Skip to content

Commit 0ff58ca

Browse files
committed
Fix address analysis
1 parent b98d463 commit 0ff58ca

File tree

3 files changed

+34
-33
lines changed

3 files changed

+34
-33
lines changed

src/bloqade/analysis/address/analysis.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def run_lattice(
119119
case ConstResult(const.Value(ir.Method() as method)):
120120
_, ret = self.call(
121121
method.code,
122+
self.method_self(method),
122123
*inputs,
123124
# **kwargs,
124125
# **{k: v for k, v in zip(kwargs, frame.get_values(stmt.kwargs))},

src/bloqade/analysis/address/impls.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -180,13 +180,10 @@ def invoke(
180180
frame: ForwardFrame[Address],
181181
stmt: func.Invoke,
182182
):
183-
184-
args = interp_.permute_values(
185-
stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
186-
)
187-
_, ret = interp_.run_method(
188-
stmt.callee,
189-
args,
183+
_, ret = interp_.call(
184+
stmt.callee.code,
185+
interp_.method_self(stmt.callee),
186+
*frame.get_values(stmt.inputs),
190187
)
191188

192189
return (ret,)
@@ -319,26 +316,28 @@ def ifelse(
319316
):
320317
body = stmt.then_body if const_cond.data else stmt.else_body
321318
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
322-
ret = interp_.run_ssacfg_region(body_frame, body, (address_cond,))
319+
ret = interp_.frame_call_region(body_frame, stmt, body, address_cond)
323320
# interp_.set_values(frame, body_frame.entries.keys(), body_frame.entries.values())
324321
return ret
325322
else:
326323
# run both branches
327324
with interp_.new_frame(stmt, has_parent_access=True) as then_frame:
328-
then_results = interp_.run_ssacfg_region(
329-
then_frame, stmt.then_body, (address_cond,)
330-
)
331-
interp_.set_values(
332-
frame, then_frame.entries.keys(), then_frame.entries.values()
325+
then_results = interp_.frame_call_region(
326+
then_frame,
327+
stmt,
328+
stmt.then_body,
329+
address_cond,
333330
)
331+
frame.set_values(then_frame.entries.keys(), then_frame.entries.values())
334332

335333
with interp_.new_frame(stmt, has_parent_access=True) as else_frame:
336-
else_results = interp_.run_ssacfg_region(
337-
else_frame, stmt.else_body, (address_cond,)
338-
)
339-
interp_.set_values(
340-
frame, else_frame.entries.keys(), else_frame.entries.values()
334+
else_results = interp_.frame_call_region(
335+
else_frame,
336+
stmt,
337+
stmt.else_body,
338+
address_cond,
341339
)
340+
frame.set_values(else_frame.entries.keys(), else_frame.entries.values())
342341
# TODO: pick the non-return value
343342
if isinstance(then_results, interp.ReturnValue) and isinstance(
344343
else_results, interp.ReturnValue
@@ -364,12 +363,12 @@ def for_loop(
364363
iter_type, iterable = interp_.unpack_iterable(frame.get(stmt.iterable))
365364

366365
if iter_type is None:
367-
return interp_.eval_stmt_fallback(frame, stmt)
366+
return interp_.eval_fallback(frame, stmt)
368367

369368
for value in iterable:
370369
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
371-
loop_vars = interp_.run_ssacfg_region(
372-
body_frame, stmt.body, (value,) + loop_vars
370+
loop_vars = interp_.frame_call_region(
371+
body_frame, stmt, stmt.body, value, *loop_vars
373372
)
374373

375374
if loop_vars is None:

test/analysis/address/test_qubit_analysis.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def test():
2121
return (q1[1], q2)
2222

2323
address_analysis = address.AddressAnalysis(test.dialects)
24-
frame, _ = address_analysis.run_analysis(test, no_raise=False)
24+
frame, _ = address_analysis.run(test)
2525
address_types = collect_address_types(frame, address.PartialTuple)
2626

2727
test.print(analysis=frame.entries)
@@ -116,7 +116,7 @@ def main():
116116
return q
117117

118118
address_analysis = address.AddressAnalysis(main.dialects)
119-
address_analysis.run_analysis(main, no_raise=False)
119+
address_analysis.run(main)
120120

121121

122122
def test_new_qubit():
@@ -125,7 +125,7 @@ def main():
125125
return squin.qubit.new()
126126

127127
address_analysis = address.AddressAnalysis(main.dialects)
128-
_, result = address_analysis.run_analysis(main, no_raise=False)
128+
_, result = address_analysis.run(main)
129129
assert result == address.AddressQubit(0)
130130

131131

@@ -139,8 +139,9 @@ def main(n: int):
139139
return qreg
140140

141141
address_analysis = address.AddressAnalysis(main.dialects)
142-
frame, result = address_analysis.run_analysis(
143-
main, args=(address.ConstResult(const.Unknown()),), no_raise=False
142+
frame, result = address_analysis.run(
143+
main,
144+
address.ConstResult(const.Unknown()),
144145
)
145146
assert result == address.AddressReg(data=tuple(range(4)))
146147

@@ -155,7 +156,7 @@ def main(n: int):
155156
return qreg
156157

157158
address_analysis = address.AddressAnalysis(main.dialects)
158-
frame, result = address_analysis.run_analysis(main, no_raise=False)
159+
frame, result = address_analysis.run(main)
159160
assert result == address.AddressReg(data=tuple(range(4)))
160161

161162

@@ -165,7 +166,7 @@ def main(n: int):
165166
return (0, 1) + (2, n)
166167

167168
address_analysis = address.AddressAnalysis(main.dialects)
168-
frame, result = address_analysis.run_analysis(main, no_raise=False)
169+
frame, result = address_analysis.run(main)
169170

170171
assert result == address.PartialTuple(
171172
data=(
@@ -183,7 +184,7 @@ def main(n: int):
183184
return (0, 1) + [2, n] # type: ignore
184185

185186
address_analysis = address.AddressAnalysis(main.dialects)
186-
frame, result = address_analysis.run_analysis(main, no_raise=False)
187+
frame, result = address_analysis.run(main)
187188

188189
assert result == address.Bottom()
189190

@@ -194,7 +195,7 @@ def main(n: tuple[int, ...]):
194195
return (0, 1) + n
195196

196197
address_analysis = address.AddressAnalysis(main.dialects)
197-
frame, result = address_analysis.run_analysis(main, no_raise=False)
198+
frame, result = address_analysis.run(main)
198199

199200
assert result == address.Unknown()
200201

@@ -207,7 +208,7 @@ def main(q: qubit.Qubit):
207208
return (0, q, 2, q)[1::2]
208209

209210
address_analysis = address.AddressAnalysis(main.dialects)
210-
frame, result = address_analysis.run_analysis(main, no_raise=False)
211+
frame, result = address_analysis.run(main)
211212
assert result == address.UnknownReg()
212213

213214

@@ -219,7 +220,7 @@ def main(n: int):
219220

220221
main.print()
221222
address_analysis = address.AddressAnalysis(main.dialects)
222-
frame, result = address_analysis.run_analysis(main, no_raise=False)
223+
frame, result = address_analysis.run(main)
223224
main.print(analysis=frame.entries)
224225
assert (
225226
result == address.UnknownReg()
@@ -260,7 +261,7 @@ def main():
260261

261262
func = main
262263
analysis = address.AddressAnalysis(squin.kernel)
263-
_, ret = analysis.run_analysis(func, no_raise=False)
264+
_, ret = analysis.run(func)
264265

265266
assert ret == address.AddressReg(data=tuple(range(20)))
266267
assert analysis.qubit_count == 20

0 commit comments

Comments
 (0)