Skip to content

Commit d4706b9

Browse files
added controlled Trotter
New API functions: - applyControlledTrotterizedPauliStrSumGadget() - applyMultiControlledTrotterizedPauliStrSumGadget() - applyMultiStateControlledTrotterizedPauliStrSumGadget() - C++-only std::vector overloads of the latter two. Additionally: - renamed the internal constituent functions, like applyFirstOrderTrotter(), to explicit internal_applyFirstOrderTrotterRepetition() - renamed paulis_getInds() to paulis_getTargetInds() Note that new validation was required to check that no PauliStrSum non-identity Paulis overlapped the control qubits. This is relatively expensive; we build a PauliStrSum target-mask in O(#terms * #qubits) time whereas the previous most expensive validation (checking PauliStrSum targets do not exceed Qureg) costs O(#terms * log(#qubits)). Such costs are still completely occluded by those of simulating/processing a PauliStrSum in the backend, but might still attract lazy evaluation of the target-mask which is bound to the PauliStrSum instance. We have deferred any such optimisation and the associated struct changes since it necessitates an update to the PauliStrSum design (like new sync functions)
1 parent f3baf34 commit d4706b9

File tree

7 files changed

+221
-53
lines changed

7 files changed

+221
-53
lines changed

quest/include/operations.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2387,6 +2387,30 @@ void multiplyPauliStrSum(Qureg qureg, PauliStrSum sum, Qureg workspace);
23872387
void applyTrotterizedPauliStrSumGadget(Qureg qureg, PauliStrSum sum, qreal angle, int order, int reps);
23882388

23892389

2390+
/// @notyetdoced
2391+
/// @notyettested
2392+
/// @see
2393+
/// - applyTrotterizedPauliStrSumGadget()
2394+
/// - applyControlledCompMatr1()
2395+
void applyControlledTrotterizedPauliStrSumGadget(Qureg qureg, int control, PauliStrSum sum, qreal angle, int order, int reps);
2396+
2397+
2398+
/// @notyetdoced
2399+
/// @notyettested
2400+
/// @see
2401+
/// - applyTrotterizedPauliStrSumGadget()
2402+
/// - applyMultiControlledCompMatr1()
2403+
void applyMultiControlledTrotterizedPauliStrSumGadget(Qureg qureg, int* controls, int numControls, PauliStrSum sum, qreal angle, int order, int reps);
2404+
2405+
2406+
/// @notyetdoced
2407+
/// @notyettested
2408+
/// @see
2409+
/// - applyTrotterizedPauliStrSumGadget()
2410+
/// - applyMultiStateControlledCompMatr1()
2411+
void applyMultiStateControlledTrotterizedPauliStrSumGadget(Qureg qureg, int* controls, int* states, int numControls, PauliStrSum sum, qreal angle, int order, int reps);
2412+
2413+
23902414
/** @notyettested
23912415
*
23922416
* A generalisation of applyTrotterizedPauliStrSumGadget() which accepts a complex angle and permits
@@ -2492,6 +2516,26 @@ void applyNonUnitaryTrotterizedPauliStrSumGadget(Qureg qureg, PauliStrSum sum, q
24922516
}
24932517
#endif
24942518

2519+
#ifdef __cplusplus
2520+
2521+
2522+
/// @notyettested
2523+
/// @notyetvalidated
2524+
/// @notyetdoced
2525+
/// @cppvectoroverload
2526+
/// @see applyMultiControlledTrotterizedPauliStrSumGadget()
2527+
void applyMultiControlledTrotterizedPauliStrSumGadget(Qureg qureg, std::vector<int> controls, PauliStrSum sum, qreal angle, int order, int reps);
2528+
2529+
2530+
/// @notyettested
2531+
/// @notyetvalidated
2532+
/// @notyetdoced
2533+
/// @cppvectoroverload
2534+
/// @see applyMultiStateControlledTrotterizedPauliStrSumGadget()
2535+
void applyMultiStateControlledTrotterizedPauliStrSumGadget(Qureg qureg, std::vector<int> controls, std::vector<int> states, PauliStrSum sum, qreal angle, int order, int reps);
2536+
2537+
2538+
#endif // __cplusplus
24952539

24962540
/** @} */
24972541

quest/src/api/operations.cpp

Lines changed: 95 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,90 +1130,145 @@ void multiplyPauliStrSum(Qureg qureg, PauliStrSum sum, Qureg workspace) {
11301130
// workspace -> qureg, and qureg -> sum * qureg
11311131
}
11321132

1133-
void applyFirstOrderTrotter(Qureg qureg, PauliStrSum sum, qcomp angle, bool reverse) {
1134-
1135-
// (internal, invoked by applyTrotterizedPauliStrSumGadget)
1136-
1133+
void internal_applyFirstOrderTrotterRepetition(
1134+
Qureg qureg, vector<int>& ketCtrls, vector<int>& braCtrls,
1135+
vector<int>& states, PauliStrSum sum, qcomp angle, bool reverse
1136+
) {
1137+
// apply each sum term as a gadget, in forward or reverse order
11371138
for (qindex i=0; i<sum.numTerms; i++) {
11381139
int j = reverse? sum.numTerms - i - 1 : i;
1140+
qcomp coeff = sum.coeffs[j];
1141+
PauliStr str = sum.strings[j];
11391142

1140-
// effect exp(i angle * sum) by undoing gadget pre-factor
1141-
qcomp arg = angle * sum.coeffs[j] / util_getPhaseFromGateAngle(1);
1142-
applyNonUnitaryPauliGadget(qureg, sum.strings[j], arg); // caller disabled valiation therein
1143-
}
1144-
}
1143+
// effect |psi> -> exp(i angle * sum)|psi>
1144+
qcomp arg = angle * coeff;
1145+
localiser_statevec_anyCtrlPauliGadget(qureg, ketCtrls, states, str, arg);
11451146

1146-
void applyHigherOrderTrotter(Qureg qureg, PauliStrSum sum, qcomp angle, int order) {
1147+
if (!qureg.isDensityMatrix)
1148+
continue;
11471149

1148-
// (internal, invoked by applyTrotterizedPauliStrSumGadget)
1150+
// effect rho -> rho dagger(i angle * sum)
1151+
arg *= paulis_hasOddNumY(str) ? 1 : -1;
1152+
str = paulis_getShiftedPauliStr(str, qureg.numQubits);
1153+
localiser_statevec_anyCtrlPauliGadget(qureg, braCtrls, states, str, arg);
1154+
}
1155+
}
11491156

1157+
void internal_applyHigherOrderTrotterRepetition(
1158+
Qureg qureg, vector<int>& ketCtrls, vector<int>& braCtrls,
1159+
vector<int>& states, PauliStrSum sum, qcomp angle, int order
1160+
) {
11501161
if (order == 1) {
1151-
applyFirstOrderTrotter(qureg, sum, angle, false);
1162+
internal_applyFirstOrderTrotterRepetition(qureg, ketCtrls, braCtrls, states, sum, angle, false);
11521163

11531164
} else if (order == 2) {
1154-
applyFirstOrderTrotter(qureg, sum, angle/2, false);
1155-
applyFirstOrderTrotter(qureg, sum, angle/2, true);
1165+
internal_applyFirstOrderTrotterRepetition(qureg, ketCtrls, braCtrls, states, sum, angle/2, false);
1166+
internal_applyFirstOrderTrotterRepetition(qureg, ketCtrls, braCtrls, states, sum, angle/2, true);
11561167

11571168
} else {
11581169
qreal p = 1. / (4 - std::pow(4, 1./(order-1)));
11591170
qcomp a = p * angle;
11601171
qcomp b = (1-4*p) * angle;
11611172

11621173
int lower = order - 2;
1163-
applyFirstOrderTrotter(qureg, sum, a, lower);
1164-
applyFirstOrderTrotter(qureg, sum, a, lower);
1165-
applyFirstOrderTrotter(qureg, sum, b, lower);
1166-
applyFirstOrderTrotter(qureg, sum, a, lower);
1167-
applyFirstOrderTrotter(qureg, sum, a, lower);
1174+
internal_applyFirstOrderTrotterRepetition(qureg, ketCtrls, braCtrls, states, sum, a, lower);
1175+
internal_applyFirstOrderTrotterRepetition(qureg, ketCtrls, braCtrls, states, sum, a, lower);
1176+
internal_applyFirstOrderTrotterRepetition(qureg, ketCtrls, braCtrls, states, sum, b, lower);
1177+
internal_applyFirstOrderTrotterRepetition(qureg, ketCtrls, braCtrls, states, sum, a, lower);
1178+
internal_applyFirstOrderTrotterRepetition(qureg, ketCtrls, braCtrls, states, sum, a, lower);
11681179
}
11691180
}
11701181

1171-
void applyNonUnitaryTrotterizedPauliStrSumGadget(Qureg qureg, PauliStrSum sum, qcomp angle, int order, int reps) {
1172-
validate_quregFields(qureg, __func__);
1173-
validate_pauliStrSumFields(sum, __func__);
1174-
validate_pauliStrSumTargets(sum, qureg, __func__);
1175-
validate_trotterParams(qureg, order, reps, __func__);
1176-
// sum is permitted to be non-Hermitian
1177-
1182+
void internal_applyAllTrotterRepetitions(
1183+
Qureg qureg, int* controls, int* states, int numControls,
1184+
PauliStrSum sum, qcomp angle, int order, int reps
1185+
) {
11781186
// exp(i angle sum) = identity when angle=0
11791187
if (angle == qcomp(0,0))
11801188
return;
11811189

1182-
// record validation state then disable to avoid repeated
1183-
// re-validations in each invoked applyPauliGadget() below
1184-
bool wasValidationEnabled = validateconfig_isEnabled();
1185-
validateconfig_disable();
1190+
// prepare control-qubit lists once for all invoked gadgets below
1191+
auto ketCtrlsVec = util_getVector(controls, numControls);
1192+
auto braCtrlsVec = (qureg.isDensityMatrix)? util_getBraQubits(ketCtrlsVec, qureg) : vector<int>{};
1193+
auto statesVec = util_getVector(states, numControls);
11861194

1187-
// perform sequence of applyPauliGadget()
1188-
for (int r=0; r<reps; r++)
1189-
applyHigherOrderTrotter(qureg, sum, angle/reps, order);
1195+
qcomp arg = angle / reps;
11901196

1191-
// potentially restore validation
1192-
if (wasValidationEnabled)
1193-
validateconfig_enable();
1197+
// perform carefully-ordered sequence of gadgets
1198+
for (int r=0; r<reps; r++)
1199+
internal_applyHigherOrderTrotterRepetition(
1200+
qureg, ketCtrlsVec, braCtrlsVec, statesVec, sum, arg, order);
11941201

11951202
/// @todo
11961203
/// the accuracy of Trotterisation is greatly improved by randomisation
11971204
/// or (even sub-optimal) grouping into commuting terms. Should we
11981205
/// implement these above or into another function?
11991206
}
12001207

1201-
void applyTrotterizedPauliStrSumGadget(Qureg qureg, PauliStrSum sum, qreal angle, int order, int reps) {
1208+
void applyNonUnitaryTrotterizedPauliStrSumGadget(Qureg qureg, PauliStrSum sum, qcomp angle, int order, int reps) {
1209+
validate_quregFields(qureg, __func__);
1210+
validate_pauliStrSumFields(sum, __func__);
1211+
validate_pauliStrSumTargets(sum, qureg, __func__);
1212+
validate_trotterParams(qureg, order, reps, __func__);
1213+
// sum is permitted to be non-Hermitian
1214+
1215+
internal_applyAllTrotterRepetitions(qureg, nullptr, nullptr, 0, sum, angle, order, reps);
1216+
}
12021217

1203-
// validate inputs here despite re-validation below so that func name is correct in error message
1218+
void applyTrotterizedPauliStrSumGadget(Qureg qureg, PauliStrSum sum, qreal angle, int order, int reps) {
12041219
validate_quregFields(qureg, __func__);
12051220
validate_pauliStrSumFields(sum, __func__);
12061221
validate_pauliStrSumTargets(sum, qureg, __func__);
12071222
validate_trotterParams(qureg, order, reps, __func__);
1223+
validate_pauliStrSumIsHermitian(sum, __func__);
1224+
1225+
internal_applyAllTrotterRepetitions(qureg, nullptr, nullptr, 0, sum, angle, order, reps);
1226+
}
1227+
1228+
void applyControlledTrotterizedPauliStrSumGadget(Qureg qureg, int control, PauliStrSum sum, qreal angle, int order, int reps) {
1229+
validate_quregFields(qureg, __func__);
1230+
validate_pauliStrSumFields(sum, __func__);
1231+
validate_controlAndPauliStrSumTargets(qureg, control, sum, __func__);
1232+
validate_trotterParams(qureg, order, reps, __func__);
1233+
validate_pauliStrSumIsHermitian(sum, __func__);
12081234

1209-
// crucially, require that sum coefficients are real
1235+
internal_applyAllTrotterRepetitions(qureg, &control, nullptr, 1, sum, angle, order, reps);
1236+
}
1237+
1238+
void applyMultiControlledTrotterizedPauliStrSumGadget(Qureg qureg, int* controls, int numControls, PauliStrSum sum, qreal angle, int order, int reps) {
1239+
validate_quregFields(qureg, __func__);
1240+
validate_pauliStrSumFields(sum, __func__);
1241+
validate_controlsAndPauliStrSumTargets(qureg, controls, numControls, sum, __func__);
1242+
validate_trotterParams(qureg, order, reps, __func__);
12101243
validate_pauliStrSumIsHermitian(sum, __func__);
12111244

1212-
applyNonUnitaryTrotterizedPauliStrSumGadget(qureg, sum, angle, order, reps);
1245+
internal_applyAllTrotterRepetitions(qureg, controls, nullptr, numControls, sum, angle, order, reps);
1246+
}
1247+
1248+
void applyMultiStateControlledTrotterizedPauliStrSumGadget(Qureg qureg, int* controls, int* states, int numControls, PauliStrSum sum, qreal angle, int order, int reps) {
1249+
validate_quregFields(qureg, __func__);
1250+
validate_pauliStrSumFields(sum, __func__);
1251+
validate_controlsAndPauliStrSumTargets(qureg, controls, numControls, sum, __func__);
1252+
validate_controlStates(states, numControls, __func__); // permits states==nullptr
1253+
validate_trotterParams(qureg, order, reps, __func__);
1254+
validate_pauliStrSumIsHermitian(sum, __func__);
1255+
1256+
internal_applyAllTrotterRepetitions(qureg, controls, states, numControls, sum, angle, order, reps);
12131257
}
12141258

12151259
} // end de-mangler
12161260

1261+
void applyMultiControlledTrotterizedPauliStrSumGadget(Qureg qureg, vector<int> controls, PauliStrSum sum, qreal angle, int order, int reps) {
1262+
1263+
applyMultiControlledTrotterizedPauliStrSumGadget(qureg, controls.data(), controls.size(), sum, angle, order, reps);
1264+
}
1265+
1266+
void applyMultiStateControlledTrotterizedPauliStrSumGadget(Qureg qureg, vector<int> controls, vector<int> states, PauliStrSum sum, qreal angle, int order, int reps) {
1267+
validate_controlsMatchStates(controls.size(), states.size(), __func__);
1268+
1269+
applyMultiStateControlledTrotterizedPauliStrSumGadget(qureg, controls.data(), states.data(), controls.size(), sum, angle, order, reps);
1270+
}
1271+
12171272

12181273

12191274
/*

quest/src/api/paulis.cpp

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,24 +207,41 @@ qcomp paulis_getPrefixPaulisElem(Qureg qureg, vector<int> prefixY, vector<int> p
207207
}
208208

209209

210-
vector<int> paulis_getInds(PauliStr str) {
210+
vector<int> paulis_getTargetInds(PauliStr str) {
211211

212212
int maxInd = paulis_getIndOfLefmostNonIdentityPauli(str);
213213

214214
vector<int> inds(0);
215215
inds.reserve(maxInd+1);
216216

217217
for (int i=0; i<=maxInd; i++)
218-
if (paulis_getPauliAt(str, i) != 0)
218+
if (paulis_getPauliAt(str, i) != 0) // Id
219219
inds.push_back(i);
220220

221221
return inds;
222222
}
223223

224224

225+
qindex paulis_getTargetBitMask(PauliStr str) {
226+
227+
/// @todo
228+
/// would compile-time MAX_NUM_PAULIS_PER_STR bound be faster here,
229+
/// since this function is invoked upon every PauliStrSum element?
230+
int maxInd = paulis_getIndOfLefmostNonIdentityPauli(str);
231+
232+
qindex mask = 0;
233+
234+
for (int i=0; i<=maxInd; i++)
235+
if (paulis_getPauliAt(str, i) != 0) // Id
236+
mask = flipBit(mask, i);
237+
238+
return mask;
239+
}
240+
241+
225242
array<vector<int>,3> paulis_getSeparateInds(PauliStr str, Qureg qureg) {
226243

227-
vector<int> iXYZ = paulis_getInds(str);
244+
vector<int> iXYZ = paulis_getTargetInds(str);
228245
vector<int> iX, iY, iZ;
229246

230247
vector<int>* ptrs[] = {&iX, &iY, &iZ};
@@ -295,6 +312,18 @@ PAULI_MASK_TYPE paulis_getKeyOfSameMixedAmpsGroup(PauliStr str) {
295312
}
296313

297314

315+
qindex paulis_getTargetBitMask(PauliStrSum sum) {
316+
317+
qindex mask = 0;
318+
319+
// mask has 1 where any str has a != Id
320+
for (int t=0; t<sum.numTerms; t++)
321+
mask |= paulis_getTargetBitMask(sum.strings[t]);
322+
323+
return mask;
324+
}
325+
326+
298327

299328
/*
300329
* PAULI STRING INITIALISATION

quest/src/core/localiser.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,7 +1238,7 @@ template void localiser_statevec_anyCtrlAnyTargAnyMatr(Qureg, vector<int>, vecto
12381238

12391239

12401240
extern bool paulis_containsXOrY(PauliStr str);
1241-
extern vector<int> paulis_getInds(PauliStr str);
1241+
extern vector<int> paulis_getTargetInds(PauliStr str);
12421242
extern std::array<vector<int>,3> paulis_getSeparateInds(PauliStr str, Qureg qureg);
12431243
extern int paulis_getPrefixZSign(Qureg qureg, vector<int> prefixZ) ;
12441244
extern qcomp paulis_getPrefixPaulisElem(Qureg qureg, vector<int> prefixY, vector<int> prefixZ);
@@ -1334,7 +1334,7 @@ void localiser_statevec_anyCtrlPauliTensor(Qureg qureg, vector<int> ctrls, vecto
13341334

13351335
bool isGadget = false;
13361336
qreal phase = 0; // ignored
1337-
anyCtrlZTensorOrGadget(qureg, ctrls, ctrlStates, paulis_getInds(str), isGadget, phase);
1337+
anyCtrlZTensorOrGadget(qureg, ctrls, ctrlStates, paulis_getTargetInds(str), isGadget, phase);
13381338
}
13391339
}
13401340

@@ -1350,7 +1350,7 @@ void localiser_statevec_anyCtrlPauliGadget(Qureg qureg, vector<int> ctrls, vecto
13501350

13511351
// when str=IZ, we must use the above bespoke algorithm
13521352
if (!paulis_containsXOrY(str)) {
1353-
localiser_statevec_anyCtrlPhaseGadget(qureg, ctrls, ctrlStates, paulis_getInds(str), phase);
1353+
localiser_statevec_anyCtrlPhaseGadget(qureg, ctrls, ctrlStates, paulis_getTargetInds(str), phase);
13541354
return;
13551355
}
13561356

0 commit comments

Comments
 (0)