diff --git a/README.md b/README.md index c9620b6..83b9b6a 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,12 @@ random_permutation = np.random.permutation(9) optimized_circuit = rls.synth(random_permutation, num_searches=1000) ``` +## 🏅 Reward and Gate Penalties (at a glance) +- Each step returns `reward = (1.0 if solved else 0.0) - penalty`. +- `penalty` is the weighted increase in cost metrics after the chosen gate: CNOT count, CNOT layers, total layers, and total gates. +- Default weights (`MetricsWeights`) are `n_cnots=0.01`, `n_layers_cnots=0.0`, `n_layers=0.0`, `n_gates=0.0001`; configure per env via `metrics_weights`. +- Metrics accumulate over the episode, but each step is charged only for the increase caused by that step's gate; the step that solves the target returns `1.0` minus that last gate's penalty, so the undiscounted episode return is `1.0` minus the total weighted cost of the circuit. + ## 🤝 Contributing We welcome contributions! Whether you're adding new synthesis problems, improving RL algorithms, or enhancing documentation - every contribution helps advance quantum computing research. @@ -100,4 +106,4 @@ Licensed under the Apache License, Version 2.0. See [LICENSE](LICENSE.txt) for d - Kremer, D., Villar, V., Paik, H., Duran, I., Faro, I., & Cruz-Benito, J. (2024). Practical and efficient quantum circuit synthesis and transpiling with reinforcement learning. arXiv preprint [arXiv:2405.13196](https://arxiv.org/abs/2405.13196). -- Dubal, A., Kremer, D., Martiel, S., Villar, V., Wang, D., & Cruz-Benito, J. (2025). Pauli Network Circuit Synthesis with Reinforcement Learning. arXiv preprint [arXiv:2503.14448](https://arxiv.org/abs/2503.14448). \ No newline at end of file +- Dubal, A., Kremer, D., Martiel, S., Villar, V., Wang, D., & Cruz-Benito, J. (2025). Pauli Network Circuit Synthesis with Reinforcement Learning. arXiv preprint [arXiv:2503.14448](https://arxiv.org/abs/2503.14448). 
diff --git a/pyproject.toml b/pyproject.toml index d5fe42d..da82e3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ dynamic = ["version"] dependencies = [ "qiskit>=2.1", "gymnasium", - "twisterl", + "twisterl~=0.4.1", ] diff --git a/rust/Cargo.lock b/rust/Cargo.lock index f725571..672d221 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -351,7 +351,7 @@ dependencies = [ "pyo3", "rand", "rayon", - "twisterl-rs", + "twisterl", ] [[package]] @@ -465,10 +465,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" [[package]] -name = "twisterl-rs" -version = "0.1.0" +name = "twisterl" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25a14d819575b07f13110e7ebfe96f7be712a567bb0126e22a624aaf30fda9e3" +checksum = "2e20f49e0f02e09d1ddee49bde56e1ea4642427580012458cf98d6d87fd0c15c" dependencies = [ "anyhow", "dyn-clone", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index a283d00..b73406b 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -14,7 +14,7 @@ nalgebra = "0.33.0" rand = "0.8.4" rayon = "1.1.0" petgraph = "0.6.5" -twisterl = {package = "twisterl-rs", version = "0.1.0", features = ["python_bindings"]} +twisterl = {version = "~0.4.1", features = ["python_bindings"]} [profile.release] opt-level = 3 diff --git a/rust/src/envs/clifford.rs b/rust/src/envs/clifford.rs index dbc74fb..a798745 100644 --- a/rust/src/envs/clifford.rs +++ b/rust/src/envs/clifford.rs @@ -14,11 +14,15 @@ that they have been altered from the originals. 
use pyo3::prelude::*; use rand::distributions::{Distribution, Uniform}; +use rand::Rng; use twisterl::rl::env::Env; -use twisterl::python_interface::env::{PyBaseEnv, get_env_ref, get_env_mut}; +use twisterl::python_interface::env::PyBaseEnv; use crate::envs::common::Gate; +use crate::envs::metrics::{MetricsCounts, MetricsTracker, MetricsWeights}; +use crate::envs::symmetry::compute_twists_clifford; +use std::collections::HashMap; #[derive(Clone)] @@ -139,6 +143,35 @@ impl CFState { } true } + + fn inverse(&self) -> Self { + let dim = self.dim(); + let mut mat = self.clone(); + let mut inv = CFState::new(self.n); + + for col in 0..dim { + if !mat.get(col, col) { + let pivot = ((col + 1)..dim).find(|&row| mat.get(row, col)); + let pivot = pivot.expect("CFState is singular; cannot invert"); + mat.swap_rows(col, pivot); + inv.swap_rows(col, pivot); + } + + for row in 0..dim { + if row != col && mat.get(row, col) { + mat.row_xor(row, col); + inv.row_xor(row, col); + } + } + } + + debug_assert!(mat.solved(), "CFState inverse computation failed"); + inv + } + + fn invert(&mut self) { + *self = self.inverse(); + } } // -------- Env: Clifford synthesis over the symplectic tableau (phase ignored) -------- @@ -153,6 +186,17 @@ pub struct Clifford { pub gateset: Vec, pub depth_slope: usize, pub max_depth: usize, + pub obs_perms: Vec>, + pub act_perms: Vec>, + metrics: MetricsTracker, + metrics_values: MetricsCounts, + metrics_weights: MetricsWeights, + reward_value: f32, + add_inverts: bool, + track_solution: bool, + solution: Vec, + solution_inv: Vec, + inverted: bool, } impl Clifford { @@ -162,12 +206,80 @@ impl Clifford { gateset: Vec, depth_slope: usize, max_depth: usize, + metrics_weights: MetricsWeights, + add_inverts: bool, + add_perms: bool, + track_solution: bool, ) -> Self { let cf = CFState::new(num_qubits); let success = cf.solved(); - Clifford { cf, depth: 1, success, difficulty, gateset, depth_slope, max_depth } + + // Only compute symmetries if enabled + let 
(obs_perms, act_perms) = if add_perms { + compute_twists_clifford(num_qubits, &gateset) + } else { + (Vec::new(), Vec::new()) + }; + + let metrics = MetricsTracker::new(num_qubits); + let metrics_values = metrics.snapshot(); + Clifford { + cf, + depth: 1, + success, + difficulty, + gateset, + depth_slope, + max_depth, + obs_perms, + act_perms, + metrics, + metrics_values, + metrics_weights, + reward_value: if success { 1.0 } else { 0.0 }, + add_inverts, + track_solution, + solution: Vec::new(), + solution_inv: Vec::new(), + inverted: false, + } } pub fn solved(&self) -> bool { self.cf.solved() } + + fn apply_gate_to_state(&mut self, gate: &Gate) { + match gate { + Gate::H(q) => self.cf.h(*q), + Gate::S(q) => self.cf.s(*q), + Gate::Sdg(q) => self.cf.sdg(*q), // identical to S modulo global phase (ignored) + Gate::SX(q) => self.cf.sx(*q), + Gate::SXdg(q) => self.cf.sxdg(*q), // identical to SX modulo global phase (ignored) + Gate::CX(c, t) => self.cf.cx(*c, *t), + Gate::CZ(a, b) => self.cf.cz(*a, *b), + Gate::SWAP(a, b) => self.cf.swap(*a, *b), + } + } + + fn maybe_random_invert(&mut self) { + if !self.add_inverts { + return; + } + if rand::thread_rng().gen_bool(0.5) { + self.cf.invert(); + self.inverted = !self.inverted; + } + } + + fn reset_internals(&mut self) { + self.success = self.solved(); + self.metrics.reset(); + self.metrics_values = self.metrics.snapshot(); + self.reward_value = if self.success { 1.0 } else { 0.0 }; + self.inverted = false; + if self.track_solution { + self.solution_inv = Vec::new(); + self.solution = Vec::new(); + } + } } impl Env for Clifford { @@ -188,38 +300,50 @@ impl Env for Clifford { // Expecting a flattened 2N x 2N boolean matrix encoded as i64s (>0 => true) self.cf.data = state.iter().map(|&x| x > 0).collect(); self.depth = self.max_depth; - self.success = self.solved(); + self.reset_internals(); } fn reset(&mut self) { self.cf = CFState::new(self.cf.n); - self.depth = self.max_depth; - self.success = self.solved(); - let mut rng 
= rand::thread_rng(); let action_range = Uniform::new(0, self.num_actions()); for _ in 0..self.difficulty { let action = action_range.sample(&mut rng); - self.step(action); + if let Some(gate) = self.gateset.get(action).cloned() { + self.apply_gate_to_state(&gate); + } } self.depth = (self.depth_slope * self.difficulty).min(self.max_depth); - self.success = self.solved(); + self.reset_internals(); } fn step(&mut self, action: usize) { - match self.gateset[action] { - Gate::H(q) => self.cf.h(q), - Gate::S(q) => self.cf.s(q), - Gate::Sdg(q) => self.cf.sdg(q), // identical to S modulo global phase (ignored) - Gate::SX(q) => self.cf.sx(q), - Gate::SXdg(q) => self.cf.sxdg(q), // identical to SX modulo global phase (ignored) - Gate::CX(c, t) => self.cf.cx(c, t), - Gate::CZ(a, b) => self.cf.cz(a, b), - Gate::SWAP(a,b) => self.cf.swap(a, b), + let mut penalty = 0.0f32; + + if let Some(gate) = self.gateset.get(action).cloned() { + let previous = self.metrics_values.clone(); + self.metrics.apply_gate(&gate); + let new_metrics = self.metrics.snapshot(); + penalty = new_metrics.weighted_delta(&previous, &self.metrics_weights); + self.metrics_values = new_metrics; + + self.apply_gate_to_state(&gate); + } + + if self.track_solution { + if self.inverted { + self.solution_inv.push(action); + } else { + self.solution.push(action); + } } + self.depth = self.depth.saturating_sub(1); + self.maybe_random_invert(); self.success = self.solved(); + let achieved = if self.success { 1.0 } else { 0.0 }; + self.reward_value = achieved - penalty; } fn masks(&self) -> Vec { @@ -228,14 +352,10 @@ impl Env for Clifford { fn is_final(&self) -> bool { self.depth == 0 || self.success } - fn reward(&self) -> f32 { - if self.success { - 1.0 - } else if self.depth == 0 { - -0.5 - } else { - -0.5 / (self.max_depth as f32) - } + fn reward(&self) -> f32 { self.reward_value } + + fn success(&self) -> bool { + self.success } fn observe(&self) -> Vec { @@ -246,6 +366,19 @@ impl Env for Clifford { 
.filter_map(|(i, &v)| if v { Some(i) } else { None }) .collect() } + + fn twists(&self) -> (Vec>, Vec>) { + (self.obs_perms.clone(), self.act_perms.clone()) + } + + fn track_solution(&self) -> bool { self.track_solution } + + fn solution(&self) -> Vec { + let mut out = Vec::with_capacity(self.solution.len() + self.solution_inv.len()); + out.extend_from_slice(&self.solution); + out.extend(self.solution_inv.iter().rev().copied()); + out + } } #[pyclass(name="CliffordEnv", extends=PyBaseEnv)] @@ -254,15 +387,41 @@ pub struct PyCliffordEnv; #[pymethods] impl PyCliffordEnv { #[new] + #[pyo3(signature = ( + num_qubits, + difficulty, + gateset, + depth_slope, + max_depth, + metrics_weights=None, + add_inverts=None, + add_perms=None, + track_solution=None, + ))] pub fn new( num_qubits: usize, difficulty: usize, gateset: Vec, depth_slope: usize, - max_depth: usize + max_depth: usize, + metrics_weights: Option>, + add_inverts: Option, + add_perms: Option, + track_solution: Option, ) -> (Self, PyBaseEnv) { - let env = Clifford::new(num_qubits, difficulty, gateset, depth_slope, max_depth); + let weights = MetricsWeights::from_hashmap(metrics_weights); + let env = Clifford::new( + num_qubits, + difficulty, + gateset, + depth_slope, + max_depth, + weights, + add_inverts.unwrap_or(true), + add_perms.unwrap_or(true), + track_solution.unwrap_or(true), + ); let env = Box::new(env); (PyCliffordEnv, PyBaseEnv { env }) } -} \ No newline at end of file +} diff --git a/rust/src/envs/linear_function.rs b/rust/src/envs/linear_function.rs index f04ff33..7cf7eb0 100644 --- a/rust/src/envs/linear_function.rs +++ b/rust/src/envs/linear_function.rs @@ -14,12 +14,17 @@ that they have been altered from the originals. 
use pyo3::prelude::*; use rand::distributions::{Distribution, Uniform}; +use rand::Rng; use twisterl::rl::env::Env; -use twisterl::python_interface::env::{PyBaseEnv, get_env_ref, get_env_mut}; +use twisterl::python_interface::env::PyBaseEnv; use crate::envs::common::Gate; +use crate::envs::metrics::{MetricsCounts, MetricsTracker, MetricsWeights}; +use crate::envs::symmetry::compute_twists_square; +use std::collections::HashMap; + // Define some internal representation #[derive(Clone)] pub struct LFState { @@ -93,6 +98,56 @@ impl LFState { } true } + + fn row_xor(&mut self, dest: usize, src: usize) { + if dest == src { + return; + } + for col in 0..self.size { + let dest_idx = self.index(dest, col); + let src_idx = self.index(src, col); + self.data[dest_idx] ^= self.data[src_idx]; + } + } + + fn swap_rows(&mut self, r1: usize, r2: usize) { + if r1 == r2 { + return; + } + for col in 0..self.size { + let i1 = self.index(r1, col); + let i2 = self.index(r2, col); + self.data.swap(i1, i2); + } + } + + fn inverse(&self) -> Self { + let mut mat = self.clone(); + let mut inv = LFState::new(self.size); + + for col in 0..self.size { + if !mat.get(col, col) { + let pivot = ((col + 1)..self.size).find(|&row| mat.get(row, col)); + let pivot = pivot.expect("LFState is singular; cannot invert"); + mat.swap_rows(col, pivot); + inv.swap_rows(col, pivot); + } + + for row in 0..self.size { + if row != col && mat.get(row, col) { + mat.row_xor(row, col); + inv.row_xor(row, col); + } + } + } + + debug_assert!(mat.solved(), "LFState inverse computation failed"); + inv + } + + fn invert(&mut self) { + *self = self.inverse(); + } } // This is the Env definition @@ -105,7 +160,18 @@ pub struct LinearFunction { pub difficulty: usize, pub gateset: Vec, pub depth_slope: usize, - pub max_depth: usize + pub max_depth: usize, + pub obs_perms: Vec>, + pub act_perms: Vec>, + metrics: MetricsTracker, + metrics_values: MetricsCounts, + metrics_weights: MetricsWeights, + reward_value: f32, + 
add_inverts: bool, + track_solution: bool, + solution: Vec, + solution_inv: Vec, + inverted: bool, } @@ -116,15 +182,77 @@ impl LinearFunction { gateset: Vec, depth_slope: usize, max_depth: usize, + metrics_weights: MetricsWeights, + add_inverts: bool, + add_perms: bool, + track_solution: bool, ) -> Self { let lf = LFState::new(num_qubits); let success = lf.solved(); - LinearFunction {lf, depth:1, success, difficulty, gateset, depth_slope, max_depth } + + // Only compute symmetries if enabled + let (obs_perms, act_perms) = if add_perms { + compute_twists_square(num_qubits, &gateset) + } else { + (Vec::new(), Vec::new()) + }; + + let metrics = MetricsTracker::new(num_qubits); + let metrics_values = metrics.snapshot(); + LinearFunction { + lf, + depth: 1, + success, + difficulty, + gateset, + depth_slope, + max_depth, + obs_perms, + act_perms, + metrics, + metrics_values, + metrics_weights, + reward_value: if success { 1.0 } else { 0.0 }, + add_inverts, + track_solution, + solution: Vec::new(), + solution_inv: Vec::new(), + inverted: false, + } } pub fn solved(&self) -> bool { self.lf.solved() } + fn maybe_random_invert(&mut self) { + if !self.add_inverts { + return; + } + if rand::thread_rng().gen_bool(0.5) { + self.lf.invert(); + self.inverted = !self.inverted; + } + } + + fn apply_gate_to_state(&mut self, gate: &Gate) { + match gate { + &Gate::CX(q1, q2) => self.lf.cx(q1, q2), + &Gate::SWAP(q1, q2) => self.lf.swap(q1, q2), + _ => {} + } + } + + fn reset_internals(&mut self) { + self.success = self.solved(); + self.metrics.reset(); + self.metrics_values = self.metrics.snapshot(); + self.reward_value = if self.success { 1.0 } else { 0.0 }; + self.inverted = false; + if self.track_solution { + self.solution_inv = Vec::new(); + self.solution = Vec::new(); + } + } } // This implements the necessary functions for the environment @@ -151,35 +279,52 @@ impl Env for LinearFunction { fn set_state(&mut self, state: Vec) { self.lf.data = state.iter().map(|&x| x>0).collect(); 
self.depth = self.max_depth; - self.success = self.solved(); + self.reset_internals(); } fn reset(&mut self) { // Create an identity matrix for the initial 'lf' state self.lf = LFState::new(self.lf.size); - self.depth = self.max_depth; - self.success = self.solved(); - let mut rng = rand::thread_rng(); let action_range = Uniform::new(0, self.num_actions()); // Apply random actions based on the difficulty for _ in 0..self.difficulty { let action = action_range.sample(&mut rng); - self.step(action); + if let Some(gate) = self.gateset.get(action).cloned() { + self.apply_gate_to_state(&gate); + } } - self.depth = (self.depth_slope * self.difficulty).min(self.max_depth); - self.success = self.solved(); + self.depth = (self.depth_slope * self.difficulty).min(self.max_depth); + self.reset_internals(); } fn step(&mut self, action: usize) { - match self.gateset[action] { - Gate::CX(q1, q2) => self.lf.cx(q1, q2), - Gate::SWAP(q1, q2) => self.lf.swap(q1, q2), - _ => {} - } + let mut penalty = 0.0f32; + + if let Some(gate) = self.gateset.get(action).cloned() { + let previous = self.metrics_values.clone(); + self.metrics.apply_gate(&gate); + let new_metrics = self.metrics.snapshot(); + penalty = new_metrics.weighted_delta(&previous, &self.metrics_weights); + self.metrics_values = new_metrics; + + self.apply_gate_to_state(&gate); + } + + if self.track_solution { + if self.inverted { + self.solution_inv.push(action); + } else { + self.solution.push(action); + } + } + self.depth = self.depth.saturating_sub(1); // Prevent underflow + self.maybe_random_invert(); self.success = self.solved(); + let achieved = if self.success { 1.0 } else { 0.0 }; + self.reward_value = achieved - penalty; } fn masks(&self) -> Vec { @@ -191,11 +336,11 @@ impl Env for LinearFunction { } fn reward(&self) -> f32 { - if self.success { - 1.0 - } else { - if self.depth == 0 { -0.5 } else { -0.5/(self.max_depth as f32) } - } + self.reward_value + } + + fn success(&self) -> bool { + self.success } fn 
observe(&self,) -> Vec { @@ -204,8 +349,20 @@ impl Env for LinearFunction { .filter_map(|(index, &value)| if value { Some(index) } else { None }) // Collect indices where the value is true .collect() } -} + fn twists(&self) -> (Vec>, Vec>) { + (self.obs_perms.clone(), self.act_perms.clone()) + } + + fn track_solution(&self) -> bool { self.track_solution } + + fn solution(&self) -> Vec { + let mut out = Vec::with_capacity(self.solution.len() + self.solution_inv.len()); + out.extend_from_slice(&self.solution); + out.extend(self.solution_inv.iter().rev().copied()); + out + } +} #[pyclass(name="LinearFunctionEnv", extends=PyBaseEnv)] pub struct PyLinearFunctionEnv; @@ -213,15 +370,41 @@ pub struct PyLinearFunctionEnv; #[pymethods] impl PyLinearFunctionEnv { #[new] + #[pyo3(signature = ( + num_qubits, + difficulty, + gateset, + depth_slope, + max_depth, + metrics_weights=None, + add_inverts=None, + add_perms=None, + track_solution=None, + ))] pub fn new( num_qubits: usize, difficulty: usize, gateset: Vec, depth_slope: usize, - max_depth: usize + max_depth: usize, + metrics_weights: Option>, + add_inverts: Option, + add_perms: Option, + track_solution: Option, ) -> (Self, PyBaseEnv) { - let env = LinearFunction::new(num_qubits, difficulty, gateset, depth_slope, max_depth); + let weights = MetricsWeights::from_hashmap(metrics_weights); + let env = LinearFunction::new( + num_qubits, + difficulty, + gateset, + depth_slope, + max_depth, + weights, + add_inverts.unwrap_or(true), + add_perms.unwrap_or(true), + track_solution.unwrap_or(true) + ); let env = Box::new(env); (PyLinearFunctionEnv, PyBaseEnv { env }) } -} \ No newline at end of file +} diff --git a/rust/src/envs/metrics.rs b/rust/src/envs/metrics.rs new file mode 100644 index 0000000..4eb8fc5 --- /dev/null +++ b/rust/src/envs/metrics.rs @@ -0,0 +1,184 @@ +// -*- coding: utf-8 -*- +/* +(C) Copyright 2025 IBM. All Rights Reserved. + +This code is licensed under the Apache License, Version 2.0. 
You may +obtain a copy of this license in the LICENSE.txt file in the root directory +of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. + +Any modifications or derivative works of this code must retain this +copyright notice, and modified files need to carry a notice indicating +that they have been altered from the originals. +*/ + +use std::collections::{HashMap, HashSet}; + +use crate::envs::common::Gate; + +#[derive(Clone)] +pub struct MetricsTracker { + num_qubits: usize, + n_cnots: usize, + n_gates: usize, + cnot_layers: HashSet, + layers: HashSet, + last_gates: Vec, + last_cxs: Vec, +} + +impl MetricsTracker { + pub fn new(num_qubits: usize) -> Self { + Self { + num_qubits, + n_cnots: 0, + n_gates: 0, + cnot_layers: HashSet::new(), + layers: HashSet::new(), + last_gates: vec![-1; num_qubits], + last_cxs: vec![-1; num_qubits], + } + } + + pub fn reset(&mut self) { + self.n_cnots = 0; + self.n_gates = 0; + self.cnot_layers.clear(); + self.layers.clear(); + for val in self.last_gates.iter_mut() { + *val = -1; + } + for val in self.last_cxs.iter_mut() { + *val = -1; + } + } + + pub fn snapshot(&self) -> MetricsCounts { + MetricsCounts { + n_cnots: self.n_cnots, + n_layers_cnots: self.cnot_layers.len(), + n_layers: self.layers.len(), + n_gates: self.n_gates, + } + } + + pub fn apply_gate(&mut self, gate: &Gate) { + match gate { + Gate::CX(c, t) => self.cx(*c, *t), + Gate::SWAP(c, t) => { + self.cx(*c, *t); + self.cx(*t, *c); + self.cx(*c, *t); + } + Gate::CZ(c, t) => { + self.single_qubit(*t); + self.cx(*c, *t); + self.single_qubit(*t); + } + Gate::H(q) | Gate::S(q) | Gate::Sdg(q) | Gate::SX(q) | Gate::SXdg(q) => { + self.single_qubit(*q); + } + } + } + + fn single_qubit(&mut self, target: usize) { + if target >= self.num_qubits { + return; + } + + self.n_gates += 1; + let gate_layer = self.last_gates[target] + 1; + self.last_gates[target] = gate_layer; + + if gate_layer >= 0 { + self.layers.insert(gate_layer as usize); + } + } + + fn cx(&mut 
self, control: usize, target: usize) { + if control == target + || control >= self.num_qubits + || target >= self.num_qubits + { + return; + } + + self.n_cnots += 1; + self.n_gates += 1; + + let gate_layer = (self.last_gates[control].max(self.last_gates[target])) + 1; + self.last_gates[control] = gate_layer; + self.last_gates[target] = gate_layer; + + if gate_layer >= 0 { + self.layers.insert(gate_layer as usize); + } + + let cx_layer = (self.last_cxs[control].max(self.last_cxs[target])) + 1; + self.last_cxs[control] = cx_layer; + self.last_cxs[target] = cx_layer; + + if cx_layer >= 0 { + self.cnot_layers.insert(cx_layer as usize); + } + } +} + +#[derive(Clone)] +pub struct MetricsCounts { + n_cnots: usize, + n_layers_cnots: usize, + n_layers: usize, + n_gates: usize, +} + +impl MetricsCounts { + pub fn weighted_delta(&self, previous: &Self, weights: &MetricsWeights) -> f32 { + let delta_cnots = self.n_cnots.saturating_sub(previous.n_cnots) as f32; + let delta_layers_cnots = + self.n_layers_cnots.saturating_sub(previous.n_layers_cnots) as f32; + let delta_layers = self.n_layers.saturating_sub(previous.n_layers) as f32; + let delta_gates = self.n_gates.saturating_sub(previous.n_gates) as f32; + + weights.n_cnots * delta_cnots + + weights.n_layers_cnots * delta_layers_cnots + + weights.n_layers * delta_layers + + weights.n_gates * delta_gates + } +} + +#[derive(Clone)] +pub struct MetricsWeights { + pub n_cnots: f32, + pub n_layers_cnots: f32, + pub n_layers: f32, + pub n_gates: f32, +} + +impl Default for MetricsWeights { + fn default() -> Self { + Self { + n_cnots: 0.01, + n_layers_cnots: 0.0, + n_layers: 0.0, + n_gates: 0.0001, + } + } +} + +impl MetricsWeights { + pub fn from_hashmap(map: Option>) -> Self { + let mut weights = Self::default(); + if let Some(values) = map { + for (key, value) in values { + match key.as_str() { + "n_cnots" => weights.n_cnots = value, + "n_layers_cnots" => weights.n_layers_cnots = value, + "n_layers" => weights.n_layers = value, + 
"n_gates" => weights.n_gates = value, + _ => {} + } + } + } + weights + } +} diff --git a/rust/src/envs/mod.rs b/rust/src/envs/mod.rs index 0218394..7a92e30 100644 --- a/rust/src/envs/mod.rs +++ b/rust/src/envs/mod.rs @@ -14,4 +14,6 @@ that they have been altered from the originals. pub mod clifford; pub mod linear_function; pub mod permutation; -pub mod common; \ No newline at end of file +pub mod common; +pub mod symmetry; +pub mod metrics; diff --git a/rust/src/envs/permutation.rs b/rust/src/envs/permutation.rs index 5f58ef5..ae70e8f 100644 --- a/rust/src/envs/permutation.rs +++ b/rust/src/envs/permutation.rs @@ -14,11 +14,15 @@ that they have been altered from the originals. use pyo3::prelude::*; use rand::distributions::{Distribution, Uniform}; +use rand::Rng; use twisterl::rl::env::Env; -use twisterl::python_interface::env::{PyBaseEnv, get_env_ref, get_env_mut}; +use twisterl::python_interface::env::PyBaseEnv; use crate::envs::common::Gate; +use crate::envs::metrics::{MetricsCounts, MetricsTracker, MetricsWeights}; +use crate::envs::symmetry::compute_twists_square; +use std::collections::HashMap; // This is the Env definition @@ -32,7 +36,18 @@ pub struct Permutation { pub difficulty: usize, pub gateset: Vec, pub depth_slope: usize, - pub max_depth: usize + pub max_depth: usize, + pub obs_perms: Vec>, + pub act_perms: Vec>, + metrics: MetricsTracker, + metrics_values: MetricsCounts, + metrics_weights: MetricsWeights, + reward_value: f32, + pub add_inverts: bool, + track_solution: bool, + solution: Vec, + solution_inv: Vec, + inverted: bool, } @@ -43,8 +58,65 @@ impl Permutation { gateset: Vec, depth_slope: usize, max_depth: usize, + metrics_weights: MetricsWeights, + add_inverts: bool, + add_perms: bool, + track_solution: bool, ) -> Self { - Permutation {state:(0..num_qubits).collect(), depth:1, success:true, num_qubits:num_qubits, difficulty:difficulty, gateset:gateset, depth_slope:depth_slope, max_depth:max_depth} + // Only compute symmetries if enabled + 
let (obs_perms, act_perms) = if add_perms { + compute_twists_square(num_qubits, &gateset) + } else { + (Vec::new(), Vec::new()) + }; + + let metrics = MetricsTracker::new(num_qubits); + let metrics_values = metrics.snapshot(); + let success = true; + Permutation { + state:(0..num_qubits).collect(), + depth:1, + success, + num_qubits, + difficulty, + gateset, + depth_slope, + max_depth, + obs_perms, + act_perms, + metrics, + metrics_values, + metrics_weights, + reward_value: 1.0, + add_inverts, + track_solution, + solution: Vec::new(), + solution_inv: Vec::new(), + inverted: false, + } + } + + /// Compute the inverse of a permutation + /// For a permutation perm, returns inv such that perm[inv[i]] = i for all i + fn invert_perm(perm: &[usize]) -> Vec { + let mut inv = vec![0; perm.len()]; + for (i, &val) in perm.iter().enumerate() { + inv[val] = i; + } + inv + } + + /// Randomly invert the permutation with 50% probability when enabled. + fn maybe_random_invert(&mut self) { + if !self.add_inverts { + return; + } + + let mut rng = rand::thread_rng(); + if rng.gen_bool(0.5) { + self.state = Self::invert_perm(&self.state); + self.inverted = !self.inverted; + } } pub fn solved(&self) -> bool { @@ -58,6 +130,18 @@ impl Permutation { pub fn get_state(&self) -> Vec { self.state.clone() } + + fn reset_internals(&mut self) { + self.success = self.solved(); + self.metrics.reset(); + self.metrics_values = self.metrics.snapshot(); + self.reward_value = if self.success { 1.0 } else { 0.0 }; + self.inverted = false; + if self.track_solution { + self.solution_inv = Vec::new(); + self.solution = Vec::new(); + } + } } // This implements the necessary functions for the environment @@ -85,32 +169,59 @@ impl Env for Permutation { self.state = state.iter().map(|&x| x as usize).collect(); self.depth = self.max_depth; - self.success = self.solved(); + self.reset_internals(); } fn reset(&mut self) { // Reset the state to the target self.state = (0..self.num_qubits).collect(); - let mut rng 
= rand::thread_rng(); let action_range = Uniform::new(0, self.num_actions()); // Apply random actions based on the difficulty for _ in 0..self.difficulty { let action = action_range.sample(&mut rng); - self.step(action); + let gate = &self.gateset[action]; + match gate { + Gate::SWAP(q1, q2) => (self.state[*q2], self.state[*q1]) = (self.state[*q1], self.state[*q2]), + _ => {} + } } self.depth = (self.depth_slope * self.difficulty).min(self.max_depth); - self.success = self.solved(); + self.reset_internals(); } fn step(&mut self, action: usize) { - match self.gateset[action] { - Gate::SWAP(q1, q2) => (self.state[q2], self.state[q1]) = (self.state[q1], self.state[q2]), - _ => {} + let mut penalty = 0.0f32; + + if action < self.gateset.len() { + let gate = &self.gateset[action]; + let previous = self.metrics_values.clone(); + self.metrics.apply_gate(gate); + let new_metrics = self.metrics.snapshot(); + penalty = new_metrics.weighted_delta(&previous, &self.metrics_weights); + self.metrics_values = new_metrics; + + match gate { + Gate::SWAP(q1, q2) => (self.state[*q2], self.state[*q1]) = (self.state[*q1], self.state[*q2]), + _ => {} + } + + if self.track_solution { + if self.inverted { + self.solution_inv.push(action); + } else { + self.solution.push(action); + } + } } + + self.maybe_random_invert(); + self.depth = self.depth.saturating_sub(1); // Prevent underflow self.success = self.solved(); + let achieved = if self.success { 1.0 } else { 0.0 }; + self.reward_value = achieved - penalty; } fn masks(&self) -> Vec { @@ -121,17 +232,28 @@ impl Env for Permutation { self.depth == 0 || self.success } - fn reward(&self) -> f32 { - if self.success { - 1.0 - } else { - if self.depth == 0 { -0.5 } else { -0.5/(self.max_depth as f32) } - } - } + fn reward(&self) -> f32 { self.reward_value } + fn success(&self) -> bool { + self.success + } + fn observe(&self,) -> Vec { self.state.iter().enumerate().map(|(i, v)| i * self.num_qubits + v ).collect() } + + fn twists(&self) -> (Vec>, 
Vec>) { + (self.obs_perms.clone(), self.act_perms.clone()) + } + + fn track_solution(&self) -> bool { self.track_solution } + + fn solution(&self) -> Vec { + let mut out = Vec::with_capacity(self.solution.len() + self.solution_inv.len()); + out.extend_from_slice(&self.solution); + out.extend(self.solution_inv.iter().rev().copied()); + out + } } @@ -141,15 +263,41 @@ pub struct PyPermutationEnv; #[pymethods] impl PyPermutationEnv { #[new] + #[pyo3(signature = ( + num_qubits, + difficulty, + gateset, + depth_slope, + max_depth, + metrics_weights=None, + add_inverts=None, + add_perms=None, + track_solution=None, + ))] pub fn new( num_qubits: usize, difficulty: usize, gateset: Vec, depth_slope: usize, - max_depth: usize + max_depth: usize, + metrics_weights: Option>, + add_inverts: Option, + add_perms: Option, + track_solution: Option, ) -> (Self, PyBaseEnv) { - let env = Permutation::new(num_qubits, difficulty, gateset, depth_slope, max_depth); + let weights = MetricsWeights::from_hashmap(metrics_weights); + let env = Permutation::new( + num_qubits, + difficulty, + gateset, + depth_slope, + max_depth, + weights, + add_inverts.unwrap_or(true), + add_perms.unwrap_or(true), + track_solution.unwrap_or(true) + ); let env = Box::new(env); (PyPermutationEnv, PyBaseEnv { env }) } -} \ No newline at end of file +} diff --git a/rust/src/envs/symmetry.rs b/rust/src/envs/symmetry.rs new file mode 100644 index 0000000..069cb1d --- /dev/null +++ b/rust/src/envs/symmetry.rs @@ -0,0 +1,303 @@ +// -*- coding: utf-8 -*- +/* +(C) Copyright 2025 IBM. All Rights Reserved. + +This code is licensed under the Apache License, Version 2.0. You may +obtain a copy of this license in the LICENSE.txt file in the root directory +of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. + +Any modifications or derivative works of this code must retain this +copyright notice, and modified files need to carry a notice indicating +that they have been altered from the originals. 
+*/ + +use std::collections::{HashMap, HashSet}; + +use crate::envs::common::Gate; +use petgraph::algo::isomorphism::subgraph_isomorphisms_iter; +use petgraph::graph::Graph; +use petgraph::visit::NodeIndexable; +use petgraph::Directed; + +#[derive(Hash, Eq, PartialEq, Clone, Copy)] +enum GateKind { + H, + S, + Sdg, + SX, + SXdg, + CX, + CZ, + Swap, +} + +#[derive(Hash, Eq, PartialEq, Clone)] +struct GateKey { + kind: GateKind, + qubits: Vec, +} + +fn gate_kind(gate: &Gate) -> GateKind { + match gate { + Gate::H(_) => GateKind::H, + Gate::S(_) => GateKind::S, + Gate::Sdg(_) => GateKind::Sdg, + Gate::SX(_) => GateKind::SX, + Gate::SXdg(_) => GateKind::SXdg, + Gate::CX(_, _) => GateKind::CX, + Gate::CZ(_, _) => GateKind::CZ, + Gate::SWAP(_, _) => GateKind::Swap, + } +} + +fn gate_qubits(gate: &Gate) -> Vec { + match gate { + Gate::H(q) + | Gate::S(q) + | Gate::Sdg(q) + | Gate::SX(q) + | Gate::SXdg(q) => vec![*q], + Gate::CX(q1, q2) + | Gate::CZ(q1, q2) + | Gate::SWAP(q1, q2) => vec![*q1, *q2], + } +} + +fn canonical_key(kind: GateKind, mut qubits: Vec) -> GateKey { + if matches!(kind, GateKind::Swap) { + qubits.sort_unstable(); + } + GateKey { kind, qubits } +} + +fn two_qubit_targets(gate: &Gate) -> Option<(usize, usize)> { + match gate { + Gate::CX(q1, q2) | Gate::CZ(q1, q2) | Gate::SWAP(q1, q2) => Some((*q1, *q2)), + _ => None, + } +} + +fn identity_perm(num_qubits: usize) -> Vec { + (0..num_qubits).collect() +} + +fn all_permutations(num_qubits: usize) -> Vec> { + let mut perm: Vec = (0..num_qubits).collect(); + let mut results = Vec::new(); + + fn heap_permute(k: usize, perm: &mut Vec, results: &mut Vec>) { + if k == 1 { + results.push(perm.clone()); + return; + } + + heap_permute(k - 1, perm, results); + + for i in 0..(k - 1) { + if k % 2 == 0 { + perm.swap(i, k - 1); + } else { + perm.swap(0, k - 1); + } + heap_permute(k - 1, perm, results); + } + } + + if num_qubits == 0 { + results.push(Vec::new()); + } else { + heap_permute(num_qubits, &mut perm, &mut 
results); + } + + results +} + +fn compute_automorphisms(adjacency: &[Vec], has_edge: bool) -> Vec> { + let n = adjacency.len(); + if n == 0 { + return vec![Vec::new()]; + } + + if !has_edge { + return all_permutations(n); + } + + // Build a directed graph with symmetric edges and use petgraph's VF2 enumerator. + let mut graph = Graph::::new(); + let mut nodes = Vec::with_capacity(n); + for node in 0..n { + nodes.push(graph.add_node(node)); + } + for i in 0..n { + for j in (i + 1)..n { + if adjacency[i][j] { + graph.add_edge(nodes[i], nodes[j], ()); + graph.add_edge(nodes[j], nodes[i], ()); + } + } + } + + let mut results: Vec> = Vec::new(); + let mut node_match = |_: &usize, _: &usize| true; + let mut edge_match = |_: &(), _: &()| true; + + // Use &&graph so G0/G1 are `&Graph`, which implement the required traits for VF2. + let graph_ref = &graph; + if let Some(iter) = + subgraph_isomorphisms_iter(&graph_ref, &graph_ref, &mut node_match, &mut edge_match) + { + for mapping in iter { + if mapping.len() != n { + continue; + } + // mapping indices are compact node indices; translate back to node labels + let mut perm = vec![usize::MAX; n]; + for (from_idx, to_idx) in mapping.into_iter().enumerate() { + let from_node = graph.from_index(from_idx); + let to_node = graph.from_index(to_idx); + let from_label = graph.node_weight(from_node).copied().unwrap_or(0); + let to_label = graph.node_weight(to_node).copied().unwrap_or(0); + perm[from_label] = to_label; + } + if perm.iter().any(|&v| v == usize::MAX) { + continue; + } + results.push(perm); + } + } + + if results.is_empty() { + results.push(identity_perm(n)); + } + + results.sort(); + results.dedup(); + results +} + +fn build_action_perm( + gateset: &[Gate], + gate_index: &HashMap, + perm: &[usize], +) -> Option> { + let mut act_perm: Vec = Vec::with_capacity(gateset.len()); + + for gate in gateset { + let kind = gate_kind(gate); + let mut qubits = gate_qubits(gate); + for q in qubits.iter_mut() { + if *q >= perm.len() { 
+ return None; + } + *q = perm[*q]; + } + let key = canonical_key(kind, qubits); + if let Some(idx) = gate_index.get(&key) { + act_perm.push(*idx); + } else { + return None; + } + } + + Some(act_perm) +} + +fn compute_twists_with_builder( + num_qubits: usize, + gateset: &[Gate], + mut build_obs_perm: F, +) -> (Vec>, Vec>) +where + F: FnMut(&[usize]) -> Vec, +{ + if num_qubits == 0 { + return (Vec::new(), Vec::new()); + } + + let mut gate_index: HashMap = HashMap::new(); + for (idx, gate) in gateset.iter().enumerate() { + let kind = gate_kind(gate); + let qubits = gate_qubits(gate); + let key = canonical_key(kind, qubits); + gate_index.insert(key, idx); + } + + let mut adjacency = vec![vec![false; num_qubits]; num_qubits]; + let mut has_edge = false; + + for gate in gateset { + if let Some((q1, q2)) = two_qubit_targets(gate) { + if q1 != q2 { + adjacency[q1][q2] = true; + adjacency[q2][q1] = true; + has_edge = true; + } + } + } + + let automorphisms = compute_automorphisms(&adjacency, has_edge); + + let mut seen: HashSet> = HashSet::new(); + let mut obs_perms: Vec> = Vec::new(); + let mut act_perms: Vec> = Vec::new(); + + for perm in automorphisms { + if !seen.insert(perm.clone()) { + continue; + } + if let Some(act_perm) = build_action_perm(gateset, &gate_index, &perm) { + obs_perms.push(build_obs_perm(&perm)); + act_perms.push(act_perm); + } + } + + if obs_perms.is_empty() { + let identity = identity_perm(num_qubits); + if let Some(act_perm) = build_action_perm(gateset, &gate_index, &identity) { + obs_perms.push(build_obs_perm(&identity)); + act_perms.push(act_perm); + } + } + + (obs_perms, act_perms) +} + +fn obs_perm_square(num_qubits: usize, perm: &[usize]) -> Vec { + let mut obs_perm = vec![0usize; num_qubits * num_qubits]; + for row in 0..num_qubits { + for col in 0..num_qubits { + let idx_old = row * num_qubits + col; + obs_perm[idx_old] = perm[row] * num_qubits + perm[col]; + } + } + obs_perm +} + +fn obs_perm_clifford(num_qubits: usize, perm: &[usize]) -> 
Vec { + let dim = 2 * num_qubits; + let mut obs_perm = vec![0usize; dim * dim]; + for row in 0..dim { + let mapped_row = if row < num_qubits { + perm[row] + } else { + num_qubits + perm[row - num_qubits] + }; + for col in 0..dim { + let mapped_col = if col < num_qubits { + perm[col] + } else { + num_qubits + perm[col - num_qubits] + }; + obs_perm[row * dim + col] = mapped_row * dim + mapped_col; + } + } + obs_perm +} + +pub fn compute_twists_square(num_qubits: usize, gateset: &[Gate]) -> (Vec>, Vec>) { + compute_twists_with_builder(num_qubits, gateset, |perm| obs_perm_square(num_qubits, perm)) +} + +pub fn compute_twists_clifford(num_qubits: usize, gateset: &[Gate]) -> (Vec>, Vec>) { + compute_twists_with_builder(num_qubits, gateset, |perm| obs_perm_clifford(num_qubits, perm)) +} diff --git a/src/qiskit_gym/envs/synthesis.py b/src/qiskit_gym/envs/synthesis.py index 19e7ad6..a6fa98f 100644 --- a/src/qiskit_gym/envs/synthesis.py +++ b/src/qiskit_gym/envs/synthesis.py @@ -43,6 +43,9 @@ def from_coupling_map( difficulty: int = 1, depth_slope: int = 2, max_depth: int = 128, + metrics_weights: dict[str, float] | None = None, + add_inverts: bool = True, + add_perms: bool = True, ): if basis_gates is None: basis_gates = tuple(cls.allowed_gates) @@ -72,18 +75,33 @@ def from_coupling_map( "gateset": gateset, "depth_slope": depth_slope, "max_depth": max_depth, + "metrics_weights": metrics_weights, + "add_inverts": add_inverts, + "add_perms": add_perms, } - return cls(**config) + # Filter config to only include parameters accepted by the class __init__ + import inspect + sig = inspect.signature(cls.__init__) + valid_params = set(sig.parameters.keys()) - {'self'} + filtered_config = {k: v for k, v in config.items() if k in valid_params} + return cls(**filtered_config) @classmethod def from_json(cls, env_config): - return cls(**env_config) + # Filter config to only include parameters accepted by the class __init__ + import inspect + sig = inspect.signature(cls.__init__) + 
valid_params = set(sig.parameters.keys()) - {'self'} + filtered_config = {k: v for k, v in env_config.items() if k in valid_params} + return cls(**filtered_config) @classmethod @abstractmethod def get_state(cls, input): pass + def post_process_synthesis(self, synth_circuit: QuantumCircuit, input_state): + return synth_circuit # --------------------------------------- # ------------- Env classes ------------- @@ -94,6 +112,22 @@ def get_state(cls, input): CliffordEnv = gym_adapter(qiskit_gym_rs.CliffordEnv) +def _solve_phases(clifford_cpy): + num_qubits = clifford_cpy.num_qubits + out = QuantumCircuit(num_qubits) + + # Add the phases (Pauli gates) to the Clifford circuit + for qubit in range(num_qubits): + stab = clifford_cpy.stab_phase[qubit] + destab = clifford_cpy.destab_phase[qubit] + if destab and stab: + out.y(qubit) + elif not destab and stab: + out.x(qubit) + elif destab and not stab: + out.z(qubit) + + return out class CliffordGym(CliffordEnv, BaseSynthesisEnv): cls_name = "CliffordEnv" @@ -106,6 +140,10 @@ def __init__( difficulty: int = 1, depth_slope: int = 2, max_depth: int = 128, + metrics_weights: dict[str, float] | None = None, + add_inverts: bool = True, + add_perms: bool = True, + track_solution: bool = True, ): super().__init__(**{ "num_qubits": num_qubits, @@ -113,12 +151,24 @@ def __init__( "gateset": gateset, "depth_slope": depth_slope, "max_depth": max_depth, + "metrics_weights": metrics_weights, + "add_inverts": add_inverts, + "add_perms": add_perms, + "track_solution": track_solution, }) def get_state(self, input: QuantumCircuit | Clifford): if isinstance(input, QuantumCircuit): input = Clifford(input) return input.adjoint().tableau[:, :-1].T.flatten().astype(int).tolist() + + def post_process_synthesis(self, synth_circuit: QuantumCircuit, input): + synth_circuit = synth_circuit.inverse() + if isinstance(input, QuantumCircuit): + input = Clifford(input) + dcliff = Clifford(synth_circuit).compose(input) + out = 
_solve_phases(dcliff).compose(synth_circuit).inverse() + return out # ------------- Linear Function ------------- @@ -138,6 +188,10 @@ def __init__( difficulty: int = 1, depth_slope: int = 2, max_depth: int = 128, + metrics_weights: dict[str, float] | None = None, + add_inverts: bool = True, + add_perms: bool = True, + track_solution: bool = True, ): super().__init__(**{ "num_qubits": num_qubits, @@ -145,15 +199,18 @@ def __init__( "gateset": gateset, "depth_slope": depth_slope, "max_depth": max_depth, + "metrics_weights": metrics_weights, + "add_inverts": add_inverts, + "add_perms": add_perms, + "track_solution": track_solution, }) def get_state(self, input: QuantumCircuit | LinearFunction): - if isinstance(input, QuantumCircuit): - input = LinearFunction(input.inverse()) - elif isinstance(input, LinearFunction): - input = LinearFunction(Clifford(input).adjoint()) + # This returns the inverse permutation to get the right + # synthesized circuit at output, instead of its inverse. + input = LinearFunction(Clifford(input).adjoint()) return np.array(input.linear).flatten().astype(int).tolist() - + # ------------- Permutation ------------- from qiskit.circuit.library.generalized_gates import PermutationGate @@ -172,6 +229,10 @@ def __init__( difficulty: int = 1, depth_slope: int = 2, max_depth: int = 128, + metrics_weights: dict[str, float] | None = None, + add_inverts: bool = True, + add_perms: bool = True, + track_solution: bool = True, ): super().__init__(**{ "num_qubits": num_qubits, @@ -179,6 +240,10 @@ def __init__( "gateset": gateset, "depth_slope": depth_slope, "max_depth": max_depth, + "metrics_weights": metrics_weights, + "add_inverts": add_inverts, + "add_perms": add_perms, + "track_solution": track_solution, }) def get_state(self, input: QuantumCircuit | PermutationGate | Iterable[int]): @@ -187,6 +252,8 @@ def get_state(self, input: QuantumCircuit | PermutationGate | Iterable[int]): elif isinstance(input, PermutationGate): input = input.pattern + # This 
returns the inverse permutation to get the right + # synthesized circuit at output, instead of its inverse. return np.argsort(np.array(input)).astype(int).tolist() diff --git a/src/qiskit_gym/rl/synthesis.py b/src/qiskit_gym/rl/synthesis.py index 5703f18..42adb99 100644 --- a/src/qiskit_gym/rl/synthesis.py +++ b/src/qiskit_gym/rl/synthesis.py @@ -17,7 +17,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from twisterl.utils import dynamic_import +from twisterl.utils import dynamic_import, load_checkpoint from qiskit_gym.rl.configs import ( AlphaZeroConfig, PPOConfig, @@ -103,9 +103,7 @@ def init_algorithm(self, model_path=None): act_perms=act_perms, ) if model_path is not None: - model.load_state_dict( - torch.load(open(model_path, "rb"), map_location=torch.device("cpu")) - ) + model.load_state_dict(load_checkpoint(model_path)) return self.algorithm_cls( self.env._raw_env, model, self.rl_config.to_json(), None @@ -125,10 +123,12 @@ def synth( state, deterministic, num_searches, num_mcts_searches, C, max_expand_depth ) if actions is not None: - return gate_list_to_circuit( + synth_circuit = gate_list_to_circuit( [self.env_config["gateset"][a] for a in actions], num_qubits=self.env.config["num_qubits"], ) + synth_circuit = self.env.post_process_synthesis(synth_circuit, input) + return synth_circuit def learn(self, initial_difficulty=1, num_iterations=int(1e10), tb_path=None): if tb_path is not None: