88import numpy .typing as npt
99
1010from bqskit .ir .gates .parameterized .unitary import VariableUnitaryGate
11+ from bqskit .ir .opt .cost .functions import HilbertSchmidtCostGenerator
1112from bqskit .ir .opt .cost .functions import HilbertSchmidtResidualsGenerator
1213from bqskit .ir .opt .cost .generator import CostFunctionGenerator
14+ from bqskit .ir .opt .cost .residual import ResidualsFunction
1315from bqskit .ir .opt .instantiater import Instantiater
1416from bqskit .ir .opt .minimizer import Minimizer
1517from bqskit .ir .opt .minimizers .ceres import CeresMinimizer
@@ -107,6 +109,8 @@ def multi_start_instantiate_inplace(
107109 start_gen = RandomStartGenerator ()
108110 starts = start_gen .gen_starting_points (num_starts , circuit , target )
109111 cost_fn = self .cost_fn_gen .gen_cost (circuit , target )
112+ if isinstance (cost_fn , ResidualsFunction ):
113+ cost_fn = HilbertSchmidtCostGenerator ().gen_cost (circuit , target )
110114 params_list = [self .instantiate (circuit , target , x0 ) for x0 in starts ]
111115 params = sorted (params_list , key = lambda x : cost_fn (x ))[0 ]
112116 circuit .set_params (params )
@@ -127,6 +131,8 @@ async def multi_start_instantiate_async(
127131 start_gen = RandomStartGenerator ()
128132 starts = start_gen .gen_starting_points (num_starts , circuit , target )
129133 cost_fn = self .cost_fn_gen .gen_cost (circuit , target )
134+ if isinstance (cost_fn , ResidualsFunction ):
135+ cost_fn = HilbertSchmidtCostGenerator ().gen_cost (circuit , target )
130136 params_list = await get_runtime ().map (
131137 self .instantiate ,
132138 [circuit ] * num_starts ,
0 commit comments