// This code is part of Qiskit.
//
// (C) Copyright IBM 2025
//
// This code is licensed under the Apache License, Version 2.0. You may
// obtain a copy of this license in the LICENSE.txt file in the root directory
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use crate::passes::alap_schedule_analysis::TimeOps;
use crate::TranspilerError;
use hashbrown::HashMap;
use pyo3::prelude::*;
use pyo3::types::PyDict;
use qiskit_circuit::dag_circuit::{DAGCircuit, Wire};
use qiskit_circuit::dag_node::{DAGNode, DAGOpNode};
use qiskit_circuit::operations::{OperationRef, StandardInstruction};
use qiskit_circuit::{Clbit, Qubit};
use rustworkx_core::petgraph::prelude::NodeIndex;

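/// Computes an ASAP (as-soon-as-possible) schedule for the op nodes of `dag`.
///
/// Walks the operation nodes in topological order and assigns each one the
/// earliest start time allowed by the busy intervals of its qubits and clbits,
/// accounting for `clbit_write_latency` on measurements. Returns a map from
/// node index to start time, generic over the duration type `T`.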
pub fn run_asap_schedule_analysis<T: TimeOps>(
    dag: &DAGCircuit,
    clbit_write_latency: T,
    node_durations: HashMap<NodeIndex, T>,
) -> PyResult<HashMap<NodeIndex, T>> {
    if dag.qregs().len() != 1 || !dag.qregs_data().contains_key("q") {
        return Err(TranspilerError::new_err(
            "ASAP schedule runs on physical circuits only",
        ));
    }

    let mut node_start_time: HashMap<NodeIndex, T> = HashMap::new();
    let mut idle_after: HashMap<Wire, T> = HashMap::new();

    let zero = T::zero();

    for index in 0..dag.qubits().len() {
        idle_after.insert(Wire::Qubit(Qubit::new(index)), zero);
    }

    for index in 0..dag.clbits().len() {
        idle_after.insert(Wire::Clbit(Clbit::new(index)), zero);
    }

    for node_index in dag.topological_op_nodes()? {
        let op = dag[node_index].unwrap_operation();

        let qargs: Vec<Wire> = dag
            .qargs_interner()
            .get(op.qubits)
            .iter()
            .map(|&q| Wire::Qubit(q))
            .collect();
        let cargs: Vec<Wire> = dag
            .cargs_interner()
            .get(op.clbits)
            .iter()
            .map(|&c| Wire::Clbit(c))
            .collect();

        let &op_duration = node_durations.get(&node_index).ok_or_else(|| {
            TranspilerError::new_err(format!(
                "No duration found for node at index {}",
                node_index.index()
            ))
        })?;
        let op_view = op.op.view();
        let is_gate_or_delay = matches!(
            op_view,
            OperationRef::Gate(_)
                | OperationRef::StandardGate(_)
                | OperationRef::StandardInstruction(StandardInstruction::Delay(_))
        );

        // Compute the instruction interval (t0, t1), where
        // t0: start time of the instruction
        // t1: end time of the instruction

        let (t0, t1) = if is_gate_or_delay {
            let t0 = qargs
                .iter()
                .map(|q| idle_after[q])
                .fold(zero, |acc, x| *T::max(&acc, &x));
            (t0, t0 + op_duration)
        } else if matches!(
            op_view,
            OperationRef::StandardInstruction(StandardInstruction::Measure)
        ) {
            // Measure instruction handling is a bit tricky due to clbit_write_latency.
            let t0q = qargs
                .iter()
                .map(|q| idle_after[q])
                .fold(zero, |acc, x| *T::max(&acc, &x));
            let t0c = cargs
                .iter()
                .map(|c| idle_after[c])
                .fold(zero, |acc, x| *T::max(&acc, &x));
            // Assume the following case (t0c > t0q):
            //
            //       |t0q
            // Q ▒▒▒▒░░░░░░░░░░░░
            // C ▒▒▒▒▒▒▒▒░░░░░░░░
            //               |t0c
            //
            // In this case, there is no actual clbit access until clbit_write_latency.
            // The node's t0 can be pushed backward by this amount:
            //
            //       |t0q' = t0c - clbit_write_latency
            // Q ▒▒▒▒░░▒▒▒▒▒▒▒▒▒▒
            // C ▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒
            //               |t0c' = t0c
            //
            // rather than naively doing
            //
            //       |t0q' = t0c
            // Q ▒▒▒▒░░░░▒▒▒▒▒▒▒▒
            // C ▒▒▒▒▒▒▒▒░░░▒▒▒▒▒
            //               |t0c' = t0c + clbit_write_latency
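            //
            // As a concrete (illustrative) numeric example: with t0q = 4,
            // t0c = 8, clbit_write_latency = 2 and op_duration = 5, the start
            // time becomes t0 = max(4, 8 - 2) = 6, the end time is
            // t1 = 6 + 5 = 11, and the clbits of the measure are marked busy
            // until 11.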
            let t0 = *T::max(&t0q, &(t0c - clbit_write_latency));
            let t1 = t0 + op_duration;
            for clbit in cargs.iter() {
                idle_after.insert(*clbit, t1);
            }
            (t0, t1)
        } else {
            // Directives (like Barrier)
            let t0 = qargs
                .iter()
                .chain(cargs.iter())
                .map(|bit| idle_after[bit])
                .fold(zero, |acc, x| *T::max(&acc, &x));
            (t0, t0 + op_duration)
        };

        for qubit in qargs {
            idle_after.insert(qubit, t1);
        }

        node_start_time.insert(node_index, t0);
    }

    Ok(node_start_time)
}

#[pyfunction]
/// Runs the ASAP schedule analysis pass on a DAG.
///
/// Args:
///     dag (DAGCircuit): DAG to schedule.
///     clbit_write_latency (u64): The latency to write classical bits.
///     node_durations (PyDict): Mapping from each DAGOpNode to its operation's duration.
///
/// Returns:
///     PyDict: A dictionary mapping each DAGOpNode to its scheduled start time.
///
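/// Example (illustrative only; the exact import path of the compiled module may differ):
///
/// ```python
/// # ``durations`` maps every DAGOpNode in ``dag`` to its duration (all int or all float)
/// start_times = asap_schedule_analysis(dag, clbit_write_latency=0, node_durations=durations)
/// for node, t0 in start_times.items():
///     print(node.op.name, t0)
/// ```
///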
#[pyo3(name = "asap_schedule_analysis", signature = (dag, clbit_write_latency, node_durations))]
pub fn py_run_asap_schedule_analysis(
    py: Python,
    dag: &DAGCircuit,
    clbit_write_latency: u64,
    node_durations: &Bound<PyDict>,
) -> PyResult<Py<PyDict>> {
    // Extract node indices and durations from the PyDict; the type of the first
    // duration decides whether the whole schedule is computed in u64 or f64.
    let mut iter = node_durations.iter();
    let py_dict = PyDict::new(py);
    if let Some((_, first_duration)) = iter.next() {
        if first_duration.extract::<u64>().is_ok() {
            // All durations are of type u64
            let mut op_durations = HashMap::new();
            for (py_node, py_duration) in node_durations.iter() {
                let node_idx = py_node
                    .downcast_into::<DAGOpNode>()?
                    .extract::<DAGNode>()?
                    .node
                    .expect("Node index not found.");
                let val = py_duration.extract::<u64>()?;
                op_durations.insert(node_idx, val);
            }
            let node_start_time =
                run_asap_schedule_analysis::<u64>(dag, clbit_write_latency, op_durations)?;
            for (node_idx, start_time) in node_start_time {
                let node = dag.get_node(py, node_idx)?;
                py_dict.set_item(node, start_time)?;
            }
        } else if first_duration.extract::<f64>().is_ok() {
            // All durations are of type f64
            let mut op_durations = HashMap::new();
            for (py_node, py_duration) in node_durations.iter() {
                let node_idx = py_node
                    .downcast_into::<DAGOpNode>()?
                    .extract::<DAGNode>()?
                    .node
                    .expect("Node index not found.");
                let val = py_duration.extract::<f64>()?;
                op_durations.insert(node_idx, val);
            }
            let node_start_time =
                run_asap_schedule_analysis::<f64>(dag, clbit_write_latency as f64, op_durations)?;
            for (node_idx, start_time) in node_start_time {
                let node = dag.get_node(py, node_idx)?;
                py_dict.set_item(node, start_time)?;
            }
        } else {
            return Err(TranspilerError::new_err("Duration must be int or float"));
        }
    } else {
        return Err(TranspilerError::new_err("No durations provided"));
    }

    Ok(py_dict.into())
}

pub fn asap_schedule_analysis_mod(m: &Bound<PyModule>) -> PyResult<()> {
    m.add_wrapped(wrap_pyfunction!(py_run_asap_schedule_analysis))?;
    Ok(())
}