Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bqskit/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ def compile(
request_data,
logging_level,
max_logging_depth,
data,
)
result = self.result(task_id)

Expand Down
6 changes: 6 additions & 0 deletions bqskit/ir/opt/instantiaters/minimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import numpy.typing as npt

from bqskit.ir.gates.parameterized.unitary import VariableUnitaryGate
from bqskit.ir.opt.cost.functions import HilbertSchmidtCostGenerator
from bqskit.ir.opt.cost.functions import HilbertSchmidtResidualsGenerator
from bqskit.ir.opt.cost.generator import CostFunctionGenerator
from bqskit.ir.opt.cost.residual import ResidualsFunction
from bqskit.ir.opt.instantiater import Instantiater
from bqskit.ir.opt.minimizer import Minimizer
from bqskit.ir.opt.minimizers.ceres import CeresMinimizer
Expand Down Expand Up @@ -107,6 +109,8 @@ def multi_start_instantiate_inplace(
start_gen = RandomStartGenerator()
starts = start_gen.gen_starting_points(num_starts, circuit, target)
cost_fn = self.cost_fn_gen.gen_cost(circuit, target)
if isinstance(cost_fn, ResidualsFunction):
cost_fn = HilbertSchmidtCostGenerator().gen_cost(circuit, target)
params_list = [self.instantiate(circuit, target, x0) for x0 in starts]
params = sorted(params_list, key=lambda x: cost_fn(x))[0]
circuit.set_params(params)
Expand All @@ -127,6 +131,8 @@ async def multi_start_instantiate_async(
start_gen = RandomStartGenerator()
starts = start_gen.gen_starting_points(num_starts, circuit, target)
cost_fn = self.cost_fn_gen.gen_cost(circuit, target)
if isinstance(cost_fn, ResidualsFunction):
cost_fn = HilbertSchmidtCostGenerator().gen_cost(circuit, target)
params_list = await get_runtime().map(
self.instantiate,
[circuit] * num_starts,
Expand Down
16 changes: 9 additions & 7 deletions bqskit/runtime/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,15 +371,12 @@ def _handle_cancel(self, addr: RuntimeAddress) -> None:
self._cancelled_task_ids.add(addr)

# Remove all tasks that are children of `addr` from initialized tasks
for key, task in self._tasks.items():
for key, task in list(self._tasks.items()):
if task.is_descendant_of(addr):
task.cancel()
for mailbox_id in self._tasks[key].owned_mailboxes:
for mailbox_id in task.owned_mailboxes:
self._mailboxes.pop(mailbox_id)
self._tasks = {
a: t for a, t in self._tasks.items()
if not t.is_descendant_of(addr)
}
self._tasks.pop(key, None)

# Remove all tasks that are children of `addr` from delayed tasks
self._delayed_tasks = [
Expand Down Expand Up @@ -466,7 +463,12 @@ def _try_step_next_ready_task(self) -> None:
except StopIteration as e:
self._process_task_completion(task, e.value)

except Exception:
except Exception as e:
if type(e) is RuntimeError:
for addr in self._cancelled_task_ids:
if task.is_descendant_of(addr):
return

assert self._active_task is not None # for type checker

# Bubble up errors
Expand Down
Loading