From 193dbb19f9ec3ee2744cd20c9b0ddb14823dc677 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 28 Oct 2025 09:38:35 -0400 Subject: [PATCH 01/45] * Added basic sparse tensor type. * Uses COO internally * Lookup is currently naive and O(n) --- src/finchlite/tensor/sparse_tensor.py | 90 +++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 src/finchlite/tensor/sparse_tensor.py diff --git a/src/finchlite/tensor/sparse_tensor.py b/src/finchlite/tensor/sparse_tensor.py new file mode 100644 index 00000000..60f441f3 --- /dev/null +++ b/src/finchlite/tensor/sparse_tensor.py @@ -0,0 +1,90 @@ +from finchlite.algebra import TensorFType +from finchlite.interface.eager import EagerTensor +import numpy as np + +class SparseTensorFType(TensorFType): + def __init__(self, shape: tuple, element_type: type): + self.shape = shape + self._element_type = element_type + + def __eq__(self, other): + if not isinstance(other, SparseTensorFType): + return False + return self.shape == other.shape and self.element_type == other.element_type + + def __hash__(self): + return hash((self.shape, self.element_type)) + + @property + def ndim(self): + return len(self.shape) + + @property + def shape_type(self): + return self.shape + + @property + def element_type(self): + return self._element_type + + @property + def fill_value(self): + return 0 + +# currently implemented with COO tensor +class SparseTensor(EagerTensor): + def __init__(self, data: np.array, coords: np.ndarray, shape: tuple, element_type=np.float64): + self.coords = coords + self.data = data + self._shape = shape + self._element_type = element_type + + # converts an eager tensor to a sparse tensor + @classmethod + def from_dense_tensor(cls, dense_tensor: np.ndarray): + + coords = np.where(dense_tensor != 0) + data = dense_tensor[coords] + shape = dense_tensor.shape + element_type = dense_tensor.dtype.type # Get the type, not the dtype + # Convert coords from tuple of arrays to array of coordinates + coords_array = np.array(coords).T + return cls(data, coords_array, shape, element_type) + + @property + def ftype(self): + return SparseTensorFType(self.shape, self._element_type) + + @property + def shape(self): + return self._shape + + @property + def ndim(self) -> int: + return len(self._shape) + + # calculates the ratio of non-zero elements to the total number of elements + @property + def density(self): + return self.coords.shape[0] / np.prod(self.shape) + + def __getitem__(self, idx: tuple): + if len(idx) != self.ndim: + raise ValueError(f"Index must have {self.ndim} dimensions") + + # coords is a 2D array where each row is a coordinate + mask = np.all(self.coords == idx, axis=1) + matching_indices = np.where(mask)[0] + + if len(matching_indices) > 0: + return self.data[matching_indices[0]] + return 0 + + def __str__(self): + return f"SparseTensor(data={self.data}, coords={self.coords}, shape={self.shape}, element_type={self._element_type})" + + def to_dense(self) -> np.ndarray: + dense_tensor = np.zeros(self.shape, dtype=self._element_type) + for i in range(self.coords.shape[0]): + dense_tensor[tuple(self.coords[i])] = self.data[i] + return dense_tensor \ No newline at end of file From 101caf3a8c6fc6b0bf7bf37fe19fbf319ae8bc80 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 28 Oct 2025 09:52:54 -0400 Subject: [PATCH 02/45] * Fixed ruff errors --- src/finchlite/tensor/sparse_tensor.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/finchlite/tensor/sparse_tensor.py b/src/finchlite/tensor/sparse_tensor.py index 60f441f3..d842c69f 100644 --- a/src/finchlite/tensor/sparse_tensor.py +++ b/src/finchlite/tensor/sparse_tensor.py @@ -1,6 +1,8 @@ +import numpy as np + from finchlite.algebra import TensorFType from finchlite.interface.eager import EagerTensor -import numpy as np + class SparseTensorFType(TensorFType): def __init__(self, shape: tuple, element_type: type): @@ -31,9 +33,16 @@ def element_type(self): def fill_value(self): return 0 + # currently implemented with COO tensor class SparseTensor(EagerTensor): - def __init__(self, data: np.array, coords: np.ndarray, shape: tuple, element_type=np.float64): + def __init__( + self, + data: np.array, + coords: np.ndarray, + shape: tuple, + element_type=np.float64, + ): self.coords = coords self.data = data self._shape = shape @@ -42,7 +51,6 @@ def __init__(self, data: np.array, coords: np.ndarray, shape: tuple, element_typ # converts an eager tensor to a sparse tensor @classmethod def from_dense_tensor(cls, dense_tensor: np.ndarray): - coords = np.where(dense_tensor != 0) data = dense_tensor[coords] shape = dense_tensor.shape @@ -81,10 +89,13 @@ def __getitem__(self, idx: tuple): return 0 def __str__(self): - return f"SparseTensor(data={self.data}, coords={self.coords}, shape={self.shape}, element_type={self._element_type})" + return ( + f"SparseTensor(data={self.data}, coords={self.coords}," + f" shape={self.shape}, element_type={self._element_type})" + ) def to_dense(self) -> np.ndarray: dense_tensor = np.zeros(self.shape, dtype=self._element_type) for i in range(self.coords.shape[0]): dense_tensor[tuple(self.coords[i])] = self.data[i] - return dense_tensor \ No newline at end of file + return dense_tensor From ae3818376ca9fe9dacd3a43f1f71c257c3779e1e Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 28 Oct 2025 10:31:14 -0400 Subject: [PATCH 03/45] * Fixed mypy typing issues --- src/finchlite/tensor/sparse_tensor.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/finchlite/tensor/sparse_tensor.py b/src/finchlite/tensor/sparse_tensor.py index d842c69f..362184cb 100644 --- a/src/finchlite/tensor/sparse_tensor.py +++ b/src/finchlite/tensor/sparse_tensor.py @@ -38,8 +38,8 @@ def fill_value(self): class SparseTensor(EagerTensor): def __init__( self, - data: np.array, - coords: np.ndarray, + data: np.typing.NDArray, + coords: np.typing.NDArray, shape: tuple, element_type=np.float64, ): @@ -54,8 +54,7 @@ def from_dense_tensor(cls, dense_tensor: np.ndarray): coords = np.where(dense_tensor != 0) data = dense_tensor[coords] shape = dense_tensor.shape - element_type = dense_tensor.dtype.type # Get the type, not the dtype - # Convert coords from tuple of arrays to array of coordinates + element_type = dense_tensor.dtype.type coords_array = np.array(coords).T return cls(data, coords_array, shape, element_type) @@ -68,8 +67,8 @@ def shape(self): return self._shape @property - def ndim(self) -> int: - return len(self._shape) + def ndim(self) -> np.intp: + return np.intp(len(self._shape)) # calculates the ratio of non-zero elements to the total number of elements @property From bddbac9c17561e464b55892ac516e5ec03c643c9 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 28 Oct 2025 10:40:44 -0400 Subject: [PATCH 04/45] * Fixed ruff whitespace errors --- src/finchlite/tensor/sparse_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/finchlite/tensor/sparse_tensor.py b/src/finchlite/tensor/sparse_tensor.py index 362184cb..e998df27 100644 --- a/src/finchlite/tensor/sparse_tensor.py +++ b/src/finchlite/tensor/sparse_tensor.py @@ -54,7 +54,7 @@ def from_dense_tensor(cls, dense_tensor: np.ndarray): coords = np.where(dense_tensor != 0) data = dense_tensor[coords] shape = dense_tensor.shape - element_type = dense_tensor.dtype.type + element_type = dense_tensor.dtype.type coords_array = np.array(coords).T return cls(data, coords_array, shape, element_type) From 7318931eec2141d94cae8448402ac136f366c909 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 28 Oct 2025 10:49:09 -0400 Subject: [PATCH 05/45] * Added in dimension check for safety --- src/finchlite/tensor/sparse_tensor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/finchlite/tensor/sparse_tensor.py b/src/finchlite/tensor/sparse_tensor.py index e998df27..6d798615 100644 --- a/src/finchlite/tensor/sparse_tensor.py +++ b/src/finchlite/tensor/sparse_tensor.py @@ -43,6 +43,9 @@ def __init__( shape: tuple, element_type=np.float64, ): + if data.shape[0] != coords.shape[0]: + raise ValueError("data and coords must have the same number of rows") + self.coords = coords self.data = data self._shape = shape From 335c75ce069fc486f88de67ba4bb505133dd2b07 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 4 Nov 2025 20:10:06 -0500 Subject: [PATCH 06/45] * Added GetAttribute Einsum IR node * Get Attribute is a general purpose node that can be used to retreive the coordinate and element arrays from a sparse tensor --- src/finchlite/finch_einsum/nodes.py | 37 ++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/src/finchlite/finch_einsum/nodes.py b/src/finchlite/finch_einsum/nodes.py index 34e92ad8..64692c67 100644 --- a/src/finchlite/finch_einsum/nodes.py +++ b/src/finchlite/finch_einsum/nodes.py @@ -1,7 +1,7 @@ import operator from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Self, cast +from typing import Any, Optional, Self, cast from finchlite.algebra import ( overwrite, @@ -176,6 +176,41 @@ def get_idxs(self) -> set["Index"]: return idxs +@dataclass(eq=True, frozen=True) +class GetAttribute(EinsumExpr, EinsumTree): + """ + Gets an attribute of a tensor. + Attributes: + obj: The object to get the attribute from. + attr: The name of the attribute to get. + """ + + obj: EinsumExpr + attr: Literal + idx: Optional[Index] + + @classmethod + def from_children(cls, *children: Term) -> Self: + # Expects 3 children: obj, attr, idx + if len(children) != 3: + raise ValueError("GetAttribute expects 3 children (obj + attr + idx)") + obj = cast(EinsumExpr, children[0]) + attr = cast(Literal, children[1]) + idx = cast(Optional[Index], children[2]) + return cls(obj, attr, idx) + + @property + def children(self): + return [self.obj, self.attr, self.idx] + + def get_idxs(self) -> set["Index"]: + idxs = set() + idxs.update(self.obj.get_idxs()) + if self.idx is not None: + idxs.update(self.idx.get_idxs()) + return idxs + + @dataclass(eq=True, frozen=True) class Einsum(EinsumTree): """ From 6ab5e06ee472e0d6b0b4465161e64b4798b5b8f2 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 4 Nov 2025 20:11:25 -0500 Subject: [PATCH 07/45] * Added support for printing GetAttribute Einsum IR node --- src/finchlite/finch_einsum/nodes.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/finchlite/finch_einsum/nodes.py b/src/finchlite/finch_einsum/nodes.py index 64692c67..3284b6e2 100644 --- a/src/finchlite/finch_einsum/nodes.py +++ b/src/finchlite/finch_einsum/nodes.py @@ -369,6 +369,10 @@ def __call__(self, prgm: EinsumNode): if len(args) == 1 and fn.val in unary_strs: return f"{unary_strs[fn.val]}{args_e[0]}" return f"{self(fn)}({', '.join(args_e)})" + case GetAttribute(obj, attr, idx): + if idx is not None: + return f"{self(obj)}.{self(attr)}[{self(idx)}]" + return f"{self(obj)}.{self(attr)}" case Einsum(op, tns, idxs, arg): op_str = infix_strs.get(op.val, op.val.__name__) self.exec( From e176ef406076e03903873e77afc7444b546492e6 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 4 Nov 2025 20:42:58 -0500 Subject: [PATCH 08/45] * Added einsum interpreter support for evaluating GetAttribute IR node * Currently only supports querying the non-zero elements/data array of a sparse tensor and querying the coordinate array (either for all indicies or a specified dimension --- src/finchlite/finch_einsum/interpreter.py | 25 +++++++++++++++++++++-- src/finchlite/finch_einsum/nodes.py | 13 ++++++------ src/finchlite/tensor/__init__.py | 4 ++++ 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 01f256fe..5c2eb9ae 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -4,6 +4,9 @@ from ..algebra import overwrite, promote_max, promote_min from . import nodes as ein +from ..tensor import SparseTensor, SparseTensorFType +from ..symbolic import ftype + nary_ops = { operator.add: "add", @@ -95,9 +98,12 @@ def __call__(self, node): case ein.Access(tns, idxs): assert len(idxs) == len(set(idxs)) assert self.loops is not None + + #convert named idxs to positional, integer indices perm = [idxs.index(idx) for idx in self.loops if idx in idxs] - tns = self(tns) - tns = xp.permute_dims(tns, perm) + + tns = self(tns) #evaluate the tensor + tns = xp.permute_dims(tns, perm) #permute the dimensions return xp.expand_dims( tns, [i for i in range(len(self.loops)) if self.loops[i] not in idxs], @@ -109,6 +115,21 @@ def __call__(self, node): return res case ein.Produces(args): return tuple(self(arg) for arg in args) + + #get non-zero elements/data array of a sparse tensor + case ein.GetAttribute(obj, ein.Literal("elems"), _): + obj = self(obj) + assert isinstance(ftype(obj), SparseTensorFType) + assert isinstance(obj, SparseTensor) + return obj.data + #get coord array of a sparse tensor + case ein.GetAttribute(obj, ein.Literal("coords"), dim): + obj = self(obj) + assert isinstance(ftype(obj), SparseTensorFType) + assert isinstance(obj, SparseTensor) + + # return the coord array for the given dimension or all dimensions + return obj.coords if dim is None else obj.coords[dim, :] case ein.Einsum(op, ein.Alias(tns), idxs, arg): # This is the main entry point for einsum execution loops = arg.get_idxs() diff --git a/src/finchlite/finch_einsum/nodes.py b/src/finchlite/finch_einsum/nodes.py index 3284b6e2..f4072526 100644 --- a/src/finchlite/finch_einsum/nodes.py +++ b/src/finchlite/finch_einsum/nodes.py @@ -183,11 +183,14 @@ class GetAttribute(EinsumExpr, EinsumTree): Attributes: obj: The object to get the attribute from. attr: The name of the attribute to get. + + dim: The dimension to get the attribute from. + Note this is an integer index, not a named index. """ obj: EinsumExpr attr: Literal - idx: Optional[Index] + dim: Optional[int] @classmethod def from_children(cls, *children: Term) -> Self: @@ -196,18 +199,16 @@ def from_children(cls, *children: Term) -> Self: raise ValueError("GetAttribute expects 3 children (obj + attr + idx)") obj = cast(EinsumExpr, children[0]) attr = cast(Literal, children[1]) - idx = cast(Optional[Index], children[2]) - return cls(obj, attr, idx) + dim = cast(Optional[int], children[2]) + return cls(obj, attr, dim) @property def children(self): - return [self.obj, self.attr, self.idx] + return [self.obj, self.attr, self.dim] def get_idxs(self) -> set["Index"]: idxs = set() idxs.update(self.obj.get_idxs()) - if self.idx is not None: - idxs.update(self.idx.get_idxs()) return idxs diff --git a/src/finchlite/tensor/__init__.py b/src/finchlite/tensor/__init__.py index 9c45ad8c..6b2c1c54 100644 --- a/src/finchlite/tensor/__init__.py +++ b/src/finchlite/tensor/__init__.py @@ -4,6 +4,8 @@ DenseLevelFType, ElementLevel, ElementLevelFType, + SparseTensor, + SparseTensorFType, dense, element, ) @@ -17,6 +19,8 @@ "FiberTensorFType", "Level", "LevelFType", + "SparseTensor", + "SparseTensorFType", "dense", "element", "tensor", From 2936eac33453ca3473425d4ccb975bed39a348da Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 4 Nov 2025 21:16:26 -0500 Subject: [PATCH 09/45] * Added support for indirection in ein.Access * Currently requires the number of input indicies to match the number of dimensions of a tensor --- src/finchlite/finch_einsum/interpreter.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 5c2eb9ae..661aa7f9 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -5,7 +5,7 @@ from ..algebra import overwrite, promote_max, promote_min from . import nodes as ein from ..tensor import SparseTensor, SparseTensorFType -from ..symbolic import ftype +from ..symbolic import ftype, gensym nary_ops = { @@ -96,18 +96,33 @@ def __call__(self, node): vals = [self(arg) for arg in args] return func(*vals) case ein.Access(tns, idxs): - assert len(idxs) == len(set(idxs)) assert self.loops is not None + + dummy_idxs = {idx: ein.Index(gensym("dummy")) for idx in idxs if not isinstance(idx, ein.Index)} + # evaluate the idxs that are not indices + evaled_idxs = {idx: self(idx) for idx in idxs if not isinstance(idx, ein.Index)} + idxs_to_perm = [idx if idx in dummy_idxs else dummy_idxs[idx] for idx in idxs] #convert named idxs to positional, integer indices - perm = [idxs.index(idx) for idx in self.loops if idx in idxs] + perm = [idxs_to_perm.index(idx) for idx in self.loops if idx in idxs_to_perm] tns = self(tns) #evaluate the tensor tns = xp.permute_dims(tns, perm) #permute the dimensions - return xp.expand_dims( + tns = xp.expand_dims( #broadcast the tensor to the new dimensions tns, [i for i in range(len(self.loops)) if self.loops[i] not in idxs], ) + + # we need to remove indicies not accessed by dummy tensors + for dummy_idx, evaled_idx in evaled_idxs.items(): + axis_to_crop = idxs_to_perm.index(dummy_idx) + axis_size = tns.shape[axis_to_crop] + + idxs_to_crop = np.setdiff1d(np.arange(axis_size), evaled_idx) + tns = xp.delete(tns, idxs_to_crop, axis=axis_to_crop) + + return tns + case ein.Plan(bodies): res = None for body in bodies: From 2c061ab4f93dc994ee66a3f7417cad3065334189 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 4 Nov 2025 21:24:10 -0500 Subject: [PATCH 10/45] * Added support for indirection with mutli-dimension indicies. I.e. A[B], where B is a 2d array where each row is an index, and each col represents a different dimension for that index --- src/finchlite/finch_einsum/interpreter.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 661aa7f9..26589dfd 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -96,12 +96,20 @@ def __call__(self, node): vals = [self(arg) for arg in args] return func(*vals) case ein.Access(tns, idxs): + assert len(idxs) == len(set(idxs)) assert self.loops is not None - dummy_idxs = {idx: ein.Index(gensym("dummy")) for idx in idxs if not isinstance(idx, ein.Index)} - # evaluate the idxs that are not indices - evaled_idxs = {idx: self(idx) for idx in idxs if not isinstance(idx, ein.Index)} - idxs_to_perm = [idx if idx in dummy_idxs else dummy_idxs[idx] for idx in idxs] + if len(idxs) == 1 and not isinstance(idxs[0], ein.Index): + evaled_idxs = self(idxs[0]) + idx_count = evaled_idxs.size[1] + + idxs_to_perm = [ein.Index(gensym("dummy")) for _ in range(idx_count)] + evaled_idxs = {idxs_to_perm[i]: evaled_idxs[:, i] for i in range(idx_count)} + else: + dummy_idxs = {idx: ein.Index(gensym("dummy")) for idx in idxs if not isinstance(idx, ein.Index)} + # evaluate the idxs that are not indices + evaled_idxs = {idx: self(idx) for idx in idxs if not isinstance(idx, ein.Index)} + idxs_to_perm = [idx if idx in dummy_idxs else dummy_idxs[idx] for idx in idxs] #convert named idxs to positional, integer indices perm = [idxs_to_perm.index(idx) for idx in self.loops if idx in idxs_to_perm] From 46fda9ccbb3c28af7332c87485502873bf3403ae Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 4 Nov 2025 21:26:26 -0500 Subject: [PATCH 11/45] * Added comments explaining indirect access implementation --- src/finchlite/finch_einsum/interpreter.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 26589dfd..40bf1d04 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -122,6 +122,8 @@ def __call__(self, node): ) # we need to remove indicies not accessed by dummy tensors + # we basically remove all indicies not accessed + # the dummy tensor system assumes all indicies are accessed at first for dummy_idx, evaled_idx in evaled_idxs.items(): axis_to_crop = idxs_to_perm.index(dummy_idx) axis_size = tns.shape[axis_to_crop] From d32ae5b3395badad8f3980d24e616ab55be14aed Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Thu, 6 Nov 2025 10:20:01 -0500 Subject: [PATCH 12/45] * Fixed bugs in einsum interpreter that stemmed from newly added support of indirection in ein.Access --- src/finchlite/finch_einsum/interpreter.py | 8 ++++++-- src/finchlite/tensor/__init__.py | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 40bf1d04..58bef857 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -4,7 +4,6 @@ from ..algebra import overwrite, promote_max, promote_min from . import nodes as ein -from ..tensor import SparseTensor, SparseTensorFType from ..symbolic import ftype, gensym @@ -81,6 +80,11 @@ def __init__(self, xp=None, bindings=None, loops=None): self.loops = loops def __call__(self, node): + from ..tensor import ( + SparseTensor, + SparseTensorFType, + ) + xp = self.xp match node: case ein.Literal(val): @@ -109,7 +113,7 @@ def __call__(self, node): dummy_idxs = {idx: ein.Index(gensym("dummy")) for idx in idxs if not isinstance(idx, ein.Index)} # evaluate the idxs that are not indices evaled_idxs = {idx: self(idx) for idx in idxs if not isinstance(idx, ein.Index)} - idxs_to_perm = [idx if idx in dummy_idxs else dummy_idxs[idx] for idx in idxs] + idxs_to_perm = [(dummy_idxs[idx] if idx in dummy_idxs else idx) for idx in idxs] #convert named idxs to positional, integer indices perm = [idxs_to_perm.index(idx) for idx in self.loops if idx in idxs_to_perm] diff --git a/src/finchlite/tensor/__init__.py b/src/finchlite/tensor/__init__.py index 6b2c1c54..b950cb1a 100644 --- a/src/finchlite/tensor/__init__.py +++ b/src/finchlite/tensor/__init__.py @@ -4,11 +4,11 @@ DenseLevelFType, ElementLevel, ElementLevelFType, - SparseTensor, - SparseTensorFType, dense, element, ) +from .sparse_tensor import SparseTensor, SparseTensorFType + __all__ = [ "DenseLevel", From 5f09be44c81619494b367760692fddc37e11ba1a Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Thu, 6 Nov 2025 11:42:46 -0500 Subject: [PATCH 13/45] * Fixed some bugs --- src/finchlite/finch_einsum/__init__.py | 2 + src/finchlite/finch_einsum/interpreter.py | 13 ++++-- src/finchlite/finch_einsum/nodes.py | 6 +++ tests/test_einsum.py | 54 +++++++++++++++++++++++ 4 files changed, 71 insertions(+), 4 deletions(-) diff --git a/src/finchlite/finch_einsum/__init__.py b/src/finchlite/finch_einsum/__init__.py index 223a5d4a..e620e997 100644 --- a/src/finchlite/finch_einsum/__init__.py +++ b/src/finchlite/finch_einsum/__init__.py @@ -7,6 +7,7 @@ EinsumExpr, EinsumNode, Index, + GetAttribute, Literal, Plan, Produces, @@ -25,6 +26,7 @@ "EinsumScheduler", "EinsumScheduler", "Index", + "GetAttribute", "Literal", "Plan", "Produces", diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 58bef857..b5f3b209 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -2,7 +2,7 @@ import numpy as np -from ..algebra import overwrite, promote_max, promote_min +from ..algebra import overwrite, promote_max, promote_min, TensorFType from . import nodes as ein from ..symbolic import ftype, gensym @@ -103,13 +103,13 @@ def __call__(self, node): assert len(idxs) == len(set(idxs)) assert self.loops is not None - if len(idxs) == 1 and not isinstance(idxs[0], ein.Index): + if len(idxs) == 1 and not isinstance(idxs[0], ein.Index): #pass in the full coords array evaled_idxs = self(idxs[0]) - idx_count = evaled_idxs.size[1] + idx_count = evaled_idxs.shape[1] idxs_to_perm = [ein.Index(gensym("dummy")) for _ in range(idx_count)] evaled_idxs = {idxs_to_perm[i]: evaled_idxs[:, i] for i in range(idx_count)} - else: + else: #pass in a mixture of indicies and other expressions dummy_idxs = {idx: ein.Index(gensym("dummy")) for idx in idxs if not isinstance(idx, ein.Index)} # evaluate the idxs that are not indices evaled_idxs = {idx: self(idx) for idx in idxs if not isinstance(idx, ein.Index)} @@ -119,6 +119,10 @@ def __call__(self, node): perm = [idxs_to_perm.index(idx) for idx in self.loops if idx in idxs_to_perm] tns = self(tns) #evaluate the tensor + + if isinstance(ftype(tns), TensorFType) and len(perm) < tns.ndim: + perm += [i for i in range(len(perm), tns.ndim)] + tns = xp.permute_dims(tns, perm) #permute the dimensions tns = xp.expand_dims( #broadcast the tensor to the new dimensions tns, @@ -162,6 +166,7 @@ def __call__(self, node): case ein.Einsum(op, ein.Alias(tns), idxs, arg): # This is the main entry point for einsum execution loops = arg.get_idxs() + assert set(idxs).issubset(loops) loops = sorted(loops, key=lambda x: x.name) ctx = EinsumInterpreter(self.xp, self.bindings, loops) diff --git a/src/finchlite/finch_einsum/nodes.py b/src/finchlite/finch_einsum/nodes.py index f4072526..f626d02f 100644 --- a/src/finchlite/finch_einsum/nodes.py +++ b/src/finchlite/finch_einsum/nodes.py @@ -250,6 +250,12 @@ def from_children(cls, *children: Term) -> Self: def children(self): return [self.op, self.tns, self.idxs, self.arg] + def get_idxs(self) -> set["Index"]: + idxs = list() + for idx in self.idxs: + idxs.extend(idx.get_idxs()) + return idxs + @dataclass(eq=True, frozen=True) class Plan(EinsumTree): diff --git a/tests/test_einsum.py b/tests/test_einsum.py index 434180ae..c144cfe4 100644 --- a/tests/test_einsum.py +++ b/tests/test_einsum.py @@ -1,9 +1,15 @@ +from typing import Any import pytest import numpy as np import finchlite +import finchlite.finch_einsum as ein +from finchlite.algebra import overwrite +import operator +from finchlite.tensor import SparseTensor + @pytest.fixture def rng(): @@ -1123,3 +1129,51 @@ def test_complex_operations(self, rng, dtype): expected = np.einsum("ij", A) assert np.allclose(result, expected) + + +class TestEinsumIndirectAccess: + """Test einsum with indirect access""" + + def run_einsum_plan(self, prgm: ein.Plan, bindings: dict[str, Any], expected: np.ndarray): + """Runs an einsum plan and returns the result""" + interpreter = ein.EinsumInterpreter(bindings=bindings) + result = interpreter(prgm)[0] + assert np.allclose(result, expected) + + def test_indirect_access(self, rng): + """Test indirect access""" + + A = rng.random((5, 5)) + B = rng.random((5, 5)) + + sparse_A = SparseTensor.from_dense_tensor(A) + + prgm = ein.Plan(( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("C"), + idxs=(ein.Index("i"),), + arg=ein.Call( + op=ein.Literal(operator.mul), + args=( + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("elems"), dim=None), + idxs=(ein.Index("i"),), + ), + ein.Access( + tns=ein.Alias("B"), + idxs=( + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("coords"), dim=None), + idxs=(ein.Index("i"),) + ), + ), + ), + ), + ), + ), + ein.Produces((ein.Alias("C"),)), + )) + + result = finchlite.multiply(A, B) + self.run_einsum_plan(prgm, {"A": sparse_A, "B": B}, result) \ No newline at end of file From 860e016cbe7c1550524d635c2b67307adeaab890 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 7 Nov 2025 10:07:59 -0500 Subject: [PATCH 14/45] * Revert changes to ein.Access --- src/finchlite/finch_einsum/interpreter.py | 34 ++--------------------- 1 file changed, 3 insertions(+), 31 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index b5f3b209..0647c045 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -102,45 +102,17 @@ def __call__(self, node): case ein.Access(tns, idxs): assert len(idxs) == len(set(idxs)) assert self.loops is not None - - if len(idxs) == 1 and not isinstance(idxs[0], ein.Index): #pass in the full coords array - evaled_idxs = self(idxs[0]) - idx_count = evaled_idxs.shape[1] - - idxs_to_perm = [ein.Index(gensym("dummy")) for _ in range(idx_count)] - evaled_idxs = {idxs_to_perm[i]: evaled_idxs[:, i] for i in range(idx_count)} - else: #pass in a mixture of indicies and other expressions - dummy_idxs = {idx: ein.Index(gensym("dummy")) for idx in idxs if not isinstance(idx, ein.Index)} - # evaluate the idxs that are not indices - evaled_idxs = {idx: self(idx) for idx in idxs if not isinstance(idx, ein.Index)} - idxs_to_perm = [(dummy_idxs[idx] if idx in dummy_idxs else idx) for idx in idxs] #convert named idxs to positional, integer indices - perm = [idxs_to_perm.index(idx) for idx in self.loops if idx in idxs_to_perm] + perm = [idxs.index(idx) for idx in self.loops if idx in idxs] tns = self(tns) #evaluate the tensor - - if isinstance(ftype(tns), TensorFType) and len(perm) < tns.ndim: - perm += [i for i in range(len(perm), tns.ndim)] - tns = xp.permute_dims(tns, perm) #permute the dimensions - tns = xp.expand_dims( #broadcast the tensor to the new dimensions + return xp.expand_dims( tns, [i for i in range(len(self.loops)) if self.loops[i] not in idxs], ) - - # we need to remove indicies not accessed by dummy tensors - # we basically remove all indicies not accessed - # the dummy tensor system assumes all indicies are accessed at first - for dummy_idx, evaled_idx in evaled_idxs.items(): - axis_to_crop = idxs_to_perm.index(dummy_idx) - axis_size = tns.shape[axis_to_crop] - - idxs_to_crop = np.setdiff1d(np.arange(axis_size), evaled_idx) - tns = xp.delete(tns, idxs_to_crop, axis=axis_to_crop) - - return tns - + case ein.Plan(bodies): res = None for body in bodies: From bf43dbb65ab1b42e927de0792d9dd3d547b330c2 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 7 Nov 2025 10:20:41 -0500 Subject: [PATCH 15/45] * Added seperate match case handlers in einsum interpreter loop to handle indirect access --- src/finchlite/finch_einsum/interpreter.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 0647c045..e37ea82b 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -99,7 +99,9 @@ def __call__(self, node): func = getattr(xp, nary_ops[func]) vals = [self(arg) for arg in args] return func(*vals) - case ein.Access(tns, idxs): + + #access a tensor with only indices + case ein.Access(tns, idxs) if all(isinstance(idx, ein.Index) for idx in idxs): assert len(idxs) == len(set(idxs)) assert self.loops is not None @@ -112,7 +114,22 @@ def __call__(self, node): tns, [i for i in range(len(self.loops)) if self.loops[i] not in idxs], ) - + + #access a tensor with only one indirect access index + case ein.Access(tns, idxs): + assert len(idxs) == 1 + true_idx = node.get_idxs()[0] + assert isinstance(true_idx, ein.Index) + + raise NotImplementedError("Access with only one indirect access index is not implemented") + + #access a tensor with a mixture of indices and other expressions + case ein.Access(tns, idxs): + true_idxs = node.get_idxs() #true field iteratior indicies + assert all(isinstance(idx, ein.Index) for idx in true_idxs) + assert self.loops is not None + + raise NotImplementedError("Access with a mixture of indices and other expressions is not implemented") case ein.Plan(bodies): res = None for body in bodies: From 2f14a25b27782ecc0bcd566fda03a30064b233ab Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 7 Nov 2025 10:26:31 -0500 Subject: [PATCH 16/45] * Added support for one index indirect access --- src/finchlite/finch_einsum/interpreter.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index e37ea82b..2b22b6ec 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -118,18 +118,17 @@ def __call__(self, node): #access a tensor with only one indirect access index case ein.Access(tns, idxs): assert len(idxs) == 1 - true_idx = node.get_idxs()[0] - assert isinstance(true_idx, ein.Index) - raise NotImplementedError("Access with only one indirect access index is not implemented") + idx = self(idxs[0]) + tns = self(tns) #evaluate the tensor + + flat_idx = xp.ravel_multi_index(idx.T, tns.shape) + return tns.flat[flat_idx] #return a 1-d array by definition #access a tensor with a mixture of indices and other expressions case ein.Access(tns, idxs): - true_idxs = node.get_idxs() #true field iteratior indicies - assert all(isinstance(idx, ein.Index) for idx in true_idxs) - assert self.loops is not None - raise NotImplementedError("Access with a mixture of indices and other expressions is not implemented") + case ein.Plan(bodies): res = None for body in bodies: From 9c573749ace78fbbf40c7438671274de12f89ef7 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 7 Nov 2025 10:36:24 -0500 Subject: [PATCH 17/45] * Added support for multiple indirect indicies in access in einsum interpreter --- src/finchlite/finch_einsum/interpreter.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 2b22b6ec..23a707cb 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -127,7 +127,18 @@ def __call__(self, node): #access a tensor with a mixture of indices and other expressions case ein.Access(tns, idxs): - raise NotImplementedError("Access with a mixture of indices and other expressions is not implemented") + assert self.loops is not None + true_idxs = node.get_idxs() #true field iteratior indicies + assert all(isinstance(idx, ein.Index) for idx in true_idxs) + + tns = self(tns) + assert len(idxs) == len(tns.shape) + + # evaluate all the indirect access indicies, and evaluate field indicies as a "grab all" + evaled_idxs = [(xp.arange(tns.shape[i]) if isinstance(idx, ein.Index) else self(idx)) for i, idx in enumerate(idxs)] + evaled_idxs = xp.vstack(evaled_idxs) + flat_idx = xp.ravel_multi_index(evaled_idxs, tns.shape) + return tns.flat[flat_idx] case ein.Plan(bodies): res = None @@ -150,7 +161,7 @@ def __call__(self, node): assert isinstance(obj, SparseTensor) # return the coord array for the given dimension or all dimensions - return obj.coords if dim is None else obj.coords[dim, :] + return obj.coords if dim is None else obj.coords[:, dim] case ein.Einsum(op, ein.Alias(tns), idxs, arg): # This is the main entry point for einsum execution loops = arg.get_idxs() From 5ea1dddc5b6fe0dd78fa6b3af0da69b5db320ebd Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 7 Nov 2025 11:04:43 -0500 Subject: [PATCH 18/45] Enhanced einsum interpreter to evaluate tensor shapes and permute dimensions for indirect access, ensuring compatibility with parent indices. --- src/finchlite/finch_einsum/interpreter.py | 36 +++++++++++++++++++++-- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 23a707cb..2b3702a8 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -131,14 +131,44 @@ def __call__(self, node): true_idxs = node.get_idxs() #true field iteratior indicies assert all(isinstance(idx, ein.Index) for idx in true_idxs) + # evaluate the tensor to access tns = self(tns) assert len(idxs) == len(tns.shape) # evaluate all the indirect access indicies, and evaluate field indicies as a "grab all" evaled_idxs = [(xp.arange(tns.shape[i]) if isinstance(idx, ein.Index) else self(idx)) for i, idx in enumerate(idxs)] - evaled_idxs = xp.vstack(evaled_idxs) - flat_idx = xp.ravel_multi_index(evaled_idxs, tns.shape) - return tns.flat[flat_idx] + # evaluate the output tensor as a flat array + flat_idx = xp.ravel_multi_index(xp.vstack(evaled_idxs), tns.shape) + tns = tns.flat[flat_idx] + + # calculate the final shape of the tensor + # we assert that all the indirect access indicies from the parent idxs have the same size + #calculate child idxs, idxs computed using the parent "true idxs" + child_idxs = { + idx: [ + child_idx for child_idx in child_idxs + if (parent_idx in child_idx.get_idxs()) + ] for parent_idx in true_idxs + } + assert all( + child_idxs[parent_idx].count(child_idxs[parent_idx][0]) == len(child_idxs[parent_idx]) + for parent_idx in true_idxs + ) + # we merge the child idxs to get the final shape that matches the true idxs + final_shape = tuple( + evaled_idxs[idxs.index(idx)].size + for idx in idxs if idx in true_idxs + ) + tns = tns.reshape(final_shape) + idxs = [idx for idx in idxs if idx in true_idxs] + + # permute and broadcast the tensor to be compatible with rest of expression + perm = [idxs.index(idx) for idx in self.loops if idx in idxs] + tns = xp.permute_dims(tns, perm) + return xp.expand_dims( + tns, + [i for i in range(len(self.loops)) if self.loops[i] not in idxs] + ) case ein.Plan(bodies): res = None From 71268a8aaa6fa5051f7c43e99cfafcece99c4a1a Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 7 Nov 2025 11:27:35 -0500 Subject: [PATCH 19/45] * Added match case in einsum interpreter loop to hand einsums with indirect output assignment --- src/finchlite/finch_einsum/interpreter.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 2b3702a8..c49d422e 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -192,7 +192,9 @@ def __call__(self, node): # return the coord array for the given dimension or all dimensions return obj.coords if dim is None else obj.coords[:, dim] - case ein.Einsum(op, ein.Alias(tns), idxs, arg): + + # standard einsum + case ein.Einsum(op, ein.Alias(tns), idxs, arg) if all(isinstance(idx, ein.Index) for idx in idxs): # This is the main entry point for einsum execution loops = arg.get_idxs() @@ -212,5 +214,13 @@ def __call__(self, node): axis = [dropped.index(idx) for idx in idxs] self.bindings[tns] = xp.permute_dims(val, axis) return (tns,) + + # indirect einsum + case ein.Einsum(op, ein.Alias(tns), idxs, arg): + loops = arg.get_idxs() + true_idxs = node.get_idxs() + assert true_idxs.issubset(loops) + + raise NotImplementedError("Indirect einsum is not implemented yet") case _: raise ValueError(f"Unknown einsum type: {type(node)}") From 24ef52c2e174d3422dad5389c9c48f939c305834 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 7 Nov 2025 12:16:51 -0500 Subject: [PATCH 20/45] * Added support for getting shape attribute of a sparse tensor --- src/finchlite/finch_einsum/interpreter.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index c49d422e..a4b56611 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -136,10 +136,14 @@ def __call__(self, node): assert len(idxs) == len(tns.shape) # evaluate all the indirect access indicies, and evaluate field indicies as a "grab all" - evaled_idxs = [(xp.arange(tns.shape[i]) if isinstance(idx, ein.Index) else self(idx)) for i, idx in enumerate(idxs)] + evaled_idxs = [ + (xp.arange(tns.shape[i]) if isinstance(idx, ein.Index) else self(idx)) + for i, idx in enumerate(idxs) + ] + # evaluate the output tensor as a flat array flat_idx = xp.ravel_multi_index(xp.vstack(evaled_idxs), tns.shape) - tns = tns.flat[flat_idx] + tns = xp.take(tns, flat_idx) # calculate the final shape of the tensor # we assert that all the indirect access indicies from the parent idxs have the same size @@ -192,6 +196,15 @@ def __call__(self, node): # return the coord array for the given dimension or all dimensions return obj.coords if dim is None else obj.coords[:, dim] + # gets the shape of a sparse tensor at a given dimension + case ein.GetAttribute(obj, ein.Literal("shape"), dim): + obj = self(obj) + assert isinstance(ftype(obj), SparseTensorFType) + assert isinstance(obj, SparseTensor) + assert dim is not None + + # return the shape for the given dimension + return obj.shape[dim] # standard einsum case ein.Einsum(op, ein.Alias(tns), idxs, arg) if all(isinstance(idx, ein.Index) for idx in idxs): @@ -217,10 +230,6 @@ def __call__(self, node): # indirect einsum case ein.Einsum(op, ein.Alias(tns), idxs, arg): - loops = arg.get_idxs() - true_idxs = node.get_idxs() - assert true_idxs.issubset(loops) - raise NotImplementedError("Indirect einsum is not implemented yet") case _: raise ValueError(f"Unknown einsum type: {type(node)}") From 85d81f7de16b15782fb93088a29af7a03b789526 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 7 Nov 2025 12:21:54 -0500 Subject: [PATCH 21/45] Implemented direct einsum handling for indirect assignments without reduction in the einsum interpreter, enhancing tensor evaluation and index handling. --- src/finchlite/finch_einsum/interpreter.py | 37 +++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index a4b56611..f5cbbb5b 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -228,8 +228,45 @@ def __call__(self, node): self.bindings[tns] = xp.permute_dims(val, axis) return (tns,) + # indirect einsum with no reduction + case ein.Einsum(op, ein.Alias(tns), idxs, arg) if op == overwrite: + loops = arg.get_idxs() + true_idxs = node.get_idxs() + assert true_idxs.equals(loops) + + # evalaute the tensor to access/write to + # output tensor must exist with initial reduction values + assert tns in self.bindings + tns = self.bindings[tns] + assert len(idxs) == len(tns.shape) + + # evaluate all the indirect assignment indicies, and evaluate field indicies as a "grab all" + evaled_idxs = [ + (xp.arange(tns.shape[i]) if isinstance(idx, ein.Index) else self(idx)) + for i, idx in enumerate(idxs) + ] + evaled_idxs = xp.vstack(evaled_idxs) + flat_idx = xp.ravel_multi_index(evaled_idxs, tns.shape) + + arg = self(arg) + tns.flat[flat_idx] = arg #no need to assign to bindings + return (tns,) + # indirect einsum case ein.Einsum(op, ein.Alias(tns), idxs, arg): + loops = arg.get_idxs() + true_idxs = node.get_idxs() + assert true_idxs.issubset(loops) + + # evalaute the tensor to access/write to + # output tensor must exist with initial reduction values + assert tns in self.bindings + tns = self.bindings[tns] + assert len(idxs) == len(tns.shape) + + reduced_axis = [idx for idx in loops if idx in true_idxs] + + # evaluate all the indirect assignment indicies, and evaluate field indicies as a "grab all" raise NotImplementedError("Indirect einsum is not implemented yet") case _: raise ValueError(f"Unknown einsum type: {type(node)}") From 710a2e300233d3dbc2e4fdc7551e44764f2de8c4 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 7 Nov 2025 12:55:26 -0500 Subject: [PATCH 22/45] Refactored indirect einsum handling in the interpreter to support reduction operations, enhancing tensor evaluation and index management. Updated assertions and added context for evaluation, ensuring proper handling of output tensors. --- src/finchlite/finch_einsum/interpreter.py | 59 ++++++++++++----------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index f5cbbb5b..f9837b4d 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -228,45 +228,46 @@ def __call__(self, node): self.bindings[tns] = xp.permute_dims(val, axis) return (tns,) - # indirect einsum with no reduction - case ein.Einsum(op, ein.Alias(tns), idxs, arg) if op == overwrite: + # indirect einsum with reduction + case ein.Einsum(op, ein.Alias(tns), idxs, arg): loops = arg.get_idxs() true_idxs = node.get_idxs() - assert true_idxs.equals(loops) + assert true_idxs.issubset(loops) # evalaute the tensor to access/write to # output tensor must exist with initial reduction values - assert tns in self.bindings - tns = self.bindings[tns] - assert len(idxs) == len(tns.shape) + assert tns in self.bindings + tns_out = self.bindings[tns] + assert len(idxs) == len(tns_out.shape) + + #Evaluate arg with full loop context + loops_sorted = sorted(loops, key=lambda x: x.name) + ctx = EinsumInterpreter(self.xp, self.bindings, loops_sorted) + arg_val = ctx(arg) + + #Reduce over non-true_idxs axes using the reduction op + reduce_axes = tuple(i for i in range(len(loops_sorted)) + if loops_sorted[i] not in true_idxs) + if reduce_axes: + op_func = self(op) # Get the actual operation + reduce_func = getattr(xp, reduction_ops[op_func]) + reduced_val = reduce_func(arg_val, axis=reduce_axes) + else: + reduced_val = arg_val - # evaluate all the indirect assignment indicies, and evaluate field indicies as a "grab all" + # Evaluate indirect index expressions (same context) evaled_idxs = [ - (xp.arange(tns.shape[i]) if isinstance(idx, ein.Index) else self(idx)) + (xp.arange(tns_out.shape[i]) if isinstance(idx, ein.Index) else ctx(idx)) for i, idx in enumerate(idxs) ] - evaled_idxs = xp.vstack(evaled_idxs) - flat_idx = xp.ravel_multi_index(evaled_idxs, tns.shape) + flat_idx = xp.ravel_multi_index(xp.vstack(evaled_idxs), tns_out.shape) - arg = self(arg) - tns.flat[flat_idx] = arg #no need to assign to bindings + # Scatter/accumulate using the SAME operation's ufunc + scatter_func = getattr(xp, nary_ops[op_func]) + scatter_func.at(tns_out.flat, flat_idx, reduced_val) + + # assign the output tensor to the bindings + self.bindings[tns] = tns_out return (tns,) - - # indirect einsum - case ein.Einsum(op, ein.Alias(tns), idxs, arg): - loops = arg.get_idxs() - true_idxs = node.get_idxs() - assert true_idxs.issubset(loops) - - # evalaute the tensor to access/write to - # output tensor must exist with initial reduction values - assert tns in self.bindings - tns = self.bindings[tns] - assert len(idxs) == len(tns.shape) - - reduced_axis = [idx for idx in loops if idx in true_idxs] - - # evaluate all the indirect assignment indicies, and evaluate field indicies as a "grab all" - raise NotImplementedError("Indirect einsum is not implemented yet") case _: raise ValueError(f"Unknown einsum type: {type(node)}") From 1dc9667391a0d5f4b129da52b339ff4458776999 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 11 Nov 2025 11:18:00 -0500 Subject: [PATCH 23/45] Removed implementation of indirect einsum with reduction in the interpreter, raising a NotImplementedError for unsupported operations. This change is temporary, and is simply for the purpose for a PR and version control --- src/finchlite/finch_einsum/interpreter.py | 42 ++--------------------- 1 file changed, 2 insertions(+), 40 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index f9837b4d..47361608 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -228,46 +228,8 @@ def __call__(self, node): self.bindings[tns] = xp.permute_dims(val, axis) return (tns,) - # indirect einsum with reduction + # indirect einsum case ein.Einsum(op, ein.Alias(tns), idxs, arg): - loops = arg.get_idxs() - true_idxs = node.get_idxs() - assert true_idxs.issubset(loops) - - # evalaute the tensor to access/write to - # output tensor must exist with initial reduction values - assert tns in self.bindings - tns_out = self.bindings[tns] - assert len(idxs) == len(tns_out.shape) - - #Evaluate arg with full loop context - loops_sorted = sorted(loops, key=lambda x: x.name) - ctx = EinsumInterpreter(self.xp, self.bindings, loops_sorted) - arg_val = ctx(arg) - - #Reduce over non-true_idxs axes using the reduction op - reduce_axes = tuple(i for i in range(len(loops_sorted)) - if loops_sorted[i] not in true_idxs) - if reduce_axes: - op_func = self(op) # Get the actual operation - reduce_func = getattr(xp, reduction_ops[op_func]) - reduced_val = reduce_func(arg_val, axis=reduce_axes) - else: - reduced_val = arg_val - - # Evaluate indirect index expressions (same context) - evaled_idxs = [ - (xp.arange(tns_out.shape[i]) if isinstance(idx, ein.Index) else ctx(idx)) - for i, idx in enumerate(idxs) - ] - flat_idx = xp.ravel_multi_index(xp.vstack(evaled_idxs), tns_out.shape) - - # Scatter/accumulate using the SAME operation's ufunc - scatter_func = getattr(xp, nary_ops[op_func]) - scatter_func.at(tns_out.flat, flat_idx, reduced_val) - - # assign the output tensor to the bindings - self.bindings[tns] = tns_out - return (tns,) + raise NotImplementedError("Indirect einsum assignment is not implemented") case _: raise ValueError(f"Unknown einsum type: {type(node)}") From d9537574b7688d6ddc13d2ebb24e5e16d19b58ba Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 11 Nov 2025 11:37:50 -0500 Subject: [PATCH 24/45] Enhanced EinsumInterpreter to handle cases with fewer indices than dimensions during tensor permutation. Updated test to flatten the result of multiplication for sparse tensors. --- src/finchlite/finch_einsum/interpreter.py | 11 +++++++---- tests/test_einsum.py | 4 +++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 47361608..34f561fd 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -2,7 +2,7 @@ import numpy as np -from ..algebra import overwrite, promote_max, promote_min, TensorFType +from ..algebra import Tensor, overwrite, promote_max, promote_min, TensorFType from . import nodes as ein from ..symbolic import ftype, gensym @@ -109,6 +109,11 @@ def __call__(self, node): perm = [idxs.index(idx) for idx in self.loops if idx in idxs] tns = self(tns) #evaluate the tensor + + #if there are fewer indicies than dimensions, add the remaining dimensions as if they werent permutated + if hasattr(tns, "ndim") and len(perm) < tns.ndim: + perm = perm + [i for i in range(len(perm), tns.ndim)] + tns = xp.permute_dims(tns, perm) #permute the dimensions return xp.expand_dims( tns, @@ -116,9 +121,7 @@ def __call__(self, node): ) #access a tensor with only one indirect access index - case ein.Access(tns, idxs): - assert len(idxs) == 1 - + case ein.Access(tns, idxs) if len(idxs) == 1: idx = self(idxs[0]) tns = self(tns) #evaluate the tensor diff --git a/tests/test_einsum.py b/tests/test_einsum.py index c144cfe4..862683a3 100644 --- a/tests/test_einsum.py +++ b/tests/test_einsum.py @@ -1148,6 +1148,8 @@ def test_indirect_access(self, rng): sparse_A = SparseTensor.from_dense_tensor(A) + # A is sparse + # C[i] = AElems[i] * B[ACoords[i]] prgm = ein.Plan(( ein.Einsum( op=ein.Literal(overwrite), @@ -1175,5 +1177,5 @@ def test_indirect_access(self, rng): ein.Produces((ein.Alias("C"),)), )) - result = finchlite.multiply(A, B) + result = finchlite.multiply(A, B).flatten() self.run_einsum_plan(prgm, {"A": sparse_A, "B": B}, result) \ No newline at end of file From 09c9879400d3c8376d76a8ecdcdfe2119a311c78 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 11 Nov 2025 11:44:26 -0500 Subject: [PATCH 25/45] Renamed test for indirect access to clarify focus on elementwise multiplication without indirect assignment in einsum functionality. --- tests/test_einsum.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_einsum.py b/tests/test_einsum.py index 862683a3..3a264028 100644 --- a/tests/test_einsum.py +++ b/tests/test_einsum.py @@ -1140,8 +1140,8 @@ def run_einsum_plan(self, prgm: ein.Plan, bindings: dict[str, Any], expected: np result = interpreter(prgm)[0] assert np.allclose(result, expected) - def test_indirect_access(self, rng): - """Test indirect access""" + def test_indirect_elementwise_multiplication(self, rng): + """Test indirect elementwise multiplication but no indirect assignment""" A = rng.random((5, 5)) B = rng.random((5, 5)) From f65e4f914c6d8b913acf2350fd6608623b1e557a Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 11 Nov 2025 11:55:03 -0500 Subject: [PATCH 26/45] Add tests for indirect access in einsum operations Implemented multiple test cases for indirect elementwise addition, multiple reads, operations with constants, nested operations, and direct access only. These tests enhance coverage for the einsum functionality, ensuring proper handling of sparse tensors and indirect indexing. --- tests/test_einsum.py | 207 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 206 insertions(+), 1 deletion(-) diff --git a/tests/test_einsum.py b/tests/test_einsum.py index 3a264028..e20577b9 100644 --- a/tests/test_einsum.py +++ b/tests/test_einsum.py @@ -1178,4 +1178,209 @@ def test_indirect_elementwise_multiplication(self, rng): )) result = finchlite.multiply(A, B).flatten() - self.run_einsum_plan(prgm, {"A": sparse_A, "B": B}, result) \ No newline at end of file + self.run_einsum_plan(prgm, {"A": sparse_A, "B": B}, result) + + def test_indirect_elementwise_addition(self, rng): + """Test indirect elementwise addition""" + + A = rng.random((4, 4)) + B = rng.random((4, 4)) + + sparse_A = SparseTensor.from_dense_tensor(A) + + # C[i] = AElems[i] + B[ACoords[i]] + prgm = ein.Plan(( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("C"), + idxs=(ein.Index("i"),), + arg=ein.Call( + op=ein.Literal(operator.add), + args=( + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("elems"), dim=None), + idxs=(ein.Index("i"),), + ), + ein.Access( + tns=ein.Alias("B"), + idxs=( + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("coords"), dim=None), + idxs=(ein.Index("i"),) + ), + ), + ), + ), + ), + ), + ein.Produces((ein.Alias("C"),)), + )) + + result = (A + B).flatten() + self.run_einsum_plan(prgm, {"A": sparse_A, "B": B}, result) + + def test_indirect_multiple_reads(self, rng): + """Test multiple indirect reads from the same tensor""" + + A = rng.random((3, 3)) + B = rng.random((3, 3)) + + sparse_A = SparseTensor.from_dense_tensor(A) + sparse_B = SparseTensor.from_dense_tensor(B) + + # C[i] = AElems[i] * BElems[i] + # Both A and B are sparse, reading their elements directly + prgm = ein.Plan(( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("C"), + idxs=(ein.Index("i"),), + arg=ein.Call( + op=ein.Literal(operator.mul), + args=( + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("elems"), dim=None), + idxs=(ein.Index("i"),), + ), + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("B"), attr=ein.Literal("elems"), dim=None), + idxs=(ein.Index("i"),), + ), + ), + ), + ), + ein.Produces((ein.Alias("C"),)), + )) + + result = finchlite.multiply(A, B).flatten() + self.run_einsum_plan(prgm, {"A": sparse_A, "B": sparse_B}, result) + + def test_indirect_with_constant(self, rng): + """Test indirect access combined with constant""" + + A = rng.random((4, 4)) + B = rng.random((4, 4)) + + sparse_A = SparseTensor.from_dense_tensor(A) + + # C[i] = AElems[i] * B[ACoords[i]] + 5.0 + prgm = ein.Plan(( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("C"), + idxs=(ein.Index("i"),), + arg=ein.Call( + op=ein.Literal(operator.add), + args=( + ein.Call( + op=ein.Literal(operator.mul), + args=( + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("elems"), dim=None), + idxs=(ein.Index("i"),), + ), + ein.Access( + tns=ein.Alias("B"), + idxs=( + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("coords"), dim=None), + idxs=(ein.Index("i"),) + ), + ), + ), + ), + ), + ein.Literal(5.0), + ), + ), + ), + ein.Produces((ein.Alias("C"),)), + )) + + result = (A * B + 5.0).flatten() + self.run_einsum_plan(prgm, {"A": sparse_A, "B": B}, result) + + def test_indirect_nested_operations(self, rng): + """Test nested operations with indirect access""" + + A = rng.random((3, 3)) + B = rng.random((3, 3)) + C = rng.random((3, 3)) + + sparse_A = SparseTensor.from_dense_tensor(A) + + # D[i] = (AElems[i] + B[ACoords[i]]) * C[ACoords[i]] + prgm = ein.Plan(( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("D"), + idxs=(ein.Index("i"),), + arg=ein.Call( + op=ein.Literal(operator.mul), + args=( + ein.Call( + op=ein.Literal(operator.add), + args=( + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("elems"), dim=None), + idxs=(ein.Index("i"),), + ), + ein.Access( + tns=ein.Alias("B"), + idxs=( + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("coords"), dim=None), + idxs=(ein.Index("i"),) + ), + ), + ), + ), + ), + ein.Access( + tns=ein.Alias("C"), + idxs=( + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("coords"), dim=None), + idxs=(ein.Index("i"),) + ), + ), + ), + ), + ), + ), + ein.Produces((ein.Alias("D"),)), + )) + + result = ((A + B) * C).flatten() + self.run_einsum_plan(prgm, {"A": sparse_A, "B": B, "C": C}, result) + + def test_indirect_direct_access_only(self, rng): + """Test accessing only the indirect coordinates""" + + A = rng.random((8,)) + B = rng.random((8,)) + + sparse_A = SparseTensor.from_dense_tensor(A) + + # C[i] = B[ACoords[i]] (read B indirectly, without using A's elements) + prgm = ein.Plan(( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("C"), + idxs=(ein.Index("i"),), + arg=ein.Access( + tns=ein.Alias("B"), + idxs=( + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("coords"), dim=None), + idxs=(ein.Index("i"),) + ), + ), + ), + ), + ein.Produces((ein.Alias("C"),)), + )) + + # Result should be B's values at A's coordinates + expected = B[A != 0] + self.run_einsum_plan(prgm, {"A": sparse_A, "B": B}, expected) \ No newline at end of file From 7549b629eacd2bdf09ef290a76b0efd393dba9fd Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 11 Nov 2025 13:01:25 -0500 Subject: [PATCH 27/45] Refactor EinsumInterpreter to handle flat indexing for 1D arrays and add extensive tests for various indirect indexing scenarios. This includes double, triple, and mixed direct/indirect indexing cases, ensuring robust handling of sparse tensors and arithmetic operations. --- src/finchlite/finch_einsum/interpreter.py | 4 +- tests/test_einsum.py | 367 +++++++++++++++++++++- 2 files changed, 368 insertions(+), 3 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 34f561fd..6087e7c0 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -124,8 +124,8 @@ def __call__(self, node): case ein.Access(tns, idxs) if len(idxs) == 1: idx = self(idxs[0]) tns = self(tns) #evaluate the tensor - - flat_idx = xp.ravel_multi_index(idx.T, tns.shape) + + flat_idx = idx if idx.ndim == 1 else xp.ravel_multi_index(idx.T, tns.shape) return tns.flat[flat_idx] #return a 1-d array by definition #access a tensor with a mixture of indices and other expressions diff --git a/tests/test_einsum.py b/tests/test_einsum.py index e20577b9..70a7519f 100644 --- a/tests/test_einsum.py +++ b/tests/test_einsum.py @@ -1383,4 +1383,369 @@ def test_indirect_direct_access_only(self, rng): # Result should be B's values at A's coordinates expected = B[A != 0] - self.run_einsum_plan(prgm, {"A": sparse_A, "B": B}, expected) \ No newline at end of file + self.run_einsum_plan(prgm, {"A": sparse_A, "B": B}, expected) + + def test_double_indirection(self, rng): + """Test double indirection: A[B[CCoords[i]]]""" + + # Create small arrays to ensure valid indexing + A = rng.random((8,)) + # B contains integer indices into A (0-7) + B = np.array([0, 1, 2, 3, 4, 5, 6, 7]) + rng.shuffle(B) + # C is sparse + C = rng.random((8,)) + + sparse_C = SparseTensor.from_dense_tensor(C) + + # D[i] = A[B[CCoords[i]]] + # First get CCoords[i], then use that to index B, then use B's value to index A + prgm = ein.Plan(( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("D"), + idxs=(ein.Index("i"),), + arg=ein.Access( + tns=ein.Alias("A"), + idxs=( + ein.Access( + tns=ein.Alias("B"), + idxs=( + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("C"), attr=ein.Literal("coords"), dim=None), + idxs=(ein.Index("i"),) + ), + ), + ), + ), + ), + ), + ein.Produces((ein.Alias("D"),)), + )) + + # Expected: for each non-zero position in C, get its coord, index into B, then index into A + c_coords = sparse_C.coords + expected = A[B[c_coords]].flatten() + self.run_einsum_plan(prgm, {"A": A, "B": B, "C": sparse_C}, expected) + + def test_triple_indirection(self, rng): + """Test triple indirection: A[B[C[DCoords[i]]]]""" + + # Create arrays with valid indexing ranges + A = rng.random((8,)) + B = np.array([0, 1, 2, 3, 4, 5, 6, 7]) + rng.shuffle(B) + C = np.array([0, 1, 2, 3, 4, 5, 6, 7]) + rng.shuffle(C) + D = rng.random((8,)) + + sparse_D = SparseTensor.from_dense_tensor(D) + + # E[i] = A[B[C[DCoords[i]]]] + prgm = ein.Plan(( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("E"), + idxs=(ein.Index("i"),), + arg=ein.Access( + tns=ein.Alias("A"), + idxs=( + ein.Access( + tns=ein.Alias("B"), + idxs=( + ein.Access( + tns=ein.Alias("C"), + idxs=( + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("D"), attr=ein.Literal("coords"), dim=None), + idxs=(ein.Index("i"),) + ), + ), + ), + ), + ), + ), + ), + ), + ein.Produces((ein.Alias("E"),)), + )) + + # Expected: chain of indirections + d_coords = sparse_D.coords + expected = A[B[C[d_coords]]].flatten() + self.run_einsum_plan(prgm, {"A": A, "B": B, "C": C, "D": sparse_D}, expected) + + def test_mixed_direct_indirect_indexing_2d(self, rng): + """Test mixed indexing: A[BCoords[i], j] - one indirect, one direct""" + + A = rng.random((5, 4)) + B = rng.random((5,)) + + sparse_B = SparseTensor.from_dense_tensor(B) + + # C[i, j] = A[BCoords[i], j] + # First index is indirect (from B's coords), second is direct + prgm = ein.Plan(( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("C"), + idxs=(ein.Index("i"), ein.Index("j")), + arg=ein.Access( + tns=ein.Alias("A"), + idxs=( + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None), + idxs=(ein.Index("i"),) + ), + ein.Index("j"), + ), + ), + ), + ein.Produces((ein.Alias("C"),)), + )) + + # Expected: A rows indexed by B's coords, all columns + b_coords = sparse_B.coords + expected = A[b_coords, :] + self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected) + + def test_mixed_direct_indirect_indexing_reversed(self, rng): + """Test mixed indexing reversed: A[i, BCoords[j]] - first direct, second indirect""" + + A = rng.random((4, 6)) + B = rng.random((6,)) + + sparse_B = SparseTensor.from_dense_tensor(B) + + # C[i, j] = A[i, BCoords[j]] + # First index is direct, second is indirect + prgm = ein.Plan(( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("C"), + idxs=(ein.Index("i"), ein.Index("j")), + arg=ein.Access( + tns=ein.Alias("A"), + idxs=( + ein.Index("i"), + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None), + idxs=(ein.Index("j"),) + ), + ), + ), + ), + ein.Produces((ein.Alias("C"),)), + )) + + # Expected: all rows of A, columns indexed by B's coords + b_coords = sparse_B.coords + expected = A[:, b_coords] + self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected) + + def test_both_indices_indirect_same_source(self, rng): + """Test both indices indirect from same source: A[BCoords[i], BCoords[i]]""" + + A = rng.random((6, 6)) + B = rng.random((6,)) + + sparse_B = SparseTensor.from_dense_tensor(B) + + # C[i] = A[BCoords[i], BCoords[i]] + # Extracting diagonal-like elements using indirect coordinates + prgm = ein.Plan(( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("C"), + idxs=(ein.Index("i"),), + arg=ein.Access( + tns=ein.Alias("A"), + idxs=( + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None), + idxs=(ein.Index("i"),) + ), + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None), + idxs=(ein.Index("i"),) + ), + ), + ), + ), + ein.Produces((ein.Alias("C"),)), + )) + + # Expected: A[coords, coords] - pseudo-diagonal at indirect positions + b_coords = sparse_B.coords + expected = A[b_coords, b_coords] + self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected) + + def test_both_indices_indirect_different_sources(self, rng): + """Test both indices indirect from different sources: A[BCoords[i], CCoords[i]]""" + + A = rng.random((5, 6)) + B = rng.random((5,)) + C = rng.random((6,)) + + sparse_B = SparseTensor.from_dense_tensor(B) + sparse_C = SparseTensor.from_dense_tensor(C) + + # D[i] = A[BCoords[i], CCoords[i]] + prgm = ein.Plan(( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("D"), + idxs=(ein.Index("i"),), + arg=ein.Access( + tns=ein.Alias("A"), + idxs=( + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None), + idxs=(ein.Index("i"),) + ), + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("C"), attr=ein.Literal("coords"), dim=None), + idxs=(ein.Index("i"),) + ), + ), + ), + ), + ein.Produces((ein.Alias("D"),)), + )) + + # Expected: A indexed by pairs of coords from B and C + # This requires both to have the same number of non-zero elements + b_coords = sparse_B.coords + c_coords = sparse_C.coords + min_len = min(len(b_coords), len(c_coords)) + expected = A[b_coords[:min_len], c_coords[:min_len]] + self.run_einsum_plan(prgm, {"A": A, "B": sparse_B, "C": sparse_C}, expected) + + def test_double_indirection_with_operation(self, rng): + """Test double indirection combined with arithmetic operation""" + + A = rng.random((8,)) + B = np.array([0, 1, 2, 3, 4, 5, 6, 7]) + rng.shuffle(B) + C = rng.random((8,)) + D = rng.random((8,)) + + sparse_C = SparseTensor.from_dense_tensor(C) + + # E[i] = A[B[CCoords[i]]] * CElems[i] + # Double indirection plus multiplication with sparse elements + prgm = ein.Plan(( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("E"), + idxs=(ein.Index("i"),), + arg=ein.Call( + op=ein.Literal(operator.mul), + args=( + ein.Access( + tns=ein.Alias("A"), + idxs=( + ein.Access( + tns=ein.Alias("B"), + idxs=( + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("C"), attr=ein.Literal("coords"), dim=None), + idxs=(ein.Index("i"),) + ), + ), + ), + ), + ), + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("C"), attr=ein.Literal("elems"), dim=None), + idxs=(ein.Index("i"),), + ), + ), + ), + ), + ein.Produces((ein.Alias("E"),)), + )) + + c_coords = sparse_C.coords + c_elems = sparse_C.elems + expected = A[B[c_coords]] * c_elems + self.run_einsum_plan(prgm, {"A": A, "B": B, "C": sparse_C}, expected) + + def test_mixed_indexing_with_computation(self, rng): + """Test mixed direct/indirect indexing with computation""" + + A = rng.random((4, 5)) + B = rng.random((4,)) + C = rng.random((5,)) + + sparse_B = SparseTensor.from_dense_tensor(B) + + # D[i, j] = A[BCoords[i], j] + BElems[i] + # Mixed indexing plus addition with sparse elements + prgm = ein.Plan(( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("D"), + idxs=(ein.Index("i"), ein.Index("j")), + arg=ein.Call( + op=ein.Literal(operator.add), + args=( + ein.Access( + tns=ein.Alias("A"), + idxs=( + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None), + idxs=(ein.Index("i"),) + ), + ein.Index("j"), + ), + ), + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("B"), attr=ein.Literal("elems"), dim=None), + idxs=(ein.Index("i"),), + ), + ), + ), + ), + ein.Produces((ein.Alias("D"),)), + )) + + b_coords = sparse_B.coords + b_elems = sparse_B.elems + # Broadcasting: A[coords, :] has shape (len(coords), 5), b_elems has shape (len(coords),) + expected = A[b_coords, :] + b_elems[:, np.newaxis] + self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected) + + def test_indirect_3d_tensor_access(self, rng): + """Test indirect access on 3D tensor with mixed indices""" + + A = rng.random((3, 4, 5)) + B = rng.random((4,)) + + sparse_B = SparseTensor.from_dense_tensor(B) + + # C[i, j, k] = A[i, BCoords[j], k] + # Middle dimension is indirectly indexed + prgm = ein.Plan(( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("C"), + idxs=(ein.Index("i"), ein.Index("j"), ein.Index("k")), + arg=ein.Access( + tns=ein.Alias("A"), + idxs=( + ein.Index("i"), + ein.Access( + tns=ein.GetAttribute(obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None), + idxs=(ein.Index("j"),) + ), + ein.Index("k"), + ), + ), + ), + ein.Produces((ein.Alias("C"),)), + )) + + b_coords = sparse_B.coords + expected = A[:, b_coords, :] + self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected) \ No newline at end of file From 29f5f03e711934ee43fae4b4c154988ea1b8d5ef Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 11 Nov 2025 18:30:15 -0500 Subject: [PATCH 28/45] Refactor EinsumInterpreter to improve handling of indirect indexing and tensor reshaping. Introduced flattening for index evaluations, optimized axis reordering, and ensured compatibility with parent indices. Updated tests to reflect changes in expected behavior for sparse tensor operations. --- src/finchlite/finch_einsum/interpreter.py | 60 +++++++++++++++++------ tests/test_einsum.py | 6 ++- 2 files changed, 51 insertions(+), 15 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 6087e7c0..ff7351e6 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -140,41 +140,72 @@ def __call__(self, node): # evaluate all the indirect access indicies, and evaluate field indicies as a "grab all" evaled_idxs = [ - (xp.arange(tns.shape[i]) if isinstance(idx, ein.Index) else self(idx)) + (xp.arange(tns.shape[i]) if isinstance(idx, ein.Index) else self(idx).flatten()) for i, idx in enumerate(idxs) ] - + + combo_idxs = xp.meshgrid(*evaled_idxs, indexing="ij") + combo_idxs = xp.stack(combo_idxs, axis=-1) + combo_idxs = combo_idxs.reshape(-1, len(idxs)) + # evaluate the output tensor as a flat array - flat_idx = xp.ravel_multi_index(xp.vstack(evaled_idxs), tns.shape) + flat_idx = xp.ravel_multi_index(combo_idxs.T, tns.shape) tns = xp.take(tns, flat_idx) - - # calculate the final shape of the tensor - # we assert that all the indirect access indicies from the parent idxs have the same size + #calculate child idxs, idxs computed using the parent "true idxs" child_idxs = { - idx: [ - child_idx for child_idx in child_idxs + parent_idx: [ + child_idx for child_idx in idxs if (parent_idx in child_idx.get_idxs()) ] for parent_idx in true_idxs } + #calculate parent idxs, just inverse of child_idxs dictionary + parent_idxs = { + child_idx: parent_idx + for parent_idx in true_idxs + for child_idx in child_idxs[parent_idx] + } + + # we assert that all the indirect access indicies from the parent idxs have the same size assert all( child_idxs[parent_idx].count(child_idxs[parent_idx][0]) == len(child_idxs[parent_idx]) for parent_idx in true_idxs ) + + # calculate the shape of the tensor truest to its current form + current_shape = tuple( + evaled_idxs[idxs.index(idx)].size + for idx in idxs + ) + tns = tns.reshape(current_shape) + + # a mapping from each idx to its axis wrt to current shape + idxs_axis = {idx: i for i, idx in enumerate(idxs)} + + true_idxs = list(true_idxs) + + # reorder the axis so that each child idx of a parent idx are consecutive + new_axes = [ + idxs_axis[child_idx] + for true_idx in true_idxs + for child_idx in child_idxs[true_idx] + ] + tns = xp.transpose(tns, axes=new_axes) + + # calculate the final shape of the tensor # we merge the child idxs to get the final shape that matches the true idxs final_shape = tuple( - evaled_idxs[idxs.index(idx)].size - for idx in idxs if idx in true_idxs + np.prod([evaled_idxs[idxs.index(child_idx)].size for child_idx in child_idxs[parent_idx]]) + for parent_idx in true_idxs ) tns = tns.reshape(final_shape) - idxs = [idx for idx in idxs if idx in true_idxs] # permute and broadcast the tensor to be compatible with rest of expression - perm = [idxs.index(idx) for idx in self.loops if idx in idxs] + perm = [true_idxs.index(idx) for idx in self.loops if idx in true_idxs] tns = xp.permute_dims(tns, perm) return xp.expand_dims( tns, - [i for i in range(len(self.loops)) if self.loops[i] not in idxs] + [i for i in range(len(self.loops)) if self.loops[i] not in true_idxs] ) case ein.Plan(bodies): @@ -198,7 +229,8 @@ def __call__(self, node): assert isinstance(obj, SparseTensor) # return the coord array for the given dimension or all dimensions - return obj.coords if dim is None else obj.coords[:, dim] + toReturn = obj.coords if dim is None else obj.coords[:, dim] + return toReturn # gets the shape of a sparse tensor at a given dimension case ein.GetAttribute(obj, ein.Literal("shape"), dim): obj = self(obj) diff --git a/tests/test_einsum.py b/tests/test_einsum.py index 70a7519f..6b1b7b65 100644 --- a/tests/test_einsum.py +++ b/tests/test_einsum.py @@ -1138,6 +1138,10 @@ def run_einsum_plan(self, prgm: ein.Plan, bindings: dict[str, Any], expected: np """Runs an einsum plan and returns the result""" interpreter = ein.EinsumInterpreter(bindings=bindings) result = interpreter(prgm)[0] + + print(result) + print(expected) + assert np.allclose(result, expected) def test_indirect_elementwise_multiplication(self, rng): @@ -1506,7 +1510,7 @@ def test_mixed_direct_indirect_indexing_2d(self, rng): # Expected: A rows indexed by B's coords, all columns b_coords = sparse_B.coords - expected = A[b_coords, :] + expected = A[b_coords.flatten(), :] self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected) def test_mixed_direct_indirect_indexing_reversed(self, rng): From 6fe0528e9018fc20ef20dbd823edb8d7dcf802da Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 11 Nov 2025 18:31:24 -0500 Subject: [PATCH 29/45] Remove redundant calculation of parent indices in EinsumInterpreter to streamline indirect indexing logic. This change simplifies the code and maintains the integrity of index assertions. --- src/finchlite/finch_einsum/interpreter.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index ff7351e6..be464c53 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -159,12 +159,6 @@ def __call__(self, node): if (parent_idx in child_idx.get_idxs()) ] for parent_idx in true_idxs } - #calculate parent idxs, just inverse of child_idxs dictionary - parent_idxs = { - child_idx: parent_idx - for parent_idx in true_idxs - for child_idx in child_idxs[parent_idx] - } # we assert that all the indirect access indicies from the parent idxs have the same size assert all( From 8d0a788aab4be9dfb4204890fc6d3578a18b37bc Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 11 Nov 2025 20:22:16 -0500 Subject: [PATCH 30/45] Refactor EinsumInterpreter to enhance indirect indexing logic by separating ein.Index and indirect access positions. Improved cartesian product computation and reshaping of tensors, ensuring correct handling of index sizes. Updated tests to reflect changes in expected behavior for indirect indexing scenarios. --- src/finchlite/finch_einsum/interpreter.py | 88 ++++++++++++++++------- tests/test_einsum.py | 33 +++++---- 2 files changed, 79 insertions(+), 42 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index be464c53..2310784e 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -138,16 +138,59 @@ def __call__(self, node): tns = self(tns) assert len(idxs) == len(tns.shape) - # evaluate all the indirect access indicies, and evaluate field indicies as a "grab all" - evaled_idxs = [ - (xp.arange(tns.shape[i]) if isinstance(idx, ein.Index) else self(idx).flatten()) - for i, idx in enumerate(idxs) - ] + # Separate ein.Index positions from indirect access positions + ein_idx_positions = [i for i, idx in enumerate(idxs) if isinstance(idx, ein.Index)] + indirect_positions = [i for i, idx in enumerate(idxs) if not isinstance(idx, ein.Index)] + + # Compute cartesian product for ein.Index instances only + if ein_idx_positions: + ein_idx_ranges = [xp.arange(tns.shape[i]) for i in ein_idx_positions] + ein_combo_idxs = xp.meshgrid(*ein_idx_ranges, indexing="ij") + ein_combo_idxs = xp.stack(ein_combo_idxs, axis=-1) + ein_combo_idxs = ein_combo_idxs.reshape(-1, len(ein_idx_positions)) # Shape: (N_ein, num_ein_indices) + else: + ein_combo_idxs = xp.empty((1, 0), dtype=xp.int64) + + # Evaluate indirect accesses as a "super index" (no cartesian product amongst them) + if indirect_positions: + indirect_vals = [self(idxs[i]).flatten() for i in indirect_positions] + indirect_combo = xp.stack(indirect_vals, axis=-1) # Shape: (M_indirect, num_indirect_indices) + else: + indirect_combo = xp.empty((1, 0), dtype=xp.int64) + + # Compute cartesian product between ein.Index group and indirect group + # The indirect group is treated as a single "super index" (no cartesian product amongst indirect values) + n_ein = ein_combo_idxs.shape[0] + n_indirect = indirect_combo.shape[0] + + # Determine which group comes first/last to decide iteration order + # For standard row-major order, the first dimension varies slowest + # Check if ein or indirect comes first + first_ein_pos = ein_idx_positions[0] if ein_idx_positions else float('inf') + first_indirect_pos = indirect_positions[0] if indirect_positions else float('inf') + + if first_ein_pos < first_indirect_pos: + # ein comes first: repeat ein, tile indirect (ein varies slowest) + ein_result = xp.repeat(ein_combo_idxs, n_indirect, axis=0) + indirect_result = xp.tile(indirect_combo, (n_ein, 1)) + else: + # indirect comes first: repeat indirect, tile ein (indirect varies slowest) + indirect_result = xp.repeat(indirect_combo, n_ein, axis=0) + ein_result = xp.tile(ein_combo_idxs, (n_indirect, 1)) + + # Now we need to interleave these back in the original order + combo_idxs = xp.empty((n_ein * n_indirect, len(idxs)), dtype=xp.int64) + idx_sizes = [] # Track the size of each index dimension + for i, idx in enumerate(idxs): + if isinstance(idx, ein.Index): + pos_in_ein = ein_idx_positions.index(i) + combo_idxs[:, i] = ein_result[:, pos_in_ein] + idx_sizes.append(tns.shape[i]) + else: + pos_in_indirect = indirect_positions.index(i) + combo_idxs[:, i] = indirect_result[:, pos_in_indirect] + idx_sizes.append(n_indirect) - combo_idxs = xp.meshgrid(*evaled_idxs, indexing="ij") - combo_idxs = xp.stack(combo_idxs, axis=-1) - combo_idxs = combo_idxs.reshape(-1, len(idxs)) - # evaluate the output tensor as a flat array flat_idx = xp.ravel_multi_index(combo_idxs.T, tns.shape) tns = xp.take(tns, flat_idx) @@ -160,36 +203,27 @@ def __call__(self, node): ] for parent_idx in true_idxs } - # we assert that all the indirect access indicies from the parent idxs have the same size + # we assert that all the indirect access indicies from the parent idxs have the same size + child_idxs_size = { + parent_idx: [idx_sizes[idxs.index(child_idx)] for child_idx in child_idxs[parent_idx]] + for parent_idx in true_idxs + } assert all( - child_idxs[parent_idx].count(child_idxs[parent_idx][0]) == len(child_idxs[parent_idx]) + child_idxs_size[parent_idx].count(child_idxs_size[parent_idx][0]) == len(child_idxs_size[parent_idx]) for parent_idx in true_idxs ) - # calculate the shape of the tensor truest to its current form - current_shape = tuple( - evaled_idxs[idxs.index(idx)].size - for idx in idxs - ) - tns = tns.reshape(current_shape) - # a mapping from each idx to its axis wrt to current shape idxs_axis = {idx: i for i, idx in enumerate(idxs)} true_idxs = list(true_idxs) - - # reorder the axis so that each child idx of a parent idx are consecutive - new_axes = [ - idxs_axis[child_idx] - for true_idx in true_idxs - for child_idx in child_idxs[true_idx] - ] - tns = xp.transpose(tns, axes=new_axes) + true_idxs = sorted(true_idxs, key=lambda idx: idxs_axis[child_idxs[idx][0]]) + print(true_idxs) # calculate the final shape of the tensor # we merge the child idxs to get the final shape that matches the true idxs final_shape = tuple( - np.prod([evaled_idxs[idxs.index(child_idx)].size for child_idx in child_idxs[parent_idx]]) + idx_sizes[idxs.index(child_idxs[parent_idx][0])] for parent_idx in true_idxs ) tns = tns.reshape(final_shape) diff --git a/tests/test_einsum.py b/tests/test_einsum.py index 6b1b7b65..df446d86 100644 --- a/tests/test_einsum.py +++ b/tests/test_einsum.py @@ -1139,6 +1139,10 @@ def run_einsum_plan(self, prgm: ein.Plan, bindings: dict[str, Any], expected: np interpreter = ein.EinsumInterpreter(bindings=bindings) result = interpreter(prgm)[0] + import numpy as np + import sys + np.set_printoptions(threshold=sys.maxsize) + print(result) print(expected) @@ -1544,7 +1548,7 @@ def test_mixed_direct_indirect_indexing_reversed(self, rng): # Expected: all rows of A, columns indexed by B's coords b_coords = sparse_B.coords - expected = A[:, b_coords] + expected = A[:, b_coords.flatten()] self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected) def test_both_indices_indirect_same_source(self, rng): @@ -1581,14 +1585,14 @@ def test_both_indices_indirect_same_source(self, rng): # Expected: A[coords, coords] - pseudo-diagonal at indirect positions b_coords = sparse_B.coords - expected = A[b_coords, b_coords] + expected = A[b_coords.flatten(), b_coords.flatten()] self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected) def test_both_indices_indirect_different_sources(self, rng): """Test both indices indirect from different sources: A[BCoords[i], CCoords[i]]""" - A = rng.random((5, 6)) - B = rng.random((5,)) + A = rng.random((6, 6)) + B = rng.random((6,)) C = rng.random((6,)) sparse_B = SparseTensor.from_dense_tensor(B) @@ -1619,10 +1623,9 @@ def test_both_indices_indirect_different_sources(self, rng): # Expected: A indexed by pairs of coords from B and C # This requires both to have the same number of non-zero elements - b_coords = sparse_B.coords - c_coords = sparse_C.coords - min_len = min(len(b_coords), len(c_coords)) - expected = A[b_coords[:min_len], c_coords[:min_len]] + b_coords = sparse_B.coords.flatten() + c_coords = sparse_C.coords.flatten() + expected = A[b_coords, c_coords] self.run_einsum_plan(prgm, {"A": A, "B": sparse_B, "C": sparse_C}, expected) def test_double_indirection_with_operation(self, rng): @@ -1632,7 +1635,6 @@ def test_double_indirection_with_operation(self, rng): B = np.array([0, 1, 2, 3, 4, 5, 6, 7]) rng.shuffle(B) C = rng.random((8,)) - D = rng.random((8,)) sparse_C = SparseTensor.from_dense_tensor(C) @@ -1670,15 +1672,16 @@ def test_double_indirection_with_operation(self, rng): ein.Produces((ein.Alias("E"),)), )) - c_coords = sparse_C.coords - c_elems = sparse_C.elems + c_coords = sparse_C.coords.flatten() + c_elems = sparse_C.data expected = A[B[c_coords]] * c_elems + self.run_einsum_plan(prgm, {"A": A, "B": B, "C": sparse_C}, expected) def test_mixed_indexing_with_computation(self, rng): """Test mixed direct/indirect indexing with computation""" - A = rng.random((4, 5)) + A = rng.random((5, 5)) B = rng.random((4,)) C = rng.random((5,)) @@ -1714,8 +1717,8 @@ def test_mixed_indexing_with_computation(self, rng): ein.Produces((ein.Alias("D"),)), )) - b_coords = sparse_B.coords - b_elems = sparse_B.elems + b_coords = sparse_B.coords.flatten() + b_elems = sparse_B.data # Broadcasting: A[coords, :] has shape (len(coords), 5), b_elems has shape (len(coords),) expected = A[b_coords, :] + b_elems[:, np.newaxis] self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected) @@ -1750,6 +1753,6 @@ def test_indirect_3d_tensor_access(self, rng): ein.Produces((ein.Alias("C"),)), )) - b_coords = sparse_B.coords + b_coords = sparse_B.coords.flatten() expected = A[:, b_coords, :] self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected) \ No newline at end of file From cc09bf5e7cd140267dccab5ae22865d0f06e3a90 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 11 Nov 2025 20:38:37 -0500 Subject: [PATCH 31/45] Refactor EinsumInterpreter to improve index handling by grouping ein.Index and indirect access positions. Enhanced cartesian product computation and index size tracking, ensuring accurate evaluation of tensor indices. Updated logic for processing indirect accesses and their dependencies on true indices. --- src/finchlite/finch_einsum/interpreter.py | 117 +++++++++++++--------- 1 file changed, 69 insertions(+), 48 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 2310784e..168fc4bf 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -137,59 +137,81 @@ def __call__(self, node): # evaluate the tensor to access tns = self(tns) assert len(idxs) == len(tns.shape) - - # Separate ein.Index positions from indirect access positions - ein_idx_positions = [i for i, idx in enumerate(idxs) if isinstance(idx, ein.Index)] - indirect_positions = [i for i, idx in enumerate(idxs) if not isinstance(idx, ein.Index)] - - # Compute cartesian product for ein.Index instances only - if ein_idx_positions: - ein_idx_ranges = [xp.arange(tns.shape[i]) for i in ein_idx_positions] - ein_combo_idxs = xp.meshgrid(*ein_idx_ranges, indexing="ij") - ein_combo_idxs = xp.stack(ein_combo_idxs, axis=-1) - ein_combo_idxs = ein_combo_idxs.reshape(-1, len(ein_idx_positions)) # Shape: (N_ein, num_ein_indices) - else: - ein_combo_idxs = xp.empty((1, 0), dtype=xp.int64) - - # Evaluate indirect accesses as a "super index" (no cartesian product amongst them) - if indirect_positions: - indirect_vals = [self(idxs[i]).flatten() for i in indirect_positions] - indirect_combo = xp.stack(indirect_vals, axis=-1) # Shape: (M_indirect, num_indirect_indices) - else: - indirect_combo = xp.empty((1, 0), dtype=xp.int64) - - # Compute cartesian product between ein.Index group and indirect group - # The indirect group is treated as a single "super index" (no cartesian product amongst indirect values) - n_ein = ein_combo_idxs.shape[0] - n_indirect = indirect_combo.shape[0] - - # Determine which group comes first/last to decide iteration order - # For standard row-major order, the first dimension varies slowest - # Check if ein or indirect comes first - first_ein_pos = ein_idx_positions[0] if ein_idx_positions else float('inf') - first_indirect_pos = indirect_positions[0] if indirect_positions else float('inf') - if first_ein_pos < first_indirect_pos: - # ein comes first: repeat ein, tile indirect (ein varies slowest) - ein_result = xp.repeat(ein_combo_idxs, n_indirect, axis=0) - indirect_result = xp.tile(indirect_combo, (n_ein, 1)) - else: - # indirect comes first: repeat indirect, tile ein (indirect varies slowest) - indirect_result = xp.repeat(indirect_combo, n_ein, axis=0) - ein_result = xp.tile(ein_combo_idxs, (n_indirect, 1)) + # Build a mapping of which positions belong to which true_idx groups + idx_groups = [] # List of (position, group_type, group_id) + idx_sizes = [] - # Now we need to interleave these back in the original order - combo_idxs = xp.empty((n_ein * n_indirect, len(idxs)), dtype=xp.int64) - idx_sizes = [] # Track the size of each index dimension for i, idx in enumerate(idxs): if isinstance(idx, ein.Index): - pos_in_ein = ein_idx_positions.index(i) - combo_idxs[:, i] = ein_result[:, pos_in_ein] + # Direct ein.Index + idx_groups.append((i, 'ein', idx)) idx_sizes.append(tns.shape[i]) else: - pos_in_indirect = indirect_positions.index(i) - combo_idxs[:, i] = indirect_result[:, pos_in_indirect] - idx_sizes.append(n_indirect) + # Indirect access - find which true_idx it depends on + child_true_idxs = idx.get_idxs() + # Should depend on exactly one true_idx for this to work + if len(child_true_idxs) == 1: + parent_idx = list(child_true_idxs)[0] + idx_groups.append((i, 'indirect', parent_idx)) + else: + # Complex case - for now treat as independent + idx_groups.append((i, 'indirect_complex', idx)) + idx_sizes.append(None) # Will be determined after evaluation + + # Group positions by their parent true index + groups_by_parent = {} + for pos, group_type, group_id in idx_groups: + if group_id not in groups_by_parent: + groups_by_parent[group_id] = [] + groups_by_parent[group_id].append((pos, group_type)) + + # Build ranges for the cartesian product + # Each group contributes one dimension to the product + group_ranges = [] + group_positions = [] # Track which positions belong to each group + + # Process groups in position order (based on first occurrence) + processed_groups = [] + for pos, group_type, group_id in idx_groups: + if group_id not in processed_groups: + processed_groups.append(group_id) + positions = [p for p, _ in groups_by_parent[group_id]] + group_positions.append(positions) + + if group_type == 'ein': + # For ein.Index, use the range + group_ranges.append(xp.arange(tns.shape[pos])) + else: + # For indirect group, evaluate the first position to get size + # All positions in the group should have the same size + first_pos = positions[0] + indirect_vals = self(idxs[first_pos]).flatten() + group_ranges.append(xp.arange(len(indirect_vals))) + # Update sizes for all positions in this group + for p in positions: + idx_sizes[p] = len(indirect_vals) + + # Compute cartesian product of group indices + # This gives us which "iteration" of each group we're on + group_grids = xp.meshgrid(*group_ranges, indexing='ij') + group_combos = xp.stack([g.flatten() for g in group_grids], axis=-1) + + # Now build the actual index combinations + combo_idxs = xp.empty((group_combos.shape[0], len(idxs)), dtype=xp.int64) + + for group_idx, (group_id, positions) in enumerate(zip(processed_groups, group_positions)): + group_iterations = group_combos[:, group_idx] + + # Fill in values for all positions in this group + for pos in positions: + if isinstance(idxs[pos], ein.Index): + # Direct index - use the iteration number + combo_idxs[:, pos] = group_iterations + else: + # Indirect access - evaluate and index with iteration + indirect_vals = self(idxs[pos]).flatten() + combo_idxs[:, pos] = indirect_vals[group_iterations] # evaluate the output tensor as a flat array flat_idx = xp.ravel_multi_index(combo_idxs.T, tns.shape) @@ -218,7 +240,6 @@ def __call__(self, node): true_idxs = list(true_idxs) true_idxs = sorted(true_idxs, key=lambda idx: idxs_axis[child_idxs[idx][0]]) - print(true_idxs) # calculate the final shape of the tensor # we merge the child idxs to get the final shape that matches the true idxs From 169cebf1ba20c2eca5eca257c83522491b05baaf Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 11 Nov 2025 20:46:05 -0500 Subject: [PATCH 32/45] Refactor EinsumInterpreter to enhance index classification and grouping logic. Improved handling of parent indices for both ein.Index and indirect access, optimizing cartesian product computation and index size tracking. Streamlined the process of building index combinations for tensor evaluation. --- src/finchlite/finch_einsum/interpreter.py | 113 +++++++++++----------- 1 file changed, 58 insertions(+), 55 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 168fc4bf..7661de0b 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -138,80 +138,83 @@ def __call__(self, node): tns = self(tns) assert len(idxs) == len(tns.shape) - # Build a mapping of which positions belong to which true_idx groups - idx_groups = [] # List of (position, group_type, group_id) - idx_sizes = [] + # Classify indices and determine their parent groups + idx_sizes = [] # Track size of each index dimension + # For each index, determine which true_idx it depends on + parent_indices = [] for i, idx in enumerate(idxs): if isinstance(idx, ein.Index): - # Direct ein.Index - idx_groups.append((i, 'ein', idx)) + parent_indices.append(idx) idx_sizes.append(tns.shape[i]) else: - # Indirect access - find which true_idx it depends on - child_true_idxs = idx.get_idxs() - # Should depend on exactly one true_idx for this to work - if len(child_true_idxs) == 1: - parent_idx = list(child_true_idxs)[0] - idx_groups.append((i, 'indirect', parent_idx)) - else: - # Complex case - for now treat as independent - idx_groups.append((i, 'indirect_complex', idx)) - idx_sizes.append(None) # Will be determined after evaluation + child_idxs = idx.get_idxs() + parent_indices.append( + list(child_idxs)[0] if len(child_idxs) == 1 else idx + ) + idx_sizes.append(None) # Will be determined later - # Group positions by their parent true index - groups_by_parent = {} - for pos, group_type, group_id in idx_groups: - if group_id not in groups_by_parent: - groups_by_parent[group_id] = [] - groups_by_parent[group_id].append((pos, group_type)) + # Find unique parent groups and their positions + # Create a mapping from parent to group index + unique_parents = [] + parent_to_group = {} + for parent in parent_indices: + found = False + for i, up in enumerate(unique_parents): + # For ein.Index, compare by name; for others by identity + if isinstance(parent, ein.Index) and isinstance(up, ein.Index): + if parent.name == up.name: + parent_to_group[id(parent)] = i + found = True + break + elif parent is up: + parent_to_group[id(parent)] = i + found = True + break + if not found: + parent_to_group[id(parent)] = len(unique_parents) + unique_parents.append(parent) - # Build ranges for the cartesian product - # Each group contributes one dimension to the product - group_ranges = [] - group_positions = [] # Track which positions belong to each group + # Build group positions from the mapping + group_positions = [[] for _ in range(len(unique_parents))] + for i, parent in enumerate(parent_indices): + group_idx = parent_to_group[id(parent)] + group_positions[group_idx].append(i) - # Process groups in position order (based on first occurrence) - processed_groups = [] - for pos, group_type, group_id in idx_groups: - if group_id not in processed_groups: - processed_groups.append(group_id) - positions = [p for p, _ in groups_by_parent[group_id]] - group_positions.append(positions) - - if group_type == 'ein': - # For ein.Index, use the range - group_ranges.append(xp.arange(tns.shape[pos])) - else: - # For indirect group, evaluate the first position to get size - # All positions in the group should have the same size - first_pos = positions[0] - indirect_vals = self(idxs[first_pos]).flatten() - group_ranges.append(xp.arange(len(indirect_vals))) - # Update sizes for all positions in this group - for p in positions: - idx_sizes[p] = len(indirect_vals) + # Build ranges for cartesian product + group_ranges = [] + for group_idx, positions in enumerate(group_positions): + first_pos = positions[0] + if isinstance(idxs[first_pos], ein.Index): + group_ranges.append(xp.arange(tns.shape[first_pos])) + else: + indirect_vals = self(idxs[first_pos]).flatten() + group_ranges.append(xp.arange(len(indirect_vals))) + # Update sizes for all positions in this group + for p in positions: + idx_sizes[p] = len(indirect_vals) - # Compute cartesian product of group indices - # This gives us which "iteration" of each group we're on + # Compute cartesian product group_grids = xp.meshgrid(*group_ranges, indexing='ij') group_combos = xp.stack([g.flatten() for g in group_grids], axis=-1) - # Now build the actual index combinations + # Build index combinations combo_idxs = xp.empty((group_combos.shape[0], len(idxs)), dtype=xp.int64) - for group_idx, (group_id, positions) in enumerate(zip(processed_groups, group_positions)): + # Pre-evaluate all indirect indices + indirect_vals_cache = {} + for i, idx in enumerate(idxs): + if not isinstance(idx, ein.Index): + indirect_vals_cache[i] = self(idx).flatten() + + # Fill combo_idxs using vectorized operations where possible + for group_idx, positions in enumerate(group_positions): group_iterations = group_combos[:, group_idx] - - # Fill in values for all positions in this group for pos in positions: if isinstance(idxs[pos], ein.Index): - # Direct index - use the iteration number combo_idxs[:, pos] = group_iterations else: - # Indirect access - evaluate and index with iteration - indirect_vals = self(idxs[pos]).flatten() - combo_idxs[:, pos] = indirect_vals[group_iterations] + combo_idxs[:, pos] = indirect_vals_cache[pos][group_iterations] # evaluate the output tensor as a flat array flat_idx = xp.ravel_multi_index(combo_idxs.T, tns.shape) From b2f81e67c296c881ff952d87838b8b2e1c115e0c Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 11 Nov 2025 20:55:46 -0500 Subject: [PATCH 33/45] * Fixed ruff errors --- src/finchlite/finch_einsum/__init__.py | 4 +- src/finchlite/finch_einsum/interpreter.py | 134 ++-- src/finchlite/finch_einsum/nodes.py | 12 +- src/finchlite/tensor/__init__.py | 1 - tests/test_einsum.py | 767 +++++++++++++--------- 5 files changed, 541 insertions(+), 377 deletions(-) diff --git a/src/finchlite/finch_einsum/__init__.py b/src/finchlite/finch_einsum/__init__.py index e620e997..87f44ef2 100644 --- a/src/finchlite/finch_einsum/__init__.py +++ b/src/finchlite/finch_einsum/__init__.py @@ -6,8 +6,8 @@ Einsum, EinsumExpr, EinsumNode, - Index, GetAttribute, + Index, Literal, Plan, Produces, @@ -25,8 +25,8 @@ "EinsumNode", "EinsumScheduler", "EinsumScheduler", - "Index", "GetAttribute", + "Index", "Literal", "Plan", "Produces", diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 7661de0b..4d084392 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -2,10 +2,9 @@ import numpy as np -from ..algebra import Tensor, overwrite, promote_max, promote_min, TensorFType +from ..algebra import overwrite, promote_max, promote_min +from ..symbolic import ftype from . import nodes as ein -from ..symbolic import ftype, gensym - nary_ops = { operator.add: "add", @@ -100,47 +99,52 @@ def __call__(self, node): vals = [self(arg) for arg in args] return func(*vals) - #access a tensor with only indices - case ein.Access(tns, idxs) if all(isinstance(idx, ein.Index) for idx in idxs): + # access a tensor with only indices + case ein.Access(tns, idxs) if all( + isinstance(idx, ein.Index) for idx in idxs + ): assert len(idxs) == len(set(idxs)) assert self.loops is not None - #convert named idxs to positional, integer indices + # convert named idxs to positional, integer indices perm = [idxs.index(idx) for idx in self.loops if idx in idxs] - - tns = self(tns) #evaluate the tensor - #if there are fewer indicies than dimensions, add the remaining dimensions as if they werent permutated - if hasattr(tns, "ndim") and len(perm) < tns.ndim: - perm = perm + [i for i in range(len(perm), tns.ndim)] + tns = self(tns) # evaluate the tensor + + # if there are fewer indicies than dimensions, add the remaining + # dimensions as if they werent permutated + if hasattr(tns, "ndim") and len(perm) < tns.ndim: + perm += list(range(len(perm), tns.ndim)) - tns = xp.permute_dims(tns, perm) #permute the dimensions + tns = xp.permute_dims(tns, perm) # permute the dimensions return xp.expand_dims( tns, [i for i in range(len(self.loops)) if self.loops[i] not in idxs], ) - #access a tensor with only one indirect access index + # access a tensor with only one indirect access index case ein.Access(tns, idxs) if len(idxs) == 1: idx = self(idxs[0]) - tns = self(tns) #evaluate the tensor - - flat_idx = idx if idx.ndim == 1 else xp.ravel_multi_index(idx.T, tns.shape) - return tns.flat[flat_idx] #return a 1-d array by definition + tns = self(tns) # evaluate the tensor + + flat_idx = ( + idx if idx.ndim == 1 else xp.ravel_multi_index(idx.T, tns.shape) + ) + return tns.flat[flat_idx] # return a 1-d array by definition - #access a tensor with a mixture of indices and other expressions + # access a tensor with a mixture of indices and other expressions case ein.Access(tns, idxs): assert self.loops is not None - true_idxs = node.get_idxs() #true field iteratior indicies + true_idxs = node.get_idxs() # true field iteratior indicies assert all(isinstance(idx, ein.Index) for idx in true_idxs) # evaluate the tensor to access tns = self(tns) assert len(idxs) == len(tns.shape) - + # Classify indices and determine their parent groups idx_sizes = [] # Track size of each index dimension - + # For each index, determine which true_idx it depends on parent_indices = [] for i, idx in enumerate(idxs): @@ -153,7 +157,7 @@ def __call__(self, node): list(child_idxs)[0] if len(child_idxs) == 1 else idx ) idx_sizes.append(None) # Will be determined later - + # Find unique parent groups and their positions # Create a mapping from parent to group index unique_parents = [] @@ -174,16 +178,16 @@ def __call__(self, node): if not found: parent_to_group[id(parent)] = len(unique_parents) unique_parents.append(parent) - + # Build group positions from the mapping group_positions = [[] for _ in range(len(unique_parents))] for i, parent in enumerate(parent_indices): group_idx = parent_to_group[id(parent)] group_positions[group_idx].append(i) - + # Build ranges for cartesian product group_ranges = [] - for group_idx, positions in enumerate(group_positions): + for positions in group_positions: first_pos = positions[0] if isinstance(idxs[first_pos], ein.Index): group_ranges.append(xp.arange(tns.shape[first_pos])) @@ -193,20 +197,22 @@ def __call__(self, node): # Update sizes for all positions in this group for p in positions: idx_sizes[p] = len(indirect_vals) - + # Compute cartesian product - group_grids = xp.meshgrid(*group_ranges, indexing='ij') + group_grids = xp.meshgrid(*group_ranges, indexing="ij") group_combos = xp.stack([g.flatten() for g in group_grids], axis=-1) - + # Build index combinations - combo_idxs = xp.empty((group_combos.shape[0], len(idxs)), dtype=xp.int64) - + combo_idxs = xp.empty( + (group_combos.shape[0], len(idxs)), dtype=xp.int64 + ) + # Pre-evaluate all indirect indices indirect_vals_cache = {} for i, idx in enumerate(idxs): if not isinstance(idx, ein.Index): indirect_vals_cache[i] = self(idx).flatten() - + # Fill combo_idxs using vectorized operations where possible for group_idx, positions in enumerate(group_positions): group_iterations = group_combos[:, group_idx] @@ -214,27 +220,36 @@ def __call__(self, node): if isinstance(idxs[pos], ein.Index): combo_idxs[:, pos] = group_iterations else: - combo_idxs[:, pos] = indirect_vals_cache[pos][group_iterations] - + combo_idxs[:, pos] = indirect_vals_cache[pos][ + group_iterations + ] + # evaluate the output tensor as a flat array flat_idx = xp.ravel_multi_index(combo_idxs.T, tns.shape) tns = xp.take(tns, flat_idx) - #calculate child idxs, idxs computed using the parent "true idxs" - child_idxs = { + # calculate child idxs, idxs computed using the parent "true idxs" + child_idxs = { parent_idx: [ - child_idx for child_idx in idxs + child_idx + for child_idx in idxs if (parent_idx in child_idx.get_idxs()) - ] for parent_idx in true_idxs + ] + for parent_idx in true_idxs } - # we assert that all the indirect access indicies from the parent idxs have the same size + # we assert that all the indirect access indicies + # from the parent idxs have the same size child_idxs_size = { - parent_idx: [idx_sizes[idxs.index(child_idx)] for child_idx in child_idxs[parent_idx]] + parent_idx: [ + idx_sizes[idxs.index(child_idx)] + for child_idx in child_idxs[parent_idx] + ] for parent_idx in true_idxs } assert all( - child_idxs_size[parent_idx].count(child_idxs_size[parent_idx][0]) == len(child_idxs_size[parent_idx]) + child_idxs_size[parent_idx].count(child_idxs_size[parent_idx][0]) + == len(child_idxs_size[parent_idx]) for parent_idx in true_idxs ) @@ -242,22 +257,30 @@ def __call__(self, node): idxs_axis = {idx: i for i, idx in enumerate(idxs)} true_idxs = list(true_idxs) - true_idxs = sorted(true_idxs, key=lambda idx: idxs_axis[child_idxs[idx][0]]) + true_idxs = sorted( + true_idxs, key=lambda idx: idxs_axis[child_idxs[idx][0]] + ) # calculate the final shape of the tensor - # we merge the child idxs to get the final shape that matches the true idxs + # we merge the child idxs to get the final shape + # that matches the true idxs final_shape = tuple( - idx_sizes[idxs.index(child_idxs[parent_idx][0])] + idx_sizes[idxs.index(child_idxs[parent_idx][0])] for parent_idx in true_idxs ) tns = tns.reshape(final_shape) - # permute and broadcast the tensor to be compatible with rest of expression + # permute and broadcast the tensor to be + # compatible with rest of expression perm = [true_idxs.index(idx) for idx in self.loops if idx in true_idxs] tns = xp.permute_dims(tns, perm) return xp.expand_dims( - tns, - [i for i in range(len(self.loops)) if self.loops[i] not in true_idxs] + tns, + [ + i + for i in range(len(self.loops)) + if self.loops[i] not in true_idxs + ], ) case ein.Plan(bodies): @@ -268,21 +291,20 @@ def __call__(self, node): case ein.Produces(args): return tuple(self(arg) for arg in args) - #get non-zero elements/data array of a sparse tensor + # get non-zero elements/data array of a sparse tensor case ein.GetAttribute(obj, ein.Literal("elems"), _): obj = self(obj) assert isinstance(ftype(obj), SparseTensorFType) assert isinstance(obj, SparseTensor) - return obj.data - #get coord array of a sparse tensor + return obj.data + # get coord array of a sparse tensor case ein.GetAttribute(obj, ein.Literal("coords"), dim): obj = self(obj) assert isinstance(ftype(obj), SparseTensorFType) assert isinstance(obj, SparseTensor) # return the coord array for the given dimension or all dimensions - toReturn = obj.coords if dim is None else obj.coords[:, dim] - return toReturn + return obj.coords if dim is None else obj.coords[:, dim] # gets the shape of a sparse tensor at a given dimension case ein.GetAttribute(obj, ein.Literal("shape"), dim): obj = self(obj) @@ -294,10 +316,12 @@ def __call__(self, node): return obj.shape[dim] # standard einsum - case ein.Einsum(op, ein.Alias(tns), idxs, arg) if all(isinstance(idx, ein.Index) for idx in idxs): + case ein.Einsum(op, ein.Alias(tns), idxs, arg) if all( + isinstance(idx, ein.Index) for idx in idxs + ): # This is the main entry point for einsum execution loops = arg.get_idxs() - + assert set(idxs).issubset(loops) loops = sorted(loops, key=lambda x: x.name) ctx = EinsumInterpreter(self.xp, self.bindings, loops) @@ -317,6 +341,8 @@ def __call__(self, node): # indirect einsum case ein.Einsum(op, ein.Alias(tns), idxs, arg): - raise NotImplementedError("Indirect einsum assignment is not implemented") + raise NotImplementedError( + "Indirect einsum assignment is not implemented" + ) case _: raise ValueError(f"Unknown einsum type: {type(node)}") diff --git a/src/finchlite/finch_einsum/nodes.py b/src/finchlite/finch_einsum/nodes.py index f626d02f..8777a941 100644 --- a/src/finchlite/finch_einsum/nodes.py +++ b/src/finchlite/finch_einsum/nodes.py @@ -1,7 +1,7 @@ import operator from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Optional, Self, cast +from typing import Any, Self, cast from finchlite.algebra import ( overwrite, @@ -183,14 +183,14 @@ class GetAttribute(EinsumExpr, EinsumTree): Attributes: obj: The object to get the attribute from. attr: The name of the attribute to get. - - dim: The dimension to get the attribute from. + + dim: The dimension to get the attribute from. Note this is an integer index, not a named index. """ obj: EinsumExpr attr: Literal - dim: Optional[int] + dim: int | None @classmethod def from_children(cls, *children: Term) -> Self: @@ -199,7 +199,7 @@ def from_children(cls, *children: Term) -> Self: raise ValueError("GetAttribute expects 3 children (obj + attr + idx)") obj = cast(EinsumExpr, children[0]) attr = cast(Literal, children[1]) - dim = cast(Optional[int], children[2]) + dim = cast(int | None, children[2]) return cls(obj, attr, dim) @property @@ -251,7 +251,7 @@ def children(self): return [self.op, self.tns, self.idxs, self.arg] def get_idxs(self) -> set["Index"]: - idxs = list() + idxs = [] for idx in self.idxs: idxs.extend(idx.get_idxs()) return idxs diff --git a/src/finchlite/tensor/__init__.py b/src/finchlite/tensor/__init__.py index b950cb1a..916e5602 100644 --- a/src/finchlite/tensor/__init__.py +++ b/src/finchlite/tensor/__init__.py @@ -9,7 +9,6 @@ ) from .sparse_tensor import SparseTensor, SparseTensorFType - __all__ = [ "DenseLevel", "DenseLevelFType", diff --git a/tests/test_einsum.py b/tests/test_einsum.py index df446d86..680f0e13 100644 --- a/tests/test_einsum.py +++ b/tests/test_einsum.py @@ -1,13 +1,13 @@ +import operator from typing import Any + import pytest import numpy as np import finchlite - import finchlite.finch_einsum as ein from finchlite.algebra import overwrite -import operator from finchlite.tensor import SparseTensor @@ -1134,13 +1134,17 @@ def test_complex_operations(self, rng, dtype): class TestEinsumIndirectAccess: """Test einsum with indirect access""" - def run_einsum_plan(self, prgm: ein.Plan, bindings: dict[str, Any], expected: np.ndarray): + def run_einsum_plan( + self, prgm: ein.Plan, bindings: dict[str, Any], expected: np.ndarray + ): """Runs an einsum plan and returns the result""" interpreter = ein.EinsumInterpreter(bindings=bindings) result = interpreter(prgm)[0] - import numpy as np import sys + + import numpy as np + np.set_printoptions(threshold=sys.maxsize) print(result) @@ -1158,32 +1162,42 @@ def test_indirect_elementwise_multiplication(self, rng): # A is sparse # C[i] = AElems[i] * B[ACoords[i]] - prgm = ein.Plan(( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias("C"), - idxs=(ein.Index("i"),), - arg=ein.Call( - op=ein.Literal(operator.mul), - args=( - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("elems"), dim=None), - idxs=(ein.Index("i"),), - ), - ein.Access( - tns=ein.Alias("B"), - idxs=( - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("coords"), dim=None), - idxs=(ein.Index("i"),) + prgm = ein.Plan( + ( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("C"), + idxs=(ein.Index("i"),), + arg=ein.Call( + op=ein.Literal(operator.mul), + args=( + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("A"), + attr=ein.Literal("elems"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), + ein.Access( + tns=ein.Alias("B"), + idxs=( + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("A"), + attr=ein.Literal("coords"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), ), ), ), ), ), - ), - ein.Produces((ein.Alias("C"),)), - )) + ein.Produces((ein.Alias("C"),)), + ) + ) result = finchlite.multiply(A, B).flatten() self.run_einsum_plan(prgm, {"A": sparse_A, "B": B}, result) @@ -1197,32 +1211,42 @@ def test_indirect_elementwise_addition(self, rng): sparse_A = SparseTensor.from_dense_tensor(A) # C[i] = AElems[i] + B[ACoords[i]] - prgm = ein.Plan(( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias("C"), - idxs=(ein.Index("i"),), - arg=ein.Call( - op=ein.Literal(operator.add), - args=( - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("elems"), dim=None), - idxs=(ein.Index("i"),), - ), - ein.Access( - tns=ein.Alias("B"), - idxs=( - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("coords"), dim=None), - idxs=(ein.Index("i"),) + prgm = ein.Plan( + ( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("C"), + idxs=(ein.Index("i"),), + arg=ein.Call( + op=ein.Literal(operator.add), + args=( + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("A"), + attr=ein.Literal("elems"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), + ein.Access( + tns=ein.Alias("B"), + idxs=( + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("A"), + attr=ein.Literal("coords"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), ), ), ), ), ), - ), - ein.Produces((ein.Alias("C"),)), - )) + ein.Produces((ein.Alias("C"),)), + ) + ) result = (A + B).flatten() self.run_einsum_plan(prgm, {"A": sparse_A, "B": B}, result) @@ -1238,27 +1262,37 @@ def test_indirect_multiple_reads(self, rng): # C[i] = AElems[i] * BElems[i] # Both A and B are sparse, reading their elements directly - prgm = ein.Plan(( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias("C"), - idxs=(ein.Index("i"),), - arg=ein.Call( - op=ein.Literal(operator.mul), - args=( - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("elems"), dim=None), - idxs=(ein.Index("i"),), - ), - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("B"), attr=ein.Literal("elems"), dim=None), - idxs=(ein.Index("i"),), + prgm = ein.Plan( + ( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("C"), + idxs=(ein.Index("i"),), + arg=ein.Call( + op=ein.Literal(operator.mul), + args=( + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("A"), + attr=ein.Literal("elems"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("B"), + attr=ein.Literal("elems"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), ), ), ), - ), - ein.Produces((ein.Alias("C"),)), - )) + ein.Produces((ein.Alias("C"),)), + ) + ) result = finchlite.multiply(A, B).flatten() self.run_einsum_plan(prgm, {"A": sparse_A, "B": sparse_B}, result) @@ -1272,38 +1306,48 @@ def test_indirect_with_constant(self, rng): sparse_A = SparseTensor.from_dense_tensor(A) # C[i] = AElems[i] * B[ACoords[i]] + 5.0 - prgm = ein.Plan(( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias("C"), - idxs=(ein.Index("i"),), - arg=ein.Call( - op=ein.Literal(operator.add), - args=( - ein.Call( - op=ein.Literal(operator.mul), - args=( - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("elems"), dim=None), - idxs=(ein.Index("i"),), - ), - ein.Access( - tns=ein.Alias("B"), - idxs=( - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("coords"), dim=None), - idxs=(ein.Index("i"),) + prgm = ein.Plan( + ( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("C"), + idxs=(ein.Index("i"),), + arg=ein.Call( + op=ein.Literal(operator.add), + args=( + ein.Call( + op=ein.Literal(operator.mul), + args=( + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("A"), + attr=ein.Literal("elems"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), + ein.Access( + tns=ein.Alias("B"), + idxs=( + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("A"), + attr=ein.Literal("coords"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), ), ), ), ), + ein.Literal(5.0), ), - ein.Literal(5.0), ), ), - ), - ein.Produces((ein.Alias("C"),)), - )) + ein.Produces((ein.Alias("C"),)), + ) + ) result = (A * B + 5.0).flatten() self.run_einsum_plan(prgm, {"A": sparse_A, "B": B}, result) @@ -1318,46 +1362,60 @@ def test_indirect_nested_operations(self, rng): sparse_A = SparseTensor.from_dense_tensor(A) # D[i] = (AElems[i] + B[ACoords[i]]) * C[ACoords[i]] - prgm = ein.Plan(( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias("D"), - idxs=(ein.Index("i"),), - arg=ein.Call( - op=ein.Literal(operator.mul), - args=( - ein.Call( - op=ein.Literal(operator.add), - args=( - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("elems"), dim=None), - idxs=(ein.Index("i"),), - ), - ein.Access( - tns=ein.Alias("B"), - idxs=( - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("coords"), dim=None), - idxs=(ein.Index("i"),) + prgm = ein.Plan( + ( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("D"), + idxs=(ein.Index("i"),), + arg=ein.Call( + op=ein.Literal(operator.mul), + args=( + ein.Call( + op=ein.Literal(operator.add), + args=( + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("A"), + attr=ein.Literal("elems"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), + ein.Access( + tns=ein.Alias("B"), + idxs=( + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("A"), + attr=ein.Literal("coords"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), ), ), ), ), - ), - ein.Access( - tns=ein.Alias("C"), - idxs=( - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("coords"), dim=None), - idxs=(ein.Index("i"),) + ein.Access( + tns=ein.Alias("C"), + idxs=( + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("A"), + attr=ein.Literal("coords"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), ), ), ), ), ), - ), - ein.Produces((ein.Alias("D"),)), - )) + ein.Produces((ein.Alias("D"),)), + ) + ) result = ((A + B) * C).flatten() self.run_einsum_plan(prgm, {"A": sparse_A, "B": B, "C": C}, result) @@ -1371,23 +1429,29 @@ def test_indirect_direct_access_only(self, rng): sparse_A = SparseTensor.from_dense_tensor(A) # C[i] = B[ACoords[i]] (read B indirectly, without using A's elements) - prgm = ein.Plan(( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias("C"), - idxs=(ein.Index("i"),), - arg=ein.Access( - tns=ein.Alias("B"), - idxs=( - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("A"), attr=ein.Literal("coords"), dim=None), - idxs=(ein.Index("i"),) + prgm = ein.Plan( + ( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("C"), + idxs=(ein.Index("i"),), + arg=ein.Access( + tns=ein.Alias("B"), + idxs=( + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("A"), + attr=ein.Literal("coords"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), ), ), ), - ), - ein.Produces((ein.Alias("C"),)), - )) + ein.Produces((ein.Alias("C"),)), + ) + ) # Result should be B's values at A's coordinates expected = B[A != 0] @@ -1403,35 +1467,42 @@ def test_double_indirection(self, rng): rng.shuffle(B) # C is sparse C = rng.random((8,)) - + sparse_C = SparseTensor.from_dense_tensor(C) # D[i] = A[B[CCoords[i]]] # First get CCoords[i], then use that to index B, then use B's value to index A - prgm = ein.Plan(( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias("D"), - idxs=(ein.Index("i"),), - arg=ein.Access( - tns=ein.Alias("A"), - idxs=( - ein.Access( - tns=ein.Alias("B"), - idxs=( - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("C"), attr=ein.Literal("coords"), dim=None), - idxs=(ein.Index("i"),) + prgm = ein.Plan( + ( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("D"), + idxs=(ein.Index("i"),), + arg=ein.Access( + tns=ein.Alias("A"), + idxs=( + ein.Access( + tns=ein.Alias("B"), + idxs=( + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("C"), + attr=ein.Literal("coords"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), ), ), ), ), ), - ), - ein.Produces((ein.Alias("D"),)), - )) + ein.Produces((ein.Alias("D"),)), + ) + ) - # Expected: for each non-zero position in C, get its coord, index into B, then index into A + # Expected: for each non-zero position in C, get its coord, + # index into B, then index into A c_coords = sparse_C.coords expected = A[B[c_coords]].flatten() self.run_einsum_plan(prgm, {"A": A, "B": B, "C": sparse_C}, expected) @@ -1446,27 +1517,33 @@ def test_triple_indirection(self, rng): C = np.array([0, 1, 2, 3, 4, 5, 6, 7]) rng.shuffle(C) D = rng.random((8,)) - + sparse_D = SparseTensor.from_dense_tensor(D) # E[i] = A[B[C[DCoords[i]]]] - prgm = ein.Plan(( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias("E"), - idxs=(ein.Index("i"),), - arg=ein.Access( - tns=ein.Alias("A"), - idxs=( - ein.Access( - tns=ein.Alias("B"), - idxs=( - ein.Access( - tns=ein.Alias("C"), - idxs=( - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("D"), attr=ein.Literal("coords"), dim=None), - idxs=(ein.Index("i"),) + prgm = ein.Plan( + ( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("E"), + idxs=(ein.Index("i"),), + arg=ein.Access( + tns=ein.Alias("A"), + idxs=( + ein.Access( + tns=ein.Alias("B"), + idxs=( + ein.Access( + tns=ein.Alias("C"), + idxs=( + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("D"), + attr=ein.Literal("coords"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), ), ), ), @@ -1474,9 +1551,9 @@ def test_triple_indirection(self, rng): ), ), ), - ), - ein.Produces((ein.Alias("E"),)), - )) + ein.Produces((ein.Alias("E"),)), + ) + ) # Expected: chain of indirections d_coords = sparse_D.coords @@ -1488,29 +1565,35 @@ def test_mixed_direct_indirect_indexing_2d(self, rng): A = rng.random((5, 4)) B = rng.random((5,)) - + sparse_B = SparseTensor.from_dense_tensor(B) # C[i, j] = A[BCoords[i], j] # First index is indirect (from B's coords), second is direct - prgm = ein.Plan(( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias("C"), - idxs=(ein.Index("i"), ein.Index("j")), - arg=ein.Access( - tns=ein.Alias("A"), - idxs=( - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None), - idxs=(ein.Index("i"),) + prgm = ein.Plan( + ( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("C"), + idxs=(ein.Index("i"), ein.Index("j")), + arg=ein.Access( + tns=ein.Alias("A"), + idxs=( + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("B"), + attr=ein.Literal("coords"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), + ein.Index("j"), ), - ein.Index("j"), ), ), - ), - ein.Produces((ein.Alias("C"),)), - )) + ein.Produces((ein.Alias("C"),)), + ) + ) # Expected: A rows indexed by B's coords, all columns b_coords = sparse_B.coords @@ -1518,33 +1601,42 @@ def test_mixed_direct_indirect_indexing_2d(self, rng): self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected) def test_mixed_direct_indirect_indexing_reversed(self, rng): - """Test mixed indexing reversed: A[i, BCoords[j]] - first direct, second indirect""" + """ + Test mixed indexing reversed + A[i, BCoords[j]] - first direct, second indirect + """ A = rng.random((4, 6)) B = rng.random((6,)) - + sparse_B = SparseTensor.from_dense_tensor(B) # C[i, j] = A[i, BCoords[j]] # First index is direct, second is indirect - prgm = ein.Plan(( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias("C"), - idxs=(ein.Index("i"), ein.Index("j")), - arg=ein.Access( - tns=ein.Alias("A"), - idxs=( - ein.Index("i"), - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None), - idxs=(ein.Index("j"),) + prgm = ein.Plan( + ( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("C"), + idxs=(ein.Index("i"), ein.Index("j")), + arg=ein.Access( + tns=ein.Alias("A"), + idxs=( + ein.Index("i"), + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("B"), + attr=ein.Literal("coords"), + dim=None, + ), + idxs=(ein.Index("j"),), + ), ), ), ), - ), - ein.Produces((ein.Alias("C"),)), - )) + ein.Produces((ein.Alias("C"),)), + ) + ) # Expected: all rows of A, columns indexed by B's coords b_coords = sparse_B.coords @@ -1556,32 +1648,42 @@ def test_both_indices_indirect_same_source(self, rng): A = rng.random((6, 6)) B = rng.random((6,)) - + sparse_B = SparseTensor.from_dense_tensor(B) # C[i] = A[BCoords[i], BCoords[i]] # Extracting diagonal-like elements using indirect coordinates - prgm = ein.Plan(( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias("C"), - idxs=(ein.Index("i"),), - arg=ein.Access( - tns=ein.Alias("A"), - idxs=( - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None), - idxs=(ein.Index("i"),) - ), - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None), - idxs=(ein.Index("i"),) + prgm = ein.Plan( + ( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("C"), + idxs=(ein.Index("i"),), + arg=ein.Access( + tns=ein.Alias("A"), + idxs=( + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("B"), + attr=ein.Literal("coords"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("B"), + attr=ein.Literal("coords"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), ), ), ), - ), - ein.Produces((ein.Alias("C"),)), - )) + ein.Produces((ein.Alias("C"),)), + ) + ) # Expected: A[coords, coords] - pseudo-diagonal at indirect positions b_coords = sparse_B.coords @@ -1589,37 +1691,50 @@ def test_both_indices_indirect_same_source(self, rng): self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected) def test_both_indices_indirect_different_sources(self, rng): - """Test both indices indirect from different sources: A[BCoords[i], CCoords[i]]""" + """ + Test both indices indirect from different sources: + A[BCoords[i], CCoords[i]] + """ A = rng.random((6, 6)) B = rng.random((6,)) C = rng.random((6,)) - + sparse_B = SparseTensor.from_dense_tensor(B) sparse_C = SparseTensor.from_dense_tensor(C) # D[i] = A[BCoords[i], CCoords[i]] - prgm = ein.Plan(( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias("D"), - idxs=(ein.Index("i"),), - arg=ein.Access( - tns=ein.Alias("A"), - idxs=( - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None), - idxs=(ein.Index("i"),) - ), - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("C"), attr=ein.Literal("coords"), dim=None), - idxs=(ein.Index("i"),) + prgm = ein.Plan( + ( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("D"), + idxs=(ein.Index("i"),), + arg=ein.Access( + tns=ein.Alias("A"), + idxs=( + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("B"), + attr=ein.Literal("coords"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("C"), + attr=ein.Literal("coords"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), ), ), ), - ), - ein.Produces((ein.Alias("D"),)), - )) + ein.Produces((ein.Alias("D"),)), + ) + ) # Expected: A indexed by pairs of coords from B and C # This requires both to have the same number of non-zero elements @@ -1635,42 +1750,52 @@ def test_double_indirection_with_operation(self, rng): B = np.array([0, 1, 2, 3, 4, 5, 6, 7]) rng.shuffle(B) C = rng.random((8,)) - + sparse_C = SparseTensor.from_dense_tensor(C) # E[i] = A[B[CCoords[i]]] * CElems[i] # Double indirection plus multiplication with sparse elements - prgm = ein.Plan(( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias("E"), - idxs=(ein.Index("i"),), - arg=ein.Call( - op=ein.Literal(operator.mul), - args=( - ein.Access( - tns=ein.Alias("A"), - idxs=( - ein.Access( - tns=ein.Alias("B"), - idxs=( - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("C"), attr=ein.Literal("coords"), dim=None), - idxs=(ein.Index("i"),) + prgm = ein.Plan( + ( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("E"), + idxs=(ein.Index("i"),), + arg=ein.Call( + op=ein.Literal(operator.mul), + args=( + ein.Access( + tns=ein.Alias("A"), + idxs=( + ein.Access( + tns=ein.Alias("B"), + idxs=( + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("C"), + attr=ein.Literal("coords"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), ), ), ), ), - ), - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("C"), attr=ein.Literal("elems"), dim=None), - idxs=(ein.Index("i"),), + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("C"), + attr=ein.Literal("elems"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), ), ), ), - ), - ein.Produces((ein.Alias("E"),)), - )) + ein.Produces((ein.Alias("E"),)), + ) + ) c_coords = sparse_C.coords.flatten() c_elems = sparse_C.data @@ -1683,43 +1808,51 @@ def test_mixed_indexing_with_computation(self, rng): A = rng.random((5, 5)) B = rng.random((4,)) - C = rng.random((5,)) - + sparse_B = SparseTensor.from_dense_tensor(B) # D[i, j] = A[BCoords[i], j] + BElems[i] # Mixed indexing plus addition with sparse elements - prgm = ein.Plan(( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias("D"), - idxs=(ein.Index("i"), ein.Index("j")), - arg=ein.Call( - op=ein.Literal(operator.add), - args=( - ein.Access( - tns=ein.Alias("A"), - idxs=( - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None), - idxs=(ein.Index("i"),) + prgm = ein.Plan( + ( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("D"), + idxs=(ein.Index("i"), ein.Index("j")), + arg=ein.Call( + op=ein.Literal(operator.add), + args=( + ein.Access( + tns=ein.Alias("A"), + idxs=( + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("B"), + attr=ein.Literal("coords"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), + ein.Index("j"), ), - ein.Index("j"), ), - ), - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("B"), attr=ein.Literal("elems"), dim=None), - idxs=(ein.Index("i"),), + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("B"), + attr=ein.Literal("elems"), + dim=None, + ), + idxs=(ein.Index("i"),), + ), ), ), ), - ), - ein.Produces((ein.Alias("D"),)), - )) + ein.Produces((ein.Alias("D"),)), + ) + ) b_coords = sparse_B.coords.flatten() b_elems = sparse_B.data - # Broadcasting: A[coords, :] has shape (len(coords), 5), b_elems has shape (len(coords),) expected = A[b_coords, :] + b_elems[:, np.newaxis] self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected) @@ -1728,31 +1861,37 @@ def test_indirect_3d_tensor_access(self, rng): A = rng.random((3, 4, 5)) B = rng.random((4,)) - + sparse_B = SparseTensor.from_dense_tensor(B) # C[i, j, k] = A[i, BCoords[j], k] # Middle dimension is indirectly indexed - prgm = ein.Plan(( - ein.Einsum( - op=ein.Literal(overwrite), - tns=ein.Alias("C"), - idxs=(ein.Index("i"), ein.Index("j"), ein.Index("k")), - arg=ein.Access( - tns=ein.Alias("A"), - idxs=( - ein.Index("i"), - ein.Access( - tns=ein.GetAttribute(obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None), - idxs=(ein.Index("j"),) + prgm = ein.Plan( + ( + ein.Einsum( + op=ein.Literal(overwrite), + tns=ein.Alias("C"), + idxs=(ein.Index("i"), ein.Index("j"), ein.Index("k")), + arg=ein.Access( + tns=ein.Alias("A"), + idxs=( + ein.Index("i"), + ein.Access( + tns=ein.GetAttribute( + obj=ein.Alias("B"), + attr=ein.Literal("coords"), + dim=None, + ), + idxs=(ein.Index("j"),), + ), + ein.Index("k"), ), - ein.Index("k"), ), ), - ), - ein.Produces((ein.Alias("C"),)), - )) + ein.Produces((ein.Alias("C"),)), + ) + ) b_coords = sparse_B.coords.flatten() expected = A[:, b_coords, :] - self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected) \ No newline at end of file + self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected) From c31f7afd2e47d0f9b3553cd3ca09ee6600306e64 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 11 Nov 2025 20:58:07 -0500 Subject: [PATCH 34/45] * Fixed mypy issues --- src/finchlite/finch_einsum/nodes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/finchlite/finch_einsum/nodes.py b/src/finchlite/finch_einsum/nodes.py index 8777a941..f8663603 100644 --- a/src/finchlite/finch_einsum/nodes.py +++ b/src/finchlite/finch_einsum/nodes.py @@ -251,9 +251,9 @@ def children(self): return [self.op, self.tns, self.idxs, self.arg] def get_idxs(self) -> set["Index"]: - idxs = [] + idxs = set() for idx in self.idxs: - idxs.extend(idx.get_idxs()) + idxs.update(idx.get_idxs()) return idxs @@ -378,7 +378,7 @@ def __call__(self, prgm: EinsumNode): return f"{self(fn)}({', '.join(args_e)})" case GetAttribute(obj, attr, idx): if idx is not None: - return f"{self(obj)}.{self(attr)}[{self(idx)}]" + return f"{self(obj)}.{self(attr)}[{idx}]" return f"{self(obj)}.{self(attr)}" case Einsum(op, tns, idxs, arg): op_str = infix_strs.get(op.val, op.val.__name__) From d6f2d2531c7b89e1e7cde5296f351cac51bd1397 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 11 Nov 2025 21:01:45 -0500 Subject: [PATCH 35/45] * Renamed GetAttribute to GetAttr to be consistent with FinchAssembly, FinchNotation --- src/finchlite/finch_einsum/__init__.py | 4 +- src/finchlite/finch_einsum/interpreter.py | 6 +-- src/finchlite/finch_einsum/nodes.py | 4 +- tests/test_einsum.py | 50 +++++++++++------------ 4 files changed, 32 insertions(+), 32 deletions(-) diff --git a/src/finchlite/finch_einsum/__init__.py b/src/finchlite/finch_einsum/__init__.py index 87f44ef2..58b0b538 100644 --- a/src/finchlite/finch_einsum/__init__.py +++ b/src/finchlite/finch_einsum/__init__.py @@ -6,7 +6,7 @@ Einsum, EinsumExpr, EinsumNode, - GetAttribute, + GetAttr, Index, Literal, Plan, @@ -25,7 +25,7 @@ "EinsumNode", "EinsumScheduler", "EinsumScheduler", - "GetAttribute", + "GetAttr", "Index", "Literal", "Plan", diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 4d084392..5053c065 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -292,13 +292,13 @@ def __call__(self, node): return tuple(self(arg) for arg in args) # get non-zero elements/data array of a sparse tensor - case ein.GetAttribute(obj, ein.Literal("elems"), _): + case ein.GetAttr(obj, ein.Literal("elems"), _): obj = self(obj) assert isinstance(ftype(obj), SparseTensorFType) assert isinstance(obj, SparseTensor) return obj.data # get coord array of a sparse tensor - case ein.GetAttribute(obj, ein.Literal("coords"), dim): + case ein.GetAttr(obj, ein.Literal("coords"), dim): obj = self(obj) assert isinstance(ftype(obj), SparseTensorFType) assert isinstance(obj, SparseTensor) @@ -306,7 +306,7 @@ def __call__(self, node): # return the coord array for the given dimension or all dimensions return obj.coords if dim is None else obj.coords[:, dim] # gets the shape of a sparse tensor at a given dimension - case ein.GetAttribute(obj, ein.Literal("shape"), dim): + case ein.GetAttr(obj, ein.Literal("shape"), dim): obj = self(obj) assert isinstance(ftype(obj), SparseTensorFType) assert isinstance(obj, SparseTensor) diff --git a/src/finchlite/finch_einsum/nodes.py b/src/finchlite/finch_einsum/nodes.py index f8663603..2496e8fd 100644 --- a/src/finchlite/finch_einsum/nodes.py +++ b/src/finchlite/finch_einsum/nodes.py @@ -177,7 +177,7 @@ def get_idxs(self) -> set["Index"]: @dataclass(eq=True, frozen=True) -class GetAttribute(EinsumExpr, EinsumTree): +class GetAttr(EinsumExpr, EinsumTree): """ Gets an attribute of a tensor. Attributes: @@ -376,7 +376,7 @@ def __call__(self, prgm: EinsumNode): if len(args) == 1 and fn.val in unary_strs: return f"{unary_strs[fn.val]}{args_e[0]}" return f"{self(fn)}({', '.join(args_e)})" - case GetAttribute(obj, attr, idx): + case GetAttr(obj, attr, idx): if idx is not None: return f"{self(obj)}.{self(attr)}[{idx}]" return f"{self(obj)}.{self(attr)}" diff --git a/tests/test_einsum.py b/tests/test_einsum.py index 680f0e13..3e85f1aa 100644 --- a/tests/test_einsum.py +++ b/tests/test_einsum.py @@ -1172,7 +1172,7 @@ def test_indirect_elementwise_multiplication(self, rng): op=ein.Literal(operator.mul), args=( ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("A"), attr=ein.Literal("elems"), dim=None, @@ -1183,7 +1183,7 @@ def test_indirect_elementwise_multiplication(self, rng): tns=ein.Alias("B"), idxs=( ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("A"), attr=ein.Literal("coords"), dim=None, @@ -1221,7 +1221,7 @@ def test_indirect_elementwise_addition(self, rng): op=ein.Literal(operator.add), args=( ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("A"), attr=ein.Literal("elems"), dim=None, @@ -1232,7 +1232,7 @@ def test_indirect_elementwise_addition(self, rng): tns=ein.Alias("B"), idxs=( ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("A"), attr=ein.Literal("coords"), dim=None, @@ -1272,7 +1272,7 @@ def test_indirect_multiple_reads(self, rng): op=ein.Literal(operator.mul), args=( ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("A"), attr=ein.Literal("elems"), dim=None, @@ -1280,7 +1280,7 @@ def test_indirect_multiple_reads(self, rng): idxs=(ein.Index("i"),), ), ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("B"), attr=ein.Literal("elems"), dim=None, @@ -1319,7 +1319,7 @@ def test_indirect_with_constant(self, rng): op=ein.Literal(operator.mul), args=( ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("A"), attr=ein.Literal("elems"), dim=None, @@ -1330,7 +1330,7 @@ def test_indirect_with_constant(self, rng): tns=ein.Alias("B"), idxs=( ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("A"), attr=ein.Literal("coords"), dim=None, @@ -1375,7 +1375,7 @@ def test_indirect_nested_operations(self, rng): op=ein.Literal(operator.add), args=( ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("A"), attr=ein.Literal("elems"), dim=None, @@ -1386,7 +1386,7 @@ def test_indirect_nested_operations(self, rng): tns=ein.Alias("B"), idxs=( ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("A"), attr=ein.Literal("coords"), dim=None, @@ -1401,7 +1401,7 @@ def test_indirect_nested_operations(self, rng): tns=ein.Alias("C"), idxs=( ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("A"), attr=ein.Literal("coords"), dim=None, @@ -1439,7 +1439,7 @@ def test_indirect_direct_access_only(self, rng): tns=ein.Alias("B"), idxs=( ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("A"), attr=ein.Literal("coords"), dim=None, @@ -1485,7 +1485,7 @@ def test_double_indirection(self, rng): tns=ein.Alias("B"), idxs=( ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("C"), attr=ein.Literal("coords"), dim=None, @@ -1537,7 +1537,7 @@ def test_triple_indirection(self, rng): tns=ein.Alias("C"), idxs=( ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("D"), attr=ein.Literal("coords"), dim=None, @@ -1580,7 +1580,7 @@ def test_mixed_direct_indirect_indexing_2d(self, rng): tns=ein.Alias("A"), idxs=( ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None, @@ -1624,7 +1624,7 @@ def test_mixed_direct_indirect_indexing_reversed(self, rng): idxs=( ein.Index("i"), ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None, @@ -1663,7 +1663,7 @@ def test_both_indices_indirect_same_source(self, rng): tns=ein.Alias("A"), idxs=( ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None, @@ -1671,7 +1671,7 @@ def test_both_indices_indirect_same_source(self, rng): idxs=(ein.Index("i"),), ), ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None, @@ -1714,7 +1714,7 @@ def test_both_indices_indirect_different_sources(self, rng): tns=ein.Alias("A"), idxs=( ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None, @@ -1722,7 +1722,7 @@ def test_both_indices_indirect_different_sources(self, rng): idxs=(ein.Index("i"),), ), ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("C"), attr=ein.Literal("coords"), dim=None, @@ -1771,7 +1771,7 @@ def test_double_indirection_with_operation(self, rng): tns=ein.Alias("B"), idxs=( ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("C"), attr=ein.Literal("coords"), dim=None, @@ -1783,7 +1783,7 @@ def test_double_indirection_with_operation(self, rng): ), ), ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("C"), attr=ein.Literal("elems"), dim=None, @@ -1826,7 +1826,7 @@ def test_mixed_indexing_with_computation(self, rng): tns=ein.Alias("A"), idxs=( ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None, @@ -1837,7 +1837,7 @@ def test_mixed_indexing_with_computation(self, rng): ), ), ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("B"), attr=ein.Literal("elems"), dim=None, @@ -1877,7 +1877,7 @@ def test_indirect_3d_tensor_access(self, rng): idxs=( ein.Index("i"), ein.Access( - tns=ein.GetAttribute( + tns=ein.GetAttr( obj=ein.Alias("B"), attr=ein.Literal("coords"), dim=None, From b59cc0a856e182dc2ad44c24eb594044f4175512 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Wed, 12 Nov 2025 11:50:22 -0500 Subject: [PATCH 36/45] Refactor EinsumInterpreter to streamline index evaluation and grouping logic. Consolidated steps for handling unique parent indices and optimized meshgrid creation for index combinations. Enhanced final shape calculation and tensor permutation to ensure accurate evaluation of tensor expressions. --- src/finchlite/finch_einsum/interpreter.py | 152 +++++----------------- 1 file changed, 29 insertions(+), 123 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 5053c065..9677d193 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -142,144 +142,50 @@ def __call__(self, node): tns = self(tns) assert len(idxs) == len(tns.shape) - # Classify indices and determine their parent groups - idx_sizes = [] # Track size of each index dimension - - # For each index, determine which true_idx it depends on - parent_indices = [] - for i, idx in enumerate(idxs): - if isinstance(idx, ein.Index): - parent_indices.append(idx) - idx_sizes.append(tns.shape[i]) - else: - child_idxs = idx.get_idxs() - parent_indices.append( - list(child_idxs)[0] if len(child_idxs) == 1 else idx - ) - idx_sizes.append(None) # Will be determined later - - # Find unique parent groups and their positions - # Create a mapping from parent to group index - unique_parents = [] - parent_to_group = {} - for parent in parent_indices: - found = False - for i, up in enumerate(unique_parents): - # For ein.Index, compare by name; for others by identity - if isinstance(parent, ein.Index) and isinstance(up, ein.Index): - if parent.name == up.name: - parent_to_group[id(parent)] = i - found = True - break - elif parent is up: - parent_to_group[id(parent)] = i - found = True - break - if not found: - parent_to_group[id(parent)] = len(unique_parents) - unique_parents.append(parent) - - # Build group positions from the mapping - group_positions = [[] for _ in range(len(unique_parents))] - for i, parent in enumerate(parent_indices): - group_idx = parent_to_group[id(parent)] - group_positions[group_idx].append(i) - - # Build ranges for cartesian product - group_ranges = [] - for positions in group_positions: - first_pos = positions[0] - if isinstance(idxs[first_pos], ein.Index): - group_ranges.append(xp.arange(tns.shape[first_pos])) - else: - indirect_vals = self(idxs[first_pos]).flatten() - group_ranges.append(xp.arange(len(indirect_vals))) - # Update sizes for all positions in this group - for p in positions: - idx_sizes[p] = len(indirect_vals) - - # Compute cartesian product - group_grids = xp.meshgrid(*group_ranges, indexing="ij") - group_combos = xp.stack([g.flatten() for g in group_grids], axis=-1) - - # Build index combinations - combo_idxs = xp.empty( - (group_combos.shape[0], len(idxs)), dtype=xp.int64 - ) - - # Pre-evaluate all indirect indices - indirect_vals_cache = {} - for i, idx in enumerate(idxs): - if not isinstance(idx, ein.Index): - indirect_vals_cache[i] = self(idx).flatten() - - # Fill combo_idxs using vectorized operations where possible - for group_idx, positions in enumerate(group_positions): - group_iterations = group_combos[:, group_idx] - for pos in positions: - if isinstance(idxs[pos], ein.Index): - combo_idxs[:, pos] = group_iterations - else: - combo_idxs[:, pos] = indirect_vals_cache[pos][ - group_iterations - ] + # Evaluate all indices into arrays + idx_arrays = [xp.arange(tns.shape[i]) if isinstance(idx, ein.Index) + else self(idx).flatten() for i, idx in enumerate(idxs)] + + # Identify unique parent indices and their positions + parents = [idx if isinstance(idx, ein.Index) else list(idx.get_idxs())[0] + for idx in idxs] + unique_parents = list(dict.fromkeys(parents)) + parent_to_idx = {p: i for i, p in enumerate(unique_parents)} + pos_groups = [[i for i, p in enumerate(parents) if p == up] + for up in unique_parents] + + # Create meshgrid only for unique groups + unique_arrays = [idx_arrays[g[0]] for g in pos_groups] + grids = xp.meshgrid(*unique_arrays, indexing='ij') + grid_flat = xp.stack([g.ravel() for g in grids], axis=-1) + + # Build final index combinations + group_idx_map = [parent_to_idx[p] for p in parents] + combo_idxs = xp.stack([idx_arrays[i][grid_flat[:, group_idx_map[i]]] + for i in range(len(idxs))], axis=1) # evaluate the output tensor as a flat array flat_idx = xp.ravel_multi_index(combo_idxs.T, tns.shape) tns = xp.take(tns, flat_idx) - # calculate child idxs, idxs computed using the parent "true idxs" - child_idxs = { - parent_idx: [ - child_idx - for child_idx in idxs - if (parent_idx in child_idx.get_idxs()) - ] - for parent_idx in true_idxs - } - - # we assert that all the indirect access indicies - # from the parent idxs have the same size - child_idxs_size = { - parent_idx: [ - idx_sizes[idxs.index(child_idx)] - for child_idx in child_idxs[parent_idx] - ] - for parent_idx in true_idxs - } - assert all( - child_idxs_size[parent_idx].count(child_idxs_size[parent_idx][0]) - == len(child_idxs_size[parent_idx]) - for parent_idx in true_idxs - ) - - # a mapping from each idx to its axis wrt to current shape - idxs_axis = {idx: i for i, idx in enumerate(idxs)} - - true_idxs = list(true_idxs) - true_idxs = sorted( - true_idxs, key=lambda idx: idxs_axis[child_idxs[idx][0]] - ) - - # calculate the final shape of the tensor - # we merge the child idxs to get the final shape - # that matches the true idxs - final_shape = tuple( - idx_sizes[idxs.index(child_idxs[parent_idx][0])] - for parent_idx in true_idxs - ) + # Calculate final shape and permutation + true_idx_pos = {idx: pos_groups[unique_parents.index(idx)][0] + for idx in true_idxs} + true_idxs_sorted = sorted(true_idxs, key=true_idx_pos.get) + final_shape = tuple(len(unique_arrays[unique_parents.index(idx)]) + for idx in true_idxs_sorted) tns = tns.reshape(final_shape) # permute and broadcast the tensor to be # compatible with rest of expression - perm = [true_idxs.index(idx) for idx in self.loops if idx in true_idxs] + perm = [true_idxs_sorted.index(idx) for idx in self.loops if idx in true_idxs_sorted] tns = xp.permute_dims(tns, perm) return xp.expand_dims( tns, [ i for i in range(len(self.loops)) - if self.loops[i] not in true_idxs + if self.loops[i] not in true_idxs_sorted ], ) From 5d0b80b0a8e97b571206dee18bbaec1aa4b6635b Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Wed, 12 Nov 2025 11:52:51 -0500 Subject: [PATCH 37/45] * Fixed ruff errors --- src/finchlite/finch_einsum/interpreter.py | 54 ++++++++++++++++------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 9677d193..b1ca5b6c 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -143,42 +143,62 @@ def __call__(self, node): assert len(idxs) == len(tns.shape) # Evaluate all indices into arrays - idx_arrays = [xp.arange(tns.shape[i]) if isinstance(idx, ein.Index) - else self(idx).flatten() for i, idx in enumerate(idxs)] - + idx_arrays = [ + xp.arange(tns.shape[i]) + if isinstance(idx, ein.Index) + else self(idx).flatten() + for i, idx in enumerate(idxs) + ] + # Identify unique parent indices and their positions - parents = [idx if isinstance(idx, ein.Index) else list(idx.get_idxs())[0] - for idx in idxs] + parents = [ + idx if isinstance(idx, ein.Index) else list(idx.get_idxs())[0] + for idx in idxs + ] unique_parents = list(dict.fromkeys(parents)) parent_to_idx = {p: i for i, p in enumerate(unique_parents)} - pos_groups = [[i for i, p in enumerate(parents) if p == up] - for up in unique_parents] - + pos_groups = [ + [i for i, p in enumerate(parents) if p == up] + for up in unique_parents + ] + # Create meshgrid only for unique groups unique_arrays = [idx_arrays[g[0]] for g in pos_groups] - grids = xp.meshgrid(*unique_arrays, indexing='ij') + grids = xp.meshgrid(*unique_arrays, indexing="ij") grid_flat = xp.stack([g.ravel() for g in grids], axis=-1) - + # Build final index combinations group_idx_map = [parent_to_idx[p] for p in parents] - combo_idxs = xp.stack([idx_arrays[i][grid_flat[:, group_idx_map[i]]] - for i in range(len(idxs))], axis=1) + combo_idxs = xp.stack( + [ + idx_arrays[i][grid_flat[:, group_idx_map[i]]] + for i in range(len(idxs)) + ], + axis=1, + ) # evaluate the output tensor as a flat array flat_idx = xp.ravel_multi_index(combo_idxs.T, tns.shape) tns = xp.take(tns, flat_idx) # Calculate final shape and permutation - true_idx_pos = {idx: pos_groups[unique_parents.index(idx)][0] - for idx in true_idxs} + true_idx_pos = { + idx: pos_groups[unique_parents.index(idx)][0] for idx in true_idxs + } true_idxs_sorted = sorted(true_idxs, key=true_idx_pos.get) - final_shape = tuple(len(unique_arrays[unique_parents.index(idx)]) - for idx in true_idxs_sorted) + final_shape = tuple( + len(unique_arrays[unique_parents.index(idx)]) + for idx in true_idxs_sorted + ) tns = tns.reshape(final_shape) # permute and broadcast the tensor to be # compatible with rest of expression - perm = [true_idxs_sorted.index(idx) for idx in self.loops if idx in true_idxs_sorted] + perm = [ + true_idxs_sorted.index(idx) + for idx in self.loops + if idx in true_idxs_sorted + ] tns = xp.permute_dims(tns, perm) return xp.expand_dims( tns, From 64bfe6e53de54f5862c83026149cea4fb7eb0ea6 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Sat, 22 Nov 2025 11:29:32 -0500 Subject: [PATCH 38/45] * Fixed minor issues in einsum node.py to adjust for new base einsum types --- src/finchlite/finch_einsum/nodes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/finchlite/finch_einsum/nodes.py b/src/finchlite/finch_einsum/nodes.py index 3d90d9e7..fc7ea306 100644 --- a/src/finchlite/finch_einsum/nodes.py +++ b/src/finchlite/finch_einsum/nodes.py @@ -187,7 +187,7 @@ def get_idxs(self) -> set["Index"]: @dataclass(eq=True, frozen=True) -class GetAttr(EinsumExpr, EinsumTree): +class GetAttr(EinsumExpression, EinsumTree): """ Gets an attribute of a tensor. Attributes: @@ -198,7 +198,7 @@ class GetAttr(EinsumExpr, EinsumTree): Note this is an integer index, not a named index. """ - obj: EinsumExpr + obj: EinsumExpression attr: Literal dim: int | None @@ -207,7 +207,7 @@ def from_children(cls, *children: Term) -> Self: # Expects 3 children: obj, attr, idx if len(children) != 3: raise ValueError("GetAttribute expects 3 children (obj + attr + idx)") - obj = cast(EinsumExpr, children[0]) + obj = cast(EinsumExpression, children[0]) attr = cast(Literal, children[1]) dim = cast(int | None, children[2]) return cls(obj, attr, dim) From 8e69d16d36c1c6de738f86ca8e7320b4abdf6de8 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Sat, 22 Nov 2025 18:12:36 -0500 Subject: [PATCH 39/45] * Changed indirect access implementation to a recursive approach --- src/finchlite/finch_einsum/interpreter.py | 86 +++++------------------ 1 file changed, 18 insertions(+), 68 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index b1ca5b6c..a928f856 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -133,81 +133,31 @@ def __call__(self, node): return tns.flat[flat_idx] # return a 1-d array by definition # access a tensor with a mixture of indices and other expressions + # note evalaution order is not standard left to right, but rather + # tns, idx0, tns, idx1, tns, idx2, ... case ein.Access(tns, idxs): assert self.loops is not None - true_idxs = node.get_idxs() # true field iteratior indicies - assert all(isinstance(idx, ein.Index) for idx in true_idxs) - - # evaluate the tensor to access + tns = self(tns) - assert len(idxs) == len(tns.shape) - - # Evaluate all indices into arrays - idx_arrays = [ - xp.arange(tns.shape[i]) - if isinstance(idx, ein.Index) - else self(idx).flatten() - for i, idx in enumerate(idxs) - ] + current_idx = idxs[0] - # Identify unique parent indices and their positions - parents = [ - idx if isinstance(idx, ein.Index) else list(idx.get_idxs())[0] - for idx in idxs - ] - unique_parents = list(dict.fromkeys(parents)) - parent_to_idx = {p: i for i, p in enumerate(unique_parents)} - pos_groups = [ - [i for i, p in enumerate(parents) if p == up] - for up in unique_parents - ] + if not isinstance(current_idx, ein.Index): + current_idx = self(current_idx) + tns = xp.take(tns, current_idx, axis=0) + + # rotate current axis to the end + tns = xp.moveaxis(tns, 0, -1) - # Create meshgrid only for unique groups - unique_arrays = [idx_arrays[g[0]] for g in pos_groups] - grids = xp.meshgrid(*unique_arrays, indexing="ij") - grid_flat = xp.stack([g.ravel() for g in grids], axis=-1) + remaining_idxs = idxs[1:] + if len(remaining_idxs) > 0: + new_access = ein.Access(tns, remaining_idxs) + tns = self(new_access) - # Build final index combinations - group_idx_map = [parent_to_idx[p] for p in parents] - combo_idxs = xp.stack( - [ - idx_arrays[i][grid_flat[:, group_idx_map[i]]] - for i in range(len(idxs)) - ], - axis=1, - ) + # rearrange the axis to conform with self.loops + if len(idxs) == tns.ndim: + pass - # evaluate the output tensor as a flat array - flat_idx = xp.ravel_multi_index(combo_idxs.T, tns.shape) - tns = xp.take(tns, flat_idx) - - # Calculate final shape and permutation - true_idx_pos = { - idx: pos_groups[unique_parents.index(idx)][0] for idx in true_idxs - } - true_idxs_sorted = sorted(true_idxs, key=true_idx_pos.get) - final_shape = tuple( - len(unique_arrays[unique_parents.index(idx)]) - for idx in true_idxs_sorted - ) - tns = tns.reshape(final_shape) - - # permute and broadcast the tensor to be - # compatible with rest of expression - perm = [ - true_idxs_sorted.index(idx) - for idx in self.loops - if idx in true_idxs_sorted - ] - tns = xp.permute_dims(tns, perm) - return xp.expand_dims( - tns, - [ - i - for i in range(len(self.loops)) - if self.loops[i] not in true_idxs_sorted - ], - ) + return tns case ein.Plan(bodies): res = None From 0e527c7d6b8ca6398f8e4faabd44b45a0d280469 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Sat, 22 Nov 2025 21:18:03 -0500 Subject: [PATCH 40/45] Rewrote ein.Access implementation to access tensors with a mixture of indices and other expressions --- src/finchlite/finch_einsum/interpreter.py | 61 +++++++++++++++-------- 1 file changed, 41 insertions(+), 20 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index a928f856..213aa51e 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -133,31 +133,52 @@ def __call__(self, node): return tns.flat[flat_idx] # return a 1-d array by definition # access a tensor with a mixture of indices and other expressions - # note evalaution order is not standard left to right, but rather - # tns, idx0, tns, idx1, tns, idx2, ... case ein.Access(tns, idxs): assert self.loops is not None tns = self(tns) - current_idx = idxs[0] - - if not isinstance(current_idx, ein.Index): - current_idx = self(current_idx) - tns = xp.take(tns, current_idx, axis=0) + indirect_idxs = [ + idx for idx in idxs + if not isinstance(idx, ein.Index) + ] + if len(indirect_idxs) == 0: + return tns + + start_index = idxs.index(indirect_idxs[0]) + iterator_idxs = indirect_idxs[0].get_idxs() + assert len(iterator_idxs) == 1 + + current_idxs = [ + idx for idx in idxs[start_index:] + if idx.get_idxs().issubset(iterator_idxs) + ] + + evaled_idxs = [ + xp.arange(tns.shape[idxs.index(idx)]) + if isinstance(idx, ein.Index) else self(idx) + for idx in current_idxs + ] + + # move the axis to access tns with the evaled idxs + target_axes = [idxs.index(idx) for idx in current_idxs] + dest_axes = [i for i in range(len(current_idxs))] + tns = xp.moveaxis(tns, target_axes, dest_axes) - # rotate current axis to the end - tns = xp.moveaxis(tns, 0, -1) - - remaining_idxs = idxs[1:] - if len(remaining_idxs) > 0: - new_access = ein.Access(tns, remaining_idxs) - tns = self(new_access) - - # rearrange the axis to conform with self.loops - if len(idxs) == tns.ndim: - pass - - return tns + # access the tensor with the evaled idxs + tns = tns[tuple(evaled_idxs)] + + # restore original tensor axis order + tns = xp.moveaxis(tns, source=0, destination=target_axes[0]) + + # we recursiveley call the interpreter with the remaining idxs + new_idxs = [ + iterator_idxs[0] if idx in current_idxs else idx + for idx in idxs + if idx not in current_idxs[1:] + ] + + new_access = ein.Access(tns, new_idxs) + return self(new_access) case ein.Plan(bodies): res = None From f163f65e72fd0c86c0225a415cd8a39efa01cd48 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Sat, 22 Nov 2025 23:59:50 -0500 Subject: [PATCH 41/45] * Added two tests to einsum tests * Fixed bugs in new implementation for indirect einsum access --- .gitignore | 3 + src/finchlite/finch_einsum/interpreter.py | 72 ++++++++++++----------- tests/test_einsum.py | 19 +++++- 3 files changed, 56 insertions(+), 38 deletions(-) diff --git a/.gitignore b/.gitignore index ef712a26..ddc5c234 100644 --- a/.gitignore +++ b/.gitignore @@ -174,3 +174,6 @@ cython_debug/ .pixi/ *.egg-info pixi.lock + +# vscode debugging configurations +.vscode/ \ No newline at end of file diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 213aa51e..e07233b8 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -99,31 +99,8 @@ def __call__(self, node): vals = [self(arg) for arg in args] return func(*vals) - # access a tensor with only indices - case ein.Access(tns, idxs) if all( - isinstance(idx, ein.Index) for idx in idxs - ): - assert len(idxs) == len(set(idxs)) - assert self.loops is not None - - # convert named idxs to positional, integer indices - perm = [idxs.index(idx) for idx in self.loops if idx in idxs] - - tns = self(tns) # evaluate the tensor - - # if there are fewer indicies than dimensions, add the remaining - # dimensions as if they werent permutated - if hasattr(tns, "ndim") and len(perm) < tns.ndim: - perm += list(range(len(perm), tns.ndim)) - - tns = xp.permute_dims(tns, perm) # permute the dimensions - return xp.expand_dims( - tns, - [i for i in range(len(self.loops)) if self.loops[i] not in idxs], - ) - # access a tensor with only one indirect access index - case ein.Access(tns, idxs) if len(idxs) == 1: + case ein.Access(tns, idxs) if len(idxs) == 1 and not isinstance(idxs[0], ein.Index): idx = self(idxs[0]) tns = self(tns) # evaluate the tensor @@ -142,7 +119,15 @@ def __call__(self, node): if not isinstance(idx, ein.Index) ] if len(indirect_idxs) == 0: - return tns + perm = [idxs.index(idx) for idx in self.loops if idx in idxs] + if hasattr(tns, "ndim") and len(perm) < tns.ndim: + perm += list(range(len(perm), tns.ndim)) + + tns = xp.permute_dims(tns, perm) # permute the dimensions + return xp.expand_dims( + tns, + [i for i in range(len(self.loops)) if self.loops[i] not in idxs], + ) start_index = idxs.index(indirect_idxs[0]) iterator_idxs = indirect_idxs[0].get_idxs() @@ -155,12 +140,22 @@ def __call__(self, node): evaled_idxs = [ xp.arange(tns.shape[idxs.index(idx)]) - if isinstance(idx, ein.Index) else self(idx) + if isinstance(idx, ein.Index) else self(idx).flat for idx in current_idxs ] # move the axis to access tns with the evaled idxs - target_axes = [idxs.index(idx) for idx in current_idxs] + #old_target_axes = [idxs.index(idx) for idx in current_idxs] + target_axes = [] + current_idxs_i = 0 + for i, idx in enumerate(idxs): + if idx == current_idxs[current_idxs_i]: + target_axes.append(i) + current_idxs_i += 1 + if current_idxs_i == len(current_idxs): + break + assert current_idxs_i == len(current_idxs) + dest_axes = [i for i in range(len(current_idxs))] tns = xp.moveaxis(tns, target_axes, dest_axes) @@ -171,13 +166,20 @@ def __call__(self, node): tns = xp.moveaxis(tns, source=0, destination=target_axes[0]) # we recursiveley call the interpreter with the remaining idxs - new_idxs = [ - iterator_idxs[0] if idx in current_idxs else idx - for idx in idxs - if idx not in current_idxs[1:] - ] - - new_access = ein.Access(tns, new_idxs) + iterator_idx = next(iter(iterator_idxs)) + + new_idxs = [] + current_idxs_i = 0 + for i, idx in enumerate(idxs): + if current_idxs_i < len(current_idxs) and idx == current_idxs[current_idxs_i]: + if current_idxs_i == 0: + new_idxs.append(iterator_idx) + current_idxs_i += 1 + else: + new_idxs.append(idx) + assert current_idxs_i == len(current_idxs) + + new_access = ein.Access(ein.Literal(tns), new_idxs) return self(new_access) case ein.Plan(bodies): @@ -201,7 +203,7 @@ def __call__(self, node): assert isinstance(obj, SparseTensor) # return the coord array for the given dimension or all dimensions - return obj.coords if dim is None else obj.coords[:, dim] + return obj.coords if dim is None else obj.coords[:, dim].flat # gets the shape of a sparse tensor at a given dimension case ein.GetAttr(obj, ein.Literal("shape"), dim): obj = self(obj) diff --git a/tests/test_einsum.py b/tests/test_einsum.py index 3e85f1aa..111eb442 100644 --- a/tests/test_einsum.py +++ b/tests/test_einsum.py @@ -15,6 +15,22 @@ def rng(): return np.random.default_rng(42) +def test_pass_through(rng): + """Test pass through of a tensor""" + A = rng.random((5, 5)) + + B = finchlite.einop("B[i,j] = A[i,j]", A=A) + + assert np.allclose(B, A) + +def test_transpose(rng): + """Test basic addition with transpose""" + A = rng.random((5, 5)) + + B = finchlite.einop("B[i,j] = A[j, i]", A=A) + B_ref = A.T + + assert np.allclose(B, B_ref) def test_basic_addition_with_transpose(rng): """Test basic addition with transpose""" @@ -1147,9 +1163,6 @@ def run_einsum_plan( np.set_printoptions(threshold=sys.maxsize) - print(result) - print(expected) - assert np.allclose(result, expected) def test_indirect_elementwise_multiplication(self, rng): From 656aabecf0a740410dc7d777b8c3c60c818ca474 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Sun, 23 Nov 2025 08:09:48 -0500 Subject: [PATCH 42/45] * Added comments to enhance clarity in ein.Access for multiple indicies * Simplified computation of target_axes, current_idxs, and new_idxs in ein.Access --- src/finchlite/finch_einsum/interpreter.py | 44 +++++++++-------------- 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index e07233b8..825bd0c8 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -118,7 +118,9 @@ def __call__(self, node): idx for idx in idxs if not isinstance(idx, ein.Index) ] - if len(indirect_idxs) == 0: + + # base case: no indirect indices, just permute the dimensions + if len(indirect_idxs) == 0: perm = [idxs.index(idx) for idx in self.loops if idx in idxs] if hasattr(tns, "ndim") and len(perm) < tns.ndim: perm += list(range(len(perm), tns.ndim)) @@ -129,33 +131,26 @@ def __call__(self, node): [i for i in range(len(self.loops)) if self.loops[i] not in idxs], ) - start_index = idxs.index(indirect_idxs[0]) - iterator_idxs = indirect_idxs[0].get_idxs() + start_index = idxs.index(indirect_idxs[0]) # index of first indirect access + iterator_idxs = indirect_idxs[0].get_idxs() # iterator indicies of the first indirect access assert len(iterator_idxs) == 1 - current_idxs = [ - idx for idx in idxs[start_index:] + # get the axes of the idxs that are associated with the current iterator indicies + target_axes = [ + i for i, idx in enumerate(idxs[start_index:], start_index) if idx.get_idxs().issubset(iterator_idxs) ] + # get associated access indicies w/ the first indirect access + current_idxs = [idxs[i] for i in target_axes] + + # evaluate the associated access indicies evaled_idxs = [ xp.arange(tns.shape[idxs.index(idx)]) if isinstance(idx, ein.Index) else self(idx).flat for idx in current_idxs ] - # move the axis to access tns with the evaled idxs - #old_target_axes = [idxs.index(idx) for idx in current_idxs] - target_axes = [] - current_idxs_i = 0 - for i, idx in enumerate(idxs): - if idx == current_idxs[current_idxs_i]: - target_axes.append(i) - current_idxs_i += 1 - if current_idxs_i == len(current_idxs): - break - assert current_idxs_i == len(current_idxs) - dest_axes = [i for i in range(len(current_idxs))] tns = xp.moveaxis(tns, target_axes, dest_axes) @@ -167,17 +162,10 @@ def __call__(self, node): # we recursiveley call the interpreter with the remaining idxs iterator_idx = next(iter(iterator_idxs)) - - new_idxs = [] - current_idxs_i = 0 - for i, idx in enumerate(idxs): - if current_idxs_i < len(current_idxs) and idx == current_idxs[current_idxs_i]: - if current_idxs_i == 0: - new_idxs.append(iterator_idx) - current_idxs_i += 1 - else: - new_idxs.append(idx) - assert current_idxs_i == len(current_idxs) + new_idxs = list(idxs[:start_index]) + [iterator_idx] + [ + idx for idx in idxs[start_index + 1:] + if idx not in current_idxs + ] new_access = ein.Access(ein.Literal(tns), new_idxs) return self(new_access) From 4bf1a787fa41470f8d0d4d4ad60b9decb33c74b0 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Sun, 23 Nov 2025 08:15:57 -0500 Subject: [PATCH 43/45] * Simplified calculation of dest_axes in ein.Access --- src/finchlite/finch_einsum/interpreter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 825bd0c8..586ba9af 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -151,7 +151,7 @@ def __call__(self, node): for idx in current_idxs ] - dest_axes = [i for i in range(len(current_idxs))] + dest_axes = list(range(len(current_idxs))) tns = xp.moveaxis(tns, target_axes, dest_axes) # access the tensor with the evaled idxs From bd8e22cfab57f4f9f07e980aded316bd74805e5f Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Sun, 23 Nov 2025 08:16:39 -0500 Subject: [PATCH 44/45] * Fixed ruff errors --- src/finchlite/finch_einsum/interpreter.py | 57 ++++++++++++++--------- tests/test_einsum.py | 5 +- 2 files changed, 40 insertions(+), 22 deletions(-) diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py index 586ba9af..e37ba796 100644 --- a/src/finchlite/finch_einsum/interpreter.py +++ b/src/finchlite/finch_einsum/interpreter.py @@ -100,7 +100,9 @@ def __call__(self, node): return func(*vals) # access a tensor with only one indirect access index - case ein.Access(tns, idxs) if len(idxs) == 1 and not isinstance(idxs[0], ein.Index): + case ein.Access(tns, idxs) if len(idxs) == 1 and not isinstance( + idxs[0], ein.Index + ): idx = self(idxs[0]) tns = self(tns) # evaluate the tensor @@ -112,32 +114,39 @@ def __call__(self, node): # access a tensor with a mixture of indices and other expressions case ein.Access(tns, idxs): assert self.loops is not None - + tns = self(tns) - indirect_idxs = [ - idx for idx in idxs - if not isinstance(idx, ein.Index) - ] + indirect_idxs = [idx for idx in idxs if not isinstance(idx, ein.Index)] # base case: no indirect indices, just permute the dimensions - if len(indirect_idxs) == 0: + if len(indirect_idxs) == 0: perm = [idxs.index(idx) for idx in self.loops if idx in idxs] if hasattr(tns, "ndim") and len(perm) < tns.ndim: perm += list(range(len(perm), tns.ndim)) - + tns = xp.permute_dims(tns, perm) # permute the dimensions return xp.expand_dims( tns, - [i for i in range(len(self.loops)) if self.loops[i] not in idxs], + [ + i + for i in range(len(self.loops)) + if self.loops[i] not in idxs + ], ) - start_index = idxs.index(indirect_idxs[0]) # index of first indirect access - iterator_idxs = indirect_idxs[0].get_idxs() # iterator indicies of the first indirect access + start_index = idxs.index( + indirect_idxs[0] + ) # index of first indirect access + iterator_idxs = indirect_idxs[ + 0 + ].get_idxs() # iterator indicies of the first indirect access assert len(iterator_idxs) == 1 - # get the axes of the idxs that are associated with the current iterator indicies + # get the axes of the idxs that are associated + # with the current iterator indicies target_axes = [ - i for i, idx in enumerate(idxs[start_index:], start_index) + i + for i, idx in enumerate(idxs[start_index:], start_index) if idx.get_idxs().issubset(iterator_idxs) ] @@ -146,26 +155,32 @@ def __call__(self, node): # evaluate the associated access indicies evaled_idxs = [ - xp.arange(tns.shape[idxs.index(idx)]) - if isinstance(idx, ein.Index) else self(idx).flat + xp.arange(tns.shape[idxs.index(idx)]) + if isinstance(idx, ein.Index) + else self(idx).flat for idx in current_idxs ] dest_axes = list(range(len(current_idxs))) tns = xp.moveaxis(tns, target_axes, dest_axes) - + # access the tensor with the evaled idxs tns = tns[tuple(evaled_idxs)] - + # restore original tensor axis order tns = xp.moveaxis(tns, source=0, destination=target_axes[0]) # we recursiveley call the interpreter with the remaining idxs iterator_idx = next(iter(iterator_idxs)) - new_idxs = list(idxs[:start_index]) + [iterator_idx] + [ - idx for idx in idxs[start_index + 1:] - if idx not in current_idxs - ] + new_idxs = ( + list(idxs[:start_index]) + + [iterator_idx] + + [ + idx + for idx in idxs[start_index + 1 :] + if idx not in current_idxs + ] + ) new_access = ein.Access(ein.Literal(tns), new_idxs) return self(new_access) diff --git a/tests/test_einsum.py b/tests/test_einsum.py index 111eb442..0f7dcf6d 100644 --- a/tests/test_einsum.py +++ b/tests/test_einsum.py @@ -15,14 +15,16 @@ def rng(): return np.random.default_rng(42) + def test_pass_through(rng): """Test pass through of a tensor""" A = rng.random((5, 5)) B = finchlite.einop("B[i,j] = A[i,j]", A=A) - + assert np.allclose(B, A) + def test_transpose(rng): """Test basic addition with transpose""" A = rng.random((5, 5)) @@ -32,6 +34,7 @@ def test_transpose(rng): assert np.allclose(B, B_ref) + def test_basic_addition_with_transpose(rng): """Test basic addition with transpose""" A = rng.random((5, 5)) From dd0167d7c6d8e1ac237a10684df0f3007adae5af Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Sun, 23 Nov 2025 08:21:32 -0500 Subject: [PATCH 45/45] * Fixed end of file issues in .gitignore --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index ddc5c234..766724a1 100644 --- a/.gitignore +++ b/.gitignore @@ -176,4 +176,4 @@ cython_debug/ pixi.lock # vscode debugging configurations -.vscode/ \ No newline at end of file +.vscode/