11from typing import Any
2- from dataclasses import dataclass
2+ from dataclasses import field , dataclass
33
44import numpy as np
55from kirin .dialects import ilist
@@ -48,11 +48,11 @@ def get_method_name(self, adjoint: bool, control: bool) -> str:
4848
4949 return method_name + self .method_name
5050
51- def apply (self , qubits : PyQrackQubit , adjoint : bool = False ) -> None :
51+ def apply (self , qubit : PyQrackQubit , adjoint : bool = False ) -> None :
5252 if not qubit .is_active ():
5353 return
5454 method_name = self .get_method_name (adjoint = adjoint , control = False )
55- getattr (qubits .sim_reg , method_name )(qubits .addr )
55+ getattr (qubit .sim_reg , method_name )(qubit .addr )
5656
5757 def control_apply (self , * qubits : PyQrackQubit , adjoint : bool = False ) -> None :
5858 ctrls = [qbit .addr for qbit in qubits [:- 1 ]]
@@ -232,12 +232,9 @@ def mat(self, adjoint: bool) -> list[complex]:
232232class RotRuntime (OperatorRuntimeABC ):
233233 axis : OperatorRuntimeABC
234234 angle : float
235+ pyqrack_axis : Pauli = field (init = False )
235236
236- def apply (self , * qubits : PyQrackQubit , adjoint : bool = False ) -> None :
237- sign = (- 1 ) ** adjoint
238- angle = sign * self .angle
239- target = qubits [- 1 ]
240-
237+ def __post_init__ (self ):
241238 if not isinstance (self .axis , OperatorRuntime ):
242239 raise RuntimeError (
243240 f"Rotation only supported for Pauli operators! Got { self .axis } "
@@ -250,7 +247,15 @@ def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
250247 f"Rotation only supported for Pauli operators! Got { self .axis } "
251248 )
252249
253- target .sim_reg .r (axis , angle , target .addr )
250+ # NOTE: weird setattr for frozen dataclasses
251+ object .__setattr__ (self , "pyqrack_axis" , axis )
252+
253+ def apply (self , * qubits : PyQrackQubit , adjoint : bool = False ) -> None :
254+ sign = (- 1 ) ** adjoint
255+ angle = sign * self .angle
256+ target = qubits [- 1 ]
257+
258+ target .sim_reg .r (self .pyqrack_axis , angle , target .addr )
254259
255260 def control_apply (self , * qubits : PyQrackQubit , adjoint : bool = False ) -> None :
256261 sign = (- 1 ) ** (not adjoint )
@@ -259,19 +264,7 @@ def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
259264 ctrls = [qbit .addr for qbit in qubits [:- 1 ]]
260265 target = qubits [- 1 ]
261266
262- if not isinstance (self .axis , OperatorRuntime ):
263- raise RuntimeError (
264- f"Rotation only supported for Pauli operators! Got { self .axis } "
265- )
266-
267- try :
268- axis = getattr (Pauli , "Pauli" + self .axis .method_name .upper ())
269- except KeyError :
270- raise RuntimeError (
271- f"Rotation only supported for Pauli operators! Got { self .axis } "
272- )
273-
274- target .sim_reg .mcr (axis , angle , ctrls , target .addr )
267+ target .sim_reg .mcr (self .pyqrack_axis , angle , ctrls , target .addr )
275268
276269
277270@dataclass (frozen = True )
0 commit comments