Skip to content

Commit ae2bddf

Browse files
committed
moving mixin methods to interpreter
1 parent 4dbee48 commit ae2bddf

File tree

2 files changed

+84
-92
lines changed

2 files changed

+84
-92
lines changed

src/bloqade/analysis/address/analysis.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
from typing import TypeVar
1+
from typing import Any, TypeVar
22
from dataclasses import field
33

44
from kirin import ir, types, interp
55
from kirin.analysis import Forward, const
6+
from kirin.dialects.ilist import IList
67
from kirin.analysis.forward import ForwardFrame
8+
from kirin.analysis.const.lattice import PartialLambda
79

8-
from .lattice import Address, ConstResult
10+
from .lattice import Address, AddressReg, ConstResult, PartialIList, PartialTuple
911

1012

1113
class AddressAnalysis(Forward[Address]):
@@ -57,6 +59,72 @@ def try_eval_const_prop(
5759
case _:
5860
return result
5961

62+
def unpack_iterable(self, collection: Address):
63+
"""Extract the values of a container lattice element.
64+
65+
Args:
66+
collection: The lattice element representing a container.
67+
68+
Returns:
69+
A tuple of the container type and the contained values.
70+
71+
"""
72+
73+
def from_constant(constant: const.Result) -> Address:
74+
return ConstResult(constant)
75+
76+
def from_literal(literal: Any) -> Address:
77+
return ConstResult(const.Value(literal))
78+
79+
match collection:
80+
case PartialIList(data):
81+
return PartialIList, data
82+
case PartialTuple(data):
83+
return PartialTuple, data
84+
case AddressReg():
85+
return PartialIList, collection.qubits
86+
case ConstResult(const.Value(IList() as data)):
87+
return PartialIList, tuple(map(from_literal, data))
88+
case ConstResult(const.Value(tuple() as data)):
89+
return PartialTuple, tuple(map(from_literal, data))
90+
case ConstResult(const.PartialTuple(data)):
91+
return PartialTuple, tuple(map(from_constant, data))
92+
case _:
93+
return None, ()
94+
95+
def run_lattice(
96+
self,
97+
callee: Address,
98+
inputs: tuple[Address, ...],
99+
kwargs: tuple[str, ...],
100+
) -> Address:
101+
"""Run a callable lattice element with the given inputs and keyword arguments.
102+
103+
Args:
104+
callee (Address): The lattice element representing the callable.
105+
inputs (tuple[Address, ...]): The input lattice elements.
106+
kwargs (tuple[str, ...]): The keyword argument names.
107+
108+
Returns:
109+
Address: The resulting lattice element after invoking the callable.
110+
111+
"""
112+
113+
match callee:
114+
case PartialLambda(code=code, argnames=argnames):
115+
_, ret = self.run_callable(
116+
code, (callee,) + self.permute_values(argnames, inputs, kwargs)
117+
)
118+
return ret
119+
case ConstResult(const.Value(ir.Method() as method)):
120+
_, ret = self.run_method(
121+
method,
122+
self.permute_values(method.arg_names, inputs, kwargs),
123+
)
124+
return ret
125+
case _:
126+
return Address.top()
127+
60128
def eval_stmt_fallback(self, frame: ForwardFrame[Address], stmt: ir.Statement):
61129
args = frame.get_values(stmt.args)
62130
if types.is_tuple_of(args, ConstResult):

src/bloqade/analysis/address/impls.py

Lines changed: 14 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22
qubit.address method table for a few builtin dialects.
33
"""
44

5-
from typing import Any
65
from itertools import chain
76

87
from kirin import ir, interp
98
from kirin.analysis import ForwardFrame, const
109
from kirin.dialects import cf, py, scf, func, ilist
11-
from kirin.dialects.ilist import IList
1210

1311
from .lattice import (
1412
Address,
@@ -22,79 +20,6 @@
2220
from .analysis import AddressAnalysis
2321

2422

25-
# TODO: put this in the interpreter
26-
class CallInterfaceMixin:
27-
"""This mixin provides a generic implementation to call lattice elements for method tables.
28-
29-
It handles PartialLambda and ConstResult wrapping ir.Method."""
30-
31-
def call_function(
32-
self,
33-
interp_: AddressAnalysis,
34-
callee: Address,
35-
inputs: tuple[Address, ...],
36-
kwargs: tuple[str, ...],
37-
) -> Address:
38-
39-
match callee:
40-
case PartialLambda(code=code, argnames=argnames):
41-
_, ret = interp_.run_callable(
42-
code, (callee,) + interp_.permute_values(argnames, inputs, kwargs)
43-
)
44-
return ret
45-
case ConstResult(const.Value(ir.Method() as method)):
46-
_, ret = interp_.run_method(
47-
method,
48-
interp_.permute_values(method.arg_names, inputs, kwargs),
49-
)
50-
return ret
51-
case _:
52-
return Address.top()
53-
54-
55-
class GetValuesMixin:
56-
"""This mixin provides a generic implementation to extract values of lattice elements
57-
58-
that are represent the values of containers. The return type is used to differentiate
59-
between IList and Tuple containers in the analysis for cases where the type information
60-
is important for the analysis not just the contained values.
61-
62-
"""
63-
64-
def get_values(self, collection: Address):
65-
"""Extract the values of a container lattice element.
66-
67-
Args:
68-
collection: The lattice element representing a container.
69-
70-
Returns:
71-
A tuple of the container type and the contained values.
72-
73-
"""
74-
75-
def from_constant(constant: const.Result) -> Address:
76-
return ConstResult(constant)
77-
78-
def from_literal(literal: Any) -> Address:
79-
return ConstResult(const.Value(literal))
80-
81-
match collection:
82-
case PartialIList(data):
83-
return PartialIList, data
84-
case PartialTuple(data):
85-
return PartialTuple, data
86-
case AddressReg():
87-
return PartialIList, collection.qubits
88-
case ConstResult(const.Value(IList() as data)):
89-
return PartialIList, tuple(map(from_literal, data))
90-
case ConstResult(const.Value(tuple() as data)):
91-
return PartialTuple, tuple(map(from_literal, data))
92-
case ConstResult(const.PartialTuple(data)):
93-
return PartialTuple, tuple(map(from_constant, data))
94-
case _:
95-
return None, ()
96-
97-
9823
@py.constant.dialect.register(key="qubit.address")
9924
class PyConstant(interp.MethodTable):
10025
@interp.impl(py.Constant)
@@ -108,7 +33,7 @@ def constant(
10833

10934

11035
@py.binop.dialect.register(key="qubit.address")
111-
class PyBinOp(interp.MethodTable, GetValuesMixin):
36+
class PyBinOp(interp.MethodTable):
11237
@interp.impl(py.Add)
11338
def add(
11439
self,
@@ -119,8 +44,8 @@ def add(
11944
lhs = frame.get(stmt.lhs)
12045
rhs = frame.get(stmt.rhs)
12146

122-
lhs_type, lhs_values = self.get_values(lhs)
123-
rhs_type, rhs_values = self.get_values(rhs)
47+
lhs_type, lhs_values = interp_.unpack_iterable(lhs)
48+
rhs_type, rhs_values = interp_.unpack_iterable(rhs)
12449

12550
if lhs_type is None or rhs_type is None:
12651
return (Address.top(),)
@@ -144,7 +69,7 @@ def new_tuple(
14469

14570

14671
@ilist.dialect.register(key="qubit.address")
147-
class IListMethods(interp.MethodTable, CallInterfaceMixin, GetValuesMixin):
72+
class IListMethods(interp.MethodTable):
14873
@interp.impl(ilist.New)
14974
def new_ilist(
15075
self,
@@ -165,7 +90,7 @@ def map_(
16590
results = []
16691
fn = frame.get(stmt.fn)
16792
collection = frame.get(stmt.collection)
168-
collection_type, values = self.get_values(collection)
93+
collection_type, values = interp_.unpack_iterable(collection)
16994

17095
if collection_type is None:
17196
return (Address.top(),)
@@ -175,21 +100,21 @@ def map_(
175100

176101
results = []
177102
for ele in values:
178-
ret = self.call_function(interp_, fn, (ele,), ())
103+
ret = interp_.run_lattice(fn, (ele,), ())
179104
results.append(ret)
180105

181106
if isinstance(stmt, ilist.Map):
182107
return (PartialIList(tuple(results)),)
183108

184109

185110
@py.len.dialect.register(key="qubit.address")
186-
class PyLen(interp.MethodTable, GetValuesMixin):
111+
class PyLen(interp.MethodTable):
187112
@interp.impl(py.Len)
188113
def len_(
189114
self, interp_: AddressAnalysis, frame: ForwardFrame[Address], stmt: py.Len
190115
):
191116
obj = frame.get(stmt.value)
192-
_, values = self.get_values(obj)
117+
_, values = interp_.unpack_iterable(obj)
193118

194119
if values is None:
195120
return (Address.top(),)
@@ -198,7 +123,7 @@ def len_(
198123

199124

200125
@py.indexing.dialect.register(key="qubit.address")
201-
class PyIndexing(interp.MethodTable, GetValuesMixin):
126+
class PyIndexing(interp.MethodTable):
202127
@interp.impl(py.GetItem)
203128
def getitem(
204129
self,
@@ -211,7 +136,7 @@ def getitem(
211136
obj = frame.get(stmt.obj)
212137
index = frame.get(stmt.index)
213138

214-
typ, values = self.get_values(obj)
139+
typ, values = interp_.unpack_iterable(obj)
215140
if typ is None:
216141
return (Address.top(),)
217142

@@ -242,7 +167,7 @@ def alias(
242167

243168
# TODO: look for abstract method table for func.
244169
@func.dialect.register(key="qubit.address")
245-
class Func(interp.MethodTable, CallInterfaceMixin):
170+
class Func(interp.MethodTable):
246171
@interp.impl(func.Return)
247172
def return_(
248173
self,
@@ -296,8 +221,7 @@ def call(
296221
frame: ForwardFrame[Address],
297222
stmt: func.Call,
298223
):
299-
result = self.call_function(
300-
interp_,
224+
result = interp_.run_lattice(
301225
frame.get(stmt.callee),
302226
frame.get_values(stmt.inputs),
303227
stmt.kwargs,
@@ -376,7 +300,7 @@ def conditional_branch(
376300

377301

378302
@scf.dialect.register(key="qubit.address")
379-
class Scf(interp.MethodTable, GetValuesMixin):
303+
class Scf(interp.MethodTable):
380304
@interp.impl(scf.Yield)
381305
def yield_(
382306
self,
@@ -442,7 +366,7 @@ def for_loop(
442366
stmt: scf.For,
443367
):
444368
loop_vars = frame.get_values(stmt.initializers)
445-
iter_type, iterable = self.get_values(frame.get(stmt.iterable))
369+
iter_type, iterable = interp_.unpack_iterable(frame.get(stmt.iterable))
446370

447371
if iter_type is None:
448372
return interp_.eval_stmt_fallback(frame, stmt)

0 commit comments

Comments
 (0)