Skip to content

Commit 7519a20

Browse files
authored
Merge pull request #347 from mihir-putcha/fix-multistart-cost-comparison
Fix multistart cost comparison for ResidualsFunction
2 parents 99c7914 + 92e2262 commit 7519a20

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

bqskit/ir/opt/instantiaters/minimization.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
import numpy.typing as npt
99

1010
from bqskit.ir.gates.parameterized.unitary import VariableUnitaryGate
11+
from bqskit.ir.opt.cost.functions import HilbertSchmidtCostGenerator
1112
from bqskit.ir.opt.cost.functions import HilbertSchmidtResidualsGenerator
1213
from bqskit.ir.opt.cost.generator import CostFunctionGenerator
14+
from bqskit.ir.opt.cost.residual import ResidualsFunction
1315
from bqskit.ir.opt.instantiater import Instantiater
1416
from bqskit.ir.opt.minimizer import Minimizer
1517
from bqskit.ir.opt.minimizers.ceres import CeresMinimizer
@@ -107,6 +109,8 @@ def multi_start_instantiate_inplace(
107109
start_gen = RandomStartGenerator()
108110
starts = start_gen.gen_starting_points(num_starts, circuit, target)
109111
cost_fn = self.cost_fn_gen.gen_cost(circuit, target)
112+
if isinstance(cost_fn, ResidualsFunction):
113+
cost_fn = HilbertSchmidtCostGenerator().gen_cost(circuit, target)
110114
params_list = [self.instantiate(circuit, target, x0) for x0 in starts]
111115
params = sorted(params_list, key=lambda x: cost_fn(x))[0]
112116
circuit.set_params(params)
@@ -127,6 +131,8 @@ async def multi_start_instantiate_async(
127131
start_gen = RandomStartGenerator()
128132
starts = start_gen.gen_starting_points(num_starts, circuit, target)
129133
cost_fn = self.cost_fn_gen.gen_cost(circuit, target)
134+
if isinstance(cost_fn, ResidualsFunction):
135+
cost_fn = HilbertSchmidtCostGenerator().gen_cost(circuit, target)
130136
params_list = await get_runtime().map(
131137
self.instantiate,
132138
[circuit] * num_starts,

0 commit comments

Comments
 (0)