Skip to content

Commit 966712a

Browse files
committed
fixing bug in call mix-in
1 parent 522174d commit 966712a

File tree

1 file changed

+33
-9
lines changed

1 file changed

+33
-9
lines changed

src/bloqade/analysis/address/impls.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
Address,
1616
AddressReg,
1717
ConstResult,
18+
AddressQubit,
1819
PartialIList,
1920
PartialTuple,
2021
PartialLambda,
@@ -31,13 +32,14 @@ def call_function(
3132
inputs: tuple[Address, ...],
3233
kwargs: tuple[str, ...],
3334
) -> Address:
35+
3436
match callee:
3537
case PartialLambda(code=code, argnames=argnames):
3638
_, ret = interp_.run_callable(
3739
code, (callee,) + interp_.permute_values(argnames, inputs, kwargs)
3840
)
3941
return ret
40-
case ConstResult(constant=const.Value(data=ir.Method() as method)):
42+
case ConstResult(const.Value(ir.Method() as method)):
4143
_, ret = interp_.run_method(
4244
method,
4345
interp_.permute_values(method.arg_names, inputs, kwargs),
@@ -58,6 +60,8 @@ def from_literal(literal: Any) -> Address:
5860
match collection:
5961
case PartialIList(data) | PartialTuple(data):
6062
return data
63+
case AddressReg(data):
64+
return tuple(map(AddressQubit, data))
6165
case ConstResult(const.Value(data)) if isinstance(data, Iterable):
6266
return tuple(map(from_literal, data))
6367
case ConstResult(const.PartialTuple(data)):
@@ -156,12 +160,28 @@ def map_(
156160

157161
results = []
158162
for ele in iterable:
159-
results.append(self.call_function(interp_, fn, (ele,), ()))
163+
ret = self.call_function(interp_, fn, (ele,), ())
164+
results.append(ret)
160165

161166
if isinstance(stmt, ilist.Map):
162167
return (PartialIList(tuple(results)),)
163168

164169

170+
@py.len.dialect.register(key="qubit.address")
171+
class PyLen(interp.MethodTable, GetValuesMixin):
172+
@interp.impl(py.Len)
173+
def len_(
174+
self, interp_: AddressAnalysis, frame: ForwardFrame[Address], stmt: py.Len
175+
):
176+
obj = frame.get(stmt.value)
177+
values = self.get_values(obj)
178+
179+
if values is None:
180+
return (Address.top(),)
181+
182+
return (ConstResult(const.Value(len(values))),)
183+
184+
165185
@py.indexing.dialect.register(key="qubit.address")
166186
class PyIndexing(interp.MethodTable, GetValuesMixin):
167187
@interp.impl(py.GetItem)
@@ -177,7 +197,6 @@ def getitem(
177197
index = frame.get(stmt.index)
178198

179199
values = self.get_values(obj)
180-
181200
if not isinstance(obj, StaticContainer):
182201
return interp_.eval_stmt_fallback(frame, stmt)
183202

@@ -224,13 +243,15 @@ def invoke(
224243
frame: ForwardFrame[Address],
225244
stmt: func.Invoke,
226245
):
246+
227247
args = interp_.permute_values(
228248
stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
229249
)
230250
_, ret = interp_.run_method(
231251
stmt.callee,
232252
args,
233253
)
254+
234255
return (ret,)
235256

236257
@interp.impl(func.Lambda)
@@ -356,30 +377,32 @@ def ifelse(
356377
stmt: scf.IfElse,
357378
):
358379
address_cond = frame.get(stmt.cond)
359-
360380
# run specific branch
361381
if isinstance(address_cond, ConstResult) and isinstance(
362382
const_cond := address_cond.result, const.Value
363383
):
364384
body = stmt.then_body if const_cond.data else stmt.else_body
365385
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
366386
ret = interp_.run_ssacfg_region(body_frame, body, (address_cond,))
367-
frame.entries.update(body_frame.entries)
387+
# interp_.set_values(frame, body_frame.entries.keys(), body_frame.entries.values())
368388
return ret
369389
else:
370390
# run both branches
371391
with interp_.new_frame(stmt, has_parent_access=True) as then_frame:
372392
then_results = interp_.run_ssacfg_region(
373393
then_frame, stmt.then_body, (address_cond,)
374394
)
395+
# interp_.set_values(
396+
# frame, then_frame.entries.keys(), then_frame.entries.values()
397+
# )
375398

376399
with interp_.new_frame(stmt, has_parent_access=True) as else_frame:
377400
else_results = interp_.run_ssacfg_region(
378401
else_frame, stmt.else_body, (address_cond,)
379402
)
380-
381-
frame.entries.update(then_frame.entries)
382-
frame.entries.update(else_frame.entries)
403+
# interp_.set_values(
404+
# frame, else_frame.entries.keys(), else_frame.entries.values()
405+
# )
383406
# TODO: pick the non-return value
384407
if isinstance(then_results, interp.ReturnValue) and isinstance(
385408
else_results, interp.ReturnValue
@@ -403,14 +426,15 @@ def for_loop(
403426
):
404427
loop_vars = frame.get_values(stmt.initializers)
405428
iterable = self.get_values(frame.get(stmt.iterable))
429+
406430
if iterable is None:
407431
return interp_.eval_stmt_fallback(frame, stmt)
408-
409432
for value in iterable:
410433
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
411434
loop_vars = interp_.run_ssacfg_region(
412435
body_frame, stmt.body, (value,) + loop_vars
413436
)
437+
414438
if loop_vars is None:
415439
loop_vars = ()
416440

0 commit comments

Comments
 (0)