From d143ce88db59ae5828251e60aa1e8a205493f258 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Thu, 9 Oct 2025 10:12:52 -0400 Subject: [PATCH 01/57] * Added einsum lowerer back; * Still needs to be tested --- src/finchlite/autoschedule/einsum.py | 194 +++++++++++++++++++++++++++ 1 file changed, 194 insertions(+) create mode 100644 src/finchlite/autoschedule/einsum.py diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py new file mode 100644 index 00000000..3a657e05 --- /dev/null +++ b/src/finchlite/autoschedule/einsum.py @@ -0,0 +1,194 @@ +import finchlite.finch_einsum as ein +from finchlite.finch_logic import ( + Plan, + Produces, + Query, + Alias, + Table, + LogicNode, + MapJoin, + Literal, + Reorder, + Aggregate, + Relabel, +) +from finchlite.algebra import overwrite, init_value, is_commutative +from collections.abc import Callable + +class EinsumLowerer: + alias_counter: int = 0 + + def __call__(self, prgm: Plan) -> tuple[ein.Plan, dict[str, Table]]: + parameters: dict[str, Table] = {} + definitions: dict[str, ein.Einsum] = {} + return self.compile_plan(prgm, parameters, definitions), parameters + + def get_next_alias(self) -> ein.Alias: + self.alias_counter += 1 + return ein.Alias(f"einsum_{self.alias_counter}") + + def rename_einsum( + self, einsum: ein.Einsum, new_alias: ein.Alias, definitions: dict[str, ein.Einsum] + ) -> ein.Einsum: + definitions[new_alias.name] = einsum + return ein.Einsum(einsum.op, new_alias, einsum.idxs, einsum.arg) + + def compile_plan( + self, plan: Plan, parameters: dict[str, Table], definitions: dict[str, ein.Einsum] + ) -> ein.Plan: + einsums: list[ein.Einsum] = [] + returnValue: list[ein.EinsumExpr] = [] + + for body in plan.bodies: + match body: + case Plan(_): + inner_plan = self.compile_plan(body, parameters, definitions) + einsums.extend(inner_plan.bodies) + returnValue.extend(inner_plan.returnValues) + break + case Query(Alias(name), Table(_, _)): + parameters[name] = body.rhs + case 
Query(Alias(name), rhs): + einsums.append( + self.rename_einsum( + self.lower_to_einsum(rhs, einsums, parameters, definitions), + ein.Alias(name), + definitions, + ) + ) + case Produces(args): + returnValue = [ + ein.Alias(arg.name) + if isinstance(arg, Alias) + else self.lower_to_einsum(arg, einsums, parameters, definitions) + for arg in args + ] + break + case _: + einsums.append( + self.rename_einsum( + self.lower_to_einsum( + body, einsums, parameters, definitions + ), + self.get_next_alias(), + definitions, + ) + ) + + return ein.Plan(tuple(einsums), tuple(returnValue)) + + def lower_to_einsum( + self, + ex: LogicNode, + einsums: list[ein.Einsum], + parameters: dict[str, Table], + definitions: dict[str, ein.Einsum], + ) -> ein.Einsum: + match ex: + case Plan(_): + raise Exception("Plans within plans are not supported.") + case MapJoin(Literal(operation), args): + args_list = [ + self.lower_to_pointwise(arg, einsums, parameters, definitions) + for arg in args + ] + pointwise_expr = self.lower_to_pointwise_op(operation, tuple(args_list)) + return ein.Einsum( + reduceOp=overwrite, + output=self.get_next_alias(), + output_fields=tuple( + ein.Index(field.name) for field in ex.fields + ), + pointwise_expr=pointwise_expr, + ) + case Reorder(arg, idxs): + return self.lower_to_einsum( + arg, einsums, parameters, definitions + ).reorder(idxs) + case Aggregate(Literal(operation), Literal(init), arg, idxs): + if init != init_value(operation, type(init)): + raise Exception(f""" + Init value {init} is not the default value + for operation {operation} of type {type(init)}. + Non standard init values are not supported. 
+ """) + aggregate_expr = self.lower_to_pointwise( + arg, einsums, parameters, definitions + ) + return ein.Einsum( + op = ein.Literal(operation), + tns = self.get_next_alias(), + idxs = tuple(ein.Index(field.name) for field in ex.fields), + arg = aggregate_expr + ) + case _: + raise Exception(f"Unrecognized logic: {ex}") + + def lower_to_pointwise_op( + self, operation: Callable, args: tuple[ein.EinsumExpr, ...] + ) -> ein.EinsumExpr: + # if operation is commutative, we simply pass + # all the args to the pointwise op since + # order of args does not matter + if is_commutative(operation): + + def flatten_args( + m_args: tuple[ein.EinsumExpr, ...], + ) -> tuple[ein.EinsumExpr, ...]: + ret_args: list[ein.EinsumExpr] = [] + for arg in m_args: + match arg: + case ein.Call(op2, _) if op2 == operation: + ret_args.extend(flatten_args(arg.args)) + case _: + ret_args.append(arg) + return tuple(ret_args) + + return ein.Call(operation, flatten_args(args)) + + # combine args from left to right (i.e a / b / c -> (a / b) / c) + return ein.Call(operation, args) + + # lowers nested mapjoin logic IR nodes into a single pointwise expression + def lower_to_pointwise( + self, + ex: LogicNode, + einsums: list[ein.Einsum], + parameters: dict[str, Table], + definitions: dict[str, ein.Einsum], + ) -> ein.EinsumExpr: + match ex: + case Reorder(arg, idxs): + return self.lower_to_pointwise(arg, einsums, parameters, definitions) + case MapJoin(Literal(operation), args): + args_list = [ + self.lower_to_pointwise(arg, einsums, parameters, definitions) + for arg in args + ] + return self.lower_to_pointwise_op(operation, tuple(args_list)) + case Relabel( + Alias(name), idxs + ): # relable is really just a glorified pointwise access + return ein.Access( + tns=ein.Alias(name), + idxs=tuple(ein.Index(idx.name) for idx in idxs), + ) + case Literal(value): + return ein.Literal(val=value) + case Aggregate( + _, _, _, _ + ): # aggregate has to be computed seperatley as it's own einsum + 
aggregate_einsum_alias = self.get_next_alias() + einsums.append( + self.rename_einsum( + self.lower_to_einsum(ex, einsums, parameters, definitions), + aggregate_einsum_alias, + definitions, + ) + ) + return ein.Access( + alias=ein.Alias(aggregate_einsum_alias), + idxs=tuple(ein.Index(field.name) for field in ex.fields), + ) + case _: + raise Exception(f"Unrecognized logic: {ex}") \ No newline at end of file From 7c518292912908496d91618afb5ca0b7bed5450f Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Thu, 9 Oct 2025 10:21:26 -0400 Subject: [PATCH 02/57] Fixed some obvious errors --- src/finchlite/autoschedule/einsum.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 3a657e05..1b115a0d 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -94,12 +94,12 @@ def lower_to_einsum( ] pointwise_expr = self.lower_to_pointwise_op(operation, tuple(args_list)) return ein.Einsum( - reduceOp=overwrite, - output=self.get_next_alias(), - output_fields=tuple( + op=ein.Literal(overwrite), + tns=self.get_next_alias(), + idxs=tuple( ein.Index(field.name) for field in ex.fields ), - pointwise_expr=pointwise_expr, + arg=pointwise_expr, ) case Reorder(arg, idxs): return self.lower_to_einsum( @@ -138,16 +138,16 @@ def flatten_args( ret_args: list[ein.EinsumExpr] = [] for arg in m_args: match arg: - case ein.Call(op2, _) if op2 == operation: + case ein.Call(ein.Literal(op2), _) if op2 == operation: ret_args.extend(flatten_args(arg.args)) case _: ret_args.append(arg) return tuple(ret_args) - return ein.Call(operation, flatten_args(args)) + return ein.Call(ein.Literal(operation), flatten_args(args)) # combine args from left to right (i.e a / b / c -> (a / b) / c) - return ein.Call(operation, args) + return ein.Call(ein.Literal(operation), args) # lowers nested mapjoin logic IR nodes into a single pointwise expression def 
lower_to_pointwise( @@ -187,7 +187,7 @@ def lower_to_pointwise( ) ) return ein.Access( - alias=ein.Alias(aggregate_einsum_alias), + tns=aggregate_einsum_alias, idxs=tuple(ein.Index(field.name) for field in ex.fields), ) case _: From d32aa8a1fd261a6720b1a2e89631c1cd0d308367 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Thu, 9 Oct 2025 10:37:00 -0400 Subject: [PATCH 03/57] * Added pytests * Currently all are failing --- tests/test_einsum_lowerer.py | 640 +++++++++++++++++++++++++++++++++++ 1 file changed, 640 insertions(+) create mode 100644 tests/test_einsum_lowerer.py diff --git a/tests/test_einsum_lowerer.py b/tests/test_einsum_lowerer.py new file mode 100644 index 00000000..19b98236 --- /dev/null +++ b/tests/test_einsum_lowerer.py @@ -0,0 +1,640 @@ +import pytest +import numpy as np +import operator +from finchlite.autoschedule.einsum import EinsumLowerer +from finchlite.finch_logic import ( + Plan, + Produces, + Query, + Alias, + Table, + MapJoin, + Literal, + Aggregate, + Relabel, + Field, + Reorder, +) +from finchlite.finch_einsum import EinsumInterpreter +from finchlite.algebra import promote_max, promote_min + + +@pytest.fixture +def rng(): + return np.random.default_rng(42) + + +def test_simple_addition(rng): + """Test lowering of simple addition A + B""" + A = rng.random((3, 3)) + B = rng.random((3, 3)) + + # Create logic IR for C[i,j] = A[i,j] + B[i,j] + plan = Plan(( + Query(Alias("A"), Table(A, (Field("i"), Field("j")))), + Query(Alias("B"), Table(B, (Field("i"), Field("j")))), + Query( + Alias("C"), + MapJoin( + Literal(operator.add), + ( + Relabel(Alias("A"), (Field("i"), Field("j"))), + Relabel(Alias("B"), (Field("i"), Field("j"))), + ), + ), + ), + Produces((Alias("C"),)), + )) + + # Lower to einsum IR + lowerer = EinsumLowerer() + einsum_plan, parameters = lowerer(plan) + + # Interpret einsum + interpreter = EinsumInterpreter(bindings=parameters) + result = interpreter(einsum_plan) + + # Compare with expected + expected = A + B + assert 
np.allclose(result, expected) + + +def test_scalar_multiplication(rng): + """Test lowering of scalar multiplication 2 * A""" + A = rng.random((4, 4)) + + plan = Plan(( + Query(Alias("A"), Table(A, (Field("i"), Field("j")))), + Query( + Alias("B"), + MapJoin( + Literal(operator.mul), + ( + Literal(2), + Relabel(Alias("A"), (Field("i"), Field("j"))), + ), + ), + ), + Produces((Alias("B"),)), + )) + + lowerer = EinsumLowerer() + einsum_plan, parameters = lowerer(plan) + interpreter = EinsumInterpreter(bindings=parameters) + result = interpreter(einsum_plan) + + expected = 2 * A + assert np.allclose(result, expected) + + +def test_element_wise_operations(rng): + """Test lowering of element-wise operations""" + A = rng.random((3, 3)) + B = rng.random((3, 3)) + C = rng.random((3, 3)) + + # D[i,j] = A[i,j] * B[i,j] + C[i,j] + plan = Plan(( + Query(Alias("A"), Table(A, (Field("i"), Field("j")))), + Query(Alias("B"), Table(B, (Field("i"), Field("j")))), + Query(Alias("C"), Table(C, (Field("i"), Field("j")))), + Query( + Alias("D"), + MapJoin( + Literal(operator.add), + ( + MapJoin( + Literal(operator.mul), + ( + Relabel(Alias("A"), (Field("i"), Field("j"))), + Relabel(Alias("B"), (Field("i"), Field("j"))), + ), + ), + Relabel(Alias("C"), (Field("i"), Field("j"))), + ), + ), + ), + Produces((Alias("D"),)), + )) + + lowerer = EinsumLowerer() + einsum_plan, parameters = lowerer(plan) + interpreter = EinsumInterpreter(bindings=parameters) + result = interpreter(einsum_plan) + + expected = A * B + C + assert np.allclose(result, expected) + + +def test_sum_reduction(rng): + """Test lowering of sum reduction C[i] = sum_j A[i,j]""" + A = rng.random((3, 4)) + + plan = Plan(( + Query(Alias("A"), Table(A, (Field("i"), Field("j")))), + Query( + Alias("C"), + Aggregate( + Literal(operator.add), + Literal(0), # init value + Relabel(Alias("A"), (Field("i"), Field("j"))), + (Field("j"),), # sum over j + ), + ), + Produces((Alias("C"),)), + )) + + lowerer = EinsumLowerer() + einsum_plan, 
parameters = lowerer(plan) + interpreter = EinsumInterpreter(bindings=parameters) + result = interpreter(einsum_plan) + + expected = np.sum(A, axis=1) + assert np.allclose(result, expected) + + +def test_max_reduction(rng): + """Test lowering of max reduction C[i] = max_j A[i,j]""" + A = rng.random((3, 4)) + + plan = Plan(( + Query(Alias("A"), Table(A, (Field("i"), Field("j")))), + Query( + Alias("C"), + Aggregate( + Literal(promote_max), + Literal(-np.inf), # init value for max + Relabel(Alias("A"), (Field("i"), Field("j"))), + (Field("j"),), # max over j + ), + ), + Produces((Alias("C"),)), + )) + + lowerer = EinsumLowerer() + einsum_plan, parameters = lowerer(plan) + interpreter = EinsumInterpreter(bindings=parameters) + result = interpreter(einsum_plan) + + expected = np.max(A, axis=1) + assert np.allclose(result, expected) + + +def test_min_reduction(rng): + """Test lowering of min reduction C[i] = min_j A[i,j]""" + A = rng.random((3, 4)) + + plan = Plan(( + Query(Alias("A"), Table(A, (Field("i"), Field("j")))), + Query( + Alias("C"), + Aggregate( + Literal(promote_min), + Literal(np.inf), # init value for min + Relabel(Alias("A"), (Field("i"), Field("j"))), + (Field("j"),), # min over j + ), + ), + Produces((Alias("C"),)), + )) + + lowerer = EinsumLowerer() + einsum_plan, parameters = lowerer(plan) + interpreter = EinsumInterpreter(bindings=parameters) + result = interpreter(einsum_plan) + + expected = np.min(A, axis=1) + assert np.allclose(result, expected) + + +def test_matrix_multiplication(rng): + """Test lowering of matrix multiplication C[i,j] = sum_k A[i,k] * B[k,j]""" + A = rng.random((3, 4)) + B = rng.random((4, 5)) + + plan = Plan(( + Query(Alias("A"), Table(A, (Field("i"), Field("k")))), + Query(Alias("B"), Table(B, (Field("k"), Field("j")))), + Query( + Alias("C"), + Aggregate( + Literal(operator.add), + Literal(0), + MapJoin( + Literal(operator.mul), + ( + Relabel(Alias("A"), (Field("i"), Field("k"))), + Relabel(Alias("B"), (Field("k"), 
Field("j"))), + ), + ), + (Field("k"),), # sum over k + ), + ), + Produces((Alias("C"),)), + )) + + lowerer = EinsumLowerer() + einsum_plan, parameters = lowerer(plan) + interpreter = EinsumInterpreter(bindings=parameters) + result = interpreter(einsum_plan) + + expected = A @ B + assert np.allclose(result, expected) + + +def test_nested_operations(rng): + """Test nested operations: D = (A + B) * C""" + A = rng.random((3, 3)) + B = rng.random((3, 3)) + C = rng.random((3, 3)) + + plan = Plan(( + Query(Alias("A"), Table(A, (Field("i"), Field("j")))), + Query(Alias("B"), Table(B, (Field("i"), Field("j")))), + Query(Alias("C"), Table(C, (Field("i"), Field("j")))), + Query( + Alias("D"), + MapJoin( + Literal(operator.mul), + ( + MapJoin( + Literal(operator.add), + ( + Relabel(Alias("A"), (Field("i"), Field("j"))), + Relabel(Alias("B"), (Field("i"), Field("j"))), + ), + ), + Relabel(Alias("C"), (Field("i"), Field("j"))), + ), + ), + ), + Produces((Alias("D"),)), + )) + + lowerer = EinsumLowerer() + einsum_plan, parameters = lowerer(plan) + interpreter = EinsumInterpreter(bindings=parameters) + result = interpreter(einsum_plan) + + expected = (A + B) * C + assert np.allclose(result, expected) + + +def test_multiple_aggregations(rng): + """Test multiple aggregations in sequence""" + A = rng.random((3, 4, 5)) + + # First sum over k: B[i,j] = sum_k A[i,j,k] + # Then sum over j: C[i] = sum_j B[i,j] + plan = Plan(( + Query(Alias("A"), Table(A, (Field("i"), Field("j"), Field("k")))), + Query( + Alias("B"), + Aggregate( + Literal(operator.add), + Literal(0), + Relabel(Alias("A"), (Field("i"), Field("j"), Field("k"))), + (Field("k"),), + ), + ), + Query( + Alias("C"), + Aggregate( + Literal(operator.add), + Literal(0), + Relabel(Alias("B"), (Field("i"), Field("j"))), + (Field("j"),), + ), + ), + Produces((Alias("C"),)), + )) + + lowerer = EinsumLowerer() + einsum_plan, parameters = lowerer(plan) + interpreter = EinsumInterpreter(bindings=parameters) + result = 
interpreter(einsum_plan) + + expected = np.sum(np.sum(A, axis=2), axis=1) + assert np.allclose(result, expected) + + +def test_aggregate_with_pointwise(rng): + """Test aggregation combined with pointwise operations""" + A = rng.random((3, 4)) + B = rng.random((3, 4)) + + # C[i] = sum_j (A[i,j] * B[i,j]) + plan = Plan(( + Query(Alias("A"), Table(A, (Field("i"), Field("j")))), + Query(Alias("B"), Table(B, (Field("i"), Field("j")))), + Query( + Alias("C"), + Aggregate( + Literal(operator.add), + Literal(0), + MapJoin( + Literal(operator.mul), + ( + Relabel(Alias("A"), (Field("i"), Field("j"))), + Relabel(Alias("B"), (Field("i"), Field("j"))), + ), + ), + (Field("j"),), + ), + ), + Produces((Alias("C"),)), + )) + + lowerer = EinsumLowerer() + einsum_plan, parameters = lowerer(plan) + interpreter = EinsumInterpreter(bindings=parameters) + result = interpreter(einsum_plan) + + expected = np.sum(A * B, axis=1) + assert np.allclose(result, expected) + + +def test_transpose(rng): + """Test lowering of transpose operation""" + A = rng.random((3, 4)) + + plan = Plan(( + Query(Alias("A"), Table(A, (Field("i"), Field("j")))), + Query( + Alias("B"), + Reorder( + Relabel(Alias("A"), (Field("i"), Field("j"))), + (Field("j"), Field("i")), + ), + ), + Produces((Alias("B"),)), + )) + + lowerer = EinsumLowerer() + einsum_plan, parameters = lowerer(plan) + interpreter = EinsumInterpreter(bindings=parameters) + result = interpreter(einsum_plan) + + expected = A.T + assert np.allclose(result, expected) + + +def test_permutation_3d(rng): + """Test permutation of 3D tensor""" + A = rng.random((2, 3, 4)) + + # Permute from [i,j,k] to [k,i,j] + plan = Plan(( + Query(Alias("A"), Table(A, (Field("i"), Field("j"), Field("k")))), + Query( + Alias("B"), + Reorder( + Relabel(Alias("A"), (Field("i"), Field("j"), Field("k"))), + (Field("k"), Field("i"), Field("j")), + ), + ), + Produces((Alias("B"),)), + )) + + lowerer = EinsumLowerer() + einsum_plan, parameters = lowerer(plan) + interpreter = 
EinsumInterpreter(bindings=parameters) + result = interpreter(einsum_plan) + + expected = np.transpose(A, (2, 0, 1)) + assert np.allclose(result, expected) + + +def test_multiple_outputs(rng): + """Test lowering with multiple output tensors""" + A = rng.random((3, 3)) + B = rng.random((3, 3)) + + plan = Plan(( + Query(Alias("A"), Table(A, (Field("i"), Field("j")))), + Query(Alias("B"), Table(B, (Field("i"), Field("j")))), + Query( + Alias("C"), + MapJoin( + Literal(operator.add), + ( + Relabel(Alias("A"), (Field("i"), Field("j"))), + Relabel(Alias("B"), (Field("i"), Field("j"))), + ), + ), + ), + Query( + Alias("D"), + MapJoin( + Literal(operator.mul), + ( + Relabel(Alias("A"), (Field("i"), Field("j"))), + Relabel(Alias("B"), (Field("i"), Field("j"))), + ), + ), + ), + Produces((Alias("C"), Alias("D"))), + )) + + lowerer = EinsumLowerer() + einsum_plan, parameters = lowerer(plan) + interpreter = EinsumInterpreter(bindings=parameters) + result_c, result_d = interpreter(einsum_plan) + + expected_c = A + B + expected_d = A * B + assert np.allclose(result_c, expected_c) + assert np.allclose(result_d, expected_d) + + +def test_empty_plan(): + """Test lowering of empty plan""" + plan = Plan(()) + lowerer = EinsumLowerer() + einsum_plan, parameters = lowerer(plan) + assert len(einsum_plan.bodies) == 0 + assert len(einsum_plan.returnValues) == 0 + + +def test_scalar_operations(): + """Test operations with scalar results""" + A = np.array([[1, 2], [3, 4]]) + + # Total sum: result = sum_{i,j} A[i,j] + plan = Plan(( + Query(Alias("A"), Table(A, (Field("i"), Field("j")))), + Query( + Alias("result"), + Aggregate( + Literal(operator.add), + Literal(0), + Relabel(Alias("A"), (Field("i"), Field("j"))), + (Field("i"), Field("j")), # sum over all dimensions + ), + ), + Produces((Alias("result"),)), + )) + + lowerer = EinsumLowerer() + einsum_plan, parameters = lowerer(plan) + interpreter = EinsumInterpreter(bindings=parameters) + result = interpreter(einsum_plan) + + expected = 
np.sum(A) + assert np.allclose(result, expected) + + +def test_nested_aggregate_in_pointwise(rng): + """Test aggregate inside a pointwise expression""" + A = rng.random((3, 4)) + + # C[i,j] = A[i,j] + (sum_k A[i,k]) + # This requires the aggregate to be computed separately + plan = Plan(( + Query(Alias("A"), Table(A, (Field("i"), Field("j")))), + Query( + Alias("C"), + MapJoin( + Literal(operator.add), + ( + Relabel(Alias("A"), (Field("i"), Field("j"))), + Aggregate( + Literal(operator.add), + Literal(0), + Relabel(Alias("A"), (Field("i"), Field("k"))), + (Field("k"),), + ), + ), + ), + ), + Produces((Alias("C"),)), + )) + + lowerer = EinsumLowerer() + einsum_plan, parameters = lowerer(plan) + interpreter = EinsumInterpreter(bindings=parameters) + result = interpreter(einsum_plan) + + row_sums = np.sum(A, axis=1, keepdims=True) + expected = A + row_sums + assert np.allclose(result, expected) + + +def test_commutative_flattening(rng): + """Test that commutative operations are flattened""" + A = rng.random((3, 3)) + B = rng.random((3, 3)) + C = rng.random((3, 3)) + D = rng.random((3, 3)) + + # (A + B) + (C + D) should be flattened to A + B + C + D + plan = Plan(( + Query(Alias("A"), Table(A, (Field("i"), Field("j")))), + Query(Alias("B"), Table(B, (Field("i"), Field("j")))), + Query(Alias("C"), Table(C, (Field("i"), Field("j")))), + Query(Alias("D"), Table(D, (Field("i"), Field("j")))), + Query( + Alias("E"), + MapJoin( + Literal(operator.add), + ( + MapJoin( + Literal(operator.add), + ( + Relabel(Alias("A"), (Field("i"), Field("j"))), + Relabel(Alias("B"), (Field("i"), Field("j"))), + ), + ), + MapJoin( + Literal(operator.add), + ( + Relabel(Alias("C"), (Field("i"), Field("j"))), + Relabel(Alias("D"), (Field("i"), Field("j"))), + ), + ), + ), + ), + ), + Produces((Alias("E"),)), + )) + + lowerer = EinsumLowerer() + einsum_plan, parameters = lowerer(plan) + interpreter = EinsumInterpreter(bindings=parameters) + result = interpreter(einsum_plan) + + expected = A + B + 
C + D + assert np.allclose(result, expected) + + +def test_non_commutative_order(): + """Test that non-commutative operations preserve order""" + A = np.array([[4.0, 6.0], [8.0, 10.0]]) + B = np.array([[2.0, 2.0], [2.0, 2.0]]) + C = np.array([[1.0, 1.0], [1.0, 1.0]]) + + # (A / B) / C should NOT be flattened + plan = Plan(( + Query(Alias("A"), Table(A, (Field("i"), Field("j")))), + Query(Alias("B"), Table(B, (Field("i"), Field("j")))), + Query(Alias("C"), Table(C, (Field("i"), Field("j")))), + Query( + Alias("D"), + MapJoin( + Literal(operator.truediv), + ( + MapJoin( + Literal(operator.truediv), + ( + Relabel(Alias("A"), (Field("i"), Field("j"))), + Relabel(Alias("B"), (Field("i"), Field("j"))), + ), + ), + Relabel(Alias("C"), (Field("i"), Field("j"))), + ), + ), + ), + Produces((Alias("D"),)), + )) + + lowerer = EinsumLowerer() + einsum_plan, parameters = lowerer(plan) + interpreter = EinsumInterpreter(bindings=parameters) + result = interpreter(einsum_plan) + + expected = (A / B) / C + assert np.allclose(result, expected) + + +def test_nested_plan(rng): + """Test lowering of nested plans""" + A = rng.random((3, 3)) + B = rng.random((3, 3)) + + inner_plan = Plan(( + Query( + Alias("temp"), + MapJoin( + Literal(operator.add), + ( + Relabel(Alias("A"), (Field("i"), Field("j"))), + Relabel(Alias("B"), (Field("i"), Field("j"))), + ), + ), + ), + Produces((Alias("temp"),)), + )) + + outer_plan = Plan(( + Query(Alias("A"), Table(A, (Field("i"), Field("j")))), + Query(Alias("B"), Table(B, (Field("i"), Field("j")))), + inner_plan, + )) + + lowerer = EinsumLowerer() + einsum_plan, parameters = lowerer(outer_plan) + interpreter = EinsumInterpreter(bindings=parameters) + result = interpreter(einsum_plan) + + expected = A + B + assert np.allclose(result, expected) \ No newline at end of file From 2524294376528ee4c6ca2475f31e6c87494543b6 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 10 Oct 2025 10:16:24 -0400 Subject: [PATCH 04/57] * Fixed error in 
EinsumInterpreter with handling return values from ein.Plan * Fixed errors handling parameters in EinsumLowerer --- src/finchlite/autoschedule/einsum.py | 32 +- src/finchlite/finch_einsum/interpreter.py | 12 +- tests/test_einsum_lowerer.py | 654 ++-------------------- 3 files changed, 77 insertions(+), 621 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 1b115a0d..2c5e4af4 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -1,3 +1,5 @@ +from numpy import isin +from finchlite.algebra.tensor import Tensor import finchlite.finch_einsum as ein from finchlite.finch_logic import ( Plan, @@ -14,12 +16,14 @@ ) from finchlite.algebra import overwrite, init_value, is_commutative from collections.abc import Callable +from typing import Any +from finchlite.interface import Scalar class EinsumLowerer: alias_counter: int = 0 - def __call__(self, prgm: Plan) -> tuple[ein.Plan, dict[str, Table]]: - parameters: dict[str, Table] = {} + def __call__(self, prgm: Plan) -> tuple[ein.Plan, dict[str, Any]]: + parameters: dict[str, Any] = {} definitions: dict[str, ein.Einsum] = {} return self.compile_plan(prgm, parameters, definitions), parameters @@ -33,8 +37,13 @@ def rename_einsum( definitions[new_alias.name] = einsum return ein.Einsum(einsum.op, new_alias, einsum.idxs, einsum.arg) + def reorder_einsum( + self, einsum: ein.Einsum, idxs: tuple[ein.Index, ...] 
+ ) -> ein.Einsum: + return ein.Einsum(einsum.op, einsum.tns, idxs, einsum.arg) + def compile_plan( - self, plan: Plan, parameters: dict[str, Table], definitions: dict[str, ein.Einsum] + self, plan: Plan, parameters: dict[str, Any], definitions: dict[str, ein.Einsum] ) -> ein.Plan: einsums: list[ein.Einsum] = [] returnValue: list[ein.EinsumExpr] = [] @@ -46,8 +55,10 @@ def compile_plan( einsums.extend(inner_plan.bodies) returnValue.extend(inner_plan.returnValues) break - case Query(Alias(name), Table(_, _)): - parameters[name] = body.rhs + case Query(Alias(name), Table(Literal(val), _)) if isinstance(val, Scalar): + parameters[name] = val.val + case Query(Alias(name), Table(Literal(tns), _)) if isinstance(tns, Tensor): + parameters[name] = tns.to_numpy() case Query(Alias(name), rhs): einsums.append( self.rename_einsum( @@ -81,7 +92,7 @@ def lower_to_einsum( self, ex: LogicNode, einsums: list[ein.Einsum], - parameters: dict[str, Table], + parameters: dict[str, Any], definitions: dict[str, ein.Einsum], ) -> ein.Einsum: match ex: @@ -102,9 +113,10 @@ def lower_to_einsum( arg=pointwise_expr, ) case Reorder(arg, idxs): - return self.lower_to_einsum( - arg, einsums, parameters, definitions - ).reorder(idxs) + return self.reorder_einsum( + self.lower_to_einsum(arg, einsums, parameters, definitions), + tuple(ein.Index(field.name) for field in idxs) + ) case Aggregate(Literal(operation), Literal(init), arg, idxs): if init != init_value(operation, type(init)): raise Exception(f""" @@ -154,7 +166,7 @@ def lower_to_pointwise( self, ex: LogicNode, einsums: list[ein.Einsum], - parameters: dict[str, Table], + parameters: dict[str, Any], definitions: dict[str, ein.Einsum], ) -> ein.EinsumExpr: match ex: diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 01f256fe..dcbae557 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -95,18 +95,22 @@ def __call__(self, node): case 
ein.Access(tns, idxs): assert len(idxs) == len(set(idxs)) assert self.loops is not None + perm = [idxs.index(idx) for idx in self.loops if idx in idxs] tns = self(tns) + tns = xp.permute_dims(tns, perm) return xp.expand_dims( tns, [i for i in range(len(self.loops)) if self.loops[i] not in idxs], ) - case ein.Plan(bodies): - res = None + case ein.Plan(bodies, returnValues): for body in bodies: - res = self(body) - return res + self(body) #execute each einsum statement individually + + if returnValues: #return and evaluate the return values seperately + return tuple(self(rv) for rv in returnValues) if len(returnValues) > 1 else self(returnValues[0]) + return None case ein.Produces(args): return tuple(self(arg) for arg in args) case ein.Einsum(op, ein.Alias(tns), idxs, arg): diff --git a/tests/test_einsum_lowerer.py b/tests/test_einsum_lowerer.py index 19b98236..51900e4e 100644 --- a/tests/test_einsum_lowerer.py +++ b/tests/test_einsum_lowerer.py @@ -1,640 +1,80 @@ import pytest import numpy as np -import operator +import finchlite from finchlite.autoschedule.einsum import EinsumLowerer -from finchlite.finch_logic import ( - Plan, - Produces, - Query, - Alias, - Table, - MapJoin, - Literal, - Aggregate, - Relabel, - Field, - Reorder, -) +from finchlite.autoschedule import optimize from finchlite.finch_einsum import EinsumInterpreter -from finchlite.algebra import promote_max, promote_min - +from finchlite.interface.fuse import compute +from finchlite.finch_logic import Plan, Query, Produces, Alias, LogicNode +from finchlite.interface.lazy import defer +from finchlite.symbolic import gensym +from finchlite.compile.bufferized_ndarray import BufferizedNDArray @pytest.fixture def rng(): return np.random.default_rng(42) +def lower_and_execute(ir: LogicNode): + """ + Helper function to optimize, lower, and execute a Logic IR plan. 
+ + Args: + plan: The Logic IR plan to execute + + Returns: + The result of executing the einsum plan + """ + # Optimize into a plan + var = Alias(gensym("result")) + plan = Plan((Query(var, ir), Produces((var,)))) + optimized_plan = optimize(plan) + + # Lower to einsum IR + lowerer = EinsumLowerer() + einsum_plan, plan_parameters = lowerer(optimized_plan) + + # Interpret and execute + interpreter = EinsumInterpreter(bindings=plan_parameters) + return interpreter(einsum_plan) + + def test_simple_addition(rng): """Test lowering of simple addition A + B""" - A = rng.random((3, 3)) - B = rng.random((3, 3)) - - # Create logic IR for C[i,j] = A[i,j] + B[i,j] - plan = Plan(( - Query(Alias("A"), Table(A, (Field("i"), Field("j")))), - Query(Alias("B"), Table(B, (Field("i"), Field("j")))), - Query( - Alias("C"), - MapJoin( - Literal(operator.add), - ( - Relabel(Alias("A"), (Field("i"), Field("j"))), - Relabel(Alias("B"), (Field("i"), Field("j"))), - ), - ), - ), - Produces((Alias("C"),)), - )) + A = defer(rng.random((3, 3))) + B = defer(rng.random((3, 3))) - # Lower to einsum IR - lowerer = EinsumLowerer() - einsum_plan, parameters = lowerer(plan) + C = finchlite.add(A, B) - # Interpret einsum - interpreter = EinsumInterpreter(bindings=parameters) - result = interpreter(einsum_plan) + # Execute the plan + result = lower_and_execute(C.data) # Compare with expected - expected = A + B + expected = compute(A + B) assert np.allclose(result, expected) def test_scalar_multiplication(rng): """Test lowering of scalar multiplication 2 * A""" - A = rng.random((4, 4)) + A = defer(rng.random((4, 4))) - plan = Plan(( - Query(Alias("A"), Table(A, (Field("i"), Field("j")))), - Query( - Alias("B"), - MapJoin( - Literal(operator.mul), - ( - Literal(2), - Relabel(Alias("A"), (Field("i"), Field("j"))), - ), - ), - ), - Produces((Alias("B"),)), - )) + B = finchlite.multiply(2, A) - lowerer = EinsumLowerer() - einsum_plan, parameters = lowerer(plan) - interpreter = 
EinsumInterpreter(bindings=parameters) - result = interpreter(einsum_plan) + result = lower_and_execute(B.data) - expected = 2 * A + expected = compute(B) assert np.allclose(result, expected) def test_element_wise_operations(rng): """Test lowering of element-wise operations""" - A = rng.random((3, 3)) - B = rng.random((3, 3)) - C = rng.random((3, 3)) - - # D[i,j] = A[i,j] * B[i,j] + C[i,j] - plan = Plan(( - Query(Alias("A"), Table(A, (Field("i"), Field("j")))), - Query(Alias("B"), Table(B, (Field("i"), Field("j")))), - Query(Alias("C"), Table(C, (Field("i"), Field("j")))), - Query( - Alias("D"), - MapJoin( - Literal(operator.add), - ( - MapJoin( - Literal(operator.mul), - ( - Relabel(Alias("A"), (Field("i"), Field("j"))), - Relabel(Alias("B"), (Field("i"), Field("j"))), - ), - ), - Relabel(Alias("C"), (Field("i"), Field("j"))), - ), - ), - ), - Produces((Alias("D"),)), - )) - - lowerer = EinsumLowerer() - einsum_plan, parameters = lowerer(plan) - interpreter = EinsumInterpreter(bindings=parameters) - result = interpreter(einsum_plan) - - expected = A * B + C - assert np.allclose(result, expected) - - -def test_sum_reduction(rng): - """Test lowering of sum reduction C[i] = sum_j A[i,j]""" - A = rng.random((3, 4)) - - plan = Plan(( - Query(Alias("A"), Table(A, (Field("i"), Field("j")))), - Query( - Alias("C"), - Aggregate( - Literal(operator.add), - Literal(0), # init value - Relabel(Alias("A"), (Field("i"), Field("j"))), - (Field("j"),), # sum over j - ), - ), - Produces((Alias("C"),)), - )) - - lowerer = EinsumLowerer() - einsum_plan, parameters = lowerer(plan) - interpreter = EinsumInterpreter(bindings=parameters) - result = interpreter(einsum_plan) - - expected = np.sum(A, axis=1) - assert np.allclose(result, expected) - - -def test_max_reduction(rng): - """Test lowering of max reduction C[i] = max_j A[i,j]""" - A = rng.random((3, 4)) - - plan = Plan(( - Query(Alias("A"), Table(A, (Field("i"), Field("j")))), - Query( - Alias("C"), - Aggregate( - 
Literal(promote_max), - Literal(-np.inf), # init value for max - Relabel(Alias("A"), (Field("i"), Field("j"))), - (Field("j"),), # max over j - ), - ), - Produces((Alias("C"),)), - )) - - lowerer = EinsumLowerer() - einsum_plan, parameters = lowerer(plan) - interpreter = EinsumInterpreter(bindings=parameters) - result = interpreter(einsum_plan) - - expected = np.max(A, axis=1) - assert np.allclose(result, expected) - - -def test_min_reduction(rng): - """Test lowering of min reduction C[i] = min_j A[i,j]""" - A = rng.random((3, 4)) - - plan = Plan(( - Query(Alias("A"), Table(A, (Field("i"), Field("j")))), - Query( - Alias("C"), - Aggregate( - Literal(promote_min), - Literal(np.inf), # init value for min - Relabel(Alias("A"), (Field("i"), Field("j"))), - (Field("j"),), # min over j - ), - ), - Produces((Alias("C"),)), - )) - - lowerer = EinsumLowerer() - einsum_plan, parameters = lowerer(plan) - interpreter = EinsumInterpreter(bindings=parameters) - result = interpreter(einsum_plan) - - expected = np.min(A, axis=1) - assert np.allclose(result, expected) - - -def test_matrix_multiplication(rng): - """Test lowering of matrix multiplication C[i,j] = sum_k A[i,k] * B[k,j]""" - A = rng.random((3, 4)) - B = rng.random((4, 5)) - - plan = Plan(( - Query(Alias("A"), Table(A, (Field("i"), Field("k")))), - Query(Alias("B"), Table(B, (Field("k"), Field("j")))), - Query( - Alias("C"), - Aggregate( - Literal(operator.add), - Literal(0), - MapJoin( - Literal(operator.mul), - ( - Relabel(Alias("A"), (Field("i"), Field("k"))), - Relabel(Alias("B"), (Field("k"), Field("j"))), - ), - ), - (Field("k"),), # sum over k - ), - ), - Produces((Alias("C"),)), - )) - - lowerer = EinsumLowerer() - einsum_plan, parameters = lowerer(plan) - interpreter = EinsumInterpreter(bindings=parameters) - result = interpreter(einsum_plan) - - expected = A @ B - assert np.allclose(result, expected) - - -def test_nested_operations(rng): - """Test nested operations: D = (A + B) * C""" - A = rng.random((3, 3)) 
- B = rng.random((3, 3)) - C = rng.random((3, 3)) - - plan = Plan(( - Query(Alias("A"), Table(A, (Field("i"), Field("j")))), - Query(Alias("B"), Table(B, (Field("i"), Field("j")))), - Query(Alias("C"), Table(C, (Field("i"), Field("j")))), - Query( - Alias("D"), - MapJoin( - Literal(operator.mul), - ( - MapJoin( - Literal(operator.add), - ( - Relabel(Alias("A"), (Field("i"), Field("j"))), - Relabel(Alias("B"), (Field("i"), Field("j"))), - ), - ), - Relabel(Alias("C"), (Field("i"), Field("j"))), - ), - ), - ), - Produces((Alias("D"),)), - )) - - lowerer = EinsumLowerer() - einsum_plan, parameters = lowerer(plan) - interpreter = EinsumInterpreter(bindings=parameters) - result = interpreter(einsum_plan) - - expected = (A + B) * C - assert np.allclose(result, expected) - - -def test_multiple_aggregations(rng): - """Test multiple aggregations in sequence""" - A = rng.random((3, 4, 5)) - - # First sum over k: B[i,j] = sum_k A[i,j,k] - # Then sum over j: C[i] = sum_j B[i,j] - plan = Plan(( - Query(Alias("A"), Table(A, (Field("i"), Field("j"), Field("k")))), - Query( - Alias("B"), - Aggregate( - Literal(operator.add), - Literal(0), - Relabel(Alias("A"), (Field("i"), Field("j"), Field("k"))), - (Field("k"),), - ), - ), - Query( - Alias("C"), - Aggregate( - Literal(operator.add), - Literal(0), - Relabel(Alias("B"), (Field("i"), Field("j"))), - (Field("j"),), - ), - ), - Produces((Alias("C"),)), - )) - - lowerer = EinsumLowerer() - einsum_plan, parameters = lowerer(plan) - interpreter = EinsumInterpreter(bindings=parameters) - result = interpreter(einsum_plan) - - expected = np.sum(np.sum(A, axis=2), axis=1) - assert np.allclose(result, expected) - - -def test_aggregate_with_pointwise(rng): - """Test aggregation combined with pointwise operations""" - A = rng.random((3, 4)) - B = rng.random((3, 4)) - - # C[i] = sum_j (A[i,j] * B[i,j]) - plan = Plan(( - Query(Alias("A"), Table(A, (Field("i"), Field("j")))), - Query(Alias("B"), Table(B, (Field("i"), Field("j")))), - Query( - 
Alias("C"), - Aggregate( - Literal(operator.add), - Literal(0), - MapJoin( - Literal(operator.mul), - ( - Relabel(Alias("A"), (Field("i"), Field("j"))), - Relabel(Alias("B"), (Field("i"), Field("j"))), - ), - ), - (Field("j"),), - ), - ), - Produces((Alias("C"),)), - )) + A = defer(rng.random((3, 3))) + B = defer(rng.random((3, 3))) + C = defer(rng.random((3, 3))) - lowerer = EinsumLowerer() - einsum_plan, parameters = lowerer(plan) - interpreter = EinsumInterpreter(bindings=parameters) - result = interpreter(einsum_plan) - - expected = np.sum(A * B, axis=1) - assert np.allclose(result, expected) - - -def test_transpose(rng): - """Test lowering of transpose operation""" - A = rng.random((3, 4)) - - plan = Plan(( - Query(Alias("A"), Table(A, (Field("i"), Field("j")))), - Query( - Alias("B"), - Reorder( - Relabel(Alias("A"), (Field("i"), Field("j"))), - (Field("j"), Field("i")), - ), - ), - Produces((Alias("B"),)), - )) - - lowerer = EinsumLowerer() - einsum_plan, parameters = lowerer(plan) - interpreter = EinsumInterpreter(bindings=parameters) - result = interpreter(einsum_plan) - - expected = A.T - assert np.allclose(result, expected) - - -def test_permutation_3d(rng): - """Test permutation of 3D tensor""" - A = rng.random((2, 3, 4)) - - # Permute from [i,j,k] to [k,i,j] - plan = Plan(( - Query(Alias("A"), Table(A, (Field("i"), Field("j"), Field("k")))), - Query( - Alias("B"), - Reorder( - Relabel(Alias("A"), (Field("i"), Field("j"), Field("k"))), - (Field("k"), Field("i"), Field("j")), - ), - ), - Produces((Alias("B"),)), - )) - - lowerer = EinsumLowerer() - einsum_plan, parameters = lowerer(plan) - interpreter = EinsumInterpreter(bindings=parameters) - result = interpreter(einsum_plan) + D = finchlite.add(finchlite.multiply(A, B), C) - expected = np.transpose(A, (2, 0, 1)) - assert np.allclose(result, expected) - - -def test_multiple_outputs(rng): - """Test lowering with multiple output tensors""" - A = rng.random((3, 3)) - B = rng.random((3, 3)) - - plan = 
Plan(( - Query(Alias("A"), Table(A, (Field("i"), Field("j")))), - Query(Alias("B"), Table(B, (Field("i"), Field("j")))), - Query( - Alias("C"), - MapJoin( - Literal(operator.add), - ( - Relabel(Alias("A"), (Field("i"), Field("j"))), - Relabel(Alias("B"), (Field("i"), Field("j"))), - ), - ), - ), - Query( - Alias("D"), - MapJoin( - Literal(operator.mul), - ( - Relabel(Alias("A"), (Field("i"), Field("j"))), - Relabel(Alias("B"), (Field("i"), Field("j"))), - ), - ), - ), - Produces((Alias("C"), Alias("D"))), - )) - - lowerer = EinsumLowerer() - einsum_plan, parameters = lowerer(plan) - interpreter = EinsumInterpreter(bindings=parameters) - result_c, result_d = interpreter(einsum_plan) - - expected_c = A + B - expected_d = A * B - assert np.allclose(result_c, expected_c) - assert np.allclose(result_d, expected_d) - - -def test_empty_plan(): - """Test lowering of empty plan""" - plan = Plan(()) - lowerer = EinsumLowerer() - einsum_plan, parameters = lowerer(plan) - assert len(einsum_plan.bodies) == 0 - assert len(einsum_plan.returnValues) == 0 - - -def test_scalar_operations(): - """Test operations with scalar results""" - A = np.array([[1, 2], [3, 4]]) - - # Total sum: result = sum_{i,j} A[i,j] - plan = Plan(( - Query(Alias("A"), Table(A, (Field("i"), Field("j")))), - Query( - Alias("result"), - Aggregate( - Literal(operator.add), - Literal(0), - Relabel(Alias("A"), (Field("i"), Field("j"))), - (Field("i"), Field("j")), # sum over all dimensions - ), - ), - Produces((Alias("result"),)), - )) - - lowerer = EinsumLowerer() - einsum_plan, parameters = lowerer(plan) - interpreter = EinsumInterpreter(bindings=parameters) - result = interpreter(einsum_plan) - - expected = np.sum(A) - assert np.allclose(result, expected) - - -def test_nested_aggregate_in_pointwise(rng): - """Test aggregate inside a pointwise expression""" - A = rng.random((3, 4)) - - # C[i,j] = A[i,j] + (sum_k A[i,k]) - # This requires the aggregate to be computed separately - plan = Plan(( - 
Query(Alias("A"), Table(A, (Field("i"), Field("j")))), - Query( - Alias("C"), - MapJoin( - Literal(operator.add), - ( - Relabel(Alias("A"), (Field("i"), Field("j"))), - Aggregate( - Literal(operator.add), - Literal(0), - Relabel(Alias("A"), (Field("i"), Field("k"))), - (Field("k"),), - ), - ), - ), - ), - Produces((Alias("C"),)), - )) - - lowerer = EinsumLowerer() - einsum_plan, parameters = lowerer(plan) - interpreter = EinsumInterpreter(bindings=parameters) - result = interpreter(einsum_plan) - - row_sums = np.sum(A, axis=1, keepdims=True) - expected = A + row_sums - assert np.allclose(result, expected) - - -def test_commutative_flattening(rng): - """Test that commutative operations are flattened""" - A = rng.random((3, 3)) - B = rng.random((3, 3)) - C = rng.random((3, 3)) - D = rng.random((3, 3)) - - # (A + B) + (C + D) should be flattened to A + B + C + D - plan = Plan(( - Query(Alias("A"), Table(A, (Field("i"), Field("j")))), - Query(Alias("B"), Table(B, (Field("i"), Field("j")))), - Query(Alias("C"), Table(C, (Field("i"), Field("j")))), - Query(Alias("D"), Table(D, (Field("i"), Field("j")))), - Query( - Alias("E"), - MapJoin( - Literal(operator.add), - ( - MapJoin( - Literal(operator.add), - ( - Relabel(Alias("A"), (Field("i"), Field("j"))), - Relabel(Alias("B"), (Field("i"), Field("j"))), - ), - ), - MapJoin( - Literal(operator.add), - ( - Relabel(Alias("C"), (Field("i"), Field("j"))), - Relabel(Alias("D"), (Field("i"), Field("j"))), - ), - ), - ), - ), - ), - Produces((Alias("E"),)), - )) - - lowerer = EinsumLowerer() - einsum_plan, parameters = lowerer(plan) - interpreter = EinsumInterpreter(bindings=parameters) - result = interpreter(einsum_plan) - - expected = A + B + C + D - assert np.allclose(result, expected) - - -def test_non_commutative_order(): - """Test that non-commutative operations preserve order""" - A = np.array([[4.0, 6.0], [8.0, 10.0]]) - B = np.array([[2.0, 2.0], [2.0, 2.0]]) - C = np.array([[1.0, 1.0], [1.0, 1.0]]) - - # (A / B) / C 
should NOT be flattened - plan = Plan(( - Query(Alias("A"), Table(A, (Field("i"), Field("j")))), - Query(Alias("B"), Table(B, (Field("i"), Field("j")))), - Query(Alias("C"), Table(C, (Field("i"), Field("j")))), - Query( - Alias("D"), - MapJoin( - Literal(operator.truediv), - ( - MapJoin( - Literal(operator.truediv), - ( - Relabel(Alias("A"), (Field("i"), Field("j"))), - Relabel(Alias("B"), (Field("i"), Field("j"))), - ), - ), - Relabel(Alias("C"), (Field("i"), Field("j"))), - ), - ), - ), - Produces((Alias("D"),)), - )) - - lowerer = EinsumLowerer() - einsum_plan, parameters = lowerer(plan) - interpreter = EinsumInterpreter(bindings=parameters) - result = interpreter(einsum_plan) - - expected = (A / B) / C - assert np.allclose(result, expected) - - -def test_nested_plan(rng): - """Test lowering of nested plans""" - A = rng.random((3, 3)) - B = rng.random((3, 3)) - - inner_plan = Plan(( - Query( - Alias("temp"), - MapJoin( - Literal(operator.add), - ( - Relabel(Alias("A"), (Field("i"), Field("j"))), - Relabel(Alias("B"), (Field("i"), Field("j"))), - ), - ), - ), - Produces((Alias("temp"),)), - )) - - outer_plan = Plan(( - Query(Alias("A"), Table(A, (Field("i"), Field("j")))), - Query(Alias("B"), Table(B, (Field("i"), Field("j")))), - inner_plan, - )) - - lowerer = EinsumLowerer() - einsum_plan, parameters = lowerer(outer_plan) - interpreter = EinsumInterpreter(bindings=parameters) - result = interpreter(einsum_plan) + result = lower_and_execute(D.data) - expected = A + B + expected = compute(D) assert np.allclose(result, expected) \ No newline at end of file From f0f2112d474881b50454ba1937fd4a298c174aba Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 10 Oct 2025 11:19:28 -0400 Subject: [PATCH 05/57] * Addressed issue #209 by properly initializing Einsum with return values instead of relying on Produce --- src/finchlite/finch_einsum/interpreter.py | 2 +- src/finchlite/interface/lazy.py | 6 ++++-- tests/test_einsum_lowerer.py | 2 +- 3 files changed, 6 
insertions(+), 4 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index dcbae557..6b586392 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -109,7 +109,7 @@ def __call__(self, node): self(body) #execute each einsum statement individually if returnValues: #return and evaluate the return values seperately - return tuple(self(rv) for rv in returnValues) if len(returnValues) > 1 else self(returnValues[0]) + return tuple(self(rv) for rv in returnValues) return None case ein.Produces(args): return tuple(self(arg) for arg in args) diff --git a/src/finchlite/interface/lazy.py b/src/finchlite/interface/lazy.py index 638affaa..f5c5f05f 100644 --- a/src/finchlite/interface/lazy.py +++ b/src/finchlite/interface/lazy.py @@ -1805,7 +1805,8 @@ def std( def einop(prgm, **kwargs): stmt = ein.parse_einop(prgm) - prgm = ein.Plan((stmt, ein.Produces((stmt.tns,)))) + prgm = ein.Plan((stmt, ), (stmt.tns,)) + xp = sys.modules[__name__] ctx = ein.EinsumInterpreter(xp, dict(**kwargs)) return ctx(prgm)[0] @@ -1813,7 +1814,8 @@ def einop(prgm, **kwargs): def einsum(prgm, *args, **kwargs): stmt, bindings = ein.parse_einsum(prgm, *args) - prgm = ein.Plan((stmt, ein.Produces((stmt.tns,)))) + prgm = ein.Plan((stmt, ), (stmt.tns,)) + xp = sys.modules[__name__] ctx = ein.EinsumInterpreter(xp, bindings) return ctx(prgm)[0] diff --git a/tests/test_einsum_lowerer.py b/tests/test_einsum_lowerer.py index 51900e4e..6ef9af5a 100644 --- a/tests/test_einsum_lowerer.py +++ b/tests/test_einsum_lowerer.py @@ -36,7 +36,7 @@ def lower_and_execute(ir: LogicNode): # Interpret and execute interpreter = EinsumInterpreter(bindings=plan_parameters) - return interpreter(einsum_plan) + return interpreter(einsum_plan)[0] def test_simple_addition(rng): From 6da9ff7cc0ce8e135c659f3953fe9e152e76d277 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 10 Oct 2025 11:23:17 -0400 Subject: [PATCH 
06/57] * Addressed issue #209 by removing redundant Produce IR node --- src/finchlite/finch_einsum/__init__.py | 4 +-- src/finchlite/finch_einsum/interpreter.py | 4 +-- src/finchlite/finch_einsum/nodes.py | 40 +++++++++++------------ 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/finchlite/finch_einsum/__init__.py b/src/finchlite/finch_einsum/__init__.py index 223a5d4a..5196bc28 100644 --- a/src/finchlite/finch_einsum/__init__.py +++ b/src/finchlite/finch_einsum/__init__.py @@ -9,7 +9,7 @@ Index, Literal, Plan, - Produces, +# Produces, ) from .parser import parse_einop, parse_einsum @@ -27,7 +27,7 @@ "Index", "Literal", "Plan", - "Produces", +# "Produces", "parse_einop", "parse_einsum", ] diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 6b586392..bd2d0dac 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -111,8 +111,8 @@ def __call__(self, node): if returnValues: #return and evaluate the return values seperately return tuple(self(rv) for rv in returnValues) return None - case ein.Produces(args): - return tuple(self(arg) for arg in args) + #case ein.Produces(args): + # return tuple(self(arg) for arg in args) case ein.Einsum(op, ein.Alias(tns), idxs, arg): # This is the main entry point for einsum execution loops = arg.get_idxs() diff --git a/src/finchlite/finch_einsum/nodes.py b/src/finchlite/finch_einsum/nodes.py index 5129443d..123b0b66 100644 --- a/src/finchlite/finch_einsum/nodes.py +++ b/src/finchlite/finch_einsum/nodes.py @@ -244,26 +244,26 @@ def children(self): return [*self.bodies, self.returnValues] -@dataclass(eq=True, frozen=True) -class Produces(EinsumTree): - """ - Represents a logical AST statement that returns `args...` from the current plan. - Halts execution of the program. - - Attributes: - args: The arguments to return. - """ - - args: tuple[EinsumNode, ...] 
- - @property - def children(self): - """Returns the children of the node.""" - return [*self.args] - - @classmethod - def from_children(cls, *args): - return cls(args) +#@dataclass(eq=True, frozen=True) +#class Produces(EinsumTree): +# """ +# Represents a logical AST statement that returns `args...` from the current plan. +# Halts execution of the program. +# +# Attributes: +# args: The arguments to return. +# """ + +# args: tuple[EinsumNode, ...] + +# @property +# def children(self): +# """Returns the children of the node.""" +# return [*self.args] + +# @classmethod +# def from_children(cls, *args): +# return cls(args) infix_strs = { From 4969b25f1b21e8346224e52526150ea1b4ce87f6 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 10 Oct 2025 11:40:48 -0400 Subject: [PATCH 07/57] * Added more pytests --- tests/test_einsum_lowerer.py | 42 ++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/test_einsum_lowerer.py b/tests/test_einsum_lowerer.py index 6ef9af5a..ec3eb0a7 100644 --- a/tests/test_einsum_lowerer.py +++ b/tests/test_einsum_lowerer.py @@ -77,4 +77,46 @@ def test_element_wise_operations(rng): result = lower_and_execute(D.data) expected = compute(D) + assert np.allclose(result, expected) + +def test_sum_reduction(rng): + """Test sum reduction using +=""" + A = defer(rng.random((3, 4))) + + B = finchlite.sum(A, axis=1) + + result = lower_and_execute(B.data) + + expected = compute(B) + assert np.allclose(result, expected) + +def test_maximum_reduction(rng): + """Test maximum reduction using max=""" + A = defer(rng.random((3, 4))) + + B = finchlite.max(A, axis=1) + + result = lower_and_execute(B.data) + expected = compute(B) + assert np.allclose(result, expected) + +def test_batch_matrix_multiplication(rng): + """Test batch matrix multiplication using +=""" + A = defer(rng.random((2, 3, 4))) + B = defer(rng.random((2, 4, 5))) + + C = finchlite.matmul(A, B) + + result = lower_and_execute(C.data) + expected = 
compute(C) + assert np.allclose(result, expected) + +def test_minimum_reduction(rng): + """Test minimum reduction using min=""" + A = defer(rng.random((3, 4))) + + B = finchlite.min(A, axis=1) + + result = lower_and_execute(B.data) + expected = compute(B) assert np.allclose(result, expected) \ No newline at end of file From b80fb7cbbb99ac1f7677dc1cb7e15aacfb224c1f Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 10 Oct 2025 11:44:25 -0400 Subject: [PATCH 08/57] * Fixed ruff errors --- src/finchlite/autoschedule/einsum.py | 30 ++++++++++++++--------- src/finchlite/finch_einsum/__init__.py | 4 +-- src/finchlite/finch_einsum/interpreter.py | 8 +++--- src/finchlite/finch_einsum/nodes.py | 4 +-- src/finchlite/interface/lazy.py | 4 +-- tests/test_einsum_lowerer.py | 17 ++++++++----- 6 files changed, 39 insertions(+), 28 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 2c5e4af4..dcb8dfca 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -19,6 +19,7 @@ from typing import Any from finchlite.interface import Scalar + class EinsumLowerer: alias_counter: int = 0 @@ -32,7 +33,10 @@ def get_next_alias(self) -> ein.Alias: return ein.Alias(f"einsum_{self.alias_counter}") def rename_einsum( - self, einsum: ein.Einsum, new_alias: ein.Alias, definitions: dict[str, ein.Einsum] + self, + einsum: ein.Einsum, + new_alias: ein.Alias, + definitions: dict[str, ein.Einsum], ) -> ein.Einsum: definitions[new_alias.name] = einsum return ein.Einsum(einsum.op, new_alias, einsum.idxs, einsum.arg) @@ -55,9 +59,13 @@ def compile_plan( einsums.extend(inner_plan.bodies) returnValue.extend(inner_plan.returnValues) break - case Query(Alias(name), Table(Literal(val), _)) if isinstance(val, Scalar): + case Query(Alias(name), Table(Literal(val), _)) if isinstance( + val, Scalar + ): parameters[name] = val.val - case Query(Alias(name), Table(Literal(tns), _)) if isinstance(tns, Tensor): + 
case Query(Alias(name), Table(Literal(tns), _)) if isinstance( + tns, Tensor + ): parameters[name] = tns.to_numpy() case Query(Alias(name), rhs): einsums.append( @@ -107,15 +115,13 @@ def lower_to_einsum( return ein.Einsum( op=ein.Literal(overwrite), tns=self.get_next_alias(), - idxs=tuple( - ein.Index(field.name) for field in ex.fields - ), + idxs=tuple(ein.Index(field.name) for field in ex.fields), arg=pointwise_expr, ) case Reorder(arg, idxs): return self.reorder_einsum( self.lower_to_einsum(arg, einsums, parameters, definitions), - tuple(ein.Index(field.name) for field in idxs) + tuple(ein.Index(field.name) for field in idxs), ) case Aggregate(Literal(operation), Literal(init), arg, idxs): if init != init_value(operation, type(init)): @@ -128,10 +134,10 @@ def lower_to_einsum( arg, einsums, parameters, definitions ) return ein.Einsum( - op = ein.Literal(operation), - tns = self.get_next_alias(), - idxs = tuple(ein.Index(field.name) for field in ex.fields), - arg = aggregate_expr + op=ein.Literal(operation), + tns=self.get_next_alias(), + idxs=tuple(ein.Index(field.name) for field in ex.fields), + arg=aggregate_expr, ) case _: raise Exception(f"Unrecognized logic: {ex}") @@ -203,4 +209,4 @@ def lower_to_pointwise( idxs=tuple(ein.Index(field.name) for field in ex.fields), ) case _: - raise Exception(f"Unrecognized logic: {ex}") \ No newline at end of file + raise Exception(f"Unrecognized logic: {ex}") diff --git a/src/finchlite/finch_einsum/__init__.py b/src/finchlite/finch_einsum/__init__.py index 5196bc28..cabb2540 100644 --- a/src/finchlite/finch_einsum/__init__.py +++ b/src/finchlite/finch_einsum/__init__.py @@ -9,7 +9,7 @@ Index, Literal, Plan, -# Produces, + # Produces, ) from .parser import parse_einop, parse_einsum @@ -27,7 +27,7 @@ "Index", "Literal", "Plan", -# "Produces", + # "Produces", "parse_einop", "parse_einsum", ] diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index bd2d0dac..f8e0f7a4 100644 --- 
a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -106,12 +106,12 @@ def __call__(self, node): ) case ein.Plan(bodies, returnValues): for body in bodies: - self(body) #execute each einsum statement individually - - if returnValues: #return and evaluate the return values seperately + self(body) # execute each einsum statement individually + + if returnValues: # return and evaluate the return values seperately return tuple(self(rv) for rv in returnValues) return None - #case ein.Produces(args): + # case ein.Produces(args): # return tuple(self(arg) for arg in args) case ein.Einsum(op, ein.Alias(tns), idxs, arg): # This is the main entry point for einsum execution diff --git a/src/finchlite/finch_einsum/nodes.py b/src/finchlite/finch_einsum/nodes.py index 123b0b66..c9e506e7 100644 --- a/src/finchlite/finch_einsum/nodes.py +++ b/src/finchlite/finch_einsum/nodes.py @@ -244,8 +244,8 @@ def children(self): return [*self.bodies, self.returnValues] -#@dataclass(eq=True, frozen=True) -#class Produces(EinsumTree): +# @dataclass(eq=True, frozen=True) +# class Produces(EinsumTree): # """ # Represents a logical AST statement that returns `args...` from the current plan. # Halts execution of the program. 
diff --git a/src/finchlite/interface/lazy.py b/src/finchlite/interface/lazy.py index f5c5f05f..4b3f8e9e 100644 --- a/src/finchlite/interface/lazy.py +++ b/src/finchlite/interface/lazy.py @@ -1805,7 +1805,7 @@ def std( def einop(prgm, **kwargs): stmt = ein.parse_einop(prgm) - prgm = ein.Plan((stmt, ), (stmt.tns,)) + prgm = ein.Plan((stmt,), (stmt.tns,)) xp = sys.modules[__name__] ctx = ein.EinsumInterpreter(xp, dict(**kwargs)) @@ -1814,7 +1814,7 @@ def einop(prgm, **kwargs): def einsum(prgm, *args, **kwargs): stmt, bindings = ein.parse_einsum(prgm, *args) - prgm = ein.Plan((stmt, ), (stmt.tns,)) + prgm = ein.Plan((stmt,), (stmt.tns,)) xp = sys.modules[__name__] ctx = ein.EinsumInterpreter(xp, bindings) diff --git a/tests/test_einsum_lowerer.py b/tests/test_einsum_lowerer.py index ec3eb0a7..6f01d5a5 100644 --- a/tests/test_einsum_lowerer.py +++ b/tests/test_einsum_lowerer.py @@ -10,6 +10,7 @@ from finchlite.symbolic import gensym from finchlite.compile.bufferized_ndarray import BufferizedNDArray + @pytest.fixture def rng(): return np.random.default_rng(42) @@ -18,10 +19,10 @@ def rng(): def lower_and_execute(ir: LogicNode): """ Helper function to optimize, lower, and execute a Logic IR plan. 
- + Args: plan: The Logic IR plan to execute - + Returns: The result of executing the einsum plan """ @@ -29,11 +30,11 @@ def lower_and_execute(ir: LogicNode): var = Alias(gensym("result")) plan = Plan((Query(var, ir), Produces((var,)))) optimized_plan = optimize(plan) - + # Lower to einsum IR lowerer = EinsumLowerer() einsum_plan, plan_parameters = lowerer(optimized_plan) - + # Interpret and execute interpreter = EinsumInterpreter(bindings=plan_parameters) return interpreter(einsum_plan)[0] @@ -79,6 +80,7 @@ def test_element_wise_operations(rng): expected = compute(D) assert np.allclose(result, expected) + def test_sum_reduction(rng): """Test sum reduction using +=""" A = defer(rng.random((3, 4))) @@ -86,10 +88,11 @@ def test_sum_reduction(rng): B = finchlite.sum(A, axis=1) result = lower_and_execute(B.data) - + expected = compute(B) assert np.allclose(result, expected) + def test_maximum_reduction(rng): """Test maximum reduction using max=""" A = defer(rng.random((3, 4))) @@ -100,6 +103,7 @@ def test_maximum_reduction(rng): expected = compute(B) assert np.allclose(result, expected) + def test_batch_matrix_multiplication(rng): """Test batch matrix multiplication using +=""" A = defer(rng.random((2, 3, 4))) @@ -111,6 +115,7 @@ def test_batch_matrix_multiplication(rng): expected = compute(C) assert np.allclose(result, expected) + def test_minimum_reduction(rng): """Test minimum reduction using min=""" A = defer(rng.random((3, 4))) @@ -119,4 +124,4 @@ def test_minimum_reduction(rng): result = lower_and_execute(B.data) expected = compute(B) - assert np.allclose(result, expected) \ No newline at end of file + assert np.allclose(result, expected) From 6ab474fe8aa220f4be2b81897a84319045d27d1c Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 10 Oct 2025 11:45:19 -0400 Subject: [PATCH 09/57] * Fixed more ruff errors --- src/finchlite/autoschedule/einsum.py | 24 ++++++++++++------------ src/finchlite/finch_einsum/nodes.py | 8 ++++---- 
tests/test_einsum_lowerer.py | 7 ++++--- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index dcb8dfca..eb5b5883 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -1,22 +1,22 @@ -from numpy import isin -from finchlite.algebra.tensor import Tensor +from collections.abc import Callable +from typing import Any + import finchlite.finch_einsum as ein +from finchlite.algebra import init_value, is_commutative, overwrite +from finchlite.algebra.tensor import Tensor from finchlite.finch_logic import ( - Plan, - Produces, - Query, + Aggregate, Alias, - Table, + Literal, LogicNode, MapJoin, - Literal, - Reorder, - Aggregate, + Plan, + Produces, + Query, Relabel, + Reorder, + Table, ) -from finchlite.algebra import overwrite, init_value, is_commutative -from collections.abc import Callable -from typing import Any from finchlite.interface import Scalar diff --git a/src/finchlite/finch_einsum/nodes.py b/src/finchlite/finch_einsum/nodes.py index c9e506e7..af529879 100644 --- a/src/finchlite/finch_einsum/nodes.py +++ b/src/finchlite/finch_einsum/nodes.py @@ -351,9 +351,9 @@ def __call__(self, prgm: EinsumNode): ctx_2(body) self.exec(ctx_2.emit()) return None - case Produces(args): - args = tuple(self(arg) for arg in args) - self.exec(f"{feed}return {args}\n") - return None + #case Produces(args): + # args = tuple(self(arg) for arg in args) + # self.exec(f"{feed}return {args}\n") + # return None case _: raise ValueError(f"Unknown expression type: {type(prgm)}") diff --git a/tests/test_einsum_lowerer.py b/tests/test_einsum_lowerer.py index 6f01d5a5..8c14ef3f 100644 --- a/tests/test_einsum_lowerer.py +++ b/tests/test_einsum_lowerer.py @@ -1,14 +1,15 @@ import pytest + import numpy as np + import finchlite -from finchlite.autoschedule.einsum import EinsumLowerer from finchlite.autoschedule import optimize +from finchlite.autoschedule.einsum import 
EinsumLowerer from finchlite.finch_einsum import EinsumInterpreter +from finchlite.finch_logic import Alias, LogicNode, Plan, Produces, Query from finchlite.interface.fuse import compute -from finchlite.finch_logic import Plan, Query, Produces, Alias, LogicNode from finchlite.interface.lazy import defer from finchlite.symbolic import gensym -from finchlite.compile.bufferized_ndarray import BufferizedNDArray @pytest.fixture From 6358d81054741cf6bb2dccd991aa7f5bedb8e6cc Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 10 Oct 2025 12:16:35 -0400 Subject: [PATCH 10/57] * Added support to EinsumPrinterContext to print Plan return values * Modified pytests accordingly --- src/finchlite/finch_einsum/nodes.py | 11 +++++++++-- tests/reference/test_einsum_printer.txt | 4 +++- tests/test_printers.py | 4 ++-- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/finchlite/finch_einsum/nodes.py b/src/finchlite/finch_einsum/nodes.py index af529879..f41d5dc6 100644 --- a/src/finchlite/finch_einsum/nodes.py +++ b/src/finchlite/finch_einsum/nodes.py @@ -319,7 +319,7 @@ def subblock(self): return blk def __call__(self, prgm: EinsumNode): - feed = self.feed + #feed = self.feed match prgm: case Literal(value): return qual_str(value).replace("\n", "") @@ -344,12 +344,19 @@ def __call__(self, prgm: EinsumNode): f"{op_str}= {self(arg)}" ) return None - case Plan(bodies): + case Plan(bodies, returnValues): self.exec(f"{self.feed}plan:") ctx_2 = self.subblock() for body in bodies: ctx_2(body) self.exec(ctx_2.emit()) + + if len(returnValues) > 0: + self.exec(f"{self.feed}returnValues:") + ctx_3 = self.subblock() + for returnValue in returnValues: + ctx_3.exec(f"{ctx_3.feed}{ctx_3(returnValue)}") + self.exec(ctx_3.emit()) return None #case Produces(args): # args = tuple(self(arg) for arg in args) diff --git a/tests/reference/test_einsum_printer.txt b/tests/reference/test_einsum_printer.txt index f284c993..bf766d81 100644 --- a/tests/reference/test_einsum_printer.txt 
+++ b/tests/reference/test_einsum_printer.txt @@ -3,4 +3,6 @@ plan: plan: D[i, j] += (A[i, k] * B[k, j]) E[i] min= lshift((A[i, k] + D[k, j]), 1) - return ('C', 'E') +returnValues: + C + E diff --git a/tests/test_printers.py b/tests/test_printers.py index e5b373a6..7a1efcd3 100644 --- a/tests/test_printers.py +++ b/tests/test_printers.py @@ -542,8 +542,8 @@ def test_einsum_printer(file_regression): ein.parse_einop("E[i] min= A[i,k] + D[k,j] << 1"), ) ), - ein.Produces((ein.Alias("C"), ein.Alias("E"))), - ) + ), + (ein.Alias("C"), ein.Alias("E")) ) file_regression.check(str(prgm), extension=".txt") From 2f6c8f8c2ce11a0cd3f36bce0ffacde5a273d39a Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 10 Oct 2025 12:18:38 -0400 Subject: [PATCH 11/57] * Ran ruff format --- src/finchlite/finch_einsum/nodes.py | 4 ++-- tests/test_printers.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/finchlite/finch_einsum/nodes.py b/src/finchlite/finch_einsum/nodes.py index f41d5dc6..428a191a 100644 --- a/src/finchlite/finch_einsum/nodes.py +++ b/src/finchlite/finch_einsum/nodes.py @@ -319,7 +319,7 @@ def subblock(self): return blk def __call__(self, prgm: EinsumNode): - #feed = self.feed + # feed = self.feed match prgm: case Literal(value): return qual_str(value).replace("\n", "") @@ -358,7 +358,7 @@ def __call__(self, prgm: EinsumNode): ctx_3.exec(f"{ctx_3.feed}{ctx_3(returnValue)}") self.exec(ctx_3.emit()) return None - #case Produces(args): + # case Produces(args): # args = tuple(self(arg) for arg in args) # self.exec(f"{feed}return {args}\n") # return None diff --git a/tests/test_printers.py b/tests/test_printers.py index 7a1efcd3..960cf051 100644 --- a/tests/test_printers.py +++ b/tests/test_printers.py @@ -543,7 +543,7 @@ def test_einsum_printer(file_regression): ) ), ), - (ein.Alias("C"), ein.Alias("E")) + (ein.Alias("C"), ein.Alias("E")), ) file_regression.check(str(prgm), extension=".txt") From 21213c5d05c8f432d6bfe09cc350acbb24b0861b Mon Sep 
17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 10 Oct 2025 12:29:07 -0400 Subject: [PATCH 12/57] * Fixed mypy type errors * Properly handled einsums in return value --- src/finchlite/autoschedule/einsum.py | 28 +++++++++++++++++----------- tests/test_einsum_lowerer.py | 3 ++- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index eb5b5883..2b2751b9 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -1,6 +1,8 @@ from collections.abc import Callable from typing import Any +import numpy as np + import finchlite.finch_einsum as ein from finchlite.algebra import init_value, is_commutative, overwrite from finchlite.algebra.tensor import Tensor @@ -19,18 +21,16 @@ ) from finchlite.interface import Scalar +from finchlite.symbolic import gensym class EinsumLowerer: - alias_counter: int = 0 - def __call__(self, prgm: Plan) -> tuple[ein.Plan, dict[str, Any]]: parameters: dict[str, Any] = {} definitions: dict[str, ein.Einsum] = {} return self.compile_plan(prgm, parameters, definitions), parameters def get_next_alias(self) -> ein.Alias: - self.alias_counter += 1 - return ein.Alias(f"einsum_{self.alias_counter}") + return ein.Alias(gensym("einsum")) def rename_einsum( self, @@ -66,7 +66,7 @@ def compile_plan( case Query(Alias(name), Table(Literal(tns), _)) if isinstance( tns, Tensor ): - parameters[name] = tns.to_numpy() + parameters[name] = np.asarray(tns) case Query(Alias(name), rhs): einsums.append( self.rename_einsum( @@ -76,12 +76,18 @@ def compile_plan( ) ) case Produces(args): - returnValue = [ - ein.Alias(arg.name) - if isinstance(arg, Alias) - else self.lower_to_einsum(arg, einsums, parameters, definitions) - for arg in args - ] + returnValue = [] + for arg in args: + if isinstance(arg, Alias): + returnValue.append(ein.Alias(arg.name)) + else: + einsum = self.rename_einsum( + self.lower_to_einsum(arg, einsums, parameters, 
definitions), + self.get_next_alias(), + definitions, + ) + einsums.append(einsum) + returnValue.append(einsum.tns) break case _: einsums.append( diff --git a/tests/test_einsum_lowerer.py b/tests/test_einsum_lowerer.py index 8c14ef3f..881b269f 100644 --- a/tests/test_einsum_lowerer.py +++ b/tests/test_einsum_lowerer.py @@ -10,6 +10,7 @@ from finchlite.interface.fuse import compute from finchlite.interface.lazy import defer from finchlite.symbolic import gensym +from typing import cast @pytest.fixture @@ -30,7 +31,7 @@ def lower_and_execute(ir: LogicNode): # Optimize into a plan var = Alias(gensym("result")) plan = Plan((Query(var, ir), Produces((var,)))) - optimized_plan = optimize(plan) + optimized_plan = cast(Plan, optimize(plan)) # Lower to einsum IR lowerer = EinsumLowerer() From 5f36539c488a240a79a2d84b32b5dee78902493b Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 10 Oct 2025 12:32:43 -0400 Subject: [PATCH 13/57] Still have to invoke tns.to_numpy; added type safe attribute checking mechanism --- src/finchlite/autoschedule/einsum.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 2b2751b9..03887134 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -66,7 +66,7 @@ def compile_plan( case Query(Alias(name), Table(Literal(tns), _)) if isinstance( tns, Tensor ): - parameters[name] = np.asarray(tns) + parameters[name] = tns.to_numpy() if hasattr(tns, "to_numpy") else np.asarray(tns) # type: ignore[attr-defined] case Query(Alias(name), rhs): einsums.append( self.rename_einsum( From c4e2eda0e1056c9b2858ef54e7e09d86ad403db1 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 10 Oct 2025 12:34:10 -0400 Subject: [PATCH 14/57] * Ran ruff --- src/finchlite/autoschedule/einsum.py | 10 +++++++--- tests/test_einsum_lowerer.py | 3 ++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git 
a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 03887134..ecbfe4e7 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -20,9 +20,9 @@ Table, ) from finchlite.interface import Scalar - from finchlite.symbolic import gensym + class EinsumLowerer: def __call__(self, prgm: Plan) -> tuple[ein.Plan, dict[str, Any]]: parameters: dict[str, Any] = {} @@ -66,7 +66,9 @@ def compile_plan( case Query(Alias(name), Table(Literal(tns), _)) if isinstance( tns, Tensor ): - parameters[name] = tns.to_numpy() if hasattr(tns, "to_numpy") else np.asarray(tns) # type: ignore[attr-defined] + parameters[name] = ( + tns.to_numpy() if hasattr(tns, "to_numpy") else np.asarray(tns) + ) # type: ignore[attr-defined] case Query(Alias(name), rhs): einsums.append( self.rename_einsum( @@ -82,7 +84,9 @@ def compile_plan( returnValue.append(ein.Alias(arg.name)) else: einsum = self.rename_einsum( - self.lower_to_einsum(arg, einsums, parameters, definitions), + self.lower_to_einsum( + arg, einsums, parameters, definitions + ), self.get_next_alias(), definitions, ) diff --git a/tests/test_einsum_lowerer.py b/tests/test_einsum_lowerer.py index 881b269f..43be026f 100644 --- a/tests/test_einsum_lowerer.py +++ b/tests/test_einsum_lowerer.py @@ -1,3 +1,5 @@ +from typing import cast + import pytest import numpy as np @@ -10,7 +12,6 @@ from finchlite.interface.fuse import compute from finchlite.interface.lazy import defer from finchlite.symbolic import gensym -from typing import cast @pytest.fixture From 9be7e18030914c873939a8368ac4b856bae55ad6 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 14 Oct 2025 09:55:03 -0400 Subject: [PATCH 15/57] * Restored produce einsum ir node --- src/finchlite/finch_einsum/nodes.py | 59 ++++++++++++----------------- 1 file changed, 25 insertions(+), 34 deletions(-) diff --git a/src/finchlite/finch_einsum/nodes.py b/src/finchlite/finch_einsum/nodes.py index 428a191a..eece5ea1 100644 --- 
a/src/finchlite/finch_einsum/nodes.py +++ b/src/finchlite/finch_einsum/nodes.py @@ -225,45 +225,43 @@ class Plan(EinsumTree): """ bodies: tuple[Einsum, ...] = () - returnValues: tuple[EinsumExpr, ...] = () @classmethod def from_children(cls, *children: Term) -> Self: # The last child is the returnValues tuple, all others are bodies if len(children) < 1: raise ValueError("Plan expects at least 1 child") - *bodies, returnValues = children + bodies = children return cls( tuple(cast(Einsum, b) for b in bodies), - cast(tuple[EinsumExpr, ...], returnValues), ) @property def children(self): - return [*self.bodies, self.returnValues] + return [*self.bodies] -# @dataclass(eq=True, frozen=True) -# class Produces(EinsumTree): -# """ -# Represents a logical AST statement that returns `args...` from the current plan. -# Halts execution of the program. -# -# Attributes: -# args: The arguments to return. -# """ +@dataclass(eq=True, frozen=True) +class Produces(EinsumTree): + """ + Represents a logical AST statement that returns `args...` from the current plan. + Halts execution of the program. + + Attributes: + args: The arguments to return. + """ -# args: tuple[EinsumNode, ...] + args: tuple[EinsumNode, ...] 
-# @property -# def children(self): -# """Returns the children of the node.""" -# return [*self.args] + @property + def children(self): + """Returns the children of the node.""" + return [*self.args] -# @classmethod -# def from_children(cls, *args): -# return cls(args) + @classmethod + def from_children(cls, *args): + return cls(args) infix_strs = { @@ -319,7 +317,7 @@ def subblock(self): return blk def __call__(self, prgm: EinsumNode): - # feed = self.feed + feed = self.feed match prgm: case Literal(value): return qual_str(value).replace("\n", "") @@ -344,23 +342,16 @@ def __call__(self, prgm: EinsumNode): f"{op_str}= {self(arg)}" ) return None - case Plan(bodies, returnValues): + case Plan(bodies): self.exec(f"{self.feed}plan:") ctx_2 = self.subblock() for body in bodies: ctx_2(body) self.exec(ctx_2.emit()) - - if len(returnValues) > 0: - self.exec(f"{self.feed}returnValues:") - ctx_3 = self.subblock() - for returnValue in returnValues: - ctx_3.exec(f"{ctx_3.feed}{ctx_3(returnValue)}") - self.exec(ctx_3.emit()) return None - # case Produces(args): - # args = tuple(self(arg) for arg in args) - # self.exec(f"{feed}return {args}\n") - # return None + case Produces(args): + args = tuple(self(arg) for arg in args) + self.exec(f"{feed}return {args}\n") + return None case _: raise ValueError(f"Unknown expression type: {type(prgm)}") From 4f7446c65f8eb6ad678e240bb0f51cfb58a1c228 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Sat, 18 Oct 2025 10:01:03 -0400 Subject: [PATCH 16/57] * Removed return values from Einsum IR Node * Return values will now be encoded by ein.Produce nodes in the Einsum's bodies tuple --- src/finchlite/autoschedule/einsum.py | 29 +++++++++++------------ src/finchlite/finch_einsum/__init__.py | 4 ++-- src/finchlite/finch_einsum/interpreter.py | 14 +++++------ src/finchlite/finch_einsum/nodes.py | 4 ++-- 4 files changed, 24 insertions(+), 27 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py 
index ecbfe4e7..921cc88d 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -49,15 +49,13 @@ def reorder_einsum( def compile_plan( self, plan: Plan, parameters: dict[str, Any], definitions: dict[str, ein.Einsum] ) -> ein.Plan: - einsums: list[ein.Einsum] = [] - returnValue: list[ein.EinsumExpr] = [] + bodies: list[ein.EinsumNode] = [] for body in plan.bodies: match body: case Plan(_): inner_plan = self.compile_plan(body, parameters, definitions) - einsums.extend(inner_plan.bodies) - returnValue.extend(inner_plan.returnValues) + bodies.extend(inner_plan.bodies) break case Query(Alias(name), Table(Literal(val), _)) if isinstance( val, Scalar @@ -70,41 +68,42 @@ def compile_plan( tns.to_numpy() if hasattr(tns, "to_numpy") else np.asarray(tns) ) # type: ignore[attr-defined] case Query(Alias(name), rhs): - einsums.append( + bodies.append( self.rename_einsum( - self.lower_to_einsum(rhs, einsums, parameters, definitions), + self.lower_to_einsum(rhs, bodies, parameters, definitions), ein.Alias(name), definitions, ) ) case Produces(args): - returnValue = [] + returnValues = [] for arg in args: if isinstance(arg, Alias): - returnValue.append(ein.Alias(arg.name)) + returnValues.append(ein.Alias(arg.name)) else: einsum = self.rename_einsum( self.lower_to_einsum( - arg, einsums, parameters, definitions + arg, bodies, parameters, definitions ), self.get_next_alias(), definitions, ) - einsums.append(einsum) - returnValue.append(einsum.tns) - break + bodies.append(einsum) + returnValues.append(einsum.tns) + + bodies.append(ein.Produces(tuple(returnValues))) case _: - einsums.append( + bodies.append( self.rename_einsum( self.lower_to_einsum( - body, einsums, parameters, definitions + body, bodies, parameters, definitions ), self.get_next_alias(), definitions, ) ) - return ein.Plan(tuple(einsums), tuple(returnValue)) + return ein.Plan(tuple(bodies)) def lower_to_einsum( self, diff --git a/src/finchlite/finch_einsum/__init__.py 
b/src/finchlite/finch_einsum/__init__.py index cabb2540..223a5d4a 100644 --- a/src/finchlite/finch_einsum/__init__.py +++ b/src/finchlite/finch_einsum/__init__.py @@ -9,7 +9,7 @@ Index, Literal, Plan, - # Produces, + Produces, ) from .parser import parse_einop, parse_einsum @@ -27,7 +27,7 @@ "Index", "Literal", "Plan", - # "Produces", + "Produces", "parse_einop", "parse_einsum", ] diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index f8e0f7a4..20806039 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -104,15 +104,13 @@ def __call__(self, node): tns, [i for i in range(len(self.loops)) if self.loops[i] not in idxs], ) - case ein.Plan(bodies, returnValues): + case ein.Plan(bodies): + returnVal = None for body in bodies: - self(body) # execute each einsum statement individually - - if returnValues: # return and evaluate the return values seperately - return tuple(self(rv) for rv in returnValues) - return None - # case ein.Produces(args): - # return tuple(self(arg) for arg in args) + returnVal = self(body) # execute each einsum statement individually + return returnVal + case ein.Produces(args): + return tuple(self(arg) for arg in args) case ein.Einsum(op, ein.Alias(tns), idxs, arg): # This is the main entry point for einsum execution loops = arg.get_idxs() diff --git a/src/finchlite/finch_einsum/nodes.py b/src/finchlite/finch_einsum/nodes.py index eece5ea1..34e92ad8 100644 --- a/src/finchlite/finch_einsum/nodes.py +++ b/src/finchlite/finch_einsum/nodes.py @@ -224,7 +224,7 @@ class Plan(EinsumTree): Basically a list of einsums and some return values. """ - bodies: tuple[Einsum, ...] = () + bodies: tuple[EinsumNode, ...] 
= () @classmethod def from_children(cls, *children: Term) -> Self: @@ -234,7 +234,7 @@ def from_children(cls, *children: Term) -> Self: bodies = children return cls( - tuple(cast(Einsum, b) for b in bodies), + tuple(cast(EinsumNode, b) for b in bodies), ) @property From a02ffcdbcad10914d2f8c5d23e0f2f7dc509b1cd Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Sat, 18 Oct 2025 10:26:45 -0400 Subject: [PATCH 17/57] * Fixed issues with einsum printer. * Reverted to old printing that supports ein.Produce * Fixed misc issues related to reverting changes --- src/finchlite/interface/lazy.py | 10 ++++++++-- tests/reference/test_einsum_printer.txt | 4 +--- tests/test_printers.py | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/finchlite/interface/lazy.py b/src/finchlite/interface/lazy.py index 4b3f8e9e..48f6b7e6 100644 --- a/src/finchlite/interface/lazy.py +++ b/src/finchlite/interface/lazy.py @@ -1805,7 +1805,10 @@ def std( def einop(prgm, **kwargs): stmt = ein.parse_einop(prgm) - prgm = ein.Plan((stmt,), (stmt.tns,)) + prgm = ein.Plan(( + stmt, + ein.Produces((stmt.tns,)) + )) xp = sys.modules[__name__] ctx = ein.EinsumInterpreter(xp, dict(**kwargs)) @@ -1814,7 +1817,10 @@ def einop(prgm, **kwargs): def einsum(prgm, *args, **kwargs): stmt, bindings = ein.parse_einsum(prgm, *args) - prgm = ein.Plan((stmt,), (stmt.tns,)) + prgm = ein.Plan(( + stmt, + ein.Produces((stmt.tns,)) + )) xp = sys.modules[__name__] ctx = ein.EinsumInterpreter(xp, bindings) diff --git a/tests/reference/test_einsum_printer.txt b/tests/reference/test_einsum_printer.txt index bf766d81..569a686e 100644 --- a/tests/reference/test_einsum_printer.txt +++ b/tests/reference/test_einsum_printer.txt @@ -3,6 +3,4 @@ plan: plan: D[i, j] += (A[i, k] * B[k, j]) E[i] min= lshift((A[i, k] + D[k, j]), 1) -returnValues: - C - E + return ('C', 'E') \ No newline at end of file diff --git a/tests/test_printers.py b/tests/test_printers.py index 960cf051..e3186736 100644 --- 
a/tests/test_printers.py +++ b/tests/test_printers.py @@ -542,8 +542,8 @@ def test_einsum_printer(file_regression): ein.parse_einop("E[i] min= A[i,k] + D[k,j] << 1"), ) ), + ein.Produces((ein.Alias("C"), ein.Alias("E"))), ), - (ein.Alias("C"), ein.Alias("E")), ) file_regression.check(str(prgm), extension=".txt") From c81a1b6ad22be3dbd48181e9961adf3b33d12ed9 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Sat, 18 Oct 2025 10:30:42 -0400 Subject: [PATCH 18/57] * Fixed type errors in einsum lowerer --- src/finchlite/autoschedule/einsum.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 921cc88d..0ba9b0cf 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -108,7 +108,7 @@ def compile_plan( def lower_to_einsum( self, ex: LogicNode, - einsums: list[ein.Einsum], + bodies: list[ein.EinsumNode], parameters: dict[str, Any], definitions: dict[str, ein.Einsum], ) -> ein.Einsum: @@ -117,7 +117,7 @@ def lower_to_einsum( raise Exception("Plans within plans are not supported.") case MapJoin(Literal(operation), args): args_list = [ - self.lower_to_pointwise(arg, einsums, parameters, definitions) + self.lower_to_pointwise(arg, bodies, parameters, definitions) for arg in args ] pointwise_expr = self.lower_to_pointwise_op(operation, tuple(args_list)) @@ -129,7 +129,7 @@ def lower_to_einsum( ) case Reorder(arg, idxs): return self.reorder_einsum( - self.lower_to_einsum(arg, einsums, parameters, definitions), + self.lower_to_einsum(arg, bodies, parameters, definitions), tuple(ein.Index(field.name) for field in idxs), ) case Aggregate(Literal(operation), Literal(init), arg, idxs): @@ -140,7 +140,7 @@ def lower_to_einsum( Non standard init values are not supported. 
""") aggregate_expr = self.lower_to_pointwise( - arg, einsums, parameters, definitions + arg, bodies, parameters, definitions ) return ein.Einsum( op=ein.Literal(operation), @@ -180,16 +180,16 @@ def flatten_args( def lower_to_pointwise( self, ex: LogicNode, - einsums: list[ein.Einsum], + bodies: list[ein.EinsumNode], parameters: dict[str, Any], definitions: dict[str, ein.Einsum], ) -> ein.EinsumExpr: match ex: case Reorder(arg, idxs): - return self.lower_to_pointwise(arg, einsums, parameters, definitions) + return self.lower_to_pointwise(arg, bodies, parameters, definitions) case MapJoin(Literal(operation), args): args_list = [ - self.lower_to_pointwise(arg, einsums, parameters, definitions) + self.lower_to_pointwise(arg, bodies, parameters, definitions) for arg in args ] return self.lower_to_pointwise_op(operation, tuple(args_list)) @@ -206,9 +206,9 @@ def lower_to_pointwise( _, _, _, _ ): # aggregate has to be computed seperatley as it's own einsum aggregate_einsum_alias = self.get_next_alias() - einsums.append( + bodies.append( self.rename_einsum( - self.lower_to_einsum(ex, einsums, parameters, definitions), + self.lower_to_einsum(ex, bodies, parameters, definitions), aggregate_einsum_alias, definitions, ) From fad5a319e48701e8a27a5139531cc997b19e6df1 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Sat, 18 Oct 2025 10:35:17 -0400 Subject: [PATCH 19/57] * Fixed ruff errors --- src/finchlite/autoschedule/einsum.py | 6 ++---- src/finchlite/interface/lazy.py | 10 ++-------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 0ba9b0cf..483fa3e7 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -90,14 +90,12 @@ def compile_plan( ) bodies.append(einsum) returnValues.append(einsum.tns) - + bodies.append(ein.Produces(tuple(returnValues))) case _: bodies.append( self.rename_einsum( - self.lower_to_einsum( - body, bodies, 
parameters, definitions - ), + self.lower_to_einsum(body, bodies, parameters, definitions), self.get_next_alias(), definitions, ) diff --git a/src/finchlite/interface/lazy.py b/src/finchlite/interface/lazy.py index 48f6b7e6..ade2247f 100644 --- a/src/finchlite/interface/lazy.py +++ b/src/finchlite/interface/lazy.py @@ -1805,10 +1805,7 @@ def std( def einop(prgm, **kwargs): stmt = ein.parse_einop(prgm) - prgm = ein.Plan(( - stmt, - ein.Produces((stmt.tns,)) - )) + prgm = ein.Plan((stmt, ein.Produces((stmt.tns,)))) xp = sys.modules[__name__] ctx = ein.EinsumInterpreter(xp, dict(**kwargs)) @@ -1817,10 +1814,7 @@ def einop(prgm, **kwargs): def einsum(prgm, *args, **kwargs): stmt, bindings = ein.parse_einsum(prgm, *args) - prgm = ein.Plan(( - stmt, - ein.Produces((stmt.tns,)) - )) + prgm = ein.Plan((stmt, ein.Produces((stmt.tns,)))) xp = sys.modules[__name__] ctx = ein.EinsumInterpreter(xp, bindings) From 8e3b5c7bbe2a62d193cd931e47727eb2753f0935 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 20 Oct 2025 08:52:40 -0400 Subject: [PATCH 20/57] * Added sparse tensor implementation * Sparse tensor implementation is COO * Stores an array of coordinates of non-zero elements and an array of non-zero elements --- src/finchlite/autoschedule/insum.py | 0 src/finchlite/tensor/sparse_tensor.py | 90 +++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 src/finchlite/autoschedule/insum.py create mode 100644 src/finchlite/tensor/sparse_tensor.py diff --git a/src/finchlite/autoschedule/insum.py b/src/finchlite/autoschedule/insum.py new file mode 100644 index 00000000..e69de29b diff --git a/src/finchlite/tensor/sparse_tensor.py b/src/finchlite/tensor/sparse_tensor.py new file mode 100644 index 00000000..60f441f3 --- /dev/null +++ b/src/finchlite/tensor/sparse_tensor.py @@ -0,0 +1,90 @@ +from finchlite.algebra import TensorFType +from finchlite.interface.eager import EagerTensor +import numpy as np + +class SparseTensorFType(TensorFType): + def 
__init__(self, shape: tuple, element_type: type): + self.shape = shape + self._element_type = element_type + + def __eq__(self, other): + if not isinstance(other, SparseTensorFType): + return False + return self.shape == other.shape and self.element_type == other.element_type + + def __hash__(self): + return hash((self.shape, self.element_type)) + + @property + def ndim(self): + return len(self.shape) + + @property + def shape_type(self): + return self.shape + + @property + def element_type(self): + return self._element_type + + @property + def fill_value(self): + return 0 + +# currently implemented with COO tensor +class SparseTensor(EagerTensor): + def __init__(self, data: np.array, coords: np.ndarray, shape: tuple, element_type=np.float64): + self.coords = coords + self.data = data + self._shape = shape + self._element_type = element_type + + # converts an eager tensor to a sparse tensor + @classmethod + def from_dense_tensor(cls, dense_tensor: np.ndarray): + + coords = np.where(dense_tensor != 0) + data = dense_tensor[coords] + shape = dense_tensor.shape + element_type = dense_tensor.dtype.type # Get the type, not the dtype + # Convert coords from tuple of arrays to array of coordinates + coords_array = np.array(coords).T + return cls(data, coords_array, shape, element_type) + + @property + def ftype(self): + return SparseTensorFType(self.shape, self._element_type) + + @property + def shape(self): + return self._shape + + @property + def ndim(self) -> int: + return len(self._shape) + + # calculates the ratio of non-zero elements to the total number of elements + @property + def density(self): + return self.coords.shape[0] / np.prod(self.shape) + + def __getitem__(self, idx: tuple): + if len(idx) != self.ndim: + raise ValueError(f"Index must have {self.ndim} dimensions") + + # coords is a 2D array where each row is a coordinate + mask = np.all(self.coords == idx, axis=1) + matching_indices = np.where(mask)[0] + + if len(matching_indices) > 0: + return 
self.data[matching_indices[0]] + return 0 + + def __str__(self): + return f"SparseTensor(data={self.data}, coords={self.coords}, shape={self.shape}, element_type={self._element_type})" + + def to_dense(self) -> np.ndarray: + dense_tensor = np.zeros(self.shape, dtype=self._element_type) + for i in range(self.coords.shape[0]): + dense_tensor[tuple(self.coords[i])] = self.data[i] + return dense_tensor \ No newline at end of file From 01152850bdcf580c480a9d1d2b70822b5db4ccb2 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 20 Oct 2025 09:17:46 -0400 Subject: [PATCH 21/57] * Added framework for Insum Lowerer --- src/finchlite/autoschedule/__init__.py | 2 ++ src/finchlite/autoschedule/insum.py | 34 ++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/src/finchlite/autoschedule/__init__.py b/src/finchlite/autoschedule/__init__.py index 32022d8e..730ec1fe 100644 --- a/src/finchlite/autoschedule/__init__.py +++ b/src/finchlite/autoschedule/__init__.py @@ -16,6 +16,7 @@ ) from ..symbolic import PostOrderDFS, PostWalk, PreWalk from .compiler import LogicCompiler +from .einsum import EinsumLowerer from .optimize import ( DefaultLogicOptimizer, concordize, @@ -43,6 +44,7 @@ "Aggregate", "Alias", "DefaultLogicOptimizer", + "EinsumLowerer", "Field", "Literal", "LogicCompiler", diff --git a/src/finchlite/autoschedule/insum.py b/src/finchlite/autoschedule/insum.py index e69de29b..361462b4 100644 --- a/src/finchlite/autoschedule/insum.py +++ b/src/finchlite/autoschedule/insum.py @@ -0,0 +1,34 @@ +import operator +from typing import Any +import finchlite.finch_einsum as ein +import finchlite.finch_logic as logic +from finchlite.finch_logic.nodes import Table +from finchlite.symbolic import ( + ftype, + PostWalk, + Rewrite, + gensym +) +from finchlite.algebra import ( + overwrite, + init_value +) +from finchlite.autoschedule import ( + EinsumLowerer +) + +class InsumLowerer: + def __init__(self): + self.el = EinsumLowerer() + + def can_optimize(self, 
node: ein.EinsumNode, sparse_params: set[str]) -> tuple[bool, dict[str, tuple[ein.Index, ...]]]: + pass + + def optimize_einsum(self, einsum: ein.Einsum, sparse_param: str, sparse_param_idxs: tuple[ein.Index, ...]) -> list[ein.EinsumNode]: + pass + + def get_sparse_params(self, parameters: dict[str, Table]) -> set[str]: + pass + + def optimize_plan(self, plan: ein.Plan, parameters: dict[str, Any]) -> tuple[ein.Plan, dict[str, Any]]: + pass \ No newline at end of file From ef393456fc0f7ab5864ac3cc271c8e22958449f6 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 20 Oct 2025 09:19:17 -0400 Subject: [PATCH 22/57] * Renamed parameters to binding in einsum lowerer to stay consitent with einsum parser naming --- src/finchlite/autoschedule/einsum.py | 34 ++++++++++++++-------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 483fa3e7..d0212daf 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -25,9 +25,9 @@ class EinsumLowerer: def __call__(self, prgm: Plan) -> tuple[ein.Plan, dict[str, Any]]: - parameters: dict[str, Any] = {} + bindings: dict[str, Any] = {} definitions: dict[str, ein.Einsum] = {} - return self.compile_plan(prgm, parameters, definitions), parameters + return self.compile_plan(prgm, bindings, definitions), bindings def get_next_alias(self) -> ein.Alias: return ein.Alias(gensym("einsum")) @@ -47,30 +47,30 @@ def reorder_einsum( return ein.Einsum(einsum.op, einsum.tns, idxs, einsum.arg) def compile_plan( - self, plan: Plan, parameters: dict[str, Any], definitions: dict[str, ein.Einsum] + self, plan: Plan, bindings: dict[str, Any], definitions: dict[str, ein.Einsum] ) -> ein.Plan: bodies: list[ein.EinsumNode] = [] for body in plan.bodies: match body: case Plan(_): - inner_plan = self.compile_plan(body, parameters, definitions) + inner_plan = self.compile_plan(body, bindings, definitions) 
bodies.extend(inner_plan.bodies) break case Query(Alias(name), Table(Literal(val), _)) if isinstance( val, Scalar ): - parameters[name] = val.val + bindings[name] = val.val case Query(Alias(name), Table(Literal(tns), _)) if isinstance( tns, Tensor ): - parameters[name] = ( + bindings[name] = ( tns.to_numpy() if hasattr(tns, "to_numpy") else np.asarray(tns) ) # type: ignore[attr-defined] case Query(Alias(name), rhs): bodies.append( self.rename_einsum( - self.lower_to_einsum(rhs, bodies, parameters, definitions), + self.lower_to_einsum(rhs, bodies, bindings, definitions), ein.Alias(name), definitions, ) @@ -83,7 +83,7 @@ def compile_plan( else: einsum = self.rename_einsum( self.lower_to_einsum( - arg, bodies, parameters, definitions + arg, bodies, bindings, definitions ), self.get_next_alias(), definitions, @@ -95,7 +95,7 @@ def compile_plan( case _: bodies.append( self.rename_einsum( - self.lower_to_einsum(body, bodies, parameters, definitions), + self.lower_to_einsum(body, bodies, bindings, definitions), self.get_next_alias(), definitions, ) @@ -107,7 +107,7 @@ def lower_to_einsum( self, ex: LogicNode, bodies: list[ein.EinsumNode], - parameters: dict[str, Any], + bindings: dict[str, Any], definitions: dict[str, ein.Einsum], ) -> ein.Einsum: match ex: @@ -115,7 +115,7 @@ def lower_to_einsum( raise Exception("Plans within plans are not supported.") case MapJoin(Literal(operation), args): args_list = [ - self.lower_to_pointwise(arg, bodies, parameters, definitions) + self.lower_to_pointwise(arg, bodies, bindings, definitions) for arg in args ] pointwise_expr = self.lower_to_pointwise_op(operation, tuple(args_list)) @@ -127,7 +127,7 @@ def lower_to_einsum( ) case Reorder(arg, idxs): return self.reorder_einsum( - self.lower_to_einsum(arg, bodies, parameters, definitions), + self.lower_to_einsum(arg, bodies, bindings, definitions), tuple(ein.Index(field.name) for field in idxs), ) case Aggregate(Literal(operation), Literal(init), arg, idxs): @@ -138,7 +138,7 @@ def 
lower_to_einsum( Non standard init values are not supported. """) aggregate_expr = self.lower_to_pointwise( - arg, bodies, parameters, definitions + arg, bodies, bindings, definitions ) return ein.Einsum( op=ein.Literal(operation), @@ -179,15 +179,15 @@ def lower_to_pointwise( self, ex: LogicNode, bodies: list[ein.EinsumNode], - parameters: dict[str, Any], + bindings: dict[str, Any], definitions: dict[str, ein.Einsum], ) -> ein.EinsumExpr: match ex: case Reorder(arg, idxs): - return self.lower_to_pointwise(arg, bodies, parameters, definitions) + return self.lower_to_pointwise(arg, bodies, bindings, definitions) case MapJoin(Literal(operation), args): args_list = [ - self.lower_to_pointwise(arg, bodies, parameters, definitions) + self.lower_to_pointwise(arg, bodies, bindings, definitions) for arg in args ] return self.lower_to_pointwise_op(operation, tuple(args_list)) @@ -206,7 +206,7 @@ def lower_to_pointwise( aggregate_einsum_alias = self.get_next_alias() bodies.append( self.rename_einsum( - self.lower_to_einsum(ex, bodies, parameters, definitions), + self.lower_to_einsum(ex, bodies, bindings, definitions), aggregate_einsum_alias, definitions, ) From ba0992d76ea298e7fbb5388f4f039e41e58f3296 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 20 Oct 2025 09:20:49 -0400 Subject: [PATCH 23/57] * Renamed parameters to bindings in InsumLowerer for consitency with EinsumLowerer and parser --- src/finchlite/autoschedule/insum.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/finchlite/autoschedule/insum.py b/src/finchlite/autoschedule/insum.py index 361462b4..2e5146fe 100644 --- a/src/finchlite/autoschedule/insum.py +++ b/src/finchlite/autoschedule/insum.py @@ -27,8 +27,8 @@ def can_optimize(self, node: ein.EinsumNode, sparse_params: set[str]) -> tuple[b def optimize_einsum(self, einsum: ein.Einsum, sparse_param: str, sparse_param_idxs: tuple[ein.Index, ...]) -> list[ein.EinsumNode]: pass - def get_sparse_params(self, parameters: 
dict[str, Table]) -> set[str]: + def get_sparse_params(self, bindings: dict[str, Any]) -> set[str]: pass - def optimize_plan(self, plan: ein.Plan, parameters: dict[str, Any]) -> tuple[ein.Plan, dict[str, Any]]: + def optimize_plan(self, plan: ein.Plan, bindings: dict[str, Any]) -> tuple[ein.Plan, dict[str, Any]]: pass \ No newline at end of file From a372a0d775a315195bd66332433679c8859290b1 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 20 Oct 2025 09:43:28 -0400 Subject: [PATCH 24/57] * Implemented can_optimize method in InsumLowerer --- src/finchlite/autoschedule/insum.py | 38 +++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/src/finchlite/autoschedule/insum.py b/src/finchlite/autoschedule/insum.py index 2e5146fe..3fce47ab 100644 --- a/src/finchlite/autoschedule/insum.py +++ b/src/finchlite/autoschedule/insum.py @@ -1,8 +1,10 @@ import operator -from typing import Any +from typing import Any, cast + +from numpy import isin import finchlite.finch_einsum as ein import finchlite.finch_logic as logic -from finchlite.finch_logic.nodes import Table +from finchlite.finch_logic.nodes import Alias, Table from finchlite.symbolic import ( ftype, PostWalk, @@ -21,10 +23,36 @@ class InsumLowerer: def __init__(self): self.el = EinsumLowerer() - def can_optimize(self, node: ein.EinsumNode, sparse_params: set[str]) -> tuple[bool, dict[str, tuple[ein.Index, ...]]]: - pass + def can_optimize(self, en: ein.EinsumNode, sparse: set[str]) -> tuple[bool, dict[str, tuple[ein.Index, ...]]]: + """ + Checks if an einsum node can be optimized via indirect einsums. + Specifically it checks whether node is an einsum that references any sparse tensor binding/parameter. 
+ """ + if not isinstance(en, ein.Einsum): + return False + + einsum = cast(ein.Einsum, en) + + refed_sparse = dict() + + def sparse_detect(node: ein.EinsumExpr): + nonlocal refed_sparse + + match node: + case ein.Access(ein.Alias(name), idxs): + if name not in sparse: + return None + + if name in refed_sparse and refed_sparse[name] != idxs: + raise ValueError( + f"Sparse binding {name} is being referenced " + "with different indicies.") + refed_sparse[name] = idxs + return None + + PostWalk(sparse_detect)(einsum.arg) - def optimize_einsum(self, einsum: ein.Einsum, sparse_param: str, sparse_param_idxs: tuple[ein.Index, ...]) -> list[ein.EinsumNode]: + def optimize_einsum(self, einsum: ein.Einsum, sparse: str, sparse_idxs: tuple[ein.Index, ...]) -> list[ein.EinsumNode]: pass def get_sparse_params(self, bindings: dict[str, Any]) -> set[str]: From d99b5712e0439a8626ef270d8d250fcc63a7a28d Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 20 Oct 2025 10:03:01 -0400 Subject: [PATCH 25/57] Impelemented get_sparse_params in InsumLowerer --- src/finchlite/autoschedule/insum.py | 32 ++++++++++++++++++++++++++++- src/finchlite/tensor/__init__.py | 14 ++++++++++++- 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/src/finchlite/autoschedule/insum.py b/src/finchlite/autoschedule/insum.py index 3fce47ab..933a0559 100644 --- a/src/finchlite/autoschedule/insum.py +++ b/src/finchlite/autoschedule/insum.py @@ -18,6 +18,9 @@ from finchlite.autoschedule import ( EinsumLowerer ) +from finchlite.tensor import ( + SparseTensorFType +) class InsumLowerer: def __init__(self): @@ -27,6 +30,15 @@ def can_optimize(self, en: ein.EinsumNode, sparse: set[str]) -> tuple[bool, dict """ Checks if an einsum node can be optimized via indirect einsums. Specifically it checks whether node is an einsum that references any sparse tensor binding/parameter. + + Arguments: + en: The einsum node to check. + sparse: The set of aliases of sparse tensor bindings/parameters. 
+ + Returns: + A tuple containing: + - A boolean indicating if the einsum node can be optimized. + - A dictionary mapping sparse binding aliases to the indices they are referenced with. """ if not isinstance(en, ein.Einsum): return False @@ -56,7 +68,25 @@ def optimize_einsum(self, einsum: ein.Einsum, sparse: str, sparse_idxs: tuple[ei pass def get_sparse_params(self, bindings: dict[str, Any]) -> set[str]: - pass + """ + Gets the set of sparse binding aliases from the bindings dictionary. + + Arguments: + bindings: The bindings dictionary. + + Returns: + A set of sparse binding aliases. + """ + + sparse = set() + + for alias, value in bindings.items(): + match value: + case logic.Table(logic.Literal(tensor_value), _): + if isinstance(ftype(tensor_value), SparseTensorFType): + sparse.add(alias) + + return sparse def optimize_plan(self, plan: ein.Plan, bindings: dict[str, Any]) -> tuple[ein.Plan, dict[str, Any]]: pass \ No newline at end of file diff --git a/src/finchlite/tensor/__init__.py b/src/finchlite/tensor/__init__.py index 9c45ad8c..0e7091d1 100644 --- a/src/finchlite/tensor/__init__.py +++ b/src/finchlite/tensor/__init__.py @@ -1,4 +1,14 @@ -from .fiber_tensor import FiberTensor, FiberTensorFType, Level, LevelFType, tensor +from .fiber_tensor import ( + FiberTensor, + FiberTensorFType, + Level, + LevelFType, + tensor +) +from .sparse_tensor import ( + SparseTensor, + SparseTensorFType +) from .level import ( DenseLevel, DenseLevelFType, @@ -15,6 +25,8 @@ "ElementLevelFType", "FiberTensor", "FiberTensorFType", + "SparseTensor", + "SparseTensorFType", "Level", "LevelFType", "dense", From 56333d7e267299924db648fe7c9264851601f4f8 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 20 Oct 2025 10:28:54 -0400 Subject: [PATCH 26/57] * Added GetAttribute Einsum IR node * Get Attribute is a general purpose node that can be used to retreive the coordinate and element arrays from a sparse tensor --- src/finchlite/autoschedule/insum.py | 17 +++++++++++++ 
src/finchlite/finch_einsum/nodes.py | 37 ++++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/src/finchlite/autoschedule/insum.py b/src/finchlite/autoschedule/insum.py index 933a0559..7ff2751e 100644 --- a/src/finchlite/autoschedule/insum.py +++ b/src/finchlite/autoschedule/insum.py @@ -1,3 +1,4 @@ +from ast import alias import operator from typing import Any, cast @@ -65,6 +66,22 @@ def sparse_detect(node: ein.EinsumExpr): PostWalk(sparse_detect)(einsum.arg) def optimize_einsum(self, einsum: ein.Einsum, sparse: str, sparse_idxs: tuple[ein.Index, ...]) -> list[ein.EinsumNode]: + #bodies: list[ein.EinsumNode] = [] + + # initialize mask tensor T which is a boolean that represents whether each reduced fiber in the sparse tensor has non-zero elements or not + # Essentially T[idxs...] = whether the sparse tensor fiber being reduced at idxs... has any non-zero elements in it + #T_idxs = tuple(idx for idx in einsum.idxs if idx in sparse_idxs) + #bodies.append(ein.Einsum( #initialize every element of T to 0 + # op=ein.Literal(overwrite), + # alias=ein.Alias(gensym(f"{sparse}_T")), + # idxs=T_idxs, + # arg=ein.Literal(0) + #)) + #bodies.append(ein.Einsum( + # op=ein.Literal(operator.add), + # alias=ein.Alias(gensym(f"{sparse}_T")), + # idxs= + #)) pass def get_sparse_params(self, bindings: dict[str, Any]) -> set[str]: diff --git a/src/finchlite/finch_einsum/nodes.py b/src/finchlite/finch_einsum/nodes.py index 34e92ad8..5a7dcc3d 100644 --- a/src/finchlite/finch_einsum/nodes.py +++ b/src/finchlite/finch_einsum/nodes.py @@ -1,7 +1,7 @@ import operator from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Self, cast +from typing import Any, Optional, Self, cast from finchlite.algebra import ( overwrite, @@ -175,6 +175,41 @@ def get_idxs(self) -> set["Index"]: idxs.update(arg.get_idxs()) return idxs +@dataclass(eq=True, frozen=True) +class GetAttribute(EinsumExpr, EinsumTree): + """ + Gets an 
attribute of a tensor. + + Attributes: + obj: The object to get the attribute from. + attr: The name of the attribute to get. + """ + + obj: EinsumExpr + attr: Literal + idx: Optional[Index] + + @classmethod + def from_children(cls, *children: Term) -> Self: + # Expects 3 children: obj, attr, idx + if len(children) != 3: + raise ValueError("GetAttribute expects 3 children (obj + attr + idx)") + obj = cast(EinsumExpr, children[0]) + attr = cast(Literal, children[1]) + idx = cast(Optional[Index], children[2]) + return cls(obj, attr, idx) + + @property + def children(self): + return [self.obj, self.attr, self.idx] + + def get_idxs(self) -> set["Index"]: + idxs = set() + idxs.update(self.obj.get_idxs()) + if self.idx is not None: + idxs.update(self.idx.get_idxs()) + return idxs + @dataclass(eq=True, frozen=True) class Einsum(EinsumTree): From 0e2645ee66cdf1f16581355fcb7e7f8112b4a375 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 20 Oct 2025 20:15:15 -0400 Subject: [PATCH 27/57] * Implemented to_insum method in InsumLowerer * We create a count tensor T, where each element in T is the number of non-zero elements in each corresponding reduced fiber of the sparse tensor * We initialize the initial reduction values based on whether each reduced fiber is all non-zero or not --- src/finchlite/autoschedule/einsum.py | 8 +- src/finchlite/autoschedule/insum.py | 147 +++++++++++++++++++++++---- 2 files changed, 136 insertions(+), 19 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index d0212daf..79ff8e27 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -4,8 +4,12 @@ import numpy as np import finchlite.finch_einsum as ein -from finchlite.algebra import init_value, is_commutative, overwrite -from finchlite.algebra.tensor import Tensor +from finchlite.algebra import ( + init_value, + is_commutative, + overwrite, + Tensor +) from finchlite.finch_logic import ( Aggregate, 
Alias, diff --git a/src/finchlite/autoschedule/insum.py b/src/finchlite/autoschedule/insum.py index 7ff2751e..3449245e 100644 --- a/src/finchlite/autoschedule/insum.py +++ b/src/finchlite/autoschedule/insum.py @@ -14,7 +14,8 @@ ) from finchlite.algebra import ( overwrite, - init_value + init_value, + ifelse ) from finchlite.autoschedule import ( EinsumLowerer @@ -65,24 +66,136 @@ def sparse_detect(node: ein.EinsumExpr): PostWalk(sparse_detect)(einsum.arg) - def optimize_einsum(self, einsum: ein.Einsum, sparse: str, sparse_idxs: tuple[ein.Index, ...]) -> list[ein.EinsumNode]: - #bodies: list[ein.EinsumNode] = [] - + def to_insum(self, einsum: ein.Einsum, sparse: str, sparse_idxs: tuple[ein.Index, ...]) -> list[ein.EinsumNode]: + bodies: list[ein.EinsumNode] = [] + reduced_idx = ein.Index(gensym(f"pos")) # initialize mask tensor T which is a boolean that represents whether each reduced fiber in the sparse tensor has non-zero elements or not # Essentially T[idxs...] = whether the sparse tensor fiber being reduced at idxs... 
has any non-zero elements in it - #T_idxs = tuple(idx for idx in einsum.idxs if idx in sparse_idxs) - #bodies.append(ein.Einsum( #initialize every element of T to 0 - # op=ein.Literal(overwrite), - # alias=ein.Alias(gensym(f"{sparse}_T")), - # idxs=T_idxs, - # arg=ein.Literal(0) - #)) - #bodies.append(ein.Einsum( - # op=ein.Literal(operator.add), - # alias=ein.Alias(gensym(f"{sparse}_T")), - # idxs= - #)) - pass + T_idxs = tuple(idx for idx in einsum.idxs if idx in sparse_idxs) + T_mask = ein.Alias(gensym(f"{sparse}_T")) + bodies.append(ein.Einsum( #initialize every element of T to 0 + op=ein.Literal(overwrite), + alias=T_mask, + idxs=T_idxs, + arg=ein.Literal(0) + )) + bodies.append(ein.Einsum( + op=ein.Literal(operator.add), + alias=T_mask, + idxs=( + ein.Access( + ein.GetAttribute( + obj=ein.Alias(sparse), + attr=ein.Literal("coords"), + idx=None + ), + (reduced_idx,) + ), + ), + arg=ein.Literal(1) + )) + + # get the reduced indicies in the sparse tensor + reduced_idxs = tuple(idx for idx in einsum.idxs if idx not in sparse_idxs) + + # get the size of the fiber in the sparse tensor being reduced + reduced_fiber_size = ein.Call(ein.Literal(operator.mul), ( + ein.Literal(1), + *[ein.GetAttribute( + obj=ein.Alias(sparse), + attr=ein.Literal("shape"), + idx=idx + ) for idx in reduced_idxs] + )) + + # rewrite the indicies used to iterate over the sparse tensor + def rewrite_indicies(idxs: tuple[ein.EinsumExpr, ...]) -> tuple[ein.EinsumExpr, ...]: + if idxs == sparse_idxs: + return (ein.Access( + ein.GetAttribute( + obj=ein.Alias(sparse), + attr=ein.Literal("coords"), + idx=None + ), + (reduced_idx,) + ),) + + new_idxs = [] + for idx in idxs: + match idx: + case ein.Index(_) if idx in sparse_idxs: + new_idxs.append(ein.Access( + ein.GetAttribute( + obj=ein.Alias(sparse), + attr=ein.Literal("coords"), + idx=idx + ), + (reduced_idx,) + )) + case _: + new_idxs.append(idx) + return tuple(new_idxs) + + # pattern matching rule to rewrite all indicies in arg + def 
rewrite_all_indicies(node: ein.EinsumExpr) -> ein.EinsumExpr: + match node: + case ein.Access(ein.Alias(name), idxs) if name == sparse and idxs == sparse_idxs: + return ein.Access( + ein.GetAttribute( + obj=ein.Alias(sparse), + attr=ein.Literal("elems"), + idx=None + ), + (reduced_idx,) + ) + case ein.Access(ein.Alias(name), idxs): + return ein.Access(ein.Alias(name), rewrite_indicies(idxs)) + + # rewrite a pointwise expression to assume that the sparse tensor is all-zero + def rewrite_zero(node: ein.EinsumExpr) -> ein.EinsumExpr: + match node: + case ein.Access(ein.Alias(name), _) if name == sparse: + return ein.Literal(0) + case ein.Access(ein.GetAttribute(ein.Alias(name), ein.Literal("elems"), None), _) if name == sparse: + return ein.Literal(0) + + # rewrite + new_einarg = Rewrite(PostWalk(rewrite_all_indicies))(einsum.arg) + zero_einarg = Rewrite(PostWalk(rewrite_zero))(einsum.arg) + + # initialize the reduction values + # essentially, we calculate the reduction values for the reduced fibers of the sparse tensor that are non zero, and hence who's iterations asre skipped + # we make the following core assumption: that the reduction operator, $f$ is associative and commutative. + # In other words, $f(a, f(b, c)) = f(f(a, b), c)$ for all $a, b, c$. + # In essence we assume a single zero element combined with the initial value passed through the reduction operator will + # be equal to the effect of one or more zero elements at any point in the reduced fiber combined with the initial value. + init = 0 if einsum.op == overwrite else init_value(einsum.op, type(0)) + bodies.append(ein.Einsum( + op=ein.Literal(overwrite), + alias=einsum.alias, + idxs=einsum.idxs, + arg=ein.Call(ein.Literal(ifelse), ( + ein.Call(ein.Literal(operator.eq), ( #check if T[idxs...] 
== reduced_fiber_size + ein.Access(T_mask, (reduced_idx,)), + reduced_fiber_size + )), + init, # if fiber is all non-zero initial reduction value is default + ein.Call(ein.Literal(einsum.op), ( + ein.Literal(init), + zero_einarg + )) + )) + )) + + #finally we execute the naive einsum -> insum + bodies.append(ein.Einsum( + op=einsum.op, + alias=einsum.alias, + idxs=rewrite_indicies(einsum.idxs), + arg=new_einarg + )) + + return bodies def get_sparse_params(self, bindings: dict[str, Any]) -> set[str]: """ From 79fda71624d9a8f6ef6b930c144c5a8cdfdfa2fb Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 20 Oct 2025 20:21:13 -0400 Subject: [PATCH 28/57] * Added top level optimize plan method to insum lowerer --- src/finchlite/autoschedule/insum.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/finchlite/autoschedule/insum.py b/src/finchlite/autoschedule/insum.py index 3449245e..9f2c42c7 100644 --- a/src/finchlite/autoschedule/insum.py +++ b/src/finchlite/autoschedule/insum.py @@ -1,11 +1,8 @@ -from ast import alias import operator from typing import Any, cast -from numpy import isin import finchlite.finch_einsum as ein import finchlite.finch_logic as logic -from finchlite.finch_logic.nodes import Alias, Table from finchlite.symbolic import ( ftype, PostWalk, @@ -65,6 +62,7 @@ def sparse_detect(node: ein.EinsumExpr): return None PostWalk(sparse_detect)(einsum.arg) + return len(refed_sparse) > 0, refed_sparse def to_insum(self, einsum: ein.Einsum, sparse: str, sparse_idxs: tuple[ein.Index, ...]) -> list[ein.EinsumNode]: bodies: list[ein.EinsumNode] = [] @@ -219,4 +217,16 @@ def get_sparse_params(self, bindings: dict[str, Any]) -> set[str]: return sparse def optimize_plan(self, plan: ein.Plan, bindings: dict[str, Any]) -> tuple[ein.Plan, dict[str, Any]]: - pass \ No newline at end of file + sparse = self.get_sparse_params(bindings) + + new_bodies = [] + for body in plan.bodies: + can_optimize, all_sparse = 
self.can_optimize(body, sparse) + if can_optimize: + sparse_binding, sparse_idxs = next(iter(all_sparse)) + new_bodies.extend(self.to_insum(body, sparse_binding, sparse_idxs)) + else: + new_bodies.append(body) + + return ein.Plan(new_bodies), bindings + \ No newline at end of file From d3cf66f735406d0d0f235e8a219e11fee57267d8 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 21 Oct 2025 10:41:13 -0400 Subject: [PATCH 29/57] * Added barebones test utilities to insum lowerer pytest. --- src/finchlite/autoschedule/__init__.py | 2 ++ tests/test_insum_lowerer.py | 45 ++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 tests/test_insum_lowerer.py diff --git a/src/finchlite/autoschedule/__init__.py b/src/finchlite/autoschedule/__init__.py index 730ec1fe..51aff62b 100644 --- a/src/finchlite/autoschedule/__init__.py +++ b/src/finchlite/autoschedule/__init__.py @@ -17,6 +17,7 @@ from ..symbolic import PostOrderDFS, PostWalk, PreWalk from .compiler import LogicCompiler from .einsum import EinsumLowerer +from .insum import InsumLowerer from .optimize import ( DefaultLogicOptimizer, concordize, @@ -45,6 +46,7 @@ "Alias", "DefaultLogicOptimizer", "EinsumLowerer", + "InsumLowerer", "Field", "Literal", "LogicCompiler", diff --git a/tests/test_insum_lowerer.py b/tests/test_insum_lowerer.py new file mode 100644 index 00000000..f24a589e --- /dev/null +++ b/tests/test_insum_lowerer.py @@ -0,0 +1,45 @@ +from importlib.abc import InspectLoader +from typing import Any, cast + +import pytest + +import numpy as np + +import finchlite +from finchlite.autoschedule import ( + EinsumLowerer, + InsumLowerer, + optimize +) +import finchlite.finch_einsum as ein +import finchlite.finch_logic as logic +from finchlite.symbolic import gensym + +@pytest.fixture +def rng(): + return np.random.default_rng(42) + +def test_einsum_to_insum(plan: ein.Plan, bindings: dict[str, Any]): + """Test converting an einsum plan to an insum plan""" + lowerer = InsumLowerer() + 
insum_plan, bindings = lowerer.optimize_plan(plan, bindings) + + interpreter = ein.EinsumInterpreter(bindings=bindings) + result = interpreter(insum_plan)[0] + result2 = interpreter(plan)[0] + + return np.allclose(result, result2) + +def test_logic_to_insum(ir: logic.LogicNode): + """Test converting a logic plan to an insum plan""" + + # Optimize into a plan + var = logic.Alias(gensym("result")) + plan = logic.Plan((logic.Query(var, ir), logic.Produces((var,)))) + optimized_plan = cast(logic.Plan, optimize(plan)) + + # Lower to einsum IR + lowerer = EinsumLowerer() + einsum_plan, bindings = lowerer(optimized_plan) + + test_einsum_to_insum(einsum_plan, bindings) \ No newline at end of file From 92f59b5a2de2a7d2efdb30577961aace34102519 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 21 Oct 2025 10:43:39 -0400 Subject: [PATCH 30/57] * Added support for GetAttribute EinsumExpr in EinsumPrinterContext --- src/finchlite/finch_einsum/nodes.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/finchlite/finch_einsum/nodes.py b/src/finchlite/finch_einsum/nodes.py index 5a7dcc3d..2c074b4d 100644 --- a/src/finchlite/finch_einsum/nodes.py +++ b/src/finchlite/finch_einsum/nodes.py @@ -369,6 +369,10 @@ def __call__(self, prgm: EinsumNode): if len(args) == 1 and fn.val in unary_strs: return f"{unary_strs[fn.val]}{args_e[0]}" return f"{self(fn)}({', '.join(args_e)})" + case GetAttribute(obj, attr, idx): + if idx is not None: + return f"{self(obj)}.{self(attr)}[{self(idx)}]" + return f"{self(obj)}.{self(attr)}" case Einsum(op, tns, idxs, arg): op_str = infix_strs.get(op.val, op.val.__name__) self.exec( From f24b43abe001a25adea0cedb81d4961ac6c06e4e Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 21 Oct 2025 17:26:33 -0400 Subject: [PATCH 31/57] * Simplified Einsum Lowerer by inlining lower_to_einsum into compile_plan * Renamed lower_to_pointwise_op to compile_expr and lower_to_pointwise to compile_operand for clarity --- 
src/finchlite/autoschedule/einsum.py | 129 ++++++++------------------- tests/test_einsum_lowerer.py | 10 ++- 2 files changed, 46 insertions(+), 93 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index d0212daf..363015e1 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -54,102 +54,49 @@ def compile_plan( for body in plan.bodies: match body: case Plan(_): - inner_plan = self.compile_plan(body, bindings, definitions) - bodies.extend(inner_plan.bodies) - break - case Query(Alias(name), Table(Literal(val), _)) if isinstance( - val, Scalar - ): - bindings[name] = val.val - case Query(Alias(name), Table(Literal(tns), _)) if isinstance( - tns, Tensor - ): - bindings[name] = ( - tns.to_numpy() if hasattr(tns, "to_numpy") else np.asarray(tns) - ) # type: ignore[attr-defined] - case Query(Alias(name), rhs): - bodies.append( - self.rename_einsum( - self.lower_to_einsum(rhs, bodies, bindings, definitions), - ein.Alias(name), - definitions, - ) - ) + bodies.append(self.compile_plan(body, bindings, definitions)) + case Query(Alias(name), Table(Literal(val), _)): + bindings[name] = val + case Query(Alias(name), MapJoin(Literal(operation), args)): + args_list = [ + self.compile_operand(arg, bodies, bindings, definitions) + for arg in args + ] + bodies.append(ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias(name), + idxs=tuple(ein.Index(field.name) for field in body.rhs.fields), + arg= self.compile_expr(operation, tuple(args_list)), + )) + case Query(Alias(name), Aggregate(Literal(operation), Literal(init), arg, _)): + remaining_idxs = tuple(ein.Index(field.name) for field in body.rhs.fields) + if init != init_value(operation, type(init)): + bodies.append(ein.Einsum( + op = ein.Literal(overwrite), + tns=ein.Alias(name), + idxs=remaining_idxs, + arg=ein.Literal(init), + )) + bodies.append(ein.Einsum( + op = ein.Literal(operation), + tns=ein.Alias(name), + idxs=remaining_idxs, 
+ arg=self.compile_operand(arg, bodies, bindings, definitions), + )) case Produces(args): returnValues = [] for arg in args: - if isinstance(arg, Alias): - returnValues.append(ein.Alias(arg.name)) - else: - einsum = self.rename_einsum( - self.lower_to_einsum( - arg, bodies, bindings, definitions - ), - self.get_next_alias(), - definitions, - ) - bodies.append(einsum) - returnValues.append(einsum.tns) + if not isinstance(arg, Alias): + raise Exception(f"Unrecognized logic: {arg}") + returnValues.append(ein.Alias(arg.name)) bodies.append(ein.Produces(tuple(returnValues))) case _: - bodies.append( - self.rename_einsum( - self.lower_to_einsum(body, bodies, bindings, definitions), - self.get_next_alias(), - definitions, - ) - ) + raise Exception(f"Unrecognized logic: {body}") return ein.Plan(tuple(bodies)) - def lower_to_einsum( - self, - ex: LogicNode, - bodies: list[ein.EinsumNode], - bindings: dict[str, Any], - definitions: dict[str, ein.Einsum], - ) -> ein.Einsum: - match ex: - case Plan(_): - raise Exception("Plans within plans are not supported.") - case MapJoin(Literal(operation), args): - args_list = [ - self.lower_to_pointwise(arg, bodies, bindings, definitions) - for arg in args - ] - pointwise_expr = self.lower_to_pointwise_op(operation, tuple(args_list)) - return ein.Einsum( - op=ein.Literal(overwrite), - tns=self.get_next_alias(), - idxs=tuple(ein.Index(field.name) for field in ex.fields), - arg=pointwise_expr, - ) - case Reorder(arg, idxs): - return self.reorder_einsum( - self.lower_to_einsum(arg, bodies, bindings, definitions), - tuple(ein.Index(field.name) for field in idxs), - ) - case Aggregate(Literal(operation), Literal(init), arg, idxs): - if init != init_value(operation, type(init)): - raise Exception(f""" - Init value {init} is not the default value - for operation {operation} of type {type(init)}. - Non standard init values are not supported. 
- """) - aggregate_expr = self.lower_to_pointwise( - arg, bodies, bindings, definitions - ) - return ein.Einsum( - op=ein.Literal(operation), - tns=self.get_next_alias(), - idxs=tuple(ein.Index(field.name) for field in ex.fields), - arg=aggregate_expr, - ) - case _: - raise Exception(f"Unrecognized logic: {ex}") - - def lower_to_pointwise_op( + def compile_expr( self, operation: Callable, args: tuple[ein.EinsumExpr, ...] ) -> ein.EinsumExpr: # if operation is commutative, we simply pass @@ -175,7 +122,7 @@ def flatten_args( return ein.Call(ein.Literal(operation), args) # lowers nested mapjoin logic IR nodes into a single pointwise expression - def lower_to_pointwise( + def compile_operand( self, ex: LogicNode, bodies: list[ein.EinsumNode], @@ -184,13 +131,13 @@ def lower_to_pointwise( ) -> ein.EinsumExpr: match ex: case Reorder(arg, idxs): - return self.lower_to_pointwise(arg, bodies, bindings, definitions) + return self.compile_operand(arg, bodies, bindings, definitions) case MapJoin(Literal(operation), args): args_list = [ - self.lower_to_pointwise(arg, bodies, bindings, definitions) + self.compile_operand(arg, bodies, bindings, definitions) for arg in args ] - return self.lower_to_pointwise_op(operation, tuple(args_list)) + return self.compile_expr(operation, tuple(args_list)) case Relabel( Alias(name), idxs ): # relable is really just a glorified pointwise access diff --git a/tests/test_einsum_lowerer.py b/tests/test_einsum_lowerer.py index 43be026f..43b54138 100644 --- a/tests/test_einsum_lowerer.py +++ b/tests/test_einsum_lowerer.py @@ -36,10 +36,16 @@ def lower_and_execute(ir: LogicNode): # Lower to einsum IR lowerer = EinsumLowerer() - einsum_plan, plan_parameters = lowerer(optimized_plan) + einsum_plan, bindings = lowerer(optimized_plan) + + for k, v in bindings.items(): + if hasattr(v, "to_numpy"): + bindings[k] = v.to_numpy() + elif isinstance(v, finchlite.interface.Scalar): + bindings[k] = v.val # Interpret and execute - interpreter = 
EinsumInterpreter(bindings=plan_parameters) + interpreter = EinsumInterpreter(bindings=bindings) return interpreter(einsum_plan)[0] From 725040e23d18fdddbb2193de99684fa2ae74486b Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 21 Oct 2025 17:30:49 -0400 Subject: [PATCH 32/57] Refactored EinsumLowerer to streamline aggregate handling and removed unused rename_einsum method. Updated logic for processing Aggregate cases to improve clarity and efficiency. --- src/finchlite/autoschedule/einsum.py | 45 +++++++++++----------------- 1 file changed, 18 insertions(+), 27 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 363015e1..47339550 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -1,11 +1,7 @@ from collections.abc import Callable from typing import Any - -import numpy as np - import finchlite.finch_einsum as ein from finchlite.algebra import init_value, is_commutative, overwrite -from finchlite.algebra.tensor import Tensor from finchlite.finch_logic import ( Aggregate, Alias, @@ -19,7 +15,6 @@ Reorder, Table, ) -from finchlite.interface import Scalar from finchlite.symbolic import gensym @@ -32,15 +27,6 @@ def __call__(self, prgm: Plan) -> tuple[ein.Plan, dict[str, Any]]: def get_next_alias(self) -> ein.Alias: return ein.Alias(gensym("einsum")) - def rename_einsum( - self, - einsum: ein.Einsum, - new_alias: ein.Alias, - definitions: dict[str, ein.Einsum], - ) -> ein.Einsum: - definitions[new_alias.name] = einsum - return ein.Einsum(einsum.op, new_alias, einsum.idxs, einsum.arg) - def reorder_einsum( self, einsum: ein.Einsum, idxs: tuple[ein.Index, ...] 
) -> ein.Einsum: @@ -147,20 +133,25 @@ def compile_operand( ) case Literal(value): return ein.Literal(val=value) - case Aggregate( - _, _, _, _ - ): # aggregate has to be computed seperatley as it's own einsum - aggregate_einsum_alias = self.get_next_alias() - bodies.append( - self.rename_einsum( - self.lower_to_einsum(ex, bodies, bindings, definitions), - aggregate_einsum_alias, - definitions, - ) - ) + case Aggregate(Literal(operation), Literal(init), arg, _): + alias = self.get_next_alias() + remaining_idxs = tuple(ein.Index(field.name) for field in ex.fields) + if init != init_value(operation, type(init)): + bodies.append(ein.Einsum( + op = ein.Literal(overwrite), + tns=ein.Alias(alias), + idxs=remaining_idxs, + arg=ein.Literal(init), + )) + bodies.append(ein.Einsum( + op = ein.Literal(operation), + tns=ein.Alias(alias), + idxs=remaining_idxs, + arg=self.compile_operand(arg, bodies, bindings, definitions), + )) return ein.Access( - tns=aggregate_einsum_alias, - idxs=tuple(ein.Index(field.name) for field in ex.fields), + tns=alias, + idxs=remaining_idxs, ) case _: raise Exception(f"Unrecognized logic: {ex}") From 0c26cc55852f2f7c3f34473bad89428a302f612f Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 21 Oct 2025 17:41:38 -0400 Subject: [PATCH 33/57] Enhanced EinsumLowerer to support Reformat cases in MapJoin and Aggregate queries, improving flexibility in handling various query formats. 
--- src/finchlite/autoschedule/einsum.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 47339550..de3e34c9 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -13,11 +13,11 @@ Query, Relabel, Reorder, + Reformat, Table, ) from finchlite.symbolic import gensym - class EinsumLowerer: def __call__(self, prgm: Plan) -> tuple[ein.Plan, dict[str, Any]]: bindings: dict[str, Any] = {} @@ -43,7 +43,8 @@ def compile_plan( bodies.append(self.compile_plan(body, bindings, definitions)) case Query(Alias(name), Table(Literal(val), _)): bindings[name] = val - case Query(Alias(name), MapJoin(Literal(operation), args)): + case Query(Alias(name), MapJoin(Literal(operation), args)) |\ + Query(Alias(name), Reformat(_, MapJoin(Literal(operation), args))): args_list = [ self.compile_operand(arg, bodies, bindings, definitions) for arg in args @@ -54,7 +55,8 @@ def compile_plan( idxs=tuple(ein.Index(field.name) for field in body.rhs.fields), arg= self.compile_expr(operation, tuple(args_list)), )) - case Query(Alias(name), Aggregate(Literal(operation), Literal(init), arg, _)): + case Query(Alias(name), Aggregate(Literal(operation), Literal(init), arg, _)) |\ + Query(Alias(name), Reformat(_, Aggregate(Literal(operation), Literal(init), arg, _))): remaining_idxs = tuple(ein.Index(field.name) for field in body.rhs.fields) if init != init_value(operation, type(init)): bodies.append(ein.Einsum( From ade72926a1f8e3f8bacd623bee264f9dc1bfa4b4 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 21 Oct 2025 17:46:13 -0400 Subject: [PATCH 34/57] * Ran ruff check and ruff format --- src/finchlite/autoschedule/einsum.py | 87 +++++++++++++++++----------- 1 file changed, 54 insertions(+), 33 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index de3e34c9..963379ba 100644 --- 
a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -1,5 +1,6 @@ from collections.abc import Callable from typing import Any + import finchlite.finch_einsum as ein from finchlite.algebra import init_value, is_commutative, overwrite from finchlite.finch_logic import ( @@ -11,13 +12,14 @@ Plan, Produces, Query, + Reformat, Relabel, Reorder, - Reformat, Table, ) from finchlite.symbolic import gensym + class EinsumLowerer: def __call__(self, prgm: Plan) -> tuple[ein.Plan, dict[str, Any]]: bindings: dict[str, Any] = {} @@ -43,34 +45,49 @@ def compile_plan( bodies.append(self.compile_plan(body, bindings, definitions)) case Query(Alias(name), Table(Literal(val), _)): bindings[name] = val - case Query(Alias(name), MapJoin(Literal(operation), args)) |\ - Query(Alias(name), Reformat(_, MapJoin(Literal(operation), args))): + case Query(Alias(name), MapJoin(Literal(operation), args)) | Query( + Alias(name), Reformat(_, MapJoin(Literal(operation), args)) + ): args_list = [ self.compile_operand(arg, bodies, bindings, definitions) for arg in args ] - bodies.append(ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias(name), - idxs=tuple(ein.Index(field.name) for field in body.rhs.fields), - arg= self.compile_expr(operation, tuple(args_list)), - )) - case Query(Alias(name), Aggregate(Literal(operation), Literal(init), arg, _)) |\ - Query(Alias(name), Reformat(_, Aggregate(Literal(operation), Literal(init), arg, _))): - remaining_idxs = tuple(ein.Index(field.name) for field in body.rhs.fields) + bodies.append( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias(name), + idxs=tuple( + ein.Index(field.name) for field in body.rhs.fields + ), + arg=self.compile_expr(operation, tuple(args_list)), + ) + ) + case Query( + Alias(name), Aggregate(Literal(operation), Literal(init), arg, _) + ) | Query( + Alias(name), + Reformat(_, Aggregate(Literal(operation), Literal(init), arg, _)), + ): + einidxs = tuple(ein.Index(field.name) for field in 
body.rhs.fields) if init != init_value(operation, type(init)): - bodies.append(ein.Einsum( - op = ein.Literal(overwrite), + bodies.append( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias(name), + idxs=einidxs, + arg=ein.Literal(init), + ) + ) + bodies.append( + ein.Einsum( + op=ein.Literal(operation), tns=ein.Alias(name), - idxs=remaining_idxs, - arg=ein.Literal(init), - )) - bodies.append(ein.Einsum( - op = ein.Literal(operation), - tns=ein.Alias(name), - idxs=remaining_idxs, - arg=self.compile_operand(arg, bodies, bindings, definitions), - )) + idxs=einidxs, + arg=self.compile_operand( + arg, bodies, bindings, definitions + ), + ) + ) case Produces(args): returnValues = [] for arg in args: @@ -139,18 +156,22 @@ def compile_operand( alias = self.get_next_alias() remaining_idxs = tuple(ein.Index(field.name) for field in ex.fields) if init != init_value(operation, type(init)): - bodies.append(ein.Einsum( - op = ein.Literal(overwrite), + bodies.append( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias(alias), + idxs=remaining_idxs, + arg=ein.Literal(init), + ) + ) + bodies.append( + ein.Einsum( + op=ein.Literal(operation), tns=ein.Alias(alias), idxs=remaining_idxs, - arg=ein.Literal(init), - )) - bodies.append(ein.Einsum( - op = ein.Literal(operation), - tns=ein.Alias(alias), - idxs=remaining_idxs, - arg=self.compile_operand(arg, bodies, bindings, definitions), - )) + arg=self.compile_operand(arg, bodies, bindings, definitions), + ) + ) return ein.Access( tns=alias, idxs=remaining_idxs, From 426e12c88ec7fd6e2cc475f3bcc041cfb2444971 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 21 Oct 2025 18:11:56 -0400 Subject: [PATCH 35/57] * Moved mapjoin and aggregate compilation into seperate functions to avoid code repitition * Fixed pytests --- src/finchlite/autoschedule/einsum.py | 148 +++++++++++++++------------ 1 file changed, 83 insertions(+), 65 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py 
b/src/finchlite/autoschedule/einsum.py index 963379ba..46cb7071 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -16,6 +16,7 @@ Relabel, Reorder, Table, + Field ) from finchlite.symbolic import gensym @@ -26,14 +27,57 @@ def __call__(self, prgm: Plan) -> tuple[ein.Plan, dict[str, Any]]: definitions: dict[str, ein.Einsum] = {} return self.compile_plan(prgm, bindings, definitions), bindings - def get_next_alias(self) -> ein.Alias: - return ein.Alias(gensym("einsum")) - def reorder_einsum( self, einsum: ein.Einsum, idxs: tuple[ein.Index, ...] ) -> ein.Einsum: return ein.Einsum(einsum.op, einsum.tns, idxs, einsum.arg) + def compile_mapjoin( + self, bodies: list[ein.EinsumNode], bindings: dict[str, Any], + definitions: dict[str, ein.Einsum], name: str, fields: list[Field], + operation: Callable, args: tuple[ein.EinsumExpr, ...] + ) -> ein.EinsumExpr: + args_list = [ + self.compile_operand(arg, bodies, bindings, definitions) + for arg in args + ] + return ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias(name), + idxs=tuple( + ein.Index(field.name) for field in fields + ), + arg=self.compile_expr(operation, tuple(args_list)), + ) + + def compile_aggregate( + self, bodies: list[ein.EinsumNode], bindings: dict[str, Any], + definitions: dict[str, ein.Einsum], name: str, fields: list[Field], + operation: Callable, init: Any, arg: LogicNode + ) -> ein.Plan: + einidxs = tuple(ein.Index(field.name) for field in fields) + bodies = [] + if init != init_value(operation, type(init)): + bodies.append( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias(name), + idxs=einidxs, + arg=ein.Literal(init), + ) + ) + bodies.append( + ein.Einsum( + op=ein.Literal(operation), + tns=ein.Alias(name), + idxs=einidxs, + arg=self.compile_operand( + arg, bodies, bindings, definitions + ), + ) + ) + return ein.Plan(tuple(bodies)) + def compile_plan( self, plan: Plan, bindings: dict[str, Any], definitions: dict[str, ein.Einsum] ) -> 
ein.Plan: @@ -45,49 +89,36 @@ def compile_plan( bodies.append(self.compile_plan(body, bindings, definitions)) case Query(Alias(name), Table(Literal(val), _)): bindings[name] = val - case Query(Alias(name), MapJoin(Literal(operation), args)) | Query( - Alias(name), Reformat(_, MapJoin(Literal(operation), args)) - ): - args_list = [ - self.compile_operand(arg, bodies, bindings, definitions) - for arg in args - ] - bodies.append( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias(name), - idxs=tuple( - ein.Index(field.name) for field in body.rhs.fields - ), - arg=self.compile_expr(operation, tuple(args_list)), - ) - ) - case Query( - Alias(name), Aggregate(Literal(operation), Literal(init), arg, _) - ) | Query( - Alias(name), - Reformat(_, Aggregate(Literal(operation), Literal(init), arg, _)), - ): - einidxs = tuple(ein.Index(field.name) for field in body.rhs.fields) - if init != init_value(operation, type(init)): - bodies.append( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias(name), - idxs=einidxs, - arg=ein.Literal(init), - ) - ) - bodies.append( - ein.Einsum( - op=ein.Literal(operation), - tns=ein.Alias(name), - idxs=einidxs, - arg=self.compile_operand( - arg, bodies, bindings, definitions - ), - ) - ) + case Query(Alias(name), MapJoin(Literal(operation), args)): + bodies.append(self.compile_mapjoin( + bodies, bindings, definitions, + name, body.rhs.fields, operation, args + )) + case Query(Alias(name), Reformat(_, MapJoin(Literal(operation), args))): + bodies.append(self.compile_mapjoin( + bodies, bindings, definitions, + name, body.rhs.fields, operation, args + )) + case Query(Alias(name), Reorder(MapJoin(Literal(operation), args), idxs)): + bodies.append(self.compile_mapjoin( + bodies, bindings, definitions, + name, idxs, operation, args + )) + case Query(Alias(name), Aggregate(Literal(operation), Literal(init), arg, _)): + bodies.append(self.compile_aggregate( + bodies, bindings, definitions, + name, body.rhs.fields, operation, init, arg + 
)) + case Query(Alias(name), Reformat(_, Aggregate(Literal(operation), Literal(init), arg, _))): + bodies.append(self.compile_aggregate( + bodies, bindings, definitions, + name, body.rhs.fields, operation, init, arg + )) + case Query(Alias(name), Reorder(Aggregate(Literal(operation), Literal(init), arg, _), idxs)): + bodies.append(self.compile_aggregate( + bodies, bindings, definitions, + name, idxs, operation, init, arg + )) case Produces(args): returnValues = [] for arg in args: @@ -153,28 +184,15 @@ def compile_operand( case Literal(value): return ein.Literal(val=value) case Aggregate(Literal(operation), Literal(init), arg, _): - alias = self.get_next_alias() + alias = ein.Alias(gensym("E")) remaining_idxs = tuple(ein.Index(field.name) for field in ex.fields) - if init != init_value(operation, type(init)): - bodies.append( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias(alias), - idxs=remaining_idxs, - arg=ein.Literal(init), - ) - ) - bodies.append( - ein.Einsum( - op=ein.Literal(operation), - tns=ein.Alias(alias), - idxs=remaining_idxs, - arg=self.compile_operand(arg, bodies, bindings, definitions), - ) - ) + bodies.append(self.compile_aggregate( + bodies, bindings, definitions, + alias, ex.fields, operation, init, arg + )) return ein.Access( tns=alias, idxs=remaining_idxs, ) case _: - raise Exception(f"Unrecognized logic: {ex}") + raise Exception(f"Unrecognized logic: {ex}") \ No newline at end of file From c94f2d0b89b6d9fd199358a3abad003d0b0bf32b Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 21 Oct 2025 18:15:56 -0400 Subject: [PATCH 36/57] * Fixed mypy typing errors --- src/finchlite/autoschedule/einsum.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 46cb7071..d2ce6978 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -34,8 +34,8 @@ def reorder_einsum( def compile_mapjoin( self, 
bodies: list[ein.EinsumNode], bindings: dict[str, Any], - definitions: dict[str, ein.Einsum], name: str, fields: list[Field], - operation: Callable, args: tuple[ein.EinsumExpr, ...] + definitions: dict[str, ein.Einsum], name: str, fields: tuple[Field, ...], + operation: Callable, args: tuple[LogicNode, ...] ) -> ein.EinsumExpr: args_list = [ self.compile_operand(arg, bodies, bindings, definitions) @@ -52,7 +52,7 @@ def compile_mapjoin( def compile_aggregate( self, bodies: list[ein.EinsumNode], bindings: dict[str, Any], - definitions: dict[str, ein.Einsum], name: str, fields: list[Field], + definitions: dict[str, ein.Einsum], name: str, fields: tuple[Field, ...], operation: Callable, init: Any, arg: LogicNode ) -> ein.Plan: einidxs = tuple(ein.Index(field.name) for field in fields) From 2db61acb5f79fa7563a783b1d9455f4b436965af Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 21 Oct 2025 18:23:17 -0400 Subject: [PATCH 37/57] * Fixed more mypy errors --- src/finchlite/autoschedule/einsum.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index d2ce6978..9aa33242 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -8,6 +8,7 @@ Alias, Literal, LogicNode, + LogicExpression, MapJoin, Plan, Produces, @@ -35,8 +36,8 @@ def reorder_einsum( def compile_mapjoin( self, bodies: list[ein.EinsumNode], bindings: dict[str, Any], definitions: dict[str, ein.Einsum], name: str, fields: tuple[Field, ...], - operation: Callable, args: tuple[LogicNode, ...] - ) -> ein.EinsumExpr: + operation: Callable, args: tuple[LogicExpression, ...] 
+ ) -> ein.EinsumNode: args_list = [ self.compile_operand(arg, bodies, bindings, definitions) for arg in args @@ -53,8 +54,8 @@ def compile_mapjoin( def compile_aggregate( self, bodies: list[ein.EinsumNode], bindings: dict[str, Any], definitions: dict[str, ein.Einsum], name: str, fields: tuple[Field, ...], - operation: Callable, init: Any, arg: LogicNode - ) -> ein.Plan: + operation: Callable, init: Any, arg: LogicExpression + ) -> ein.EinsumNode: einidxs = tuple(ein.Index(field.name) for field in fields) bodies = [] if init != init_value(operation, type(init)): @@ -184,14 +185,15 @@ def compile_operand( case Literal(value): return ein.Literal(val=value) case Aggregate(Literal(operation), Literal(init), arg, _): - alias = ein.Alias(gensym("E")) + alias = gensym("E") remaining_idxs = tuple(ein.Index(field.name) for field in ex.fields) bodies.append(self.compile_aggregate( bodies, bindings, definitions, - alias, ex.fields, operation, init, arg + alias, tuple(ex.fields), + operation, init, arg )) return ein.Access( - tns=alias, + tns=ein.Alias(alias), idxs=remaining_idxs, ) case _: From 7f10097e2c16ec234b5730dc4f6cce96487c6b68 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 21 Oct 2025 18:32:59 -0400 Subject: [PATCH 38/57] * Finally fixed all mypy errors --- src/finchlite/autoschedule/einsum.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 9aa33242..be053f88 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -122,10 +122,10 @@ def compile_plan( )) case Produces(args): returnValues = [] - for arg in args: - if not isinstance(arg, Alias): - raise Exception(f"Unrecognized logic: {arg}") - returnValues.append(ein.Alias(arg.name)) + for ret_arg in args: + if not isinstance(ret_arg, Alias): + raise Exception(f"Unrecognized logic: {ret_arg}") + returnValues.append(ein.Alias(ret_arg.name)) 
bodies.append(ein.Produces(tuple(returnValues))) case _: From 6aaceeea77ebc3e5a8eaca8f3b973c2109f0ac54 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 21 Oct 2025 18:34:42 -0400 Subject: [PATCH 39/57] * Ran ruff --- src/finchlite/autoschedule/einsum.py | 164 ++++++++++++++++++--------- 1 file changed, 113 insertions(+), 51 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index be053f88..bf9affaa 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -6,9 +6,10 @@ from finchlite.finch_logic import ( Aggregate, Alias, + Field, Literal, - LogicNode, LogicExpression, + LogicNode, MapJoin, Plan, Produces, @@ -17,7 +18,6 @@ Relabel, Reorder, Table, - Field ) from finchlite.symbolic import gensym @@ -34,27 +34,35 @@ def reorder_einsum( return ein.Einsum(einsum.op, einsum.tns, idxs, einsum.arg) def compile_mapjoin( - self, bodies: list[ein.EinsumNode], bindings: dict[str, Any], - definitions: dict[str, ein.Einsum], name: str, fields: tuple[Field, ...], - operation: Callable, args: tuple[LogicExpression, ...] 
+ self, + bodies: list[ein.EinsumNode], + bindings: dict[str, Any], + definitions: dict[str, ein.Einsum], + name: str, + fields: tuple[Field, ...], + operation: Callable, + args: tuple[LogicExpression, ...], ) -> ein.EinsumNode: args_list = [ - self.compile_operand(arg, bodies, bindings, definitions) - for arg in args + self.compile_operand(arg, bodies, bindings, definitions) for arg in args ] return ein.Einsum( op=ein.Literal(overwrite), tns=ein.Alias(name), - idxs=tuple( - ein.Index(field.name) for field in fields - ), + idxs=tuple(ein.Index(field.name) for field in fields), arg=self.compile_expr(operation, tuple(args_list)), ) def compile_aggregate( - self, bodies: list[ein.EinsumNode], bindings: dict[str, Any], - definitions: dict[str, ein.Einsum], name: str, fields: tuple[Field, ...], - operation: Callable, init: Any, arg: LogicExpression + self, + bodies: list[ein.EinsumNode], + bindings: dict[str, Any], + definitions: dict[str, ein.Einsum], + name: str, + fields: tuple[Field, ...], + operation: Callable, + init: Any, + arg: LogicExpression, ) -> ein.EinsumNode: einidxs = tuple(ein.Index(field.name) for field in fields) bodies = [] @@ -72,9 +80,7 @@ def compile_aggregate( op=ein.Literal(operation), tns=ein.Alias(name), idxs=einidxs, - arg=self.compile_operand( - arg, bodies, bindings, definitions - ), + arg=self.compile_operand(arg, bodies, bindings, definitions), ) ) return ein.Plan(tuple(bodies)) @@ -90,36 +96,85 @@ def compile_plan( bodies.append(self.compile_plan(body, bindings, definitions)) case Query(Alias(name), Table(Literal(val), _)): bindings[name] = val - case Query(Alias(name), MapJoin(Literal(operation), args)): - bodies.append(self.compile_mapjoin( - bodies, bindings, definitions, - name, body.rhs.fields, operation, args - )) + case Query(Alias(name), MapJoin(Literal(operation), args)): + bodies.append( + self.compile_mapjoin( + bodies, + bindings, + definitions, + name, + body.rhs.fields, + operation, + args, + ) + ) case Query(Alias(name), 
Reformat(_, MapJoin(Literal(operation), args))): - bodies.append(self.compile_mapjoin( - bodies, bindings, definitions, - name, body.rhs.fields, operation, args - )) - case Query(Alias(name), Reorder(MapJoin(Literal(operation), args), idxs)): - bodies.append(self.compile_mapjoin( - bodies, bindings, definitions, - name, idxs, operation, args - )) - case Query(Alias(name), Aggregate(Literal(operation), Literal(init), arg, _)): - bodies.append(self.compile_aggregate( - bodies, bindings, definitions, - name, body.rhs.fields, operation, init, arg - )) - case Query(Alias(name), Reformat(_, Aggregate(Literal(operation), Literal(init), arg, _))): - bodies.append(self.compile_aggregate( - bodies, bindings, definitions, - name, body.rhs.fields, operation, init, arg - )) - case Query(Alias(name), Reorder(Aggregate(Literal(operation), Literal(init), arg, _), idxs)): - bodies.append(self.compile_aggregate( - bodies, bindings, definitions, - name, idxs, operation, init, arg - )) + bodies.append( + self.compile_mapjoin( + bodies, + bindings, + definitions, + name, + body.rhs.fields, + operation, + args, + ) + ) + case Query( + Alias(name), Reorder(MapJoin(Literal(operation), args), idxs) + ): + bodies.append( + self.compile_mapjoin( + bodies, bindings, definitions, name, idxs, operation, args + ) + ) + case Query( + Alias(name), Aggregate(Literal(operation), Literal(init), arg, _) + ): + bodies.append( + self.compile_aggregate( + bodies, + bindings, + definitions, + name, + body.rhs.fields, + operation, + init, + arg, + ) + ) + case Query( + Alias(name), + Reformat(_, Aggregate(Literal(operation), Literal(init), arg, _)), + ): + bodies.append( + self.compile_aggregate( + bodies, + bindings, + definitions, + name, + body.rhs.fields, + operation, + init, + arg, + ) + ) + case Query( + Alias(name), + Reorder(Aggregate(Literal(operation), Literal(init), arg, _), idxs), + ): + bodies.append( + self.compile_aggregate( + bodies, + bindings, + definitions, + name, + idxs, + operation, + 
init, + arg, + ) + ) case Produces(args): returnValues = [] for ret_arg in args: @@ -187,14 +242,21 @@ def compile_operand( case Aggregate(Literal(operation), Literal(init), arg, _): alias = gensym("E") remaining_idxs = tuple(ein.Index(field.name) for field in ex.fields) - bodies.append(self.compile_aggregate( - bodies, bindings, definitions, - alias, tuple(ex.fields), - operation, init, arg - )) + bodies.append( + self.compile_aggregate( + bodies, + bindings, + definitions, + alias, + tuple(ex.fields), + operation, + init, + arg, + ) + ) return ein.Access( tns=ein.Alias(alias), idxs=remaining_idxs, ) case _: - raise Exception(f"Unrecognized logic: {ex}") \ No newline at end of file + raise Exception(f"Unrecognized logic: {ex}") From b3777484db41ed6095aa0d2df21da8359cbbf85f Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 21 Oct 2025 18:54:56 -0400 Subject: [PATCH 40/57] * Undid changes to effectively unchanged files --- src/finchlite/finch_einsum/interpreter.py | 8 +++----- src/finchlite/interface/lazy.py | 2 -- tests/reference/test_einsum_printer.txt | 2 +- tests/test_printers.py | 2 +- 4 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 20806039..01f256fe 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -95,20 +95,18 @@ def __call__(self, node): case ein.Access(tns, idxs): assert len(idxs) == len(set(idxs)) assert self.loops is not None - perm = [idxs.index(idx) for idx in self.loops if idx in idxs] tns = self(tns) - tns = xp.permute_dims(tns, perm) return xp.expand_dims( tns, [i for i in range(len(self.loops)) if self.loops[i] not in idxs], ) case ein.Plan(bodies): - returnVal = None + res = None for body in bodies: - returnVal = self(body) # execute each einsum statement individually - return returnVal + res = self(body) + return res case ein.Produces(args): return tuple(self(arg) for arg in 
args) case ein.Einsum(op, ein.Alias(tns), idxs, arg): diff --git a/src/finchlite/interface/lazy.py b/src/finchlite/interface/lazy.py index ade2247f..638affaa 100644 --- a/src/finchlite/interface/lazy.py +++ b/src/finchlite/interface/lazy.py @@ -1806,7 +1806,6 @@ def std( def einop(prgm, **kwargs): stmt = ein.parse_einop(prgm) prgm = ein.Plan((stmt, ein.Produces((stmt.tns,)))) - xp = sys.modules[__name__] ctx = ein.EinsumInterpreter(xp, dict(**kwargs)) return ctx(prgm)[0] @@ -1815,7 +1814,6 @@ def einop(prgm, **kwargs): def einsum(prgm, *args, **kwargs): stmt, bindings = ein.parse_einsum(prgm, *args) prgm = ein.Plan((stmt, ein.Produces((stmt.tns,)))) - xp = sys.modules[__name__] ctx = ein.EinsumInterpreter(xp, bindings) return ctx(prgm)[0] diff --git a/tests/reference/test_einsum_printer.txt b/tests/reference/test_einsum_printer.txt index 569a686e..f284c993 100644 --- a/tests/reference/test_einsum_printer.txt +++ b/tests/reference/test_einsum_printer.txt @@ -3,4 +3,4 @@ plan: plan: D[i, j] += (A[i, k] * B[k, j]) E[i] min= lshift((A[i, k] + D[k, j]), 1) - return ('C', 'E') \ No newline at end of file + return ('C', 'E') diff --git a/tests/test_printers.py b/tests/test_printers.py index e3186736..e5b373a6 100644 --- a/tests/test_printers.py +++ b/tests/test_printers.py @@ -543,7 +543,7 @@ def test_einsum_printer(file_regression): ) ), ein.Produces((ein.Alias("C"), ein.Alias("E"))), - ), + ) ) file_regression.check(str(prgm), extension=".txt") From fe29f8333d1210614fd882058aa63da73425fa64 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 21 Oct 2025 20:28:06 -0400 Subject: [PATCH 41/57] Refactored EinsumLowerer by removing the compile_aggregate method and inlining its logic directly into the compile_plan method. This change simplifies the code and enhances clarity in handling Aggregate cases. 
--- src/finchlite/autoschedule/einsum.py | 110 ++++----------------------- 1 file changed, 14 insertions(+), 96 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index bf9affaa..e2783aab 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -28,11 +28,6 @@ def __call__(self, prgm: Plan) -> tuple[ein.Plan, dict[str, Any]]: definitions: dict[str, ein.Einsum] = {} return self.compile_plan(prgm, bindings, definitions), bindings - def reorder_einsum( - self, einsum: ein.Einsum, idxs: tuple[ein.Index, ...] - ) -> ein.Einsum: - return ein.Einsum(einsum.op, einsum.tns, idxs, einsum.arg) - def compile_mapjoin( self, bodies: list[ein.EinsumNode], @@ -53,38 +48,6 @@ def compile_mapjoin( arg=self.compile_expr(operation, tuple(args_list)), ) - def compile_aggregate( - self, - bodies: list[ein.EinsumNode], - bindings: dict[str, Any], - definitions: dict[str, ein.Einsum], - name: str, - fields: tuple[Field, ...], - operation: Callable, - init: Any, - arg: LogicExpression, - ) -> ein.EinsumNode: - einidxs = tuple(ein.Index(field.name) for field in fields) - bodies = [] - if init != init_value(operation, type(init)): - bodies.append( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias(name), - idxs=einidxs, - arg=ein.Literal(init), - ) - ) - bodies.append( - ein.Einsum( - op=ein.Literal(operation), - tns=ein.Alias(name), - idxs=einidxs, - arg=self.compile_operand(arg, bodies, bindings, definitions), - ) - ) - return ein.Plan(tuple(bodies)) - def compile_plan( self, plan: Plan, bindings: dict[str, Any], definitions: dict[str, ein.Einsum] ) -> ein.Plan: @@ -131,48 +94,22 @@ def compile_plan( case Query( Alias(name), Aggregate(Literal(operation), Literal(init), arg, _) ): - bodies.append( - self.compile_aggregate( - bodies, - bindings, - definitions, - name, - body.rhs.fields, - operation, - init, - arg, - ) - ) - case Query( - Alias(name), - Reformat(_, Aggregate(Literal(operation), 
Literal(init), arg, _)), - ): - bodies.append( - self.compile_aggregate( - bodies, - bindings, - definitions, - name, - body.rhs.fields, - operation, - init, - arg, + einidxs = tuple(ein.Index(field.name) for field in body.rhs.fields) + if init != init_value(operation, type(init)): + bodies.append( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias(name), + idxs=einidxs, + arg=ein.Literal(init), + ) ) - ) - case Query( - Alias(name), - Reorder(Aggregate(Literal(operation), Literal(init), arg, _), idxs), - ): bodies.append( - self.compile_aggregate( - bodies, - bindings, - definitions, - name, - idxs, - operation, - init, - arg, + ein.Einsum( + op=ein.Literal(operation), + tns=ein.Alias(name), + idxs=einidxs, + arg=self.compile_operand(arg, bodies, bindings, definitions), ) ) case Produces(args): @@ -239,24 +176,5 @@ def compile_operand( ) case Literal(value): return ein.Literal(val=value) - case Aggregate(Literal(operation), Literal(init), arg, _): - alias = gensym("E") - remaining_idxs = tuple(ein.Index(field.name) for field in ex.fields) - bodies.append( - self.compile_aggregate( - bodies, - bindings, - definitions, - alias, - tuple(ex.fields), - operation, - init, - arg, - ) - ) - return ein.Access( - tns=ein.Alias(alias), - idxs=remaining_idxs, - ) case _: raise Exception(f"Unrecognized logic: {ex}") From 89557e2346581cfe079743350538ccf77e026d79 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 21 Oct 2025 20:57:46 -0400 Subject: [PATCH 42/57] Refactored EinsumLowerer to consolidate mapjoin handling into a single case structure, improving code clarity and reducing redundancy. This change enhances the handling of various query formats, including Reformat and Reorder cases. 
--- src/finchlite/autoschedule/einsum.py | 76 +++++++++------------------- 1 file changed, 24 insertions(+), 52 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index e2783aab..f2a58b1a 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -28,26 +28,6 @@ def __call__(self, prgm: Plan) -> tuple[ein.Plan, dict[str, Any]]: definitions: dict[str, ein.Einsum] = {} return self.compile_plan(prgm, bindings, definitions), bindings - def compile_mapjoin( - self, - bodies: list[ein.EinsumNode], - bindings: dict[str, Any], - definitions: dict[str, ein.Einsum], - name: str, - fields: tuple[Field, ...], - operation: Callable, - args: tuple[LogicExpression, ...], - ) -> ein.EinsumNode: - args_list = [ - self.compile_operand(arg, bodies, bindings, definitions) for arg in args - ] - return ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias(name), - idxs=tuple(ein.Index(field.name) for field in fields), - arg=self.compile_expr(operation, tuple(args_list)), - ) - def compile_plan( self, plan: Plan, bindings: dict[str, Any], definitions: dict[str, ein.Einsum] ) -> ein.Plan: @@ -59,38 +39,30 @@ def compile_plan( bodies.append(self.compile_plan(body, bindings, definitions)) case Query(Alias(name), Table(Literal(val), _)): bindings[name] = val - case Query(Alias(name), MapJoin(Literal(operation), args)): - bodies.append( - self.compile_mapjoin( - bodies, - bindings, - definitions, - name, - body.rhs.fields, - operation, - args, - ) - ) - case Query(Alias(name), Reformat(_, MapJoin(Literal(operation), args))): - bodies.append( - self.compile_mapjoin( - bodies, - bindings, - definitions, - name, - body.rhs.fields, - operation, - args, - ) - ) - case Query( - Alias(name), Reorder(MapJoin(Literal(operation), args), idxs) - ): - bodies.append( - self.compile_mapjoin( - bodies, bindings, definitions, name, idxs, operation, args - ) - ) + case Query(Alias(name), rhs): + einarg = 
self.compile_operand(rhs, bodies, bindings, definitions) + bodies.append(ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias(name), + idxs=tuple(ein.Index(field.name) for field in body.rhs.fields), + arg=einarg, + )) + case Query(Alias(name), Reformat(_, rhs)): + einarg = self.compile_operand(rhs, bodies, bindings, definitions) + bodies.append(ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias(name), + idxs=tuple(ein.Index(field.name) for field in body.rhs.fields), + arg=einarg, + )) + case Query(Alias(name), Reorder(rhs, idxs)): + einarg = self.compile_operand(rhs, bodies, bindings, definitions) + bodies.append(ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias(name), + idxs=idxs, + arg=einarg, + )) case Query( Alias(name), Aggregate(Literal(operation), Literal(init), arg, _) ): From d34b5b3518285df8215a31b74b335ef11830edf8 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 21 Oct 2025 21:08:05 -0400 Subject: [PATCH 43/57] Refactored EinsumLowerer to streamline Aggregate case handling by consolidating logic into a single case structure, improving code clarity and reducing redundancy. This update also removes outdated code, enhancing maintainability. 
--- src/finchlite/autoschedule/einsum.py | 41 ++++++++++++++-------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index f2a58b1a..0115a81e 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -39,6 +39,26 @@ def compile_plan( bodies.append(self.compile_plan(body, bindings, definitions)) case Query(Alias(name), Table(Literal(val), _)): bindings[name] = val + case Query(Alias(name), Aggregate(Literal(operation), Literal(init), arg, _) + ) | Query(Alias(name), Aggregate(Literal(operation), Literal(init), Reorder(arg, _), _)): + einidxs = tuple(ein.Index(field.name) for field in body.rhs.fields) + if init != init_value(operation, type(init)): + bodies.append( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias(name), + idxs=einidxs, + arg=ein.Literal(init), + ) + ) + bodies.append( + ein.Einsum( + op=ein.Literal(operation), + tns=ein.Alias(name), + idxs=einidxs, + arg=self.compile_operand(arg, bodies, bindings, definitions), + ) + ) case Query(Alias(name), rhs): einarg = self.compile_operand(rhs, bodies, bindings, definitions) bodies.append(ein.Einsum( @@ -63,27 +83,6 @@ def compile_plan( idxs=idxs, arg=einarg, )) - case Query( - Alias(name), Aggregate(Literal(operation), Literal(init), arg, _) - ): - einidxs = tuple(ein.Index(field.name) for field in body.rhs.fields) - if init != init_value(operation, type(init)): - bodies.append( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias(name), - idxs=einidxs, - arg=ein.Literal(init), - ) - ) - bodies.append( - ein.Einsum( - op=ein.Literal(operation), - tns=ein.Alias(name), - idxs=einidxs, - arg=self.compile_operand(arg, bodies, bindings, definitions), - ) - ) case Produces(args): returnValues = [] for ret_arg in args: From 07c490328b7b8723aed625aae4e873abdf0bf5ea Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 21 Oct 2025 21:12:46 -0400 Subject: [PATCH 
44/57] * Ran pytests, ruff, and mypy --- src/finchlite/autoschedule/einsum.py | 61 +++++++++++++++++----------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 0115a81e..c2a5e3b1 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -6,9 +6,7 @@ from finchlite.finch_logic import ( Aggregate, Alias, - Field, Literal, - LogicExpression, LogicNode, MapJoin, Plan, @@ -19,7 +17,6 @@ Reorder, Table, ) -from finchlite.symbolic import gensym class EinsumLowerer: @@ -39,8 +36,12 @@ def compile_plan( bodies.append(self.compile_plan(body, bindings, definitions)) case Query(Alias(name), Table(Literal(val), _)): bindings[name] = val - case Query(Alias(name), Aggregate(Literal(operation), Literal(init), arg, _) - ) | Query(Alias(name), Aggregate(Literal(operation), Literal(init), Reorder(arg, _), _)): + case Query( + Alias(name), Aggregate(Literal(operation), Literal(init), arg, _) + ) | Query( + Alias(name), + Aggregate(Literal(operation), Literal(init), Reorder(arg, _), _), + ): einidxs = tuple(ein.Index(field.name) for field in body.rhs.fields) if init != init_value(operation, type(init)): bodies.append( @@ -56,33 +57,45 @@ def compile_plan( op=ein.Literal(operation), tns=ein.Alias(name), idxs=einidxs, - arg=self.compile_operand(arg, bodies, bindings, definitions), + arg=self.compile_operand( + arg, bodies, bindings, definitions + ), ) ) case Query(Alias(name), rhs): einarg = self.compile_operand(rhs, bodies, bindings, definitions) - bodies.append(ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias(name), - idxs=tuple(ein.Index(field.name) for field in body.rhs.fields), - arg=einarg, - )) + bodies.append( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias(name), + idxs=tuple( + ein.Index(field.name) for field in body.rhs.fields + ), + arg=einarg, + ) + ) case Query(Alias(name), Reformat(_, rhs)): einarg = 
self.compile_operand(rhs, bodies, bindings, definitions) - bodies.append(ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias(name), - idxs=tuple(ein.Index(field.name) for field in body.rhs.fields), - arg=einarg, - )) + bodies.append( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias(name), + idxs=tuple( + ein.Index(field.name) for field in body.rhs.fields + ), + arg=einarg, + ) + ) case Query(Alias(name), Reorder(rhs, idxs)): einarg = self.compile_operand(rhs, bodies, bindings, definitions) - bodies.append(ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias(name), - idxs=idxs, - arg=einarg, - )) + bodies.append( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias(name), + idxs=tuple(ein.Index(idx.name) for idx in idxs), + arg=einarg, + ) + ) case Produces(args): returnValues = [] for ret_arg in args: From 0ebaa91b6b1abb28b42093c73c7b58bb377d015d Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Fri, 24 Oct 2025 17:01:03 -0400 Subject: [PATCH 45/57] suggested changes for clarity, more to come --- src/finchlite/autoschedule/einsum.py | 57 +++++++++++----------------- 1 file changed, 22 insertions(+), 35 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index c2a5e3b1..d86d84d3 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -3,46 +3,33 @@ import finchlite.finch_einsum as ein from finchlite.algebra import init_value, is_commutative, overwrite -from finchlite.finch_logic import ( - Aggregate, - Alias, - Literal, - LogicNode, - MapJoin, - Plan, - Produces, - Query, - Reformat, - Relabel, - Reorder, - Table, -) - +import finchlite.finch_logic as lgc +from finchlite.finch_logic import LogicNode class EinsumLowerer: - def __call__(self, prgm: Plan) -> tuple[ein.Plan, dict[str, Any]]: + def __call__(self, prgm: lgc.Plan) -> tuple[ein.Plan, dict[str, Any]]: bindings: dict[str, Any] = {} definitions: dict[str, ein.Einsum] = {} return 
self.compile_plan(prgm, bindings, definitions), bindings def compile_plan( - self, plan: Plan, bindings: dict[str, Any], definitions: dict[str, ein.Einsum] + self, plan: lgc.Plan, bindings: dict[str, Any], definitions: dict[str, ein.Einsum] ) -> ein.Plan: bodies: list[ein.EinsumNode] = [] for body in plan.bodies: match body: - case Plan(_): + case lgc.Plan(_): bodies.append(self.compile_plan(body, bindings, definitions)) - case Query(Alias(name), Table(Literal(val), _)): + case lgc.Query(lgc.Alias(name), lgc.Table(lgc.Literal(val), _)): bindings[name] = val - case Query( - Alias(name), Aggregate(Literal(operation), Literal(init), arg, _) - ) | Query( - Alias(name), - Aggregate(Literal(operation), Literal(init), Reorder(arg, _), _), + case lgc.Query( + lgc.Alias(name), lgc.Aggregate(lgc.Literal(operation), lgc.Literal(init), arg, idxs) + ) | lgc.Query( + lgc.Alias(name), + lgc.Aggregate(lgc.Literal(operation), lgc.Literal(init), lgc.Reorder(arg, _), idxs), ): - einidxs = tuple(ein.Index(field.name) for field in body.rhs.fields) + einidxs = tuple(ein.Index(field.name) for field in idxs) if init != init_value(operation, type(init)): bodies.append( ein.Einsum( @@ -62,7 +49,7 @@ def compile_plan( ), ) ) - case Query(Alias(name), rhs): + case lgc.Query(lgc.Alias(name), rhs): einarg = self.compile_operand(rhs, bodies, bindings, definitions) bodies.append( ein.Einsum( @@ -74,7 +61,7 @@ def compile_plan( arg=einarg, ) ) - case Query(Alias(name), Reformat(_, rhs)): + case lgc.Query(lgc.Alias(name), lgc.Reformat(_, rhs)): einarg = self.compile_operand(rhs, bodies, bindings, definitions) bodies.append( ein.Einsum( @@ -86,7 +73,7 @@ def compile_plan( arg=einarg, ) ) - case Query(Alias(name), Reorder(rhs, idxs)): + case lgc.Query(lgc.Alias(name), lgc.Reorder(rhs, idxs)): einarg = self.compile_operand(rhs, bodies, bindings, definitions) bodies.append( ein.Einsum( @@ -96,10 +83,10 @@ def compile_plan( arg=einarg, ) ) - case Produces(args): + case lgc.Produces(args): returnValues = 
[] for ret_arg in args: - if not isinstance(ret_arg, Alias): + if not isinstance(ret_arg, lgc.Alias): raise Exception(f"Unrecognized logic: {ret_arg}") returnValues.append(ein.Alias(ret_arg.name)) @@ -143,22 +130,22 @@ def compile_operand( definitions: dict[str, ein.Einsum], ) -> ein.EinsumExpr: match ex: - case Reorder(arg, idxs): + case lgc.Reorder(arg, idxs): return self.compile_operand(arg, bodies, bindings, definitions) - case MapJoin(Literal(operation), args): + case lgc.MapJoin(lgc.Literal(operation), args): args_list = [ self.compile_operand(arg, bodies, bindings, definitions) for arg in args ] return self.compile_expr(operation, tuple(args_list)) - case Relabel( - Alias(name), idxs + case lgc.Relabel( + lgc.Alias(name), idxs ): # relable is really just a glorified pointwise access return ein.Access( tns=ein.Alias(name), idxs=tuple(ein.Index(idx.name) for idx in idxs), ) - case Literal(value): + case lgc.Literal(value): return ein.Literal(val=value) case _: raise Exception(f"Unrecognized logic: {ex}") From 97a815eb3c84b58d1a16f95ba6eb36a83f934276 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 27 Oct 2025 09:32:34 -0400 Subject: [PATCH 46/57] Revert "suggested changes for clarity, more to come" This reverts commit 0ebaa91b6b1abb28b42093c73c7b58bb377d015d. 
--- src/finchlite/autoschedule/einsum.py | 57 +++++++++++++++++----------- 1 file changed, 35 insertions(+), 22 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index d86d84d3..c2a5e3b1 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -3,33 +3,46 @@ import finchlite.finch_einsum as ein from finchlite.algebra import init_value, is_commutative, overwrite -import finchlite.finch_logic as lgc -from finchlite.finch_logic import LogicNode +from finchlite.finch_logic import ( + Aggregate, + Alias, + Literal, + LogicNode, + MapJoin, + Plan, + Produces, + Query, + Reformat, + Relabel, + Reorder, + Table, +) + class EinsumLowerer: - def __call__(self, prgm: lgc.Plan) -> tuple[ein.Plan, dict[str, Any]]: + def __call__(self, prgm: Plan) -> tuple[ein.Plan, dict[str, Any]]: bindings: dict[str, Any] = {} definitions: dict[str, ein.Einsum] = {} return self.compile_plan(prgm, bindings, definitions), bindings def compile_plan( - self, plan: lgc.Plan, bindings: dict[str, Any], definitions: dict[str, ein.Einsum] + self, plan: Plan, bindings: dict[str, Any], definitions: dict[str, ein.Einsum] ) -> ein.Plan: bodies: list[ein.EinsumNode] = [] for body in plan.bodies: match body: - case lgc.Plan(_): + case Plan(_): bodies.append(self.compile_plan(body, bindings, definitions)) - case lgc.Query(lgc.Alias(name), lgc.Table(lgc.Literal(val), _)): + case Query(Alias(name), Table(Literal(val), _)): bindings[name] = val - case lgc.Query( - lgc.Alias(name), lgc.Aggregate(lgc.Literal(operation), lgc.Literal(init), arg, idxs) - ) | lgc.Query( - lgc.Alias(name), - lgc.Aggregate(lgc.Literal(operation), lgc.Literal(init), lgc.Reorder(arg, _), idxs), + case Query( + Alias(name), Aggregate(Literal(operation), Literal(init), arg, _) + ) | Query( + Alias(name), + Aggregate(Literal(operation), Literal(init), Reorder(arg, _), _), ): - einidxs = tuple(ein.Index(field.name) for field in idxs) + einidxs = 
tuple(ein.Index(field.name) for field in body.rhs.fields) if init != init_value(operation, type(init)): bodies.append( ein.Einsum( @@ -49,7 +62,7 @@ def compile_plan( ), ) ) - case lgc.Query(lgc.Alias(name), rhs): + case Query(Alias(name), rhs): einarg = self.compile_operand(rhs, bodies, bindings, definitions) bodies.append( ein.Einsum( @@ -61,7 +74,7 @@ def compile_plan( arg=einarg, ) ) - case lgc.Query(lgc.Alias(name), lgc.Reformat(_, rhs)): + case Query(Alias(name), Reformat(_, rhs)): einarg = self.compile_operand(rhs, bodies, bindings, definitions) bodies.append( ein.Einsum( @@ -73,7 +86,7 @@ def compile_plan( arg=einarg, ) ) - case lgc.Query(lgc.Alias(name), lgc.Reorder(rhs, idxs)): + case Query(Alias(name), Reorder(rhs, idxs)): einarg = self.compile_operand(rhs, bodies, bindings, definitions) bodies.append( ein.Einsum( @@ -83,10 +96,10 @@ def compile_plan( arg=einarg, ) ) - case lgc.Produces(args): + case Produces(args): returnValues = [] for ret_arg in args: - if not isinstance(ret_arg, lgc.Alias): + if not isinstance(ret_arg, Alias): raise Exception(f"Unrecognized logic: {ret_arg}") returnValues.append(ein.Alias(ret_arg.name)) @@ -130,22 +143,22 @@ def compile_operand( definitions: dict[str, ein.Einsum], ) -> ein.EinsumExpr: match ex: - case lgc.Reorder(arg, idxs): + case Reorder(arg, idxs): return self.compile_operand(arg, bodies, bindings, definitions) - case lgc.MapJoin(lgc.Literal(operation), args): + case MapJoin(Literal(operation), args): args_list = [ self.compile_operand(arg, bodies, bindings, definitions) for arg in args ] return self.compile_expr(operation, tuple(args_list)) - case lgc.Relabel( - lgc.Alias(name), idxs + case Relabel( + Alias(name), idxs ): # relable is really just a glorified pointwise access return ein.Access( tns=ein.Alias(name), idxs=tuple(ein.Index(idx.name) for idx in idxs), ) - case lgc.Literal(value): + case Literal(value): return ein.Literal(val=value) case _: raise Exception(f"Unrecognized logic: {ex}") From 
523e135cbb25591a1efddfd707b525e87e15c802 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 27 Oct 2025 09:37:09 -0400 Subject: [PATCH 47/57] * Removed unecessary imports from logic in einsum lowerer * Imported logic AST nodes will be refered to with lgc. prefix prependded for clarity in einsumlowerer --- src/finchlite/autoschedule/einsum.py | 55 +++++++++++----------------- 1 file changed, 21 insertions(+), 34 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index c2a5e3b1..193a2cb9 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -3,44 +3,31 @@ import finchlite.finch_einsum as ein from finchlite.algebra import init_value, is_commutative, overwrite -from finchlite.finch_logic import ( - Aggregate, - Alias, - Literal, - LogicNode, - MapJoin, - Plan, - Produces, - Query, - Reformat, - Relabel, - Reorder, - Table, -) +import finchlite.finch_logic as lgc class EinsumLowerer: - def __call__(self, prgm: Plan) -> tuple[ein.Plan, dict[str, Any]]: + def __call__(self, prgm: lgc.Plan) -> tuple[ein.Plan, dict[str, Any]]: bindings: dict[str, Any] = {} definitions: dict[str, ein.Einsum] = {} return self.compile_plan(prgm, bindings, definitions), bindings def compile_plan( - self, plan: Plan, bindings: dict[str, Any], definitions: dict[str, ein.Einsum] + self, plan: lgc.Plan, bindings: dict[str, Any], definitions: dict[str, ein.Einsum] ) -> ein.Plan: bodies: list[ein.EinsumNode] = [] for body in plan.bodies: match body: - case Plan(_): + case lgc.Plan(_): bodies.append(self.compile_plan(body, bindings, definitions)) - case Query(Alias(name), Table(Literal(val), _)): + case lgc.Query(lgc.Alias(name), lgc.Table(lgc.Literal(val), _)): bindings[name] = val - case Query( - Alias(name), Aggregate(Literal(operation), Literal(init), arg, _) - ) | Query( - Alias(name), - Aggregate(Literal(operation), Literal(init), Reorder(arg, _), _), + case lgc.Query( + lgc.Alias(name), 
lgc.Aggregate(lgc.Literal(operation), lgc.Literal(init), arg, _) + ) | lgc.Query( + lgc.Alias(name), + lgc.Aggregate(lgc.Literal(operation), lgc.Literal(init), lgc.Reorder(arg, _), _), ): einidxs = tuple(ein.Index(field.name) for field in body.rhs.fields) if init != init_value(operation, type(init)): @@ -62,7 +49,7 @@ def compile_plan( ), ) ) - case Query(Alias(name), rhs): + case lgc.Query(lgc.Alias(name), rhs): einarg = self.compile_operand(rhs, bodies, bindings, definitions) bodies.append( ein.Einsum( @@ -74,7 +61,7 @@ def compile_plan( arg=einarg, ) ) - case Query(Alias(name), Reformat(_, rhs)): + case lgc.Query(lgc.Alias(name), lgc.Reformat(_, rhs)): einarg = self.compile_operand(rhs, bodies, bindings, definitions) bodies.append( ein.Einsum( @@ -86,7 +73,7 @@ def compile_plan( arg=einarg, ) ) - case Query(Alias(name), Reorder(rhs, idxs)): + case lgc.Query(lgc.Alias(name), lgc.Reorder(rhs, idxs)): einarg = self.compile_operand(rhs, bodies, bindings, definitions) bodies.append( ein.Einsum( @@ -96,10 +83,10 @@ def compile_plan( arg=einarg, ) ) - case Produces(args): + case lgc.Produces(args): returnValues = [] for ret_arg in args: - if not isinstance(ret_arg, Alias): + if not isinstance(ret_arg, lgc.Alias): raise Exception(f"Unrecognized logic: {ret_arg}") returnValues.append(ein.Alias(ret_arg.name)) @@ -137,28 +124,28 @@ def flatten_args( # lowers nested mapjoin logic IR nodes into a single pointwise expression def compile_operand( self, - ex: LogicNode, + ex: lgc.LogicNode, bodies: list[ein.EinsumNode], bindings: dict[str, Any], definitions: dict[str, ein.Einsum], ) -> ein.EinsumExpr: match ex: - case Reorder(arg, idxs): + case lgc.Reorder(arg, idxs): return self.compile_operand(arg, bodies, bindings, definitions) - case MapJoin(Literal(operation), args): + case lgc.MapJoin(lgc.Literal(operation), args): args_list = [ self.compile_operand(arg, bodies, bindings, definitions) for arg in args ] return self.compile_expr(operation, tuple(args_list)) - case Relabel( 
- Alias(name), idxs + case lgc.Relabel( + lgc.Alias(name), idxs ): # relable is really just a glorified pointwise access return ein.Access( tns=ein.Alias(name), idxs=tuple(ein.Index(idx.name) for idx in idxs), ) - case Literal(value): + case lgc.Literal(value): return ein.Literal(val=value) case _: raise Exception(f"Unrecognized logic: {ex}") From cbf5f4385e9efec28ae8e807ad57019d4d488394 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 27 Oct 2025 09:46:59 -0400 Subject: [PATCH 48/57] * Removed unecessary cases, all of which are optimized out --- src/finchlite/autoschedule/einsum.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 193a2cb9..04185dcc 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -25,9 +25,6 @@ def compile_plan( bindings[name] = val case lgc.Query( lgc.Alias(name), lgc.Aggregate(lgc.Literal(operation), lgc.Literal(init), arg, _) - ) | lgc.Query( - lgc.Alias(name), - lgc.Aggregate(lgc.Literal(operation), lgc.Literal(init), lgc.Reorder(arg, _), _), ): einidxs = tuple(ein.Index(field.name) for field in body.rhs.fields) if init != init_value(operation, type(init)): @@ -61,28 +58,6 @@ def compile_plan( arg=einarg, ) ) - case lgc.Query(lgc.Alias(name), lgc.Reformat(_, rhs)): - einarg = self.compile_operand(rhs, bodies, bindings, definitions) - bodies.append( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias(name), - idxs=tuple( - ein.Index(field.name) for field in body.rhs.fields - ), - arg=einarg, - ) - ) - case lgc.Query(lgc.Alias(name), lgc.Reorder(rhs, idxs)): - einarg = self.compile_operand(rhs, bodies, bindings, definitions) - bodies.append( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias(name), - idxs=tuple(ein.Index(idx.name) for idx in idxs), - arg=einarg, - ) - ) case lgc.Produces(args): returnValues = [] for ret_arg in args: From 
87ba773389f36948b725158c326d26d4cae2c895 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 27 Oct 2025 10:03:22 -0400 Subject: [PATCH 49/57] * Added case for logical reformat AST node in compile operand --- src/finchlite/autoschedule/einsum.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 04185dcc..14f9a6e7 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -105,6 +105,8 @@ def compile_operand( definitions: dict[str, ein.Einsum], ) -> ein.EinsumExpr: match ex: + case lgc.Reformat(_, rhs): + return self.compile_operand(rhs, bodies, bindings, definitions) case lgc.Reorder(arg, idxs): return self.compile_operand(arg, bodies, bindings, definitions) case lgc.MapJoin(lgc.Literal(operation), args): From e0095f097a9c20d5a4ca3b32ff5158234702967c Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 27 Oct 2025 10:05:25 -0400 Subject: [PATCH 50/57] Inlined compile_expression method --- src/finchlite/autoschedule/einsum.py | 38 ++++++++++------------------ 1 file changed, 13 insertions(+), 25 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 14f9a6e7..abc5c31f 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -71,15 +71,15 @@ def compile_plan( return ein.Plan(tuple(bodies)) - def compile_expr( - self, operation: Callable, args: tuple[ein.EinsumExpr, ...] 
+ # lowers nested mapjoin logic IR nodes into a single pointwise expression + def compile_operand( + self, + ex: lgc.LogicNode, + bodies: list[ein.EinsumNode], + bindings: dict[str, Any], + definitions: dict[str, ein.Einsum], ) -> ein.EinsumExpr: - # if operation is commutative, we simply pass - # all the args to the pointwise op since - # order of args does not matter - if is_commutative(operation): - - def flatten_args( + def flatten_args( m_args: tuple[ein.EinsumExpr, ...], ) -> tuple[ein.EinsumExpr, ...]: ret_args: list[ein.EinsumExpr] = [] @@ -90,31 +90,19 @@ def flatten_args( case _: ret_args.append(arg) return tuple(ret_args) - - return ein.Call(ein.Literal(operation), flatten_args(args)) - - # combine args from left to right (i.e a / b / c -> (a / b) / c) - return ein.Call(ein.Literal(operation), args) - - # lowers nested mapjoin logic IR nodes into a single pointwise expression - def compile_operand( - self, - ex: lgc.LogicNode, - bodies: list[ein.EinsumNode], - bindings: dict[str, Any], - definitions: dict[str, ein.Einsum], - ) -> ein.EinsumExpr: + match ex: case lgc.Reformat(_, rhs): return self.compile_operand(rhs, bodies, bindings, definitions) case lgc.Reorder(arg, idxs): return self.compile_operand(arg, bodies, bindings, definitions) case lgc.MapJoin(lgc.Literal(operation), args): - args_list = [ + args = tuple([ self.compile_operand(arg, bodies, bindings, definitions) for arg in args - ] - return self.compile_expr(operation, tuple(args_list)) + ]) + return ein.Call(ein.Literal(operation), args + if is_commutative(operation) else flatten_args(args)) case lgc.Relabel( lgc.Alias(name), idxs ): # relable is really just a glorified pointwise access From 4583f9ca0242fa67567f04b4f1fed7fad3117e54 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 27 Oct 2025 10:07:05 -0400 Subject: [PATCH 51/57] * Removed unused parameters from compile operand --- src/finchlite/autoschedule/einsum.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 
deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index abc5c31f..9eafb10a 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -41,13 +41,11 @@ def compile_plan( op=ein.Literal(operation), tns=ein.Alias(name), idxs=einidxs, - arg=self.compile_operand( - arg, bodies, bindings, definitions - ), + arg=self.compile_operand(arg), ) ) case lgc.Query(lgc.Alias(name), rhs): - einarg = self.compile_operand(rhs, bodies, bindings, definitions) + einarg = self.compile_operand(rhs) bodies.append( ein.Einsum( op=ein.Literal(overwrite), @@ -75,9 +73,6 @@ def compile_plan( def compile_operand( self, ex: lgc.LogicNode, - bodies: list[ein.EinsumNode], - bindings: dict[str, Any], - definitions: dict[str, ein.Einsum], ) -> ein.EinsumExpr: def flatten_args( m_args: tuple[ein.EinsumExpr, ...], @@ -93,12 +88,12 @@ def flatten_args( match ex: case lgc.Reformat(_, rhs): - return self.compile_operand(rhs, bodies, bindings, definitions) + return self.compile_operand(rhs) case lgc.Reorder(arg, idxs): - return self.compile_operand(arg, bodies, bindings, definitions) + return self.compile_operand(arg) case lgc.MapJoin(lgc.Literal(operation), args): args = tuple([ - self.compile_operand(arg, bodies, bindings, definitions) + self.compile_operand(arg) for arg in args ]) return ein.Call(ein.Literal(operation), args From 8ca2eb788a97d3eab028d5c86aa28d4340c195d3 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 27 Oct 2025 10:11:45 -0400 Subject: [PATCH 52/57] * Fixed potential mypy error in EinsumLowerer compile_operand --- src/finchlite/autoschedule/einsum.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 9eafb10a..224b81f9 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -91,10 +91,10 @@ def flatten_args( return self.compile_operand(rhs) case 
lgc.Reorder(arg, idxs): return self.compile_operand(arg) - case lgc.MapJoin(lgc.Literal(operation), args): + case lgc.MapJoin(lgc.Literal(operation), lgcargs): args = tuple([ self.compile_operand(arg) - for arg in args + for arg in lgcargs ]) return ein.Call(ein.Literal(operation), args if is_commutative(operation) else flatten_args(args)) From 08dabeb45e965c90e19b6cd5175b0845f73ca827 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 27 Oct 2025 10:15:06 -0400 Subject: [PATCH 53/57] * Fixed ruff errors --- src/finchlite/autoschedule/einsum.py | 44 +++++++++++++++------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 224b81f9..90695266 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -1,9 +1,8 @@ -from collections.abc import Callable from typing import Any import finchlite.finch_einsum as ein -from finchlite.algebra import init_value, is_commutative, overwrite import finchlite.finch_logic as lgc +from finchlite.algebra import init_value, is_commutative, overwrite class EinsumLowerer: @@ -13,7 +12,10 @@ def __call__(self, prgm: lgc.Plan) -> tuple[ein.Plan, dict[str, Any]]: return self.compile_plan(prgm, bindings, definitions), bindings def compile_plan( - self, plan: lgc.Plan, bindings: dict[str, Any], definitions: dict[str, ein.Einsum] + self, + plan: lgc.Plan, + bindings: dict[str, Any], + definitions: dict[str, ein.Einsum], ) -> ein.Plan: bodies: list[ein.EinsumNode] = [] @@ -24,7 +26,8 @@ def compile_plan( case lgc.Query(lgc.Alias(name), lgc.Table(lgc.Literal(val), _)): bindings[name] = val case lgc.Query( - lgc.Alias(name), lgc.Aggregate(lgc.Literal(operation), lgc.Literal(init), arg, _) + lgc.Alias(name), + lgc.Aggregate(lgc.Literal(operation), lgc.Literal(init), arg, _), ): einidxs = tuple(ein.Index(field.name) for field in body.rhs.fields) if init != init_value(operation, type(init)): @@ -75,29 
+78,28 @@ def compile_operand( ex: lgc.LogicNode, ) -> ein.EinsumExpr: def flatten_args( - m_args: tuple[ein.EinsumExpr, ...], - ) -> tuple[ein.EinsumExpr, ...]: - ret_args: list[ein.EinsumExpr] = [] - for arg in m_args: - match arg: - case ein.Call(ein.Literal(op2), _) if op2 == operation: - ret_args.extend(flatten_args(arg.args)) - case _: - ret_args.append(arg) - return tuple(ret_args) - + m_args: tuple[ein.EinsumExpr, ...], + ) -> tuple[ein.EinsumExpr, ...]: + ret_args: list[ein.EinsumExpr] = [] + for arg in m_args: + match arg: + case ein.Call(ein.Literal(op2), _) if op2 == operation: + ret_args.extend(flatten_args(arg.args)) + case _: + ret_args.append(arg) + return tuple(ret_args) + match ex: case lgc.Reformat(_, rhs): return self.compile_operand(rhs) case lgc.Reorder(arg, idxs): return self.compile_operand(arg) case lgc.MapJoin(lgc.Literal(operation), lgcargs): - args = tuple([ - self.compile_operand(arg) - for arg in lgcargs - ]) - return ein.Call(ein.Literal(operation), args - if is_commutative(operation) else flatten_args(args)) + args = tuple([self.compile_operand(arg) for arg in lgcargs]) + return ein.Call( + ein.Literal(operation), + args if is_commutative(operation) else flatten_args(args), + ) case lgc.Relabel( lgc.Alias(name), idxs ): # relable is really just a glorified pointwise access From 140051c8da4d72bb52f7804209e34cf869f8ca42 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 27 Oct 2025 11:24:38 -0400 Subject: [PATCH 54/57] * Removed flatten args * Will add back in another optimization pass, sometime in the future --- src/finchlite/autoschedule/einsum.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 90695266..2129294a 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -2,7 +2,7 @@ import finchlite.finch_einsum as ein import finchlite.finch_logic as lgc -from 
finchlite.algebra import init_value, is_commutative, overwrite +from finchlite.algebra import init_value, overwrite class EinsumLowerer: @@ -77,18 +77,6 @@ def compile_operand( self, ex: lgc.LogicNode, ) -> ein.EinsumExpr: - def flatten_args( - m_args: tuple[ein.EinsumExpr, ...], - ) -> tuple[ein.EinsumExpr, ...]: - ret_args: list[ein.EinsumExpr] = [] - for arg in m_args: - match arg: - case ein.Call(ein.Literal(op2), _) if op2 == operation: - ret_args.extend(flatten_args(arg.args)) - case _: - ret_args.append(arg) - return tuple(ret_args) - match ex: case lgc.Reformat(_, rhs): return self.compile_operand(rhs) @@ -96,10 +84,7 @@ def flatten_args( return self.compile_operand(arg) case lgc.MapJoin(lgc.Literal(operation), lgcargs): args = tuple([self.compile_operand(arg) for arg in lgcargs]) - return ein.Call( - ein.Literal(operation), - args if is_commutative(operation) else flatten_args(args), - ) + return ein.Call(ein.Literal(operation), args) case lgc.Relabel( lgc.Alias(name), idxs ): # relable is really just a glorified pointwise access From 42eed927e3f91f5a4e1fe23d4481b48b00551039 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 27 Oct 2025 11:46:20 -0400 Subject: [PATCH 55/57] Refactored compile_plan as a fucntion that only returns an EinsumNode and doesn't rely on building a vector of statements/bodies --- src/finchlite/autoschedule/einsum.py | 97 ++++++++++++++-------------- 1 file changed, 48 insertions(+), 49 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 2129294a..4f414618 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -13,64 +13,63 @@ def __call__(self, prgm: lgc.Plan) -> tuple[ein.Plan, dict[str, Any]]: def compile_plan( self, - plan: lgc.Plan, + node: lgc.LogicNode, bindings: dict[str, Any], definitions: dict[str, ein.Einsum], - ) -> ein.Plan: - bodies: list[ein.EinsumNode] = [] - - for body in plan.bodies: - match body: - case 
lgc.Plan(_): - bodies.append(self.compile_plan(body, bindings, definitions)) - case lgc.Query(lgc.Alias(name), lgc.Table(lgc.Literal(val), _)): - bindings[name] = val - case lgc.Query( - lgc.Alias(name), - lgc.Aggregate(lgc.Literal(operation), lgc.Literal(init), arg, _), - ): - einidxs = tuple(ein.Index(field.name) for field in body.rhs.fields) - if init != init_value(operation, type(init)): - bodies.append( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias(name), - idxs=einidxs, - arg=ein.Literal(init), - ) - ) - bodies.append( + ) -> ein.EinsumNode: + match node: + case lgc.Plan(bodies): + ein_bodies = [self.compile_plan(body, bindings, definitions) for body in bodies] + not_none_bodies = [body for body in ein_bodies if body is not None] + return ein.Plan(tuple(not_none_bodies)) + case lgc.Query(lgc.Alias(name), lgc.Table(lgc.Literal(val), _)): + bindings[name] = val + return None + case lgc.Query( + lgc.Alias(name), + lgc.Aggregate(lgc.Literal(operation), lgc.Literal(init), arg, _), + ): + einidxs = tuple(ein.Index(field.name) for field in node.rhs.fields) + my_bodies = [] + if init != init_value(operation, type(init)): + my_bodies.append( ein.Einsum( - op=ein.Literal(operation), + op=ein.Literal(overwrite), tns=ein.Alias(name), idxs=einidxs, - arg=self.compile_operand(arg), + arg=ein.Literal(init), ) ) - case lgc.Query(lgc.Alias(name), rhs): - einarg = self.compile_operand(rhs) - bodies.append( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias(name), - idxs=tuple( - ein.Index(field.name) for field in body.rhs.fields - ), - arg=einarg, - ) + my_bodies.append( + ein.Einsum( + op=ein.Literal(operation), + tns=ein.Alias(name), + idxs=einidxs, + arg=self.compile_operand(arg), ) - case lgc.Produces(args): - returnValues = [] - for ret_arg in args: - if not isinstance(ret_arg, lgc.Alias): - raise Exception(f"Unrecognized logic: {ret_arg}") - returnValues.append(ein.Alias(ret_arg.name)) - - bodies.append(ein.Produces(tuple(returnValues))) - case _: - 
raise Exception(f"Unrecognized logic: {body}") + ) + return ein.Plan(tuple(my_bodies)) + case lgc.Query(lgc.Alias(name), rhs): + einarg = self.compile_operand(rhs) + return ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias(name), + idxs=tuple( + ein.Index(field.name) for field in node.rhs.fields + ), + arg=einarg, + ) + + case lgc.Produces(args): + returnValues = [] + for ret_arg in args: + if not isinstance(ret_arg, lgc.Alias): + raise Exception(f"Unrecognized logic: {ret_arg}") + returnValues.append(ein.Alias(ret_arg.name)) - return ein.Plan(tuple(bodies)) + return ein.Produces(tuple(returnValues)) + case _: + raise Exception(f"Unrecognized logic: {node}") # lowers nested mapjoin logic IR nodes into a single pointwise expression def compile_operand( From 874c76de077613c51b37807b7abad522d097d251 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 27 Oct 2025 11:48:00 -0400 Subject: [PATCH 56/57] * Fixed potential mypy errors --- src/finchlite/autoschedule/einsum.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 4f414618..02c2ad94 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, cast import finchlite.finch_einsum as ein import finchlite.finch_logic as lgc @@ -9,14 +9,14 @@ class EinsumLowerer: def __call__(self, prgm: lgc.Plan) -> tuple[ein.Plan, dict[str, Any]]: bindings: dict[str, Any] = {} definitions: dict[str, ein.Einsum] = {} - return self.compile_plan(prgm, bindings, definitions), bindings + return cast(ein.Plan,self.compile_plan(prgm, bindings, definitions)), bindings def compile_plan( self, node: lgc.LogicNode, bindings: dict[str, Any], definitions: dict[str, ein.Einsum], - ) -> ein.EinsumNode: + ) -> ein.EinsumNode | None: match node: case lgc.Plan(bodies): ein_bodies = [self.compile_plan(body, bindings, definitions) for body 
in bodies] From 4c4813cbd6ec852b5202b4230f570d21d4a1ee42 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 27 Oct 2025 11:49:12 -0400 Subject: [PATCH 57/57] * Fixed ruff errors --- src/finchlite/autoschedule/einsum.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 02c2ad94..d2baf341 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -9,7 +9,7 @@ class EinsumLowerer: def __call__(self, prgm: lgc.Plan) -> tuple[ein.Plan, dict[str, Any]]: bindings: dict[str, Any] = {} definitions: dict[str, ein.Einsum] = {} - return cast(ein.Plan,self.compile_plan(prgm, bindings, definitions)), bindings + return cast(ein.Plan, self.compile_plan(prgm, bindings, definitions)), bindings def compile_plan( self, @@ -19,7 +19,9 @@ def compile_plan( ) -> ein.EinsumNode | None: match node: case lgc.Plan(bodies): - ein_bodies = [self.compile_plan(body, bindings, definitions) for body in bodies] + ein_bodies = [ + self.compile_plan(body, bindings, definitions) for body in bodies + ] not_none_bodies = [body for body in ein_bodies if body is not None] return ein.Plan(tuple(not_none_bodies)) case lgc.Query(lgc.Alias(name), lgc.Table(lgc.Literal(val), _)): @@ -54,12 +56,10 @@ def compile_plan( return ein.Einsum( op=ein.Literal(overwrite), tns=ein.Alias(name), - idxs=tuple( - ein.Index(field.name) for field in node.rhs.fields - ), + idxs=tuple(ein.Index(field.name) for field in node.rhs.fields), arg=einarg, ) - + case lgc.Produces(args): returnValues = [] for ret_arg in args: