Skip to content

Commit 9ed7a90

Browse files
committed
Fix recursion depth issues, save best intermediate solutions, show scipy optimize progress, and fix a mistake in target terms that occurs for some heralding cases
1 parent df604d5 commit 9ed7a90

File tree

6 files changed

+203
-115
lines changed

6 files changed

+203
-115
lines changed

pytheus/fancy_classes.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,7 @@ def getStateCatalog(self, order=None, full=False):
301301
else:
302302
self._state_catalog = th.stateCatalog(th.findEdgeCovers(self.edges,
303303
order=order, loops=self.loops))
304-
305-
if self._state_catalog_tensor is not None:
304+
if self._state_catalog_tensor is None:
306305
self._state_catalog_tensor = {}
307306
for ket, pm_list in self._state_catalog.items():
308307
self._state_catalog_tensor[ket] = [[self.complete_graph_edges.index(edge) for edge in pm] for pm in pm_list]

pytheus/leiwand.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,8 @@ def leiwand(data, name='graph'):
128128
poly = reversed(list(dict(sorted(poly.items(), key=lambda x: x[0])).values()))
129129
for i, coord in enumerate(poly):
130130
print(r"\node[vertex] ({name}) at ({x},{y}) {xname};".format(name=vertices[i],
131-
xname=r"{\color{fontcolor}" + vertices[
132-
i] + "}", x=coord[0], y=coord[1]),
131+
xname=r"{\color{fontcolor}" + str(vertices[
132+
i]) + "}", x=coord[0], y=coord[1]),
133133
file=outf)
134134

135135
# edge_string = r"\path ({v1}) edge[{options}, opacity={opacity}] ({v2});"

pytheus/lossfunctions.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -180,13 +180,13 @@ def count_rate(graph, target_state, cnfg):
180180
new_loss = cnfg.get("new_loss", False)
181181
if not new_loss:
182182
# set up target equation
183-
target = target_state.targetEquation(state_catalog=graph.state_catalog, imaginary=cnfg["imaginary"])
184183
# get variable names
185184
variables = th.stringEdges(graph.edges, imaginary=cnfg["imaginary"])
186185

187186
# non-heralded, post-selection case
188187
if not cnfg["heralding_out"]:
189188
# only looking at perfect matchings
189+
target = target_state.targetEquation(state_catalog=graph.state_catalog, imaginary=cnfg["imaginary"])
190190
graph.getNorm()
191191
norm = graph.norm
192192
# heralded case, more complicated selection rules
@@ -199,6 +199,7 @@ def count_rate(graph, target_state, cnfg):
199199
edgecovers = brutal_covers(cnfg, graph)
200200
cat = th.stateCatalog(edgecovers)
201201
norm = th.writeNorm(cat, imaginary=cnfg["imaginary"])
202+
target = target_state.targetEquation(state_catalog=cat, imaginary=cnfg["imaginary"])
202203
lambdaloss = "".join(["1-", target, "/(1+", norm, ")"])
203204
func, lossstring = th.buildLossString(lambdaloss, variables)
204205
else:
@@ -209,17 +210,32 @@ def count_rate(graph, target_state, cnfg):
209210
state_catalog_tensor = np.array(graph._state_catalog_tensor)
210211
target_normed = np.array(target_normed)
211212

212-
graph_state = lambda edges: edges[state_catalog_tensor].prod(axis=-1).sum(axis=-1)
213-
normed_state = lambda state: state / (1 + np.linalg.norm(state, axis=-1))
214-
count_rate = lambda state: abs(state @ target_normed)**2
215-
func0 = lambda x: 1 - count_rate(normed_state(graph_state(x)))
216-
213+
# print(state_catalog_tensor)
214+
def graph_state(edges):
215+
result = edges[state_catalog_tensor].prod(axis=-1).sum(axis=-1)
216+
return result
217+
218+
def normed_state(state):
219+
print(f"[normed_state] Input shape: {state.shape}")
220+
norm_val = np.linalg.norm(state, axis=-1)
221+
result = state / (1 + norm_val)
222+
return result
223+
224+
def count_rate(state):
225+
result = abs(state @ target_normed) ** 2
226+
return result
227+
228+
def func0(x):
229+
state = graph_state(x)
230+
normalized_state = normed_state(state)
231+
return 1 - count_rate(normalized_state)
217232
#matrix that transforms weights of length len(graph.edges) to weights of length len(graph.complete_graph_edges)
218233
mat = np.zeros((len(graph.complete_graph_edges), len(graph.edges)))
219234
for i, edge in enumerate(graph.edges):
220235
mat[graph.complete_graph_edges.index(edge),i] = 1
221236

222-
func = lambda x: func0(mat @ x)
237+
def func(x):
238+
return func0(mat @ x)
223239

224240
print('count rate done', flush=True)
225241
return func
@@ -230,8 +246,6 @@ def fidelity(graph, target_state, cnfg):
230246
new_loss = cnfg.get("new_loss", False)
231247

232248
if not new_loss:
233-
# set up target equation
234-
target = target_state.targetEquation(state_catalog=graph.state_catalog, imaginary=cnfg["imaginary"])
235249
# get variable names
236250
variables = th.stringEdges(graph.edges, imaginary=cnfg["imaginary"])
237251

@@ -240,6 +254,7 @@ def fidelity(graph, target_state, cnfg):
240254
# only looking at perfect matchings
241255
graph.getNorm()
242256
norm = graph.norm
257+
target = target_state.targetEquation(state_catalog=graph.state_catalog, imaginary=cnfg["imaginary"])
243258
# heralded case, more complicated selection rules
244259
else:
245260
if not cnfg["brutal_covers"]:
@@ -250,6 +265,7 @@ def fidelity(graph, target_state, cnfg):
250265
edgecovers = brutal_covers(cnfg, graph)
251266
cat = th.stateCatalog(edgecovers)
252267
norm = th.writeNorm(cat, imaginary=cnfg["imaginary"])
268+
target = target_state.targetEquation(state_catalog=cat, imaginary=cnfg["imaginary"])
253269
lambdaloss = "".join(["1-", target, "/(0+", norm, ")"])
254270
func, lossstring = th.buildLossString(lambdaloss, variables)
255271

pytheus/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ def setup_for_target(cnfg, state_cat=True):
421421
if edge[0] == connection[1] and edge[1] == connection[0]:
422422
graph.remove(edge)
423423

424-
print(f'start graph has {len(edge_list)} edges.')
424+
print(f'start graph has {len(graph.edges)} edges.')
425425
return target_state, graph, cnfg
426426

427427

pytheus/optimizer.py

Lines changed: 120 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import json
1414

1515
import logging
16-
16+
import time
1717
log = logging.getLogger(__name__)
1818

1919

@@ -29,11 +29,16 @@ def __init__(self, start_graph: Graph, saver: saver, ent_dic=None, target_state=
2929
else:
3030
self.target = target_state # object of State class
3131

32+
self.best_loss = np.inf
3233
# do preoptimization on complete starting graph, this might already take some time
3334
self.graph = self.pre_optimize_start_graph(start_graph)
34-
self.saver = saver
35+
3536
self.save_hist = safe_history
3637
self.history = []
38+
39+
self.saver = saver
40+
#save preoptimized graph to file
41+
self.saver.save_graph(self)
3742

3843
def check(self, result: object, lossfunctions: object):
3944
"""
@@ -57,13 +62,9 @@ def check(self, result: object, lossfunctions: object):
5762
if abs(result.fun) - abs(self.loss_val[0]) > self.config['thresholds'][0]:
5863
return False
5964
else:
60-
# uncomment to see where checks fail
61-
# print(result.fun, self.config['thresholds'][0])
62-
if result.fun > self.config['thresholds'][0]:
63-
# if check fails return false
64-
return False
6565
# check if all loss functions are under the corresponding threshold
66-
for ii in range(1, len(lossfunctions)):
66+
for ii in range(len(lossfunctions)):
67+
print(f'loss {ii}: {lossfunctions[ii](result.x):.4f}, threshold: {self.config["thresholds"][ii]:.4f}')
6768
if lossfunctions[ii](result.x) > self.config['thresholds'][ii]:
6869
# if check fails return false
6970
return False
@@ -166,7 +167,8 @@ def get_loss_functions(self, current_graph: Graph):
166167
for loss in callable_loss:
167168
try:
168169
loss(testinit)
169-
except Exception:
170+
except Exception as e:
171+
print(f'Error in loss function {loss}: {e}')
170172
raise RuntimeError('Loss function gives error for a test input, so it is not properly defined. This could be due to configuration parameters given for the optimization. Trying to compute perfect matchings for an odd number of total particles (main+ancilla) will lead to a meaningless loss function (0/0 --> division by zero).')
171173
return callable_loss
172174

@@ -204,94 +206,115 @@ def pre_optimize_start_graph(self, graph) -> Graph:
204206
preopt_graph
205207
206208
"""
207-
# losses is a list of callable lossfunctions, e.g. [countrate(x), fidelity(x)], where x is a vector of edge weights
208-
# that can be given to scipy.optimize
209-
log.info('loading losses')
210-
losses = self.get_loss_functions(graph)
211-
log.info('losses done')
212-
valid = False
213-
counter = 0
214-
# repeat optimization of complete graph until a good solution is found (which satifies self.check())
215-
while not valid:
216-
# prepare optimizer
217-
initial_values, bounds = self.prepOptimizer(len(graph))
218-
# optimization with scipy
219-
log.info('begin preopt')
220-
best_result = optimize.minimize(losses[0], x0=initial_values,
221-
bounds=bounds,
222-
method=self.config['optimizer'],
223-
options={'ftol': self.config['ftol']})
224-
log.info('end preopt')
225-
self.loss_val = self.update_losses(best_result, losses)
226-
# check if solution is valid
227-
valid = self.check(best_result, losses)
228-
counter += 1
229-
# print a warning if preoptimization is stuck in a loop
230-
if counter % 10 == 0:
231-
print('10 invalid preoptimization, consider changing parameters.')
232-
log.info('10 invalid preoptimization, consider changing parameters.')
233-
if counter % 100 == 0:
234-
print('100 invalid preoptimization, state cannot be found.')
235-
log.info('100 invalid preoptimization, state cannot be found.')
236-
raise ValueError('100 invalid preoptimization steps. Conclusion: State cannot be created with provides parameters. Consider adding more ancillas or using less restrictions if possible (e.g. removed_connections).')
237-
238-
# if num_pre is set to larger than 1 in config, do num_pre preoptimization and choose the best one.
239-
# for optimizations with concrete target state, num_pre = 1 is enough
240-
for __ in range(self.config['num_pre'] - 1):
241-
initial_values, bounds = self.prepOptimizer(len(graph))
242-
result = optimize.minimize(losses[0], x0=initial_values,
243-
bounds=bounds,
244-
method=self.config['optimizer'],
245-
options={'ftol': self.config['ftol']})
246-
247-
if result.fun < best_result.fun:
248-
best_result = result
249-
self.loss_val = self.update_losses(best_result, losses)
250-
print(f'best result from pre-opt: {abs(best_result.fun)}')
251-
log.info(f'best result from pre-opt: {abs(best_result.fun)}')
252-
253-
for ii, edge in enumerate(graph.edges):
254-
graph[edge] = best_result.x[ii]
255-
preopt_graph = graph.copy()
256-
257-
try:
258-
bulk_thr = self.config['bulk_thr']
259-
except:
260-
bulk_thr = 0
261-
if bulk_thr > 0:
262-
# cut all edges smaller than bulk_thr and optimize again
263-
# this can save a lot of time
264-
cont = True
265-
num_deleted = 0
266-
while cont:
267-
# delete smallest edges one by one
268-
min_edge = preopt_graph.minimum()
269-
amplitude = preopt_graph[min_edge]
270-
if self.imaginary == 'polar':
271-
amplitude = amplitude[0]
272-
if abs(amplitude) < bulk_thr:
273-
preopt_graph.remove(min_edge, update=True)
274-
num_deleted += 1
209+
if 'init_graph' in self.config:
210+
print('SKIPPING PREOPTIMIZATION, USING INIT_GRAPH FROM CONFIG', flush=True)
211+
init_edges = self.config['init_graph']
212+
init_edges = {eval(k): v for k, v in init_edges.items()}
213+
print(init_edges)
214+
print(graph.edges)
215+
for edge in list(graph.edges):
216+
if edge in init_edges:
217+
graph[edge] = init_edges[edge]
275218
else:
276-
cont = False
277-
print(f'{num_deleted} edges deleted')
278-
log.info(f'{num_deleted} edges deleted')
219+
graph.remove(edge)
220+
print(graph)
221+
preopt_graph = graph.copy()
222+
else:
223+
# losses is a list of callable lossfunctions, e.g. [countrate(x), fidelity(x)], where x is a vector of edge weights
224+
# that can be given to scipy.optimize
225+
log.info('loading losses')
226+
losses = self.get_loss_functions(graph)
227+
log.info('losses done')
279228
valid = False
229+
counter = 0
230+
# repeat optimization of complete graph until a good solution is found (which satifies self.check())
280231
while not valid:
281-
# it is necessary that the truncated graph passes the checks
282-
initial_values, bounds = self.prepOptimizer(len(preopt_graph))
283-
losses = self.get_loss_functions(preopt_graph)
284-
trunc_result = optimize.minimize(losses[0], x0=initial_values,
285-
bounds=bounds,
286-
method=self.config['optimizer'],
287-
options={'ftol': self.config['ftol']})
288-
self.loss_val = self.update_losses(trunc_result, losses)
289-
print(f'result after truncation: {abs(trunc_result.fun)}')
290-
log.info(f'result after truncation: {abs(trunc_result.fun)}')
291-
valid = self.check(trunc_result, losses)
292-
293-
for ii, edge in enumerate(preopt_graph.edges):
294-
preopt_graph[edge] = trunc_result.x[ii]
232+
# prepare optimizer
233+
initial_values, bounds = self.prepOptimizer(len(graph))
234+
# optimization with scipy
235+
log.info('begin preopt')
236+
self.tt = time.time()
237+
def callback(xk):
238+
print(f'preopt step with loss {losses[0](xk):.3f}', flush=True)
239+
print(f'time for last step: {time.time()-self.tt:.1f}s', flush=True)
240+
self.tt = time.time()
241+
best_result = optimize.minimize(losses[0], x0=initial_values,
242+
bounds=bounds,
243+
method=self.config['optimizer'],
244+
options={'ftol': self.config['ftol']},
245+
callback=callback)
246+
log.info('end preopt')
247+
self.loss_val = self.update_losses(best_result, losses)
248+
# check if solution is valid
249+
valid = self.check(best_result, losses)
250+
counter += 1
251+
# print a warning if preoptimization is stuck in a loop
252+
if counter % 10 == 0:
253+
print('10 invalid preoptimization, consider changing parameters.')
254+
log.info('10 invalid preoptimization, consider changing parameters.')
255+
if counter % 100 == 0:
256+
print('100 invalid preoptimization, state cannot be found.')
257+
log.info('100 invalid preoptimization, state cannot be found.')
258+
raise ValueError('100 invalid preoptimization steps. Conclusion: State cannot be created with provides parameters. Consider adding more ancillas or using less restrictions if possible (e.g. removed_connections).')
259+
260+
# if num_pre is set to larger than 1 in config, do num_pre preoptimization and choose the best one.
261+
# for optimizations with concrete target state, num_pre = 1 is enough
262+
for __ in range(self.config['num_pre'] - 1):
263+
initial_values, bounds = self.prepOptimizer(len(graph))
264+
result = optimize.minimize(losses[0], x0=initial_values,
265+
bounds=bounds,
266+
method=self.config['optimizer'],
267+
options={'ftol': self.config['ftol']})
268+
269+
if result.fun < best_result.fun:
270+
best_result = result
271+
self.loss_val = self.update_losses(best_result, losses)
272+
print(f'best result from pre-opt: {abs(best_result.fun)}')
273+
log.info(f'best result from pre-opt: {abs(best_result.fun)}')
274+
self.best_loss = best_result.fun
275+
276+
for ii, edge in enumerate(graph.edges):
277+
graph[edge] = best_result.x[ii]
278+
preopt_graph = graph.copy()
279+
280+
try:
281+
bulk_thr = self.config['bulk_thr']
282+
except:
283+
bulk_thr = 0
284+
if bulk_thr > 0:
285+
# cut all edges smaller than bulk_thr and optimize again
286+
# this can save a lot of time
287+
cont = True
288+
num_deleted = 0
289+
while cont:
290+
# delete smallest edges one by one
291+
min_edge = preopt_graph.minimum()
292+
amplitude = preopt_graph[min_edge]
293+
if self.imaginary == 'polar':
294+
amplitude = amplitude[0]
295+
if abs(amplitude) < bulk_thr:
296+
preopt_graph.remove(min_edge, update=True)
297+
num_deleted += 1
298+
else:
299+
cont = False
300+
print(f'{num_deleted} edges deleted')
301+
log.info(f'{num_deleted} edges deleted')
302+
valid = False
303+
while not valid:
304+
# it is necessary that the truncated graph passes the checks
305+
initial_values, bounds = self.prepOptimizer(len(preopt_graph))
306+
losses = self.get_loss_functions(preopt_graph)
307+
trunc_result = optimize.minimize(losses[0], x0=initial_values,
308+
bounds=bounds,
309+
method=self.config['optimizer'],
310+
options={'ftol': self.config['ftol']})
311+
self.loss_val = self.update_losses(trunc_result, losses)
312+
print(f'result after truncation: {abs(trunc_result.fun)}')
313+
log.info(f'result after truncation: {abs(trunc_result.fun)}')
314+
valid = self.check(trunc_result, losses)
315+
316+
for ii, edge in enumerate(preopt_graph.edges):
317+
preopt_graph[edge] = trunc_result.x[ii]
295318

296319
return preopt_graph
297320

@@ -419,6 +442,9 @@ def optimize_one_edge(self, num_edge: int,
419442
# to a file even if topological optimization can be continued
420443
if all(np.array(abs(self.graph)) > 0.95):
421444
self.saver.save_graph(self)
445+
elif result.fun < self.best_loss:
446+
self.best_loss = result.fun
447+
self.saver.save_graph(self)
422448
if self.save_hist:
423449
self.history.append([str(self.graph),self.loss_val])
424450
# return updated result graph

0 commit comments

Comments
 (0)