Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/finchlite/finch_assembly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
assembly_build_cfg,
assembly_number_uses,
)
from .dataflow import AssemblyCopyPropagation, assembly_copy_propagation
from .dataflow import (
AssemblyAvailableExpressions,
AssemblyCopyPropagation,
assembly_copy_propagation,
)
from .interpreter import AssemblyInterpreter, AssemblyInterpreterKernel
from .nodes import (
AssemblyNode,
Expand Down Expand Up @@ -40,6 +44,7 @@
from .type_checker import AssemblyTypeChecker, AssemblyTypeError, assembly_check_types

__all__ = [
"AssemblyAvailableExpressions",
"AssemblyCFGBuilder",
"AssemblyCopyPropagation",
"AssemblyInterpreter",
Expand Down
38 changes: 32 additions & 6 deletions src/finchlite/finch_assembly/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@

from ..symbolic import DataFlowAnalysis, PostOrderDFS
from .cfg_builder import assembly_build_cfg
from .nodes import (
AssemblyNode,
Assign,
TaggedVariable,
Variable,
)
from .nodes import AssemblyNode, Assign, TaggedVariable, Variable


def assembly_copy_propagation(node: AssemblyNode):
Expand All @@ -24,6 +19,21 @@ def assembly_copy_propagation(node: AssemblyNode):
return ctx


# def assembly_available_expressions(node: AssemblyNode):
# """Run available-expressions analysis on a FinchAssembly node.

# Args:
# node: Root FinchAssembly node to analyze.

# Returns:
# AssemblyAvailableExpressions: The completed analysis context.
# """

# ctx = AssemblyAvailableExpressions(assembly_build_cfg(node))
# ctx.analyze()
# return ctx


class AbstractAssemblyDataflow(DataFlowAnalysis):
"""Assembly-specific base for dataflow analyses."""

Expand Down Expand Up @@ -129,3 +139,19 @@ def join(self, state_1: dict, state_2: dict) -> dict:
result[var_name] = state_1[var_name]

return result


class AssemblyAvailableExpressions(AbstractAssemblyDataflow):
"""Available expressions analysis for FinchAssembly."""

def direction(self) -> str:
return "forward"

def transfer(self, stmts, state: dict) -> dict:
pass

def join(self, state_1: dict, state_2: dict) -> dict:
pass

def print_lattice_value(self, state, stmt) -> list[tuple[str, object]]:
pass
12 changes: 12 additions & 0 deletions tests/reference/test_asm_printer_comprehensive2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
def avail_demo(a: int64, b: int64, c: int64, cond: bool) -> int64:
t1: int64 = add(a, b)
t2: int64 = add(a, b)
t3: int64 = mul(b, c)
t1: int64 = add(add(a, b), mul(b, c))
if cond:
t2: int64 = add(add(a, b), 1)
else:
a: int64 = 0
t3: int64 = sub(a, c)
t3: int64 = add(t1, t2)
return t3
70 changes: 69 additions & 1 deletion tests/test_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import finchlite.finch_assembly as asm
from finchlite.codegen.numpy_buffer import NumpyBuffer
from finchlite.finch_assembly.cfg_builder import assembly_build_cfg
from finchlite.finch_assembly.dataflow import assembly_copy_propagation
from finchlite.finch_assembly.dataflow import (
assembly_copy_propagation,
)


def test_asm_cfg_printer_if(file_regression):
Expand Down Expand Up @@ -666,3 +668,69 @@ def test_asm_copy_propagation_comprehensive(file_regression):

copy_propagation = assembly_copy_propagation(root)
file_regression.check(str(copy_propagation), extension=".txt")


# def test_asm_avail_exp_comprehensive2(file_regression):
# a = asm.Variable("a", np.int64)
# b = asm.Variable("b", np.int64)
# c = asm.Variable("c", np.int64)
# cond = asm.Variable("cond", np.bool_)
# t1 = asm.Variable("t1", np.int64)
# t2 = asm.Variable("t2", np.int64)
# t3 = asm.Variable("t3", np.int64)

# add_ab = asm.Call(asm.Literal(operator.add), (a, b))
# mul_bc = asm.Call(asm.Literal(operator.mul), (b, c))
# sub_ac = asm.Call(asm.Literal(operator.sub), (a, c))
# nested = asm.Call(asm.Literal(operator.add), (add_ab, mul_bc))

# root = asm.Module(
# (
# asm.Function(
# asm.Variable("avail_demo", np.int64),
# (a, b, c, cond),
# asm.Block(
# (
# asm.Assign(t1, add_ab),
# asm.Assign(t2, add_ab),
# asm.Assign(t3, mul_bc),
# asm.Assign(
# t1,
# nested,
# ),
# asm.IfElse(
# cond,
# asm.Block(
# (
# asm.Assign(
# t2,
# asm.Call(
# asm.Literal(operator.add),
# (add_ab, asm.Literal(np.int64(1))),
# ),
# ),
# )
# ),
# asm.Block(
# (
# asm.Assign(a, asm.Literal(np.int64(0))),
# asm.Assign(t3, sub_ac),
# )
# ),
# ),
# asm.Assign(
# t3,
# asm.Call(
# asm.Literal(operator.add),
# (t1, t2),
# ),
# ),
# asm.Return(t3),
# )
# ),
# ),
# )
# )

# avail_exp = assembly_available_expressions(root)
# file_regression.check(str(avail_exp), extension=".txt")
65 changes: 65 additions & 0 deletions tests/test_printers.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,71 @@ def test_tagged_asm_printer_if(file_regression):
file_regression.check(str(root), extension=".txt")


def test_asm_printer_comprehensive2(file_regression):
a = asm.Variable("a", np.int64)
b = asm.Variable("b", np.int64)
c = asm.Variable("c", np.int64)
cond = asm.Variable("cond", np.bool_)
t1 = asm.Variable("t1", np.int64)
t2 = asm.Variable("t2", np.int64)
t3 = asm.Variable("t3", np.int64)

add_ab = asm.Call(asm.Literal(operator.add), (a, b))
mul_bc = asm.Call(asm.Literal(operator.mul), (b, c))
sub_ac = asm.Call(asm.Literal(operator.sub), (a, c))
nested = asm.Call(asm.Literal(operator.add), (add_ab, mul_bc))

root = asm.Module(
(
asm.Function(
asm.Variable("avail_demo", np.int64),
(a, b, c, cond),
asm.Block(
(
asm.Assign(t1, add_ab),
asm.Assign(t2, add_ab),
asm.Assign(t3, mul_bc),
asm.Assign(
t1,
nested,
),
asm.IfElse(
cond,
asm.Block(
(
asm.Assign(
t2,
asm.Call(
asm.Literal(operator.add),
(add_ab, asm.Literal(np.int64(1))),
),
),
)
),
asm.Block(
(
asm.Assign(a, asm.Literal(np.int64(0))),
asm.Assign(t3, sub_ac),
)
),
),
asm.Assign(
t3,
asm.Call(
asm.Literal(operator.add),
(t1, t2),
),
),
asm.Return(t3),
)
),
),
)
)

file_regression.check(str(root), extension=".txt")


def test_asm_printer_dot(file_regression):
c = asm.Variable("c", np.float64)
i = asm.Variable("i", np.int64)
Expand Down
Loading