@@ -3,7 +3,7 @@ use nalgebra::DMatrix;
33use nalgebra_sparse:: { coo:: CooMatrix , csr:: CsrMatrix } ;
44use num_complex:: Complex ;
55
6- use crate :: gates:: { h_matrix, x_matrix} ;
6+ use crate :: gates:: { h_matrix, rx_matrix , ry_matrix , rz_matrix , x_matrix} ;
77use crate :: qstate:: QState ;
88use crate :: Qbit ;
99
@@ -12,16 +12,32 @@ enum Gate {
1212 Sparse ( CsrMatrix < Qbit > ) ,
1313}
1414
15+ pub enum ParameterizedGate {
16+ RX ,
17+ RY ,
18+ RZ ,
19+ }
20+
21+ struct Parameter {
22+ gate_index : usize ,
23+ qbit_index : usize ,
24+ gate : ParameterizedGate ,
25+ value : f64 ,
26+ }
27+
1528pub struct Circuit {
1629 gates : Vec < Gate > ,
1730 num_of_qbits : usize ,
31+
32+ parameters : Vec < Parameter > ,
1833}
1934
2035impl Circuit {
2136 pub fn new ( num_of_qbits : usize ) -> Self {
2237 Self {
2338 gates : Vec :: new ( ) ,
2439 num_of_qbits,
40+ parameters : Vec :: new ( ) ,
2541 }
2642 }
2743
@@ -54,6 +70,19 @@ impl Circuit {
5470 Ok ( matrix)
5571 }
5672
73+ fn create_parametric_gate_for_index (
74+ & self ,
75+ param : & Parameter ,
76+ value : f64 ,
77+ ) -> Result < CsrMatrix < Qbit > > {
78+ let gate = match param. gate {
79+ ParameterizedGate :: RX => rx_matrix ( value) ,
80+ ParameterizedGate :: RY => ry_matrix ( value) ,
81+ ParameterizedGate :: RZ => rz_matrix ( value) ,
82+ } ;
83+ self . create_gate_for_index ( param. qbit_index , & gate)
84+ }
85+
5786 pub fn gate_at ( mut self , index : usize , gate : CsrMatrix < Qbit > ) -> Result < Self > {
5887 let gate = self . create_gate_for_index ( index, & gate) ?;
5988 self . add_gate ( gate) ;
@@ -66,6 +95,60 @@ impl Circuit {
6695 Ok ( ( ) )
6796 }
6897
98+ pub fn add_parametric_gate_at (
99+ & mut self ,
100+ index : usize ,
101+ gate : ParameterizedGate ,
102+ value : f64 ,
103+ ) -> Result < ( ) > {
104+ let param = Parameter {
105+ gate_index : self . gates . len ( ) ,
106+ qbit_index : index,
107+ gate,
108+ value,
109+ } ;
110+ let gate = self . create_parametric_gate_for_index ( & param, value) ?;
111+
112+ self . parameters . push ( param) ;
113+ self . add_gate ( gate) ;
114+
115+ Ok ( ( ) )
116+ }
117+
118+ pub fn get_parameters ( & self ) -> Vec < f64 > {
119+ self . parameters . iter ( ) . map ( |param| param. value ) . collect ( )
120+ }
121+
122+ pub fn set_parameter ( & mut self , param_index : usize , value : f64 ) -> Result < ( ) > {
123+ if let Some ( param) = self . parameters . get_mut ( param_index) {
124+ param. value = value;
125+ } else {
126+ return Err ( anyhow:: anyhow!( "Parameter index out of bounds" ) ) ;
127+ } ;
128+
129+ // No index check is needed
130+ let param = & self . parameters [ param_index] ;
131+
132+ let gate = self . create_parametric_gate_for_index ( param, value) ?;
133+ self . gates [ param. gate_index ] = Gate :: Sparse ( gate) ;
134+
135+ Ok ( ( ) )
136+ }
137+
138+ pub fn set_parameters ( & mut self , values : & [ f64 ] ) -> Result < ( ) > {
139+ if values. len ( ) != self . parameters . len ( ) {
140+ return Err ( anyhow:: anyhow!(
141+ "Number of values does not match number of parameters"
142+ ) ) ;
143+ }
144+
145+ for ( i, & value) in values. iter ( ) . enumerate ( ) {
146+ self . set_parameter ( i, value) ?;
147+ }
148+
149+ Ok ( ( ) )
150+ }
151+
69152 #[ allow( non_snake_case) ]
70153 pub fn H ( self , index : usize ) -> Result < Self > {
71154 self . gate_at ( index, h_matrix ( ) )
@@ -176,6 +259,8 @@ pub fn kronecker_product(x: &CsrMatrix<Qbit>, y: &CsrMatrix<Qbit>) -> CsrMatrix<
176259
177260#[ cfg( test) ]
178261mod tests {
262+ use std:: f64:: consts:: PI ;
263+
179264 use crate :: {
180265 assert_approx_complex_eq,
181266 gates:: { s_matrix, t_matrix} ,
@@ -249,4 +334,35 @@ mod tests {
249334
250335 Ok ( ( ) )
251336 }
337+
338+ #[ test]
339+ fn test_parameterized_gate ( ) -> Result < ( ) > {
340+ let q00 = QState :: from_str ( "00" ) . unwrap ( ) ;
341+ let mut circuit = Circuit :: new ( q00. num_of_qbits ( ) ) ;
342+ circuit. add_parametric_gate_at ( 0 , ParameterizedGate :: RX , PI ) ?;
343+
344+ let result = circuit. apply ( & q00) ;
345+
346+ assert_approx_complex_eq ! ( 0.0 , 0.0 , result. state[ 0 ] ) ;
347+ assert_approx_complex_eq ! ( 0.0 , -1.0 , result. state[ 1 ] ) ;
348+
349+ // Update the parameter to PI/2
350+ let mut param = circuit. get_parameters ( ) ;
351+ assert_eq ! ( 1 , param. len( ) ) ;
352+ assert_eq ! ( PI , param[ 0 ] ) ;
353+
354+ param[ 0 ] = PI / 2.0 ;
355+ circuit. set_parameters ( & param) ?;
356+
357+ let param = circuit. get_parameters ( ) ;
358+ assert_eq ! ( 1 , param. len( ) ) ;
359+ assert_eq ! ( PI / 2.0 , param[ 0 ] ) ;
360+
361+ let result = circuit. apply ( & q00) ;
362+
363+ assert_approx_complex_eq ! ( 1.0 / 2f64 . sqrt( ) , 0.0 , result. state[ 0 ] ) ;
364+ assert_approx_complex_eq ! ( 0.0 , -1.0 / 2f64 . sqrt( ) , result. state[ 1 ] ) ;
365+
366+ Ok ( ( ) )
367+ }
252368}
0 commit comments