Skip to content

Commit b2f1238

Browse files
committed
remove subclasses
1 parent 52e8828 commit b2f1238

File tree

2 files changed

+23
-20
lines changed

2 files changed

+23
-20
lines changed

src/bloqade/squin/qubit.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,18 @@
2020
dialect = ir.Dialect("squin.qubit")
2121

2222

23-
@statement
24-
class MultiQubitStatement(ir.Statement):
23+
@statement(dialect=dialect)
24+
class Apply(ir.Statement):
2525
traits = frozenset({lowering.FromPythonCall()})
2626
operator: ir.SSAValue = info.argument(OpType)
2727
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
2828

2929

3030
@statement(dialect=dialect)
31-
class Apply(MultiQubitStatement):
32-
pass
33-
34-
35-
@statement(dialect=dialect)
36-
class Broadcast(MultiQubitStatement):
37-
pass
31+
class Broadcast(ir.Statement):
32+
traits = frozenset({lowering.FromPythonCall()})
33+
operator: ir.SSAValue = info.argument(OpType)
34+
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
3835

3936

4037
@statement(dialect=dialect)

src/bloqade/squin/wire.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,10 @@ class Unwrap(ir.Statement):
4848
result: ir.ResultValue = info.result(WireType)
4949

5050

51-
@statement
52-
class MultiWireStatement(ir.Statement):
51+
# In Quake, you put a wire in and get a wire out when you "apply" an operator
52+
# In this case though we just need to indicate that an operator is applied to list[wires]
53+
@statement(dialect=dialect)
54+
class Apply(ir.Statement):
5355
traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
5456
operator: ir.SSAValue = info.argument(OpType)
5557
inputs: tuple[ir.SSAValue, ...] = info.argument(WireType)
@@ -66,18 +68,22 @@ def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue):
6668
) # custom lowering required for wrapper to work here
6769

6870

69-
# In Quake, you put a wire in and get a wire out when you "apply" an operator
70-
# In this case though we just need to indicate that an operator is applied to list[wires]
7171
@statement(dialect=dialect)
72-
class Apply(MultiWireStatement): # apply(op, w1, w2, ...)
73-
def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue):
74-
super().__init__(operator, *args)
75-
72+
class Broadcast(ir.Statement):
73+
traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
74+
operator: ir.SSAValue = info.argument(OpType)
75+
inputs: tuple[ir.SSAValue, ...] = info.argument(WireType)
7676

77-
@statement(dialect=dialect)
78-
class Broadcast(MultiWireStatement):
7977
def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue):
80-
super().__init__(operator, *args)
78+
result_types = tuple(WireType for _ in args)
79+
super().__init__(
80+
args=(operator,) + args,
81+
result_types=result_types, # result types of the Apply statement, should all be WireTypes
82+
args_slice={
83+
"operator": 0,
84+
"inputs": slice(1, None),
85+
}, # pretty printing + syntax sugar
86+
) # custom lowering required for wrapper to work here
8187

8288

8389
# NOTE: measurement cannot be pure because they will collapse the state

0 commit comments

Comments
 (0)