Skip to content

Commit 3d97bb2

Browse files
committed
Fix types in qbraid lowering
1 parent 519f016 commit 3d97bb2

File tree

4 files changed

+24
-16
lines changed

4 files changed

+24
-16
lines changed

src/bloqade/qasm2/dialects/expr/stmts.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def print_impl(self, printer: Printer) -> None:
8787

8888

8989
# QASM 2.0 arithmetic operations
90-
PyNum = types.Union(types.Int, types.Float)
90+
PyNum = types.TypeVar("PyNum", bound=types.Union(types.Int, types.Float))
9191

9292

9393
@statement(dialect=dialect)
@@ -110,7 +110,7 @@ class Sin(ir.Statement):
110110
traits = frozenset({lowering.FromPythonCall()})
111111
value: ir.SSAValue = info.argument(PyNum)
112112
"""value (Union[int, float]): The number to take the sine of."""
113-
result: ir.ResultValue = info.result(PyNum)
113+
result: ir.ResultValue = info.result(types.Float)
114114
"""result (float): The sine of the number."""
115115

116116

@@ -122,7 +122,7 @@ class Cos(ir.Statement):
122122
traits = frozenset({lowering.FromPythonCall()})
123123
value: ir.SSAValue = info.argument(PyNum)
124124
"""value (Union[int, float]): The number to take the cosine of."""
125-
result: ir.ResultValue = info.result(PyNum)
125+
result: ir.ResultValue = info.result(types.Float)
126126
"""result (float): The cosine of the number."""
127127

128128

@@ -134,7 +134,7 @@ class Tan(ir.Statement):
134134
traits = frozenset({lowering.FromPythonCall()})
135135
value: ir.SSAValue = info.argument(PyNum)
136136
"""value (Union[int, float]): The number to take the tangent of."""
137-
result: ir.ResultValue = info.result(PyNum)
137+
result: ir.ResultValue = info.result(types.Float)
138138
"""result (float): The tangent of the number."""
139139

140140

@@ -146,7 +146,7 @@ class Exp(ir.Statement):
146146
traits = frozenset({lowering.FromPythonCall()})
147147
value: ir.SSAValue = info.argument(PyNum)
148148
"""value (Union[int, float]): The number to take the exponential of."""
149-
result: ir.ResultValue = info.result(PyNum)
149+
result: ir.ResultValue = info.result(types.Float)
150150
"""result (float): The exponential of the number."""
151151

152152

@@ -158,7 +158,7 @@ class Log(ir.Statement):
158158
traits = frozenset({lowering.FromPythonCall()})
159159
value: ir.SSAValue = info.argument(PyNum)
160160
"""value (Union[int, float]): The number to take the natural log of."""
161-
result: ir.ResultValue = info.result(PyNum)
161+
result: ir.ResultValue = info.result(types.Float)
162162
"""result (float): The natural log of the number."""
163163

164164

@@ -170,7 +170,7 @@ class Sqrt(ir.Statement):
170170
traits = frozenset({lowering.FromPythonCall()})
171171
value: ir.SSAValue = info.argument(PyNum)
172172
"""value (Union[int, float]): The number to take the square root of."""
173-
result: ir.ResultValue = info.result(PyNum)
173+
result: ir.ResultValue = info.result(types.Float)
174174
"""result (float): The square root of the number."""
175175

176176

src/bloqade/qbraid/lowering.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,5 +320,6 @@ def lower_full_turns(self, value: float) -> ir.SSAValue:
320320
self.block_list.append(const_pi)
321321
turns = self.lower_number(2 * value)
322322
mul = qasm2.expr.Mul(const_pi.result, turns)
323+
mul.result.type = types.Float
323324
self.block_list.append(mul)
324325
return mul.result

src/bloqade/qbraid/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,13 +238,13 @@ def decompiled_circuit(self) -> str:
238238
str: The decompiled circuit from hardware execution.
239239
240240
"""
241-
from bloqade.noise import native
242241
from bloqade.qasm2.emit import QASM2
243242
from bloqade.qasm2.passes import glob, parallel
243+
from bloqade.qasm2.rewrite.noise import remove_noise
244244

245245
mt = self.lower_noise_model("method")
246246

247-
native.RemoveNoisePass(mt.dialects)(mt)
247+
remove_noise.RemoveNoisePass(mt.dialects)(mt)
248248
parallel.ParallelToUOp(mt.dialects)(mt)
249249
glob.GlobalToUOP(mt.dialects)(mt)
250250
return QASM2(qelib1=True).emit_str(mt)

test/qbraid/test_lowering.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,8 @@ def run_assert(noise_model: schema.NoiseModel, expected_stmts: List[ir.Statement
5656
)
5757

5858
expected_mt = ir.Method(
59-
mod=None,
60-
py_func=None,
6159
dialects=lowering.qbraid_noise,
6260
sym_name="test",
63-
arg_names=[],
6461
code=expected_func_stmt,
6562
)
6663

@@ -242,7 +239,10 @@ def test_lowering_global_w():
242239
(lam_num := as_float(2 * -(0.5 + phi_val))),
243240
(lam := qasm2.expr.Mul(pi_lam.result, lam_num.result)),
244241
parallel.UGate(
245-
theta=theta.result, phi=phi.result, lam=lam.result, qargs=qargs.result
242+
theta=ir.ResultValue(theta, 0, type=types.Float),
243+
phi=ir.ResultValue(phi, 0, type=types.Float),
244+
lam=ir.ResultValue(lam, 0, type=types.Float),
245+
qargs=qargs.result,
246246
),
247247
func.Return(creg.result),
248248
]
@@ -304,7 +304,10 @@ def test_lowering_local_w():
304304
(lam_num := as_float(2 * -(0.5 + phi_val))),
305305
(lam := qasm2.expr.Mul(pi_lam.result, lam_num.result)),
306306
parallel.UGate(
307-
qargs=qargs.result, theta=theta.result, phi=phi.result, lam=lam.result
307+
qargs=qargs.result,
308+
theta=ir.ResultValue(theta, 0, type=types.Float),
309+
phi=ir.ResultValue(phi, 0, type=types.Float),
310+
lam=ir.ResultValue(lam, 0, type=types.Float),
308311
),
309312
func.Return(creg.result),
310313
]
@@ -348,7 +351,9 @@ def test_lowering_global_rz():
348351
(theta_pi := qasm2.expr.ConstPI()),
349352
(theta_num := as_float(2 * phi_val)),
350353
(theta := qasm2.expr.Mul(theta_pi.result, theta_num.result)),
351-
parallel.RZ(theta=theta.result, qargs=qargs.result),
354+
parallel.RZ(
355+
theta=ir.ResultValue(theta, 0, type=types.Float), qargs=qargs.result
356+
),
352357
func.Return(creg.result),
353358
]
354359

@@ -401,7 +406,9 @@ def test_lowering_local_rz():
401406
(theta_pi := qasm2.expr.ConstPI()),
402407
(theta_num := as_float(2 * phi_val)),
403408
(theta := qasm2.expr.Mul(theta_pi.result, theta_num.result)),
404-
parallel.RZ(theta=theta.result, qargs=qargs.result),
409+
parallel.RZ(
410+
theta=ir.ResultValue(theta, 0, type=types.Float), qargs=qargs.result
411+
),
405412
func.Return(creg.result),
406413
]
407414

0 commit comments

Comments
 (0)