Skip to content

Commit ca005e0

Browse files
weinbe58david-pl
andauthored
New Address Analysis. (#563)
~In this PR I implement a Joint analysis with constant prop and Address analysis to try to get the address analysis to work properly going into the call stack.~ I do not use a Cartesian product for the lattic because I need to implement partial lambda and partial IList and tuple to get the analysis to work. In order to support constant folding, however, I opted into adding an extra lattice element to wrap the constant prop results. --------- Co-authored-by: David Plankensteiner <[email protected]>
1 parent 517722f commit ca005e0

File tree

25 files changed

+1852
-1098
lines changed

25 files changed

+1852
-1098
lines changed
Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
from . import impls as impls
22
from .lattice import (
3+
Bottom as Bottom,
34
Address as Address,
4-
NotQubit as NotQubit,
5+
Unknown as Unknown,
56
AddressReg as AddressReg,
6-
AnyAddress as AnyAddress,
7-
AddressWire as AddressWire,
7+
UnknownReg as UnknownReg,
8+
ConstResult as ConstResult,
89
AddressQubit as AddressQubit,
9-
AddressTuple as AddressTuple,
10+
PartialIList as PartialIList,
11+
PartialTuple as PartialTuple,
12+
UnknownQubit as UnknownQubit,
13+
PartialLambda as PartialLambda,
1014
)
1115
from .analysis import AddressAnalysis as AddressAnalysis
Lines changed: 119 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
from typing import TypeVar
1+
from typing import Any, Type, TypeVar
22
from dataclasses import field
33

4-
from kirin import ir, interp
4+
from kirin import ir, types, interp
55
from kirin.analysis import Forward, const
6+
from kirin.dialects.ilist import IList
67
from kirin.analysis.forward import ForwardFrame
8+
from kirin.analysis.const.lattice import PartialLambda
79

8-
from bloqade.types import QubitType
9-
10-
from .lattice import Address
10+
from .lattice import Address, AddressReg, ConstResult, PartialIList, PartialTuple
1111

1212

1313
class AddressAnalysis(Forward[Address]):
@@ -16,12 +16,15 @@ class AddressAnalysis(Forward[Address]):
1616
"""
1717

1818
keys = ["qubit.address"]
19+
_const_prop: const.Propagate
1920
lattice = Address
2021
next_address: int = field(init=False)
2122

2223
def initialize(self):
2324
super().initialize()
2425
self.next_address: int = 0
26+
self._const_prop = const.Propagate(self.dialects)
27+
self._const_prop.initialize()
2528
return self
2629

2730
@property
@@ -31,30 +34,117 @@ def qubit_count(self) -> int:
3134

3235
T = TypeVar("T")
3336

34-
def get_const_value(self, typ: type[T], value: ir.SSAValue) -> T:
35-
if isinstance(hint := value.hints.get("const"), const.Value):
36-
data = hint.data
37-
if isinstance(data, typ):
38-
return hint.data
39-
raise interp.InterpreterError(
40-
f"Expected constant value <type = {typ}>, got {data}"
41-
)
42-
raise interp.InterpreterError(
43-
f"Expected constant value <type = {typ}>, got {value}"
44-
)
45-
46-
def eval_stmt_fallback(
47-
self, frame: ForwardFrame[Address], stmt: ir.Statement
48-
) -> tuple[Address, ...] | interp.SpecialValue[Address]:
49-
return tuple(
50-
(
51-
self.lattice.top()
52-
if result.type.is_subseteq(QubitType)
53-
else self.lattice.bottom()
54-
)
55-
for result in stmt.results
56-
)
37+
def to_address(self, result: const.Result):
38+
return ConstResult(result)
39+
40+
def try_eval_const_prop(
41+
self,
42+
frame: ForwardFrame[Address],
43+
stmt: ir.Statement,
44+
args: tuple[ConstResult, ...],
45+
) -> interp.StatementResult[Address]:
46+
_frame = self._const_prop.initialize_frame(frame.code)
47+
_frame.set_values(stmt.args, tuple(x.result for x in args))
48+
result = self._const_prop.eval_stmt(_frame, stmt)
49+
50+
match result:
51+
case interp.ReturnValue(constant_ret):
52+
return interp.ReturnValue(self.to_address(constant_ret))
53+
case interp.YieldValue(constant_values):
54+
return interp.YieldValue(tuple(map(self.to_address, constant_values)))
55+
case interp.Successor(block, block_args):
56+
return interp.Successor(block, *map(self.to_address, block_args))
57+
case tuple():
58+
return tuple(map(self.to_address, result))
59+
case _:
60+
return result
61+
62+
def unpack_iterable(self, iterable: Address):
63+
"""Extract the values of a container lattice element.
64+
65+
Args:
66+
iterable: The lattice element representing a container.
67+
68+
Returns:
69+
A tuple of the container type and the contained values.
70+
71+
"""
72+
73+
def from_constant(constant: const.Result) -> Address:
74+
return ConstResult(constant)
75+
76+
def from_literal(literal: Any) -> Address:
77+
return ConstResult(const.Value(literal))
78+
79+
match iterable:
80+
case PartialIList(data):
81+
return PartialIList, data
82+
case PartialTuple(data):
83+
return PartialTuple, data
84+
case AddressReg():
85+
return PartialIList, iterable.qubits
86+
case ConstResult(const.Value(IList() as data)):
87+
return PartialIList, tuple(map(from_literal, data))
88+
case ConstResult(const.Value(tuple() as data)):
89+
return PartialTuple, tuple(map(from_literal, data))
90+
case ConstResult(const.PartialTuple(data)):
91+
return PartialTuple, tuple(map(from_constant, data))
92+
case _:
93+
return None, ()
94+
95+
def run_lattice(
96+
self,
97+
callee: Address,
98+
inputs: tuple[Address, ...],
99+
kwargs: tuple[str, ...],
100+
) -> Address:
101+
"""Run a callable lattice element with the given inputs and keyword arguments.
102+
103+
Args:
104+
callee (Address): The lattice element representing the callable.
105+
inputs (tuple[Address, ...]): The input lattice elements.
106+
kwargs (tuple[str, ...]): The keyword argument names.
107+
108+
Returns:
109+
Address: The resulting lattice element after invoking the callable.
110+
111+
"""
112+
113+
match callee:
114+
case PartialLambda(code=code, argnames=argnames):
115+
_, ret = self.run_callable(
116+
code, (callee,) + self.permute_values(argnames, inputs, kwargs)
117+
)
118+
return ret
119+
case ConstResult(const.Value(ir.Method() as method)):
120+
_, ret = self.run_method(
121+
method,
122+
self.permute_values(method.arg_names, inputs, kwargs),
123+
)
124+
return ret
125+
case _:
126+
return Address.top()
127+
128+
def get_const_value(self, addr: Address, typ: Type[T]) -> T | None:
129+
if not isinstance(addr, ConstResult):
130+
return None
131+
132+
if not isinstance(result := addr.result, const.Value):
133+
return None
134+
135+
if not isinstance(value := result.data, typ):
136+
return None
137+
138+
return value
139+
140+
def eval_stmt_fallback(self, frame: ForwardFrame[Address], stmt: ir.Statement):
141+
args = frame.get_values(stmt.args)
142+
if types.is_tuple_of(args, ConstResult):
143+
return self.try_eval_const_prop(frame, stmt, args)
144+
145+
return tuple(Address.from_type(result.type) for result in stmt.results)
57146

58147
def run_method(self, method: ir.Method, args: tuple[Address, ...]):
59148
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
60-
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
149+
self_mt = ConstResult(const.Value(method))
150+
return self.run_callable(method.code, (self_mt,) + args)

0 commit comments

Comments
 (0)