diff --git a/src/finchlite/finch_assembly/__init__.py b/src/finchlite/finch_assembly/__init__.py index a2417d22..5def8b1d 100644 --- a/src/finchlite/finch_assembly/__init__.py +++ b/src/finchlite/finch_assembly/__init__.py @@ -4,7 +4,11 @@ assembly_build_cfg, assembly_number_uses, ) -from .dataflow import AssemblyCopyPropagation, assembly_copy_propagation +from .dataflow import ( + AssemblyAvailableExpressions, + AssemblyCopyPropagation, + assembly_copy_propagation, +) from .interpreter import AssemblyInterpreter, AssemblyInterpreterKernel from .nodes import ( AssemblyNode, @@ -40,6 +44,7 @@ from .type_checker import AssemblyTypeChecker, AssemblyTypeError, assembly_check_types __all__ = [ + "AssemblyAvailableExpressions", "AssemblyCFGBuilder", "AssemblyCopyPropagation", "AssemblyInterpreter", diff --git a/src/finchlite/finch_assembly/dataflow.py b/src/finchlite/finch_assembly/dataflow.py index 92239ffc..986d7846 100644 --- a/src/finchlite/finch_assembly/dataflow.py +++ b/src/finchlite/finch_assembly/dataflow.py @@ -2,12 +2,7 @@ from ..symbolic import DataFlowAnalysis, PostOrderDFS from .cfg_builder import assembly_build_cfg -from .nodes import ( - AssemblyNode, - Assign, - TaggedVariable, - Variable, -) +from .nodes import AssemblyNode, Assign, TaggedVariable, Variable def assembly_copy_propagation(node: AssemblyNode): @@ -24,6 +19,21 @@ def assembly_copy_propagation(node: AssemblyNode): return ctx +# def assembly_available_expressions(node: AssemblyNode): +# """Run available-expressions analysis on a FinchAssembly node. + +# Args: +# node: Root FinchAssembly node to analyze. + +# Returns: +# AssemblyAvailableExpressions: The completed analysis context. +# """ + +# ctx = AssemblyAvailableExpressions(assembly_build_cfg(node)) +# ctx.analyze() +# return ctx + + class AbstractAssemblyDataflow(DataFlowAnalysis): """Assembly-specific base for dataflow analyses.""" @@ -129,3 +139,19 @@ def join(self, state_1: dict, state_2: dict) -> dict: result[var_name] = state_1[var_name] return result + + +class AssemblyAvailableExpressions(AbstractAssemblyDataflow): + """Available expressions analysis for FinchAssembly.""" + + def direction(self) -> str: + return "forward" + + def transfer(self, stmts, state: dict) -> dict: + pass + + def join(self, state_1: dict, state_2: dict) -> dict: + pass + + def print_lattice_value(self, state, stmt) -> list[tuple[str, object]]: + pass diff --git a/tests/reference/test_asm_printer_comprehensive2.txt b/tests/reference/test_asm_printer_comprehensive2.txt new file mode 100644 index 00000000..e1ad2f36 --- /dev/null +++ b/tests/reference/test_asm_printer_comprehensive2.txt @@ -0,0 +1,12 @@ +def avail_demo(a: int64, b: int64, c: int64, cond: bool) -> int64: + t1: int64 = add(a, b) + t2: int64 = add(a, b) + t3: int64 = mul(b, c) + t1: int64 = add(add(a, b), mul(b, c)) + if cond: + t2: int64 = add(add(a, b), 1) + else: + a: int64 = 0 + t3: int64 = sub(a, c) + t3: int64 = add(t1, t2) + return t3 diff --git a/tests/test_dataflow.py b/tests/test_dataflow.py index 097eb297..168e6f6c 100644 --- a/tests/test_dataflow.py +++ b/tests/test_dataflow.py @@ -5,7 +5,9 @@ import finchlite.finch_assembly as asm from finchlite.codegen.numpy_buffer import NumpyBuffer from finchlite.finch_assembly.cfg_builder import assembly_build_cfg -from finchlite.finch_assembly.dataflow import assembly_copy_propagation +from finchlite.finch_assembly.dataflow import ( + assembly_copy_propagation, +) def test_asm_cfg_printer_if(file_regression): @@ -666,3 +668,69 @@ def test_asm_copy_propagation_comprehensive(file_regression): copy_propagation = assembly_copy_propagation(root) file_regression.check(str(copy_propagation), extension=".txt") + + +# def test_asm_avail_exp_comprehensive2(file_regression): +# a = asm.Variable("a", np.int64) +# b = asm.Variable("b", np.int64) +# c = asm.Variable("c", np.int64) +# cond = asm.Variable("cond", np.bool_) +# t1 = asm.Variable("t1", np.int64) +# t2 = asm.Variable("t2", np.int64) +# t3 = asm.Variable("t3", np.int64) + +# add_ab = asm.Call(asm.Literal(operator.add), (a, b)) +# mul_bc = asm.Call(asm.Literal(operator.mul), (b, c)) +# sub_ac = asm.Call(asm.Literal(operator.sub), (a, c)) +# nested = asm.Call(asm.Literal(operator.add), (add_ab, mul_bc)) + +# root = asm.Module( +# ( +# asm.Function( +# asm.Variable("avail_demo", np.int64), +# (a, b, c, cond), +# asm.Block( +# ( +# asm.Assign(t1, add_ab), +# asm.Assign(t2, add_ab), +# asm.Assign(t3, mul_bc), +# asm.Assign( +# t1, +# nested, +# ), +# asm.IfElse( +# cond, +# asm.Block( +# ( +# asm.Assign( +# t2, +# asm.Call( +# asm.Literal(operator.add), +# (add_ab, asm.Literal(np.int64(1))), +# ), +# ), +# ) +# ), +# asm.Block( +# ( +# asm.Assign(a, asm.Literal(np.int64(0))), +# asm.Assign(t3, sub_ac), +# ) +# ), +# ), +# asm.Assign( +# t3, +# asm.Call( +# asm.Literal(operator.add), +# (t1, t2), +# ), +# ), +# asm.Return(t3), +# ) +# ), +# ), +# ) +# ) + +# avail_exp = assembly_available_expressions(root) +# file_regression.check(str(avail_exp), extension=".txt") diff --git a/tests/test_printers.py b/tests/test_printers.py index a53a22fb..cecea04a 100644 --- a/tests/test_printers.py +++ b/tests/test_printers.py @@ -281,6 +281,71 @@ def test_tagged_asm_printer_if(file_regression): file_regression.check(str(root), extension=".txt") +def test_asm_printer_comprehensive2(file_regression): + a = asm.Variable("a", np.int64) + b = asm.Variable("b", np.int64) + c = asm.Variable("c", np.int64) + cond = asm.Variable("cond", np.bool_) + t1 = asm.Variable("t1", np.int64) + t2 = asm.Variable("t2", np.int64) + t3 = asm.Variable("t3", np.int64) + + add_ab = asm.Call(asm.Literal(operator.add), (a, b)) + mul_bc = asm.Call(asm.Literal(operator.mul), (b, c)) + sub_ac = asm.Call(asm.Literal(operator.sub), (a, c)) + nested = asm.Call(asm.Literal(operator.add), (add_ab, mul_bc)) + + root = asm.Module( + ( + asm.Function( + asm.Variable("avail_demo", np.int64), + (a, b, c, cond), + asm.Block( + ( + asm.Assign(t1, add_ab), + asm.Assign(t2, add_ab), + asm.Assign(t3, mul_bc), + asm.Assign( + t1, + nested, + ), + asm.IfElse( + cond, + asm.Block( + ( + asm.Assign( + t2, + asm.Call( + asm.Literal(operator.add), + (add_ab, asm.Literal(np.int64(1))), + ), + ), + ) + ), + asm.Block( + ( + asm.Assign(a, asm.Literal(np.int64(0))), + asm.Assign(t3, sub_ac), + ) + ), + ), + asm.Assign( + t3, + asm.Call( + asm.Literal(operator.add), + (t1, t2), + ), + ), + asm.Return(t3), + ) + ), + ), + ) + ) + + file_regression.check(str(root), extension=".txt") + + def test_asm_printer_dot(file_regression): c = asm.Variable("c", np.float64) i = asm.Variable("i", np.int64)