55from bloqade import qubit , annotate
66
77from .lattice import (
8+ Predicate ,
89 AnyMeasureId ,
910 NotMeasureId ,
11+ RawMeasureId ,
1012 MeasureIdBool ,
1113 MeasureIdTuple ,
1214 InvalidMeasureId ,
1315)
1416from .analysis import MeasureIDFrame , MeasurementIDAnalysis
1517
16- ## Can't do wire right now because of
17- ## unresolved RFC on return type
18- # from bloqade.squin import wire
19-
2018
2119@qubit .dialect .register (key = "measure_id" )
2220class SquinQubit (interp .MethodTable ):
@@ -30,7 +28,6 @@ def measure_qubit_list(
3028 ):
3129
3230 # try to get the length of the list
33- ## "...safely assume the type inference will give you what you need"
3431 qubits_type = stmt .qubits .type
3532 # vars[0] is just the type of the elements in the ilist,
3633 # vars[1] can contain a literal with length information
@@ -41,10 +38,41 @@ def measure_qubit_list(
4138 measure_id_bools = []
4239 for _ in range (num_qubits .data ):
4340 interp .measure_count += 1
44- measure_id_bools .append (MeasureIdBool (interp .measure_count ))
41+ measure_id_bools .append (RawMeasureId (interp .measure_count ))
4542
4643 return (MeasureIdTuple (data = tuple (measure_id_bools )),)
4744
45+ @interp .impl (qubit .stmts .IsLost )
46+ @interp .impl (qubit .stmts .IsOne )
47+ @interp .impl (qubit .stmts .IsZero )
48+ def measurement_predicate (
49+ self ,
50+ interp : MeasurementIDAnalysis ,
51+ frame : interp .Frame ,
52+ stmt : qubit .stmts .IsLost | qubit .stmts .IsOne | qubit .stmts .IsZero ,
53+ ):
54+ original_measure_id_tuple = frame .get (stmt .measurements )
55+ if not all (
56+ isinstance (measure_id , RawMeasureId )
57+ for measure_id in original_measure_id_tuple .data
58+ ):
59+ return (InvalidMeasureId (),)
60+
61+ if isinstance (stmt , qubit .stmts .IsLost ):
62+ predicate = Predicate .IS_LOST
63+ elif isinstance (stmt , qubit .stmts .IsOne ):
64+ predicate = Predicate .IS_ONE
65+ elif isinstance (stmt , qubit .stmts .IsZero ):
66+ predicate = Predicate .IS_ZERO
67+ else :
68+ return (InvalidMeasureId (),)
69+
70+ predicate_measure_ids = [
71+ MeasureIdBool (measure_id .idx , predicate )
72+ for measure_id in original_measure_id_tuple .data
73+ ]
74+ return (MeasureIdTuple (data = tuple (predicate_measure_ids )),)
75+
4876
4977@annotate .dialect .register (key = "measure_id" )
5078class Annotate (interp .MethodTable ):
@@ -94,14 +122,10 @@ def getitem(
94122 self , interp : MeasurementIDAnalysis , frame : interp .Frame , stmt : py .GetItem
95123 ):
96124
97- idx_or_slice = interp .get_const_value ( (int , slice ), stmt . index )
125+ idx_or_slice = interp .maybe_const ( stmt . index , (int , slice ))
98126 if idx_or_slice is None :
99127 return (InvalidMeasureId (),)
100128
101- # hint = stmt.index.hints.get("const")
102- # if hint is None or not isinstance(hint, const.Value):
103- # return (InvalidMeasureId(),)
104-
105129 obj = frame .get (stmt .obj )
106130 if isinstance (obj , MeasureIdTuple ):
107131 if isinstance (idx_or_slice , slice ):
0 commit comments