Skip to content

Commit 76d0b2b

Browse files
jzaia18maliasadi
andauthored
Resource tracking can now handle statepreps (#2230)
**Context:** The `quantum.set_state` and `quantum.set_basis_state` instructions are not currently tracked via runtime resource tracking. **Description of the Change:** * Adds 2 new functions, `SetState` and `SetBasisState` to the ResourceTracker class, allowing these instructions to be tracked. * Reports uses of these functions as `StatePrep` and `BasisState` operators respectively **Benefits:** Device-level `qml.specs` for qjit'd circuits now reports the number of initial `StatePrep` operators. **Possible Drawbacks:** **Related GitHub Issues:** PennyLaneAI/pennylane#8650 [sc-104048] --------- Co-authored-by: Ali Asadi <[email protected]>
1 parent 25d3e96 commit 76d0b2b

File tree

5 files changed

+88
-14
lines changed

5 files changed

+88
-14
lines changed

doc/releases/changelog-dev.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@
6969

7070
<h3>Improvements 🛠</h3>
7171

72+
* Resource tracking now tracks calls to `SetState` and `SetBasisState`, and can report results
73+
that include `qml.StatePrep` operations.
74+
[(#2230)](https://github.com/PennyLaneAI/catalyst/pull/2230)
75+
7276
* Remove the hardcoded list of runtime operations in the frontend.
7377
This will allow arbitrary PL gates to be represented without hyperparameters in MLIR.
7478
For gates that do not have a QIR representation, a runtime error will be raised at execution.

runtime/lib/backend/null_qubit/NullQubit.hpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,12 @@ struct NullQubit final : public Catalyst::Runtime::QuantumDevice {
224224
* @param state The state vector data (ignored)
225225
* @param wires The qubits to prepare (ignored)
226226
*/
227-
void SetState(DataView<std::complex<double>, 1> &, std::vector<QubitIdType> &) {}
227+
void SetState(DataView<std::complex<double>, 1> &, std::vector<QubitIdType> &wires)
228+
{
229+
if (this->track_resources_) {
230+
this->resource_tracker_.SetState(wires);
231+
}
232+
}
228233

229234
/**
230235
* @brief No-op implementation for computational basis state preparation
@@ -235,7 +240,12 @@ struct NullQubit final : public Catalyst::Runtime::QuantumDevice {
235240
* @param basis_state The computational basis state (ignored)
236241
* @param wires The qubits to prepare (ignored)
237242
*/
238-
void SetBasisState(DataView<int8_t, 1> &, std::vector<QubitIdType> &) {}
243+
void SetBasisState(DataView<int8_t, 1> &, std::vector<QubitIdType> &wires)
244+
{
245+
if (this->track_resources_) {
246+
this->resource_tracker_.SetBasisState(wires);
247+
}
248+
}
239249

240250
/**
241251
* @brief No-op implementation for a named quantum operation

runtime/lib/backend/null_qubit/ResourceTracker.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,26 @@ struct ResourceTracker final {
318318
RecordOperation(op_name, wires, controlled_wires);
319319
}
320320

321+
/**
322+
* @brief Records a state preparation operation for resource tracking
323+
*
324+
* @param wires The target wires the operation acts upon
325+
*/
326+
void SetState(const std::vector<QubitIdType> &wires)
327+
{
328+
RecordOperation("StatePrep", wires, {});
329+
}
330+
331+
/**
332+
* @brief Records a basis state preparation operation for resource tracking
333+
*
334+
* @param wires The target wires the operation acts upon
335+
*/
336+
void SetBasisState(const std::vector<QubitIdType> &wires)
337+
{
338+
RecordOperation("BasisState", wires, {});
339+
}
340+
321341
/**
322342
* @brief Prints resource usage statistics in JSON format to the specified file
323343
*

runtime/tests/Test_NullQubit.cpp

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,20 @@ TEST_CASE("Test NullQubit device resource tracking integration", "[NullQubit]")
924924

925925
std::vector<QubitIdType> Qs = sim->AllocateQubits(4);
926926

927+
// Apply set state operations
928+
{
929+
std::vector<std::complex<double>> data = {{0.5, 0.5}, {0.0, 0.0}};
930+
DataView<std::complex<double>, 1> data_view(data);
931+
std::vector<QubitIdType> wires = {1};
932+
sim->SetState(data_view, wires);
933+
}
934+
{
935+
std::vector<int8_t> data = {0};
936+
DataView<int8_t, 1> data_view(data);
937+
std::vector<QubitIdType> wires = {0};
938+
sim->SetBasisState(data_view, wires);
939+
}
940+
927941
// Apply named gates to test all possible name modifiers
928942
sim->NamedOperation("PauliX", {}, {Qs[0]}, false);
929943
sim->NamedOperation("T", {}, {Qs[0]}, true);
@@ -953,16 +967,20 @@ TEST_CASE("Test NullQubit device resource tracking integration", "[NullQubit]")
953967
resource_file_r.open(RESOURCES_FILENAME);
954968
CHECK(resource_file_r.is_open()); // fail-fast if file failed to create
955969

956-
std::vector<std::string> resource_names = {"PauliX",
957-
"C(Adjoint(T))",
958-
"Adjoint(T)",
959-
"C(S)",
960-
"2C(S)",
961-
"CNOT",
962-
"Adjoint(ControlledQubitUnitary)",
963-
"ControlledQubitUnitary",
964-
"Adjoint(QubitUnitary)",
965-
"QubitUnitary"};
970+
std::vector<std::string> resource_names = {
971+
"PauliX",
972+
"C(Adjoint(T))",
973+
"Adjoint(T)",
974+
"C(S)",
975+
"2C(S)",
976+
"CNOT",
977+
"Adjoint(ControlledQubitUnitary)",
978+
"ControlledQubitUnitary",
979+
"Adjoint(QubitUnitary)",
980+
"QubitUnitary",
981+
"StatePrep",
982+
"BasisState",
983+
};
966984

967985
// Read full Json, check if num_wires and num_gates are correct
968986
std::string full_json;
@@ -973,10 +991,10 @@ TEST_CASE("Test NullQubit device resource tracking integration", "[NullQubit]")
973991
CHECK(line.find("4") != std::string::npos);
974992
}
975993
if (line.find("num_gates") != std::string::npos) {
976-
CHECK(line.find("10") != std::string::npos);
994+
CHECK(line.find("12") != std::string::npos);
977995
}
978996
if (line.find("depth") != std::string::npos) {
979-
CHECK(line.find("10") != std::string::npos);
997+
CHECK(line.find("11") != std::string::npos);
980998
}
981999
full_json += line + "\n";
9821000
}

runtime/tests/Test_ResourceTracker.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,3 +427,25 @@ TEST_CASE("Test Resource Tracker WriteOut", "[resourcetracking]")
427427
CHECK(full_json.find(name) != std::string::npos);
428428
}
429429
}
430+
431+
TEST_CASE("Test Resource Tracker SetState Operations", "[resourcetracking]")
432+
{
433+
ResourceTracker tracker;
434+
tracker.SetComputeDepth(true);
435+
for (size_t i = 0; i < 5; i++) {
436+
tracker.AllocateQubit(i);
437+
}
438+
CHECK(tracker.GetNumGates() == 0);
439+
CHECK(tracker.GetNumGates("StatePrep") == 0);
440+
CHECK(tracker.GetNumGates("BasisState") == 0);
441+
442+
tracker.SetState({0});
443+
tracker.SetBasisState({0, 1, 2});
444+
CHECK(tracker.GetNumGates() == 2);
445+
CHECK(tracker.GetNumGates("StatePrep") == 1);
446+
CHECK(tracker.GetNumGates("BasisState") == 1);
447+
448+
CHECK(tracker.GetNumGatesBySize(1) == 1);
449+
CHECK(tracker.GetNumGatesBySize(3) == 1);
450+
CHECK(tracker.GetDepth() == 2);
451+
}

0 commit comments

Comments
 (0)