diff --git a/src/finchlite/algebra/__init__.py b/src/finchlite/algebra/__init__.py
index 6411af90..fa048411 100644
--- a/src/finchlite/algebra/__init__.py
+++ b/src/finchlite/algebra/__init__.py
@@ -4,6 +4,7 @@
     init_value,
     is_annihilator,
     is_associative,
+    is_commutative,
    is_distributive,
     is_idempotent,
     is_identity,
@@ -37,11 +38,8 @@
     "TensorFType",
     "TensorPlaceholder",
     "conjugate",
-    "conjugate",
-    "element_type",
     "element_type",
     "fill_value",
-    "fill_value",
     "first_arg",
     "fixpoint_type",
     "identity",
@@ -49,6 +47,7 @@
     "is_annihilator",
     "is_associative",
+    "is_commutative",
     "is_distributive",
     "is_idempotent",
     "is_identity",
     "overwrite",
diff --git a/src/finchlite/algebra/algebra.py b/src/finchlite/algebra/algebra.py
index fd2ae6a3..718b1c0b 100644
--- a/src/finchlite/algebra/algebra.py
+++ b/src/finchlite/algebra/algebra.py
@@ -345,6 +345,23 @@ def is_associative(op: Any) -> bool:
 register_property(np.logical_or, "__call__", "is_associative", lambda op: True)
 register_property(np.logical_xor, "__call__", "is_associative", lambda op: True)
 
+# Commutative properties
+
+for op in (operator.add, operator.mul, operator.and_, operator.or_, operator.xor,
+           np.logical_and, np.logical_or, np.logical_xor):
+    register_property(op, "__call__", "is_commutative", lambda op: True)
+
+for op in (operator.sub, operator.truediv, operator.floordiv, operator.pow,
+           operator.lshift, operator.rshift):
+    register_property(op, "__call__", "is_commutative", lambda op: False)
+
+
+def is_commutative(op: Any) -> bool:
+    """
+    Return True if the operation is known to be commutative, False otherwise.
+    """
+    return query_property(op, "__call__", "is_commutative")
+
 
 def is_identity(op: Any, val: Any) -> bool:
     """
diff --git a/src/finchlite/autoschedule/__init__.py b/src/finchlite/autoschedule/__init__.py
index 32022d8e..9f8f0faa 100644
--- a/src/finchlite/autoschedule/__init__.py
+++ b/src/finchlite/autoschedule/__init__.py
@@ -38,11 +38,21 @@
     push_fields,
     set_loop_order,
 )
+from .einsum import (
+    Einsum,
+    EinsumCompiler,
+    EinsumPlan,
+    EinsumScheduler,
+)
 
 __all__ = [
     "Aggregate",
     "Alias",
     "DefaultLogicOptimizer",
+    "Einsum",
+    "EinsumCompiler",
+    "EinsumPlan",
+    "EinsumScheduler",
     "Field",
     "Literal",
     "LogicCompiler",
diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py
new file mode 100644
index 00000000..1c1db625
--- /dev/null
+++ b/src/finchlite/autoschedule/einsum.py
@@ -0,0 +1,547 @@
+import operator
+from abc import ABC
+from dataclasses import dataclass
+from typing import Callable, Self
+
+import numpy as np
+
+from finchlite.algebra import (
+    init_value,
+    is_commutative,
+    overwrite,
+    promote_max,
+    promote_min,
+)
+from finchlite.finch_logic import Alias, Field, Literal, LogicNode, Plan, Query, Relabel
+from finchlite.finch_logic.nodes import Aggregate, MapJoin, Produces, Reorder, Table
+from finchlite.symbolic import PostWalk, Rewrite, Term, TermTree
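+
+# The einsum IR defined below follows, roughly, this grammar (an illustrative
+# sketch, not enforced by the classes):
+#
+#   plan      := statement* "return" value*
+#   statement := einsum | extract-coo | extract-values
+#   einsum    := alias "[" fields "]" reduceOp pointwise
+#   pointwise := access | literal | op "(" pointwise* ")"
+#   access    := alias "[" fields "]" | alias "[" coo_alias "[" field "]" "]"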
+
+
+@dataclass(eq=True, frozen=True)
+class PointwiseNode(Term, ABC):
+    """
+    PointwiseNode
+
+    Represents an AST node in the einsum pointwise expression IR.
+    """
+
+    @classmethod
+    def head(cls):
+        """Returns the head of the node."""
+        return cls
+
+    @classmethod
+    def make_term(cls, head, *children: Term) -> Self:
+        return head.from_children(*children)
+
+    @classmethod
+    def from_children(cls, *children: Term) -> Self:
+        return cls(*children)
+
+    def __str__(self):
+        ctx = EinsumPrinterContext()
+        return ctx.print_pointwise_expr(self)
+
+
+@dataclass(eq=True, frozen=True)
+class PointwiseAccess(PointwiseNode, TermTree):
+    """
+    PointwiseAccess
+
+    A tensor access such as a[i, j]. This node is a leaf.
+
+    Attributes:
+        alias: The alias of the tensor to access.
+        idxs: The indices at which to access the tensor, e.g. (Field('i'), Field('j')).
+    """
+
+    alias: str
+    idxs: tuple[Field, ...]
+
+    @classmethod
+    def from_children(cls, alias: str, idxs: tuple[Field, ...]) -> Self:
+        return cls(alias, idxs)
+
+    @property
+    def children(self):
+        return [self.alias, self.idxs]
+
+
+@dataclass(eq=True, frozen=True)
+class PointwiseIndirectCOOAccess(PointwiseNode, TermTree):
+    """
+    PointwiseIndirectCOOAccess
+
+    A tensor access such as a[i, j], but for sparse tensors: the access is
+    routed through the coordinate array, i.e. a[coo_coords[i]].
+
+    Attributes:
+        alias: The alias of the tensor to access.
+        coo_coord_alias: The alias of the COO coordinate array (itself a tensor).
+        idx: The single index needed to access the COO coordinate array.
+    """
+
+    alias: str
+    coo_coord_alias: str
+    idx: Field
+
+    @classmethod
+    def from_children(cls, alias: str, coo_coord_alias: str, idx: Field) -> Self:
+        return cls(alias, coo_coord_alias, idx)
+
+    @property
+    def children(self):
+        return [self.alias, self.coo_coord_alias, self.idx]
+
+
+@dataclass(eq=True, frozen=True)
+class PointwiseOp(PointwiseNode, TermTree):
+    """
+    PointwiseOp
+
+    An operation such as + or * applied to pointwise operands. If the
+    operation is not commutative, the node must be binary (exactly two args).
+
+    Attributes:
+        op: The callable to apply, e.g. operator.add, operator.mul, operator.sub.
+        args: The arguments to the operation (the subtrees).
+    """
+
+    op: Callable
+    args: tuple[PointwiseNode, ...]
+
+    @classmethod
+    def from_children(cls, op: Callable, *args: PointwiseNode) -> Self:
+        return cls(op, args)
+
+    @property
+    def children(self):
+        return [self.op, *self.args]
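+
+
+# Example (illustrative): the pointwise expression a[i, j] * b[j, k] is
+#   PointwiseOp(operator.mul, (
+#       PointwiseAccess("a", (Field("i"), Field("j"))),
+#       PointwiseAccess("b", (Field("j"), Field("k"))),
+#   ))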
+
+
+@dataclass(eq=True, frozen=True)
+class PointwiseLiteral(PointwiseNode):
+    """
+    PointwiseLiteral
+
+    A scalar literal/value in a pointwise expression.
+    """
+
+    val: float
+
+
+@dataclass(eq=True, frozen=True)
+class EinsumPlanStatement(Term, ABC):
+    """
+    EinsumPlanStatement
+
+    Represents an AST node in the einsum plan IR.
+    """
+
+    @classmethod
+    def head(cls):
+        """Returns the head of the node."""
+        return cls
+
+    @classmethod
+    def make_term(cls, head, *children: Term) -> Self:
+        return head.from_children(*children)
+
+    @classmethod
+    def from_children(cls, *children: Term) -> Self:
+        return cls(*children)
+
+    def __str__(self):
+        ctx = EinsumPrinterContext()
+        return ctx.print_einsum_plan_statement(self)
+
+
+# Einsum and the einsum AST are not part of the logic IR; logic programs are
+# transformed into this separate language.
+@dataclass(eq=True, frozen=True)
+class Einsum(EinsumPlanStatement, TermTree):
+    """
+    Einsum
+
+    An einsum operation that maps pointwise expressions and aggregates them.
+
+    Attributes:
+        reduceOp: The reduction applied to the mapped values (e.g. +=, max=),
+            much akin to the operation in Aggregate; `overwrite` means plain =.
+        output_fields: The indices used in the output (e.g. i, j).
+        pointwise_expr: The pointwise expression that is mapped and aggregated.
+        output_alias: The alias the result is bound to, if any.
+        indirect_coo_alias: The COO coordinate array used for indirect output
+            access, if the output is sparse.
+    """
+
+    reduceOp: Callable
+    output_fields: tuple[Field, ...]
+    pointwise_expr: PointwiseNode
+    output_alias: str | None
+    indirect_coo_alias: str | None
+
+    @classmethod
+    def from_children(
+        cls,
+        reduceOp: Callable,
+        output_fields: tuple[Field, ...],
+        pointwise_expr: PointwiseNode,
+        output_alias: str | None,
+        indirect_coo_alias: str | None,
+    ) -> Self:
+        return cls(reduceOp, output_fields, pointwise_expr, output_alias, indirect_coo_alias)
+
+    @property
+    def children(self):
+        # must match the argument order of from_children
+        return [
+            self.reduceOp,
+            self.output_fields,
+            self.pointwise_expr,
+            self.output_alias,
+            self.indirect_coo_alias,
+        ]
+
+    def rename(self, new_alias: str):
+        return Einsum(self.reduceOp, self.output_fields, self.pointwise_expr, new_alias, self.indirect_coo_alias)
+
+    def reorder(self, idxs: tuple[Field, ...]):
+        return Einsum(self.reduceOp, idxs, self.pointwise_expr, self.output_alias, self.indirect_coo_alias)
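+
+
+# Example (illustrative): the matrix multiply C[i, j] += A[i, k] * B[k, j] is
+#   Einsum(reduceOp=operator.add,
+#          output_fields=(Field("i"), Field("j")),
+#          pointwise_expr=PointwiseOp(operator.mul, (...)),  # A and B accesses
+#          output_alias="C",
+#          indirect_coo_alias=None)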
+
+
+@dataclass(eq=True, frozen=True)
+class ExtractCOOFromSparse(EinsumPlanStatement):
+    """
+    ExtractCOOFromSparse
+
+    A plan statement that extracts the coordinate array from a COO sparse tensor.
+    """
+
+    alias: str
+
+    @classmethod
+    def from_children(cls, alias: str) -> Self:
+        return cls(alias)
+
+    @property
+    def children(self):
+        return [self.alias]
+
+
+@dataclass(eq=True, frozen=True)
+class ExtractValuesFromSparse(EinsumPlanStatement):
+    """
+    ExtractValuesFromSparse
+
+    A plan statement that extracts the values array from a COO sparse tensor.
+    """
+
+    alias: str
+
+    @classmethod
+    def from_children(cls, alias: str) -> Self:
+        return cls(alias)
+
+    @property
+    def children(self):
+        return [self.alias]
+
+
+@dataclass(eq=True, frozen=True)
+class EinsumPlan(Plan):
+    """
+    EinsumPlan
+
+    A plan that contains einsum operations: essentially a sequence of einsum
+    statements plus the values to return.
+    """
+
+    bodies: tuple[EinsumPlanStatement, ...] = ()
+    returnValues: tuple[Einsum | str, ...] = ()
+
+    @classmethod
+    def from_children(
+        cls, bodies: tuple[EinsumPlanStatement, ...], returnValues: tuple[Einsum | str, ...]
+    ) -> Self:
+        return cls(bodies, returnValues)
+
+    @property
+    def children(self):
+        return [*self.bodies, self.returnValues]
+
+    def __str__(self):
+        ctx = EinsumPrinterContext()
+        return ctx(self)
+
+
+class EinsumLowerer:
+    alias_counter: int = 0
+
+    def __call__(self, prgm: Plan, parameters: dict[str, Table], definitions: dict[str, Einsum]) -> EinsumPlan:
+        return self.compile_plan(prgm, parameters, definitions)
+
+    def get_next_alias(self) -> str:
+        self.alias_counter += 1
+        return f"einsum_{self.alias_counter}"
+
+    def rename_einsum(self, einsum: Einsum, new_alias: str, definitions: dict[str, Einsum]) -> Einsum:
+        definitions[new_alias] = einsum
+        return einsum.rename(new_alias)
+
+    def compile_plan(self, plan: Plan, parameters: dict[str, Table], definitions: dict[str, Einsum]) -> EinsumPlan:
+        einsum_statements: list[EinsumPlanStatement] = []
+        return_values = []
+
+        for body in plan.bodies:
+            match body:
+                case Plan(_):
+                    einsum_plan = self.compile_plan(body, parameters, definitions)
+                    einsum_statements.extend(einsum_plan.bodies)
+
+                    if einsum_plan.returnValues:
+                        if return_values:
+                            raise Exception("Cannot invoke return more than once.")
+                        return_values = list(einsum_plan.returnValues)
+                case Query(Alias(name), Table(_, _)):
+                    parameters[name] = body.rhs
+                case Query(Alias(name), rhs):
+                    einsum_statements.append(
+                        self.rename_einsum(
+                            self.lower_to_einsum(rhs, einsum_statements, parameters, definitions),
+                            name,
+                            definitions,
+                        )
+                    )
+                case Produces(args):
+                    if return_values:
+                        raise Exception("Cannot invoke return more than once.")
+                    for arg in args:
+                        return_values.append(
+                            arg.name
+                            if isinstance(arg, Alias)
+                            else self.lower_to_einsum(arg, einsum_statements, parameters, definitions)
+                        )
+                case _:
+                    einsum_statements.append(
+                        self.rename_einsum(
+                            self.lower_to_einsum(body, einsum_statements, parameters, definitions),
+                            self.get_next_alias(),
+                            definitions,
+                        )
+                    )
+
+        return EinsumPlan(tuple(einsum_statements), tuple(return_values))
+
+    def lower_to_einsum(
+        self,
+        ex: LogicNode,
+        einsum_statements: list[EinsumPlanStatement],
+        parameters: dict[str, Table],
+        definitions: dict[str, Einsum],
+    ) -> Einsum | PointwiseNode:
+        match ex:
+            case Plan(_):
+                plan = self.compile_plan(ex, parameters, definitions)
+                einsum_statements.extend(plan.bodies)
+
+                if not plan.returnValues:
+                    raise Exception("A plan used as an expression must have a return value.")
+
+                if len(plan.returnValues) > 1:
+                    raise Exception("Only one return value is supported.")
+
+                if isinstance(plan.returnValues[0], str):
+                    returned_alias = plan.returnValues[0]
+                    returned_einsum = definitions[returned_alias]
+                    return PointwiseAccess(alias=returned_alias, idxs=returned_einsum.output_fields)
+
+                return plan.returnValues[0]
+            case MapJoin(Literal(operation), args):
+                lowered_args = [
+                    self.lower_to_pointwise(arg, einsum_statements, parameters, definitions) for arg in args
+                ]
+                pointwise_expr = self.lower_to_pointwise_op(operation, lowered_args)
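+                # a bare MapJoin is a pure elementwise map: nothing is reduced,
+                # so the reducer is `overwrite` and all fields stay in the output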
+                return Einsum(
+                    reduceOp=overwrite,
+                    output_fields=tuple(ex.fields),
+                    pointwise_expr=pointwise_expr,
+                    output_alias=None,
+                    indirect_coo_alias=None,
+                )
+            case Reorder(arg, idxs):
+                return self.lower_to_einsum(arg, einsum_statements, parameters, definitions).reorder(idxs)
+            case Aggregate(Literal(operation), Literal(init), arg, _):
+                if init != init_value(operation, type(init)):
+                    raise Exception(
+                        f"Init value {init} is not the default for operation {operation} "
+                        f"of type {type(init)}. Non-standard init values are not supported."
+                    )
+                pointwise_expr = self.lower_to_pointwise(arg, einsum_statements, parameters, definitions)
+                return Einsum(operation, tuple(ex.fields), pointwise_expr, self.get_next_alias(), None)
+            case _:
+                raise Exception(f"Unrecognized logic: {ex}")
+
+    def lower_to_pointwise_op(self, operation: Callable, args: tuple[PointwiseNode, ...]) -> PointwiseOp:
+        # if the operation is commutative, pass all args to a single pointwise
+        # op, flattening nested applications, since argument order does not matter
+        if is_commutative(operation):
+            def flatten_args(m_args: tuple[PointwiseNode, ...]) -> tuple[PointwiseNode, ...]:
+                ret_args = []
+                for arg in m_args:
+                    match arg:
+                        case PointwiseOp(op2, inner_args) if op2 == operation:
+                            ret_args.extend(flatten_args(inner_args))
+                        case _:
+                            ret_args.append(arg)
+                return tuple(ret_args)
+
+            return PointwiseOp(operation, flatten_args(args))
+
+        # otherwise combine args left to right (e.g. a / b / c -> (a / b) / c)
+        assert len(args) >= 2, "non-commutative operations need at least two operands"
+        result = PointwiseOp(operation, (args[0], args[1]))
+        for arg in args[2:]:
+            result = PointwiseOp(operation, (result, arg))
+        return result
+
+    # lowers nested MapJoin logic IR nodes into a single pointwise expression
+    def lower_to_pointwise(
+        self,
+        ex: LogicNode,
+        einsum_statements: list[EinsumPlanStatement],
+        parameters: dict[str, Table],
+        definitions: dict[str, Einsum],
+    ) -> PointwiseNode:
+        match ex:
+            case Reorder(arg, _):
+                return self.lower_to_pointwise(arg, einsum_statements, parameters, definitions)
+            case MapJoin(Literal(operation), args):
+                lowered_args = [
+                    self.lower_to_pointwise(arg, einsum_statements, parameters, definitions) for arg in args
+                ]
+                return self.lower_to_pointwise_op(operation, lowered_args)
+            case Relabel(Alias(name), idxs):  # Relabel is really just a glorified pointwise access
+                return PointwiseAccess(alias=name, idxs=idxs)
+            case Literal(value):
+                return PointwiseLiteral(val=value)
+            case Aggregate(_, _, _, _):  # an aggregate has to be computed separately, as its own einsum
+                aggregate_einsum_alias = self.get_next_alias()
+                einsum_statements.append(
+                    self.rename_einsum(
+                        self.lower_to_einsum(ex, einsum_statements, parameters, definitions),
+                        aggregate_einsum_alias,
+                        definitions,
+                    )
+                )
+                return PointwiseAccess(alias=aggregate_einsum_alias, idxs=tuple(ex.fields))
+            case _:
+                raise Exception(f"Unrecognized logic: {ex}")
+
+
+class EinsumPrinterContext:
+    def print_indices(self, idxs: tuple[Field, ...]):
+        return ", ".join(str(idx) for idx in idxs)
+
+    def print_reducer(self, reducer: Callable):
+        str_map = {
+            overwrite: "=",
+            operator.add: "+=",
+            operator.sub: "-=",
+            operator.mul: "*=",
+            operator.truediv: "/=",
+            operator.floordiv: "//=",
+            operator.mod: "%=",
+            operator.pow: "**=",
+            operator.and_: "&=",
+            operator.or_: "|=",
+            operator.xor: "^=",
+            promote_max: "max=",
+            promote_min: "min=",
+        }
+        return str_map[reducer]
+
+    def print_pointwise_op_callable(self, op: Callable):
+        str_map = {
+            operator.add: "+",
+            operator.sub: "-",
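+            # assumed extension: also cover the floordiv/bitwise ops that
+            # print_reducer above already maps, so printing them cannot KeyError
+            operator.floordiv: "//",
+            operator.and_: "&",
+            operator.or_: "|",
+            operator.xor: "^",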
+            operator.mul: "*",
+            operator.truediv: "/",
+            operator.mod: "%",
+            operator.pow: "**",
+        }
+        return str_map[op]
+
+    def print_pointwise_op(self, pointwise_op: PointwiseOp):
+        op_str = self.print_pointwise_op_callable(pointwise_op.op)
+        if not is_commutative(pointwise_op.op):
+            lhs = self.print_pointwise_expr(pointwise_op.args[0])
+            rhs = self.print_pointwise_expr(pointwise_op.args[1])
+            return f"({lhs} {op_str} {rhs})"
+        sep = f" {op_str} "
+        return f"({sep.join(self.print_pointwise_expr(arg) for arg in pointwise_op.args)})"
+
+    def print_pointwise_expr(self, pointwise_expr: PointwiseNode):
+        match pointwise_expr:
+            case PointwiseAccess(alias, idxs):
+                return f"{alias}[{self.print_indices(idxs)}]"
+            case PointwiseIndirectCOOAccess(alias, coo_coord_alias, idx):
+                return f"{alias}[{coo_coord_alias}[{self.print_indices((idx,))}]]"
+            case PointwiseOp(_, _):
+                return self.print_pointwise_op(pointwise_expr)
+            case PointwiseLiteral(val):
+                return str(val)
+            case _:
+                raise Exception(f"Unrecognized pointwise expression: {pointwise_expr}")
+
+    def print_einsum(self, einsum: Einsum) -> str:
+        reducer = self.print_reducer(einsum.reduceOp)
+        rhs = self.print_pointwise_expr(einsum.pointwise_expr)
+        if einsum.indirect_coo_alias:
+            return f"{einsum.output_alias}[{einsum.indirect_coo_alias}[{self.print_indices(einsum.output_fields)}]] {reducer} {rhs}"
+        return f"{einsum.output_alias}[{self.print_indices(einsum.output_fields)}] {reducer} {rhs}"
+
+    def print_return_value(self, return_value: Einsum | str) -> str:
+        return return_value if isinstance(return_value, str) else self.print_einsum(return_value)
+
+    def print_einsum_plan_statement(self, einsum_plan_statement: EinsumPlanStatement) -> str:
+        match einsum_plan_statement:
+            case Einsum(_, _, _, _, _):
+                return self.print_einsum(einsum_plan_statement)
+            case ExtractCOOFromSparse(alias):
+                return f"{alias}_coo = ExtractCOO({alias})"
+            case ExtractValuesFromSparse(alias):
+                return f"{alias}_values = ExtractValues({alias})"
+            case _:
+                raise Exception(f"Unrecognized einsum plan statement: {einsum_plan_statement}")
+
+    def print_einsum_plan(self, einsum_plan: EinsumPlan) -> str:
+        statements = "\n".join(
+            self.print_einsum_plan_statement(statement) for statement in einsum_plan.bodies
+        )
+        if not einsum_plan.returnValues:
+            return statements
+        returns = ", ".join(
+            self.print_return_value(return_value) for return_value in einsum_plan.returnValues
+        )
+        return f"{statements}\nreturn {returns}"
+
+    def __call__(self, prgm: EinsumPlan) -> str:
+        return self.print_einsum_plan(prgm)
+
+
+class EinsumInterpreter:
+    def __call__(self, einsum_plan: EinsumPlan, parameters: dict[str, Table]):
+        return self.print(einsum_plan, parameters)
+
+    def print(self, einsum_plan: EinsumPlan, parameters: dict[str, Table]):
+        for name, table in parameters.items():
+            print(f"Parameter: {name} = {table}")
+
+        print(einsum_plan)
+        # placeholder: actual interpretation of the plan is not implemented yet
+        return (np.arange(6, dtype=np.float32).reshape(2, 3),)
+
+
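+# Sparse handling (sketch): each distinct index tuple backed by a COO tensor
+# gets one synthetic field "sf_<n>" that ranges over the stored coordinates.
+# Value reads become <alias>_values[sf_<n>], and other accesses over the same
+# indices are routed through the owning tensor's coordinate array <alias>_coo.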
+class EinsumCompiler:
+    def __init__(self):
+        self.el = EinsumLowerer()
+
+    def find_sparse_tensors(
+        self, parameters: dict[str, Table]
+    ) -> tuple[dict[str, Field], dict[tuple[Field, ...], tuple[str, Field]]]:
+        from finchlite.autoschedule.sparse_tensor import SparseTensor
+
+        special_field_number = 1
+        sparse_aliases: dict[str, Field] = {}
+        sparse_fields: dict[tuple[Field, ...], tuple[str, Field]] = {}
+
+        for alias, value in parameters.items():
+            match value:
+                case Table(Literal(tensor), idxs):
+                    if isinstance(tensor, SparseTensor):
+                        if idxs in sparse_fields:
+                            sparse_aliases[alias] = sparse_fields[idxs][1]
+                        else:
+                            sparse_aliases[alias] = Field(f"sf_{special_field_number}")
+                            sparse_fields[idxs] = (alias, Field(f"sf_{special_field_number}"))
+                            special_field_number += 1
+        return sparse_aliases, sparse_fields
+
+    def optimize_einsum(
+        self,
+        einsum_plan: EinsumPlan,
+        sparse_aliases: dict[str, Field],
+        sparse_fields: dict[tuple[Field, ...], tuple[str, Field]],
+    ) -> EinsumPlan:
+        def optimize_sparse_einsum(einsum: Einsum, extra_ops: list[EinsumPlanStatement]) -> Einsum:
+            extracted_value_aliases = set()
+            extracted_coo_aliases = set()
+
+            def optimize_pointwise_access(node: PointwiseNode) -> PointwiseNode:
+                match node:
+                    case PointwiseAccess(alias, idxs):
+                        if idxs in sparse_fields:
+                            if alias in sparse_aliases:
+                                if alias not in extracted_value_aliases:
+                                    extracted_value_aliases.add(alias)
+                                    extra_ops.append(ExtractValuesFromSparse(alias))
+                                return PointwiseAccess(f"{alias}_values", (sparse_aliases[alias],))
+                            else:
+                                if alias not in extracted_coo_aliases:
+                                    extracted_coo_aliases.add(alias)
+                                    extra_ops.append(ExtractCOOFromSparse(alias))
+                                return PointwiseIndirectCOOAccess(
+                                    alias, f"{sparse_fields[idxs][0]}_coo", sparse_fields[idxs][1]
+                                )
+                return node
+
+            new_pointwise_expr = Rewrite(PostWalk(optimize_pointwise_access))(einsum.pointwise_expr)
+
+            if einsum.output_fields in sparse_fields:
+                return Einsum(
+                    einsum.reduceOp,
+                    (sparse_fields[einsum.output_fields][1],),
+                    new_pointwise_expr,
+                    einsum.output_alias,
+                    f"{sparse_fields[einsum.output_fields][0]}_coo",
+                )
+            return Einsum(einsum.reduceOp, einsum.output_fields, new_pointwise_expr, einsum.output_alias, None)
+
+        optimized_einsums: list[EinsumPlanStatement] = []
+        for statement in einsum_plan.bodies:
+            match statement:
+                case Einsum(_, _, _, _, _):
+                    optimized_einsums.append(optimize_sparse_einsum(statement, optimized_einsums))
+                case _:
+                    optimized_einsums.append(statement)
+
+        optimized_returns = []
+        for return_value in einsum_plan.returnValues:
+            match return_value:
+                case Einsum(_, _, _, _, _):
+                    optimized_returns.append(optimize_sparse_einsum(return_value, optimized_einsums))
+                case _:
+                    optimized_returns.append(return_value)
+        return EinsumPlan(tuple(optimized_einsums), tuple(optimized_returns))
+
+    def __call__(self, prgm: Plan):
+        parameters: dict[str, Table] = {}
+        definitions: dict[str, Einsum] = {}
+        einsum_plan = self.el(prgm, parameters, definitions)
+
+        sparse_aliases, sparse_fields = self.find_sparse_tensors(parameters)
+        einsum_plan = self.optimize_einsum(einsum_plan, sparse_aliases, sparse_fields)
+
+        return einsum_plan, parameters, definitions
+
+
+class EinsumScheduler:
+    def __init__(self, ctx: EinsumCompiler):
+        self.ctx = ctx
+        self.interpret = EinsumInterpreter()
+
+    def __call__(self, prgm: LogicNode):
+        einsum_plan, parameters, _ = self.ctx(prgm)
+        return self.interpret(einsum_plan, parameters)
diff --git a/src/finchlite/autoschedule/optimize.py b/src/finchlite/autoschedule/optimize.py
index 10b9d984..0b1765ba 100644
--- a/src/finchlite/autoschedule/optimize.py
+++ b/src/finchlite/autoschedule/optimize.py
@@ -773,7 +773,7 @@ def rule_1(ex):
 
 
 class DefaultLogicOptimizer:
-    def __init__(self, ctx: LogicCompiler):
+    def __init__(self, ctx):  # a LogicCompiler or an EinsumCompiler
         self.ctx = ctx
 
     def __call__(self, prgm: LogicNode):
diff --git a/src/finchlite/autoschedule/sparse_tensor.py b/src/finchlite/autoschedule/sparse_tensor.py
new file mode 100644
index 00000000..9e4b914b
--- /dev/null
+++ b/src/finchlite/autoschedule/sparse_tensor.py
@@ -0,0 +1,91 @@
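+# COO ("coordinate") format, sketched: the nonzero values live in a dense
+# `data` vector, and `coords` is an (nnz, ndim) array with one stored
+# coordinate per row; every coordinate not listed is an implicit zero.
+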
+import numpy as np
+
+from finchlite.algebra import TensorFType
+from finchlite.interface.eager import EagerTensor
+
+
+class SparseTensorFType(TensorFType):
+    def __init__(self, shape: tuple, element_type: type):
+        self.shape = shape
+        self._element_type = element_type
+
+    def __eq__(self, other):
+        if not isinstance(other, SparseTensorFType):
+            return False
+        return self.shape == other.shape and self.element_type == other.element_type
+
+    def __hash__(self):
+        return hash((self.shape, self.element_type))
+
+    @property
+    def ndim(self):
+        return len(self.shape)
+
+    @property
+    def shape_type(self):
+        return self.shape
+
+    @property
+    def element_type(self):
+        return self._element_type
+
+    @property
+    def fill_value(self):
+        return 0
+
+
+# currently implemented as a COO tensor
+class SparseTensor(EagerTensor):
+    def __init__(self, data: np.ndarray, coords: np.ndarray, shape: tuple, element_type=np.float64):
+        self.coords = coords
+        self.data = data
+        self._shape = shape
+        self._element_type = element_type
+
+    # converts a dense numpy array to a sparse tensor
+    @classmethod
+    def from_dense_tensor(cls, dense_tensor: np.ndarray):
+        coords = np.where(dense_tensor != 0)
+        data = dense_tensor[coords]
+        shape = dense_tensor.shape
+        element_type = dense_tensor.dtype.type  # the scalar type, not the dtype
+        # convert coords from a tuple of arrays to an (nnz, ndim) array
+        coords_array = np.array(coords).T
+        return cls(data, coords_array, shape, element_type)
+
+    @property
+    def ftype(self):
+        return SparseTensorFType(self.shape, self._element_type)
+
+    @property
+    def shape(self):
+        return self._shape
+
+    @property
+    def ndim(self) -> int:
+        return len(self._shape)
+
+    # the ratio of stored (nonzero) elements to the total number of elements
+    @property
+    def density(self):
+        return self.coords.shape[0] / np.prod(self.shape)
+
+    def __getitem__(self, idx: tuple):
+        if len(idx) != self.ndim:
+            raise ValueError(f"Index must have {self.ndim} dimensions")
+
+        # coords is a 2D array where each row is a stored coordinate
+        mask = np.all(self.coords == idx, axis=1)
+        matching_indices = np.where(mask)[0]
+
+        if len(matching_indices) > 0:
+            return self.data[matching_indices[0]]
+        return 0
+
+    def __str__(self):
+        return (
+            f"SparseTensor(data={self.data}, coords={self.coords}, "
+            f"shape={self.shape}, element_type={self._element_type})"
+        )
+
+    def to_dense(self) -> np.ndarray:
+        dense_tensor = np.zeros(self.shape, dtype=self._element_type)
+        for i in range(self.coords.shape[0]):
+            dense_tensor[tuple(self.coords[i])] = self.data[i]
+        return dense_tensor
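+
+
+# Minimal round-trip sketch (illustrative, not a test):
+#   dense = np.array([[0.0, 1.0], [2.0, 0.0]])
+#   sp = SparseTensor.from_dense_tensor(dense)
+#   sp.density                       # -> 0.5 (2 stored values out of 4)
+#   sp[(0, 1)]                       # -> 1.0
+#   (sp.to_dense() == dense).all()   # -> True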