Skip to content

Commit c36c8b1

Browse files
committed
Fixing interface for extracting iterable values
1 parent 5cc6a5e commit c36c8b1

File tree

2 files changed

+45
-38
lines changed

2 files changed

+45
-38
lines changed

src/bloqade/analysis/address/impls.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from typing import Any
66
from itertools import chain
7-
from collections.abc import Iterable
87

98
from kirin import ir, interp
109
from kirin.analysis import ForwardFrame, const
@@ -57,14 +56,20 @@ def from_literal(literal: Any) -> Address:
5756
return ConstResult(const.Value(literal))
5857

5958
match collection:
60-
case PartialIList(data) | PartialTuple(data):
61-
return data
59+
case PartialIList(data):
60+
return PartialIList, data
61+
case PartialTuple(data):
62+
return PartialTuple, data
6263
case AddressReg():
63-
return collection.qubits
64-
case ConstResult(const.Value(data)) if isinstance(data, Iterable):
65-
return tuple(map(from_literal, data))
64+
return PartialIList, collection.qubits
65+
case ConstResult(const.Value(IList() as data)):
66+
return PartialIList, tuple(map(from_literal, data))
67+
case ConstResult(const.Value(tuple() as data)):
68+
return PartialTuple, tuple(map(from_literal, data))
6669
case ConstResult(const.PartialTuple(data)):
67-
return tuple(map(from_constant, data))
70+
return PartialTuple, tuple(map(from_constant, data))
71+
case _:
72+
return None, ()
6873

6974

7075
@py.constant.dialect.register(key="qubit.address")
@@ -90,36 +95,17 @@ def add(
9095
):
9196
lhs = frame.get(stmt.lhs)
9297
rhs = frame.get(stmt.rhs)
93-
print(lhs, rhs)
94-
match lhs, rhs:
95-
case (PartialTuple(lhs_data), PartialTuple(rhs_data)) | (
96-
PartialIList(lhs_data),
97-
PartialIList(rhs_data),
98-
):
99-
return (lhs.new(lhs_data + rhs_data),)
100-
case (AddressReg(lhs_data), AddressReg(rhs_data)):
101-
return (AddressReg(tuple(chain(lhs_data, rhs_data))),)
10298

103-
lhs_constant = lhs.result if isinstance(lhs, ConstResult) else const.Unknown()
104-
rhs_constant = rhs.result if isinstance(rhs, ConstResult) else const.Unknown()
99+
lhs_type, lhs_values = self.get_values(lhs)
100+
rhs_type, rhs_values = self.get_values(rhs)
105101

106-
match (lhs, rhs_constant):
107-
case PartialIList() | AddressReg(), const.Value(IList() as lst) if (
108-
len(lst) == 0
109-
):
110-
return (lhs,)
111-
case PartialTuple(), const.Value(()):
112-
return (lhs,)
102+
if lhs_type is None or rhs_type is None:
103+
return (Address.top(),)
113104

114-
match (lhs_constant, rhs):
115-
case const.Value(IList() as lst), PartialIList() | AddressReg() if (
116-
len(lst) == 0
117-
):
118-
return (rhs,)
119-
case const.Value(()), PartialTuple():
120-
return (rhs,)
105+
if lhs_type is not rhs_type:
106+
return (Address.bottom(),)
121107

122-
return interp_.eval_stmt_fallback(frame, stmt)
108+
return (lhs_type(tuple(chain(lhs_values, rhs_values))),)
123109

124110

125111
@py.tuple.dialect.register(key="qubit.address")
@@ -156,11 +142,14 @@ def map_(
156142
results = []
157143
fn = frame.get(stmt.fn)
158144
collection = frame.get(stmt.collection)
159-
iterable = self.get_values(collection)
145+
iterable_type, iterable = self.get_values(collection)
160146

161-
if iterable is None:
147+
if iterable_type is None:
162148
return (Address.top(),)
163149

150+
if iterable_type is not PartialIList:
151+
return (Address.bottom(),)
152+
164153
results = []
165154
for ele in iterable:
166155
ret = self.call_function(interp_, fn, (ele,), ())
@@ -429,9 +418,9 @@ def for_loop(
429418
stmt: scf.For,
430419
):
431420
loop_vars = frame.get_values(stmt.initializers)
432-
iterable = self.get_values(frame.get(stmt.iterable))
421+
iter_type, iterable = self.get_values(frame.get(stmt.iterable))
433422

434-
if iterable is None:
423+
if iter_type is None:
435424
return interp_.eval_stmt_fallback(frame, stmt)
436425

437426
for value in iterable:

test/analysis/address/test_qubit_analysis.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,26 @@ def main(n: int):
159159
assert result == address.AddressReg(data=tuple(range(4)))
160160

161161

162+
def test_partial_tuple_add():
163+
@squin.kernel
164+
def main(n: int):
165+
return (0, 1) + (2, n)
166+
167+
address_analysis = address.AddressAnalysis(main.dialects)
168+
frame, result = address_analysis.run_analysis(main, no_raise=False)
169+
170+
assert result == address.PartialTuple(
171+
data=(
172+
address.ConstResult(const.Value(0)),
173+
address.ConstResult(const.Value(1)),
174+
address.ConstResult(const.Value(2)),
175+
address.Unknown(),
176+
)
177+
)
178+
179+
162180
if __name__ == "__main__":
163-
test_partial_tuple()
181+
test_partial_tuple_add()
164182

165183

166184
@pytest.mark.xfail

0 commit comments

Comments
 (0)