diff --git a/.gitignore b/.gitignore
index ef712a26..766724a1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -174,3 +174,6 @@ cython_debug/
 .pixi/
 *.egg-info
 pixi.lock
+
+# vscode debugging configurations
+.vscode/
diff --git a/src/finchlite/finch_einsum/__init__.py b/src/finchlite/finch_einsum/__init__.py
index 9d722c35..92305651 100644
--- a/src/finchlite/finch_einsum/__init__.py
+++ b/src/finchlite/finch_einsum/__init__.py
@@ -7,6 +7,7 @@
     EinsumExpression,
     EinsumNode,
     EinsumStatement,
+    GetAttr,
     Index,
     Literal,
     Plan,
@@ -26,6 +27,7 @@
     "EinsumScheduler",
     "EinsumScheduler",
     "EinsumStatement",
+    "GetAttr",
     "Index",
     "Literal",
     "Plan",
diff --git a/src/finchlite/finch_einsum/interpreter.py b/src/finchlite/finch_einsum/interpreter.py
index 01f256fe..e37ba796 100644
--- a/src/finchlite/finch_einsum/interpreter.py
+++ b/src/finchlite/finch_einsum/interpreter.py
@@ -3,6 +3,7 @@
 import numpy as np
 
 from ..algebra import overwrite, promote_max, promote_min
+from ..symbolic import ftype
 from . import nodes as ein
 
 nary_ops = {
@@ -78,6 +79,11 @@ def __init__(self, xp=None, bindings=None, loops=None):
         self.loops = loops
 
     def __call__(self, node):
+        from ..tensor import (
+            SparseTensor,
+            SparseTensorFType,
+        )
+
         xp = self.xp
         match node:
             case ein.Literal(val):
@@ -92,16 +98,93 @@
                 func = getattr(xp, nary_ops[func])
                 vals = [self(arg) for arg in args]
                 return func(*vals)
+
+            # access a tensor with only one indirect access index
+            case ein.Access(tns, idxs) if len(idxs) == 1 and not isinstance(
+                idxs[0], ein.Index
+            ):
+                idx = self(idxs[0])
+                tns = self(tns)  # evaluate the tensor
+
+                flat_idx = (
+                    idx if idx.ndim == 1 else xp.ravel_multi_index(idx.T, tns.shape)
+                )
+                return tns.flat[flat_idx]  # return a 1-d array by definition
+
+            # access a tensor with a mixture of indices and other expressions
             case ein.Access(tns, idxs):
-                assert len(idxs) == len(set(idxs))
                 assert self.loops is not None
-                perm = [idxs.index(idx) for idx in self.loops if idx in idxs]
+
                 tns = self(tns)
-                tns = xp.permute_dims(tns, perm)
-                return xp.expand_dims(
-                    tns,
-                    [i for i in range(len(self.loops)) if self.loops[i] not in idxs],
+                indirect_idxs = [idx for idx in idxs if not isinstance(idx, ein.Index)]
+
+                # base case: no indirect indices, just permute the dimensions
+                if len(indirect_idxs) == 0:
+                    perm = [idxs.index(idx) for idx in self.loops if idx in idxs]
+                    if hasattr(tns, "ndim") and len(perm) < tns.ndim:
+                        perm += list(range(len(perm), tns.ndim))
+
+                    tns = xp.permute_dims(tns, perm)  # permute the dimensions
+                    return xp.expand_dims(
+                        tns,
+                        [
+                            i
+                            for i in range(len(self.loops))
+                            if self.loops[i] not in idxs
+                        ],
+                    )
+
+                start_index = idxs.index(
+                    indirect_idxs[0]
+                )  # index of first indirect access
+                iterator_idxs = indirect_idxs[
+                    0
+                ].get_idxs()  # iterator indices of the first indirect access
+                assert len(iterator_idxs) == 1
+
+                # get the axes of the idxs that are associated
+                # with the current iterator indices
+                target_axes = [
+                    i
+                    for i, idx in enumerate(idxs[start_index:], start_index)
+                    if idx.get_idxs().issubset(iterator_idxs)
+                ]
+
+                # get the access indices associated with the first indirect access
+                current_idxs = [idxs[i] for i in target_axes]
+
+                # evaluate the associated access indices
+                evaled_idxs = [
+                    xp.arange(tns.shape[idxs.index(idx)])
+                    if isinstance(idx, ein.Index)
+                    else self(idx).flat
+                    for idx in current_idxs
+                ]
+
+                dest_axes = list(range(len(current_idxs)))
+                tns = xp.moveaxis(tns, target_axes, dest_axes)
+
+                # access the tensor with the evaled idxs
+                tns = tns[tuple(evaled_idxs)]
+
+                # restore original tensor axis order
+                tns = xp.moveaxis(tns, source=0, destination=target_axes[0])
+
+                # recursively call the interpreter with the remaining idxs
+                iterator_idx = next(iter(iterator_idxs))
+                new_idxs = (
+                    list(idxs[:start_index])
+                    + [iterator_idx]
+                    + [
+                        idx
+                        for idx in idxs[start_index + 1 :]
+                        if idx not in current_idxs
+                    ]
                 )
+
+                new_access = ein.Access(ein.Literal(tns), new_idxs)
+                return self(new_access)
+
             case ein.Plan(bodies):
                 res = None
                 for body in bodies:
@@ -109,9 +192,38 @@
                 return res
             case ein.Produces(args):
                 return tuple(self(arg) for arg in args)
-            case ein.Einsum(op, ein.Alias(tns), idxs, arg):
+
+            # get non-zero elements/data array of a sparse tensor
+            case ein.GetAttr(obj, ein.Literal("elems"), _):
+                obj = self(obj)
+                assert isinstance(ftype(obj), SparseTensorFType)
+                assert isinstance(obj, SparseTensor)
+                return obj.data
+            # get coord array of a sparse tensor
+            case ein.GetAttr(obj, ein.Literal("coords"), dim):
+                obj = self(obj)
+                assert isinstance(ftype(obj), SparseTensorFType)
+                assert isinstance(obj, SparseTensor)
+
+                # return the coord array for the given dimension or all dimensions
+                return obj.coords if dim is None else obj.coords[:, dim].flat
+            # get the shape of a sparse tensor at a given dimension
+            case ein.GetAttr(obj, ein.Literal("shape"), dim):
+                obj = self(obj)
+                assert isinstance(ftype(obj), SparseTensorFType)
+                assert isinstance(obj, SparseTensor)
+                assert dim is not None
+
+                # return the shape for the given dimension
+                return obj.shape[dim]
+
+            # standard einsum
+            case ein.Einsum(op, ein.Alias(tns), idxs, arg) if all(
+                isinstance(idx, ein.Index) for idx in idxs
+            ):
                 # This is the main entry point for einsum execution
                 loops = arg.get_idxs()
+                assert set(idxs).issubset(loops)
                 loops = sorted(loops, key=lambda x: x.name)
 
                 ctx = EinsumInterpreter(self.xp, self.bindings, loops)
@@ -128,5 +240,11 @@
                 axis = [dropped.index(idx) for idx in idxs]
                 self.bindings[tns] = xp.permute_dims(val, axis)
                 return (tns,)
+
+            # indirect einsum
+            case ein.Einsum(op, ein.Alias(tns), idxs, arg):
+                raise NotImplementedError(
+                    "Indirect einsum assignment is not implemented"
+                )
             case _:
                 raise ValueError(f"Unknown einsum type: {type(node)}")
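For intuition, here is a minimal NumPy sketch (illustrative only, not part of the patch) of the flat-index translation the new single-indirect `Access` case performs: a COO-style `(nnz, ndim)` coordinate array is transposed into `ravel_multi_index` form and gathered through `.flat`, which is why the case returns a 1-d array. The mixed-index case above reduces to this by moving the indirect axes to the front, gathering, and recursing on the remaining indices.

import numpy as np

A = np.arange(12.0).reshape(3, 4)
coords = np.array([[0, 1], [2, 3]])  # (nnz, ndim): two (row, col) pairs
flat = np.ravel_multi_index(coords.T, A.shape)  # row-major offsets -> [1, 11]
vals = A.flat[flat]  # 1-d gather, as in the new Access case
assert np.allclose(vals, A[coords[:, 0], coords[:, 1]])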
diff --git a/src/finchlite/finch_einsum/nodes.py b/src/finchlite/finch_einsum/nodes.py
index ef233047..fc7ea306 100644
--- a/src/finchlite/finch_einsum/nodes.py
+++ b/src/finchlite/finch_einsum/nodes.py
@@ -186,6 +186,42 @@ def get_idxs(self) -> set["Index"]:
         return idxs
 
 
+@dataclass(eq=True, frozen=True)
+class GetAttr(EinsumExpression, EinsumTree):
+    """
+    Gets an attribute of a tensor.
+
+    Attributes:
+        obj: The object to get the attribute from.
+        attr: The name of the attribute to get.
+        dim: The dimension to get the attribute from.
+            Note this is an integer index, not a named index.
+    """
+
+    obj: EinsumExpression
+    attr: Literal
+    dim: int | None
+
+    @classmethod
+    def from_children(cls, *children: Term) -> Self:
+        # Expects 3 children: obj, attr, dim
+        if len(children) != 3:
+            raise ValueError("GetAttr expects 3 children (obj, attr, dim)")
+        obj = cast(EinsumExpression, children[0])
+        attr = cast(Literal, children[1])
+        dim = cast(int | None, children[2])
+        return cls(obj, attr, dim)
+
+    @property
+    def children(self):
+        return [self.obj, self.attr, self.dim]
+
+    def get_idxs(self) -> set["Index"]:
+        idxs = set()
+        idxs.update(self.obj.get_idxs())
+        return idxs
+
+
 @dataclass(eq=True, frozen=True)
 class Einsum(EinsumTree, EinsumStatement):
     """
@@ -224,6 +260,12 @@ def from_children(cls, *children: Term) -> Self:
     @property
     def children(self):
         return [self.op, self.tns, self.idxs, self.arg]
 
+    def get_idxs(self) -> set["Index"]:
+        idxs = set()
+        for idx in self.idxs:
+            idxs.update(idx.get_idxs())
+        return idxs
+
 
 @dataclass(eq=True, frozen=True)
 class Plan(EinsumTree, EinsumStatement):
     """
@@ -344,6 +386,10 @@ def __call__(self, prgm: EinsumNode):
                 if len(args) == 1 and fn.val in unary_strs:
                     return f"{unary_strs[fn.val]}{args_e[0]}"
                 return f"{self(fn)}({', '.join(args_e)})"
+            case GetAttr(obj, attr, dim):
+                if dim is not None:
+                    return f"{self(obj)}.{self(attr)}[{dim}]"
+                return f"{self(obj)}.{self(attr)}"
             case Einsum(op, tns, idxs, arg):
                 op_str = infix_strs.get(op.val, op.val.__name__)
                 self.exec(
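A quick sketch of how the new `GetAttr` node composes (based only on the constructors in this patch; the aliases and index names are taken from the tests below): `A.elems`, `A.coords`, and `A.shape[0]` become IR nodes, and an indirect read like `B[ACoords[i]]` nests an `Access` whose index is itself an `Access` over a `GetAttr`. The printer case added above should render these roughly as `A.elems` and `A.shape[0]`.

import finchlite.finch_einsum as ein

# attribute nodes over a tensor alias "A"
elems = ein.GetAttr(obj=ein.Alias("A"), attr=ein.Literal("elems"), dim=None)
coords = ein.GetAttr(obj=ein.Alias("A"), attr=ein.Literal("coords"), dim=None)
dim0 = ein.GetAttr(obj=ein.Alias("A"), attr=ein.Literal("shape"), dim=0)

# the indirect read B[ACoords[i]]: an Access whose index is another Access
read = ein.Access(
    tns=ein.Alias("B"),
    idxs=(ein.Access(tns=coords, idxs=(ein.Index("i"),)),),
)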
+ """ + + obj: EinsumExpression + attr: Literal + dim: int | None + + @classmethod + def from_children(cls, *children: Term) -> Self: + # Expects 3 children: obj, attr, idx + if len(children) != 3: + raise ValueError("GetAttribute expects 3 children (obj + attr + idx)") + obj = cast(EinsumExpression, children[0]) + attr = cast(Literal, children[1]) + dim = cast(int | None, children[2]) + return cls(obj, attr, dim) + + @property + def children(self): + return [self.obj, self.attr, self.dim] + + def get_idxs(self) -> set["Index"]: + idxs = set() + idxs.update(self.obj.get_idxs()) + return idxs + + @dataclass(eq=True, frozen=True) class Einsum(EinsumTree, EinsumStatement): """ @@ -224,6 +260,12 @@ def from_children(cls, *children: Term) -> Self: def children(self): return [self.op, self.tns, self.idxs, self.arg] + def get_idxs(self) -> set["Index"]: + idxs = set() + for idx in self.idxs: + idxs.update(idx.get_idxs()) + return idxs + @dataclass(eq=True, frozen=True) class Plan(EinsumTree, EinsumStatement): @@ -344,6 +386,10 @@ def __call__(self, prgm: EinsumNode): if len(args) == 1 and fn.val in unary_strs: return f"{unary_strs[fn.val]}{args_e[0]}" return f"{self(fn)}({', '.join(args_e)})" + case GetAttr(obj, attr, idx): + if idx is not None: + return f"{self(obj)}.{self(attr)}[{idx}]" + return f"{self(obj)}.{self(attr)}" case Einsum(op, tns, idxs, arg): op_str = infix_strs.get(op.val, op.val.__name__) self.exec( diff --git a/src/finchlite/tensor/__init__.py b/src/finchlite/tensor/__init__.py index 9c45ad8c..916e5602 100644 --- a/src/finchlite/tensor/__init__.py +++ b/src/finchlite/tensor/__init__.py @@ -7,6 +7,7 @@ dense, element, ) +from .sparse_tensor import SparseTensor, SparseTensorFType __all__ = [ "DenseLevel", @@ -17,6 +18,8 @@ "FiberTensorFType", "Level", "LevelFType", + "SparseTensor", + "SparseTensorFType", "dense", "element", "tensor", diff --git a/src/finchlite/tensor/sparse_tensor.py b/src/finchlite/tensor/sparse_tensor.py new file mode 100644 index 00000000..6d798615 --- /dev/null +++ b/src/finchlite/tensor/sparse_tensor.py @@ -0,0 +1,103 @@ +import numpy as np + +from finchlite.algebra import TensorFType +from finchlite.interface.eager import EagerTensor + + +class SparseTensorFType(TensorFType): + def __init__(self, shape: tuple, element_type: type): + self.shape = shape + self._element_type = element_type + + def __eq__(self, other): + if not isinstance(other, SparseTensorFType): + return False + return self.shape == other.shape and self.element_type == other.element_type + + def __hash__(self): + return hash((self.shape, self.element_type)) + + @property + def ndim(self): + return len(self.shape) + + @property + def shape_type(self): + return self.shape + + @property + def element_type(self): + return self._element_type + + @property + def fill_value(self): + return 0 + + +# currently implemented with COO tensor +class SparseTensor(EagerTensor): + def __init__( + self, + data: np.typing.NDArray, + coords: np.typing.NDArray, + shape: tuple, + element_type=np.float64, + ): + if data.shape[0] != coords.shape[0]: + raise ValueError("data and coords must have the same number of rows") + + self.coords = coords + self.data = data + self._shape = shape + self._element_type = element_type + + # converts an eager tensor to a sparse tensor + @classmethod + def from_dense_tensor(cls, dense_tensor: np.ndarray): + coords = np.where(dense_tensor != 0) + data = dense_tensor[coords] + shape = dense_tensor.shape + element_type = dense_tensor.dtype.type + coords_array = np.array(coords).T + 
diff --git a/tests/test_einsum.py b/tests/test_einsum.py
index 434180ae..0f7dcf6d 100644
--- a/tests/test_einsum.py
+++ b/tests/test_einsum.py
@@ -1,8 +1,14 @@
+import operator
+from typing import Any
+
 import pytest
 
 import numpy as np
 
 import finchlite
+import finchlite.finch_einsum as ein
+from finchlite.algebra import overwrite
+from finchlite.tensor import SparseTensor
 
 
 @pytest.fixture
@@ -10,6 +16,25 @@ def rng():
     return np.random.default_rng(42)
 
 
+def test_pass_through(rng):
+    """Test pass-through of a tensor"""
+    A = rng.random((5, 5))
+
+    B = finchlite.einop("B[i,j] = A[i,j]", A=A)
+
+    assert np.allclose(B, A)
+
+
+def test_transpose(rng):
+    """Test basic transpose"""
+    A = rng.random((5, 5))
+
+    B = finchlite.einop("B[i,j] = A[j, i]", A=A)
+    B_ref = A.T
+
+    assert np.allclose(B, B_ref)
+
+
 def test_basic_addition_with_transpose(rng):
     """Test basic addition with transpose"""
     A = rng.random((5, 5))
@@ -1123,3 +1148,766 @@
         expected = np.einsum("ij", A)
 
         assert np.allclose(result, expected)
+
+
+class TestEinsumIndirectAccess:
+    """Test einsum with indirect access"""
+
+    def run_einsum_plan(
+        self, prgm: ein.Plan, bindings: dict[str, Any], expected: np.ndarray
+    ):
+        """Runs an einsum plan and checks its result against the expected array"""
+        interpreter = ein.EinsumInterpreter(bindings=bindings)
+        result = interpreter(prgm)[0]
+
+        assert np.allclose(result, expected)
+
+    def test_indirect_elementwise_multiplication(self, rng):
+        """Test indirect elementwise multiplication but no indirect assignment"""
+
+        A = rng.random((5, 5))
+        B = rng.random((5, 5))
+
+        sparse_A = SparseTensor.from_dense_tensor(A)
+
+        # A is sparse
+        # C[i] = AElems[i] * B[ACoords[i]]
+        prgm = ein.Plan(
+            (
+                ein.Einsum(
+                    op=ein.Literal(overwrite),
+                    tns=ein.Alias("C"),
+                    idxs=(ein.Index("i"),),
+                    arg=ein.Call(
+                        op=ein.Literal(operator.mul),
+                        args=(
+                            ein.Access(
+                                tns=ein.GetAttr(
+                                    obj=ein.Alias("A"),
+                                    attr=ein.Literal("elems"),
+                                    dim=None,
+                                ),
+                                idxs=(ein.Index("i"),),
+                            ),
+                            ein.Access(
+                                tns=ein.Alias("B"),
+                                idxs=(
+                                    ein.Access(
+                                        tns=ein.GetAttr(
+                                            obj=ein.Alias("A"),
+                                            attr=ein.Literal("coords"),
+                                            dim=None,
+                                        ),
+                                        idxs=(ein.Index("i"),),
+                                    ),
+                                ),
+                            ),
+                        ),
+                    ),
+                ),
+                ein.Produces((ein.Alias("C"),)),
+            )
+        )
+
+        expected = finchlite.multiply(A, B).flatten()
+        self.run_einsum_plan(prgm, {"A": sparse_A, "B": B}, expected)
+    def test_indirect_elementwise_addition(self, rng):
+        """Test indirect elementwise addition"""
+
+        A = rng.random((4, 4))
+        B = rng.random((4, 4))
+
+        sparse_A = SparseTensor.from_dense_tensor(A)
+
+        # C[i] = AElems[i] + B[ACoords[i]]
+        prgm = ein.Plan(
+            (
+                ein.Einsum(
+                    op=ein.Literal(overwrite),
+                    tns=ein.Alias("C"),
+                    idxs=(ein.Index("i"),),
+                    arg=ein.Call(
+                        op=ein.Literal(operator.add),
+                        args=(
+                            ein.Access(
+                                tns=ein.GetAttr(
+                                    obj=ein.Alias("A"),
+                                    attr=ein.Literal("elems"),
+                                    dim=None,
+                                ),
+                                idxs=(ein.Index("i"),),
+                            ),
+                            ein.Access(
+                                tns=ein.Alias("B"),
+                                idxs=(
+                                    ein.Access(
+                                        tns=ein.GetAttr(
+                                            obj=ein.Alias("A"),
+                                            attr=ein.Literal("coords"),
+                                            dim=None,
+                                        ),
+                                        idxs=(ein.Index("i"),),
+                                    ),
+                                ),
+                            ),
+                        ),
+                    ),
+                ),
+                ein.Produces((ein.Alias("C"),)),
+            )
+        )
+
+        expected = (A + B).flatten()
+        self.run_einsum_plan(prgm, {"A": sparse_A, "B": B}, expected)
+
+    def test_indirect_multiple_reads(self, rng):
+        """Test multiple indirect reads from the same tensor"""
+
+        A = rng.random((3, 3))
+        B = rng.random((3, 3))
+
+        sparse_A = SparseTensor.from_dense_tensor(A)
+        sparse_B = SparseTensor.from_dense_tensor(B)
+
+        # C[i] = AElems[i] * BElems[i]
+        # Both A and B are sparse, reading their elements directly
+        prgm = ein.Plan(
+            (
+                ein.Einsum(
+                    op=ein.Literal(overwrite),
+                    tns=ein.Alias("C"),
+                    idxs=(ein.Index("i"),),
+                    arg=ein.Call(
+                        op=ein.Literal(operator.mul),
+                        args=(
+                            ein.Access(
+                                tns=ein.GetAttr(
+                                    obj=ein.Alias("A"),
+                                    attr=ein.Literal("elems"),
+                                    dim=None,
+                                ),
+                                idxs=(ein.Index("i"),),
+                            ),
+                            ein.Access(
+                                tns=ein.GetAttr(
+                                    obj=ein.Alias("B"),
+                                    attr=ein.Literal("elems"),
+                                    dim=None,
+                                ),
+                                idxs=(ein.Index("i"),),
+                            ),
+                        ),
+                    ),
+                ),
+                ein.Produces((ein.Alias("C"),)),
+            )
+        )
+
+        expected = finchlite.multiply(A, B).flatten()
+        self.run_einsum_plan(prgm, {"A": sparse_A, "B": sparse_B}, expected)
+
+    def test_indirect_with_constant(self, rng):
+        """Test indirect access combined with constant"""
+
+        A = rng.random((4, 4))
+        B = rng.random((4, 4))
+
+        sparse_A = SparseTensor.from_dense_tensor(A)
+
+        # C[i] = AElems[i] * B[ACoords[i]] + 5.0
+        prgm = ein.Plan(
+            (
+                ein.Einsum(
+                    op=ein.Literal(overwrite),
+                    tns=ein.Alias("C"),
+                    idxs=(ein.Index("i"),),
+                    arg=ein.Call(
+                        op=ein.Literal(operator.add),
+                        args=(
+                            ein.Call(
+                                op=ein.Literal(operator.mul),
+                                args=(
+                                    ein.Access(
+                                        tns=ein.GetAttr(
+                                            obj=ein.Alias("A"),
+                                            attr=ein.Literal("elems"),
+                                            dim=None,
+                                        ),
+                                        idxs=(ein.Index("i"),),
+                                    ),
+                                    ein.Access(
+                                        tns=ein.Alias("B"),
+                                        idxs=(
+                                            ein.Access(
+                                                tns=ein.GetAttr(
+                                                    obj=ein.Alias("A"),
+                                                    attr=ein.Literal("coords"),
+                                                    dim=None,
+                                                ),
+                                                idxs=(ein.Index("i"),),
+                                            ),
+                                        ),
+                                    ),
+                                ),
+                            ),
+                            ein.Literal(5.0),
+                        ),
+                    ),
+                ),
+                ein.Produces((ein.Alias("C"),)),
+            )
+        )
+
+        expected = (A * B + 5.0).flatten()
+        self.run_einsum_plan(prgm, {"A": sparse_A, "B": B}, expected)
+
+    def test_indirect_nested_operations(self, rng):
+        """Test nested operations with indirect access"""
+
+        A = rng.random((3, 3))
+        B = rng.random((3, 3))
+        C = rng.random((3, 3))
+
+        sparse_A = SparseTensor.from_dense_tensor(A)
+
+        # D[i] = (AElems[i] + B[ACoords[i]]) * C[ACoords[i]]
+        prgm = ein.Plan(
+            (
+                ein.Einsum(
+                    op=ein.Literal(overwrite),
+                    tns=ein.Alias("D"),
+                    idxs=(ein.Index("i"),),
+                    arg=ein.Call(
+                        op=ein.Literal(operator.mul),
+                        args=(
+                            ein.Call(
+                                op=ein.Literal(operator.add),
+                                args=(
+                                    ein.Access(
+                                        tns=ein.GetAttr(
+                                            obj=ein.Alias("A"),
+                                            attr=ein.Literal("elems"),
+                                            dim=None,
+                                        ),
+                                        idxs=(ein.Index("i"),),
+                                    ),
+                                    ein.Access(
+                                        tns=ein.Alias("B"),
+                                        idxs=(
+                                            ein.Access(
+                                                tns=ein.GetAttr(
+                                                    obj=ein.Alias("A"),
+                                                    attr=ein.Literal("coords"),
+                                                    dim=None,
+                                                ),
+                                                idxs=(ein.Index("i"),),
+                                            ),
+                                        ),
+                                    ),
+                                ),
+                            ),
+                            ein.Access(
+                                tns=ein.Alias("C"),
+                                idxs=(
+                                    ein.Access(
+                                        tns=ein.GetAttr(
+                                            obj=ein.Alias("A"),
+                                            attr=ein.Literal("coords"),
+                                            dim=None,
+                                        ),
+                                        idxs=(ein.Index("i"),),
+                                    ),
+                                ),
+                            ),
+                        ),
+                    ),
+                ),
+                ein.Produces((ein.Alias("D"),)),
+            )
+        )
+
+        expected = ((A + B) * C).flatten()
+        self.run_einsum_plan(prgm, {"A": sparse_A, "B": B, "C": C}, expected)
+
+    def test_indirect_direct_access_only(self, rng):
+        """Test accessing only the indirect coordinates"""
+
+        A = rng.random((8,))
+        B = rng.random((8,))
+
+        sparse_A = SparseTensor.from_dense_tensor(A)
+
+        # C[i] = B[ACoords[i]] (read B indirectly, without using A's elements)
+        prgm = ein.Plan(
+            (
+                ein.Einsum(
+                    op=ein.Literal(overwrite),
+                    tns=ein.Alias("C"),
+                    idxs=(ein.Index("i"),),
+                    arg=ein.Access(
+                        tns=ein.Alias("B"),
+                        idxs=(
+                            ein.Access(
+                                tns=ein.GetAttr(
+                                    obj=ein.Alias("A"),
+                                    attr=ein.Literal("coords"),
+                                    dim=None,
+                                ),
+                                idxs=(ein.Index("i"),),
+                            ),
+                        ),
+                    ),
+                ),
+                ein.Produces((ein.Alias("C"),)),
+            )
+        )
+
+        # Result should be B's values at A's coordinates
+        expected = B[A != 0]
+        self.run_einsum_plan(prgm, {"A": sparse_A, "B": B}, expected)
+
+    def test_double_indirection(self, rng):
+        """Test double indirection: A[B[CCoords[i]]]"""
+
+        # Create small arrays to ensure valid indexing
+        A = rng.random((8,))
+        # B contains integer indices into A (0-7)
+        B = np.array([0, 1, 2, 3, 4, 5, 6, 7])
+        rng.shuffle(B)
+        # C is sparse
+        C = rng.random((8,))
+
+        sparse_C = SparseTensor.from_dense_tensor(C)
+
+        # D[i] = A[B[CCoords[i]]]
+        # First get CCoords[i], then use that to index B, then use B's value to index A
+        prgm = ein.Plan(
+            (
+                ein.Einsum(
+                    op=ein.Literal(overwrite),
+                    tns=ein.Alias("D"),
+                    idxs=(ein.Index("i"),),
+                    arg=ein.Access(
+                        tns=ein.Alias("A"),
+                        idxs=(
+                            ein.Access(
+                                tns=ein.Alias("B"),
+                                idxs=(
+                                    ein.Access(
+                                        tns=ein.GetAttr(
+                                            obj=ein.Alias("C"),
+                                            attr=ein.Literal("coords"),
+                                            dim=None,
+                                        ),
+                                        idxs=(ein.Index("i"),),
+                                    ),
+                                ),
+                            ),
+                        ),
+                    ),
+                ),
+                ein.Produces((ein.Alias("D"),)),
+            )
+        )
+
+        # Expected: for each non-zero position in C, get its coord,
+        # index into B, then index into A
+        c_coords = sparse_C.coords
+        expected = A[B[c_coords]].flatten()
+        self.run_einsum_plan(prgm, {"A": A, "B": B, "C": sparse_C}, expected)
+
+    def test_triple_indirection(self, rng):
+        """Test triple indirection: A[B[C[DCoords[i]]]]"""
+
+        # Create arrays with valid indexing ranges
+        A = rng.random((8,))
+        B = np.array([0, 1, 2, 3, 4, 5, 6, 7])
+        rng.shuffle(B)
+        C = np.array([0, 1, 2, 3, 4, 5, 6, 7])
+        rng.shuffle(C)
+        D = rng.random((8,))
+
+        sparse_D = SparseTensor.from_dense_tensor(D)
+
+        # E[i] = A[B[C[DCoords[i]]]]
+        prgm = ein.Plan(
+            (
+                ein.Einsum(
+                    op=ein.Literal(overwrite),
+                    tns=ein.Alias("E"),
+                    idxs=(ein.Index("i"),),
+                    arg=ein.Access(
+                        tns=ein.Alias("A"),
+                        idxs=(
+                            ein.Access(
+                                tns=ein.Alias("B"),
+                                idxs=(
+                                    ein.Access(
+                                        tns=ein.Alias("C"),
+                                        idxs=(
+                                            ein.Access(
+                                                tns=ein.GetAttr(
+                                                    obj=ein.Alias("D"),
+                                                    attr=ein.Literal("coords"),
+                                                    dim=None,
+                                                ),
+                                                idxs=(ein.Index("i"),),
+                                            ),
+                                        ),
+                                    ),
+                                ),
+                            ),
+                        ),
+                    ),
+                ),
+                ein.Produces((ein.Alias("E"),)),
+            )
+        )
+
+        # Expected: chain of indirections
+        d_coords = sparse_D.coords
+        expected = A[B[C[d_coords]]].flatten()
+        self.run_einsum_plan(prgm, {"A": A, "B": B, "C": C, "D": sparse_D}, expected)
+
+    def test_mixed_direct_indirect_indexing_2d(self, rng):
+        """Test mixed indexing: A[BCoords[i], j] - one indirect, one direct"""
+
+        A = rng.random((5, 4))
+        B = rng.random((5,))
+
+        sparse_B = SparseTensor.from_dense_tensor(B)
+
+        # C[i, j] = A[BCoords[i], j]
+        # First index is indirect (from B's coords), second is direct
+        prgm = ein.Plan(
+            (
+                ein.Einsum(
+                    op=ein.Literal(overwrite),
+                    tns=ein.Alias("C"),
+                    idxs=(ein.Index("i"), ein.Index("j")),
+                    arg=ein.Access(
+                        tns=ein.Alias("A"),
+                        idxs=(
+                            ein.Access(
+                                tns=ein.GetAttr(
+                                    obj=ein.Alias("B"),
+                                    attr=ein.Literal("coords"),
+                                    dim=None,
+                                ),
+                                idxs=(ein.Index("i"),),
+                            ),
+                            ein.Index("j"),
+                        ),
+                    ),
+                ),
+                ein.Produces((ein.Alias("C"),)),
+            )
+        )
+
+        # Expected: A rows indexed by B's coords, all columns
+        b_coords = sparse_B.coords
+        expected = A[b_coords.flatten(), :]
+        self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected)
+
+    def test_mixed_direct_indirect_indexing_reversed(self, rng):
+        """
+        Test mixed indexing reversed
+        A[i, BCoords[j]] - first direct, second indirect
+        """
+
+        A = rng.random((4, 6))
+        B = rng.random((6,))
+
+        sparse_B = SparseTensor.from_dense_tensor(B)
+
+        # C[i, j] = A[i, BCoords[j]]
+        # First index is direct, second is indirect
+        prgm = ein.Plan(
+            (
+                ein.Einsum(
+                    op=ein.Literal(overwrite),
+                    tns=ein.Alias("C"),
+                    idxs=(ein.Index("i"), ein.Index("j")),
+                    arg=ein.Access(
+                        tns=ein.Alias("A"),
+                        idxs=(
+                            ein.Index("i"),
+                            ein.Access(
+                                tns=ein.GetAttr(
+                                    obj=ein.Alias("B"),
+                                    attr=ein.Literal("coords"),
+                                    dim=None,
+                                ),
+                                idxs=(ein.Index("j"),),
+                            ),
+                        ),
+                    ),
+                ),
+                ein.Produces((ein.Alias("C"),)),
+            )
+        )
+
+        # Expected: all rows of A, columns indexed by B's coords
+        b_coords = sparse_B.coords
+        expected = A[:, b_coords.flatten()]
+        self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected)
+
+    def test_both_indices_indirect_same_source(self, rng):
+        """Test both indices indirect from same source: A[BCoords[i], BCoords[i]]"""
+
+        A = rng.random((6, 6))
+        B = rng.random((6,))
+
+        sparse_B = SparseTensor.from_dense_tensor(B)
+
+        # C[i] = A[BCoords[i], BCoords[i]]
+        # Extracting diagonal-like elements using indirect coordinates
+        prgm = ein.Plan(
+            (
+                ein.Einsum(
+                    op=ein.Literal(overwrite),
+                    tns=ein.Alias("C"),
+                    idxs=(ein.Index("i"),),
+                    arg=ein.Access(
+                        tns=ein.Alias("A"),
+                        idxs=(
+                            ein.Access(
+                                tns=ein.GetAttr(
+                                    obj=ein.Alias("B"),
+                                    attr=ein.Literal("coords"),
+                                    dim=None,
+                                ),
+                                idxs=(ein.Index("i"),),
+                            ),
+                            ein.Access(
+                                tns=ein.GetAttr(
+                                    obj=ein.Alias("B"),
+                                    attr=ein.Literal("coords"),
+                                    dim=None,
+                                ),
+                                idxs=(ein.Index("i"),),
+                            ),
+                        ),
+                    ),
+                ),
+                ein.Produces((ein.Alias("C"),)),
+            )
+        )
+
+        # Expected: A[coords, coords] - pseudo-diagonal at indirect positions
+        b_coords = sparse_B.coords
+        expected = A[b_coords.flatten(), b_coords.flatten()]
+        self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected)
+
+    def test_both_indices_indirect_different_sources(self, rng):
+        """
+        Test both indices indirect from different sources:
+        A[BCoords[i], CCoords[i]]
+        """
+
+        A = rng.random((6, 6))
+        B = rng.random((6,))
+        C = rng.random((6,))
+
+        sparse_B = SparseTensor.from_dense_tensor(B)
+        sparse_C = SparseTensor.from_dense_tensor(C)
+
+        # D[i] = A[BCoords[i], CCoords[i]]
+        prgm = ein.Plan(
+            (
+                ein.Einsum(
+                    op=ein.Literal(overwrite),
+                    tns=ein.Alias("D"),
+                    idxs=(ein.Index("i"),),
+                    arg=ein.Access(
+                        tns=ein.Alias("A"),
+                        idxs=(
+                            ein.Access(
+                                tns=ein.GetAttr(
+                                    obj=ein.Alias("B"),
+                                    attr=ein.Literal("coords"),
+                                    dim=None,
+                                ),
+                                idxs=(ein.Index("i"),),
+                            ),
+                            ein.Access(
+                                tns=ein.GetAttr(
+                                    obj=ein.Alias("C"),
+                                    attr=ein.Literal("coords"),
+                                    dim=None,
+                                ),
+                                idxs=(ein.Index("i"),),
+                            ),
+                        ),
+                    ),
+                ),
+                ein.Produces((ein.Alias("D"),)),
+            )
+        )
+
+        # Expected: A indexed by pairs of coords from B and C
+        # This requires both to have the same number of non-zero elements
+        b_coords = sparse_B.coords.flatten()
+        c_coords = sparse_C.coords.flatten()
+        expected = A[b_coords, c_coords]
+        self.run_einsum_plan(prgm, {"A": A, "B": sparse_B, "C": sparse_C}, expected)
+
+    def test_double_indirection_with_operation(self, rng):
+        """Test double indirection combined with arithmetic operation"""
+
+        A = rng.random((8,))
+        B = np.array([0, 1, 2, 3, 4, 5, 6, 7])
+        rng.shuffle(B)
+        C = rng.random((8,))
+
+        sparse_C = SparseTensor.from_dense_tensor(C)
+
+        # E[i] = A[B[CCoords[i]]] * CElems[i]
+        # Double indirection plus multiplication with sparse elements
+        prgm = ein.Plan(
+            (
+                ein.Einsum(
+                    op=ein.Literal(overwrite),
+                    tns=ein.Alias("E"),
+                    idxs=(ein.Index("i"),),
+                    arg=ein.Call(
+                        op=ein.Literal(operator.mul),
+                        args=(
+                            ein.Access(
+                                tns=ein.Alias("A"),
+                                idxs=(
+                                    ein.Access(
+                                        tns=ein.Alias("B"),
+                                        idxs=(
+                                            ein.Access(
+                                                tns=ein.GetAttr(
+                                                    obj=ein.Alias("C"),
+                                                    attr=ein.Literal("coords"),
+                                                    dim=None,
+                                                ),
+                                                idxs=(ein.Index("i"),),
+                                            ),
+                                        ),
+                                    ),
+                                ),
+                            ),
+                            ein.Access(
+                                tns=ein.GetAttr(
+                                    obj=ein.Alias("C"),
+                                    attr=ein.Literal("elems"),
+                                    dim=None,
+                                ),
+                                idxs=(ein.Index("i"),),
+                            ),
+                        ),
+                    ),
+                ),
+                ein.Produces((ein.Alias("E"),)),
+            )
+        )
+
+        c_coords = sparse_C.coords.flatten()
+        c_elems = sparse_C.data
+        expected = A[B[c_coords]] * c_elems
+
+        self.run_einsum_plan(prgm, {"A": A, "B": B, "C": sparse_C}, expected)
+
+    def test_mixed_indexing_with_computation(self, rng):
+        """Test mixed direct/indirect indexing with computation"""
+
+        A = rng.random((5, 5))
+        B = rng.random((4,))
+
+        sparse_B = SparseTensor.from_dense_tensor(B)
+
+        # D[i, j] = A[BCoords[i], j] + BElems[i]
+        # Mixed indexing plus addition with sparse elements
+        prgm = ein.Plan(
+            (
+                ein.Einsum(
+                    op=ein.Literal(overwrite),
+                    tns=ein.Alias("D"),
+                    idxs=(ein.Index("i"), ein.Index("j")),
+                    arg=ein.Call(
+                        op=ein.Literal(operator.add),
+                        args=(
+                            ein.Access(
+                                tns=ein.Alias("A"),
+                                idxs=(
+                                    ein.Access(
+                                        tns=ein.GetAttr(
+                                            obj=ein.Alias("B"),
+                                            attr=ein.Literal("coords"),
+                                            dim=None,
+                                        ),
+                                        idxs=(ein.Index("i"),),
+                                    ),
+                                    ein.Index("j"),
+                                ),
+                            ),
+                            ein.Access(
+                                tns=ein.GetAttr(
+                                    obj=ein.Alias("B"),
+                                    attr=ein.Literal("elems"),
+                                    dim=None,
+                                ),
+                                idxs=(ein.Index("i"),),
+                            ),
+                        ),
+                    ),
+                ),
+                ein.Produces((ein.Alias("D"),)),
+            )
+        )
+
+        b_coords = sparse_B.coords.flatten()
+        b_elems = sparse_B.data
+        expected = A[b_coords, :] + b_elems[:, np.newaxis]
+        self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected)
+
+    def test_indirect_3d_tensor_access(self, rng):
+        """Test indirect access on 3D tensor with mixed indices"""
+
+        A = rng.random((3, 4, 5))
+        B = rng.random((4,))
+
+        sparse_B = SparseTensor.from_dense_tensor(B)
+
+        # C[i, j, k] = A[i, BCoords[j], k]
+        # Middle dimension is indirectly indexed
+        prgm = ein.Plan(
+            (
+                ein.Einsum(
+                    op=ein.Literal(overwrite),
+                    tns=ein.Alias("C"),
+                    idxs=(ein.Index("i"), ein.Index("j"), ein.Index("k")),
+                    arg=ein.Access(
+                        tns=ein.Alias("A"),
+                        idxs=(
+                            ein.Index("i"),
+                            ein.Access(
+                                tns=ein.GetAttr(
+                                    obj=ein.Alias("B"),
+                                    attr=ein.Literal("coords"),
+                                    dim=None,
+                                ),
+                                idxs=(ein.Index("j"),),
+                            ),
+                            ein.Index("k"),
+                        ),
+                    ),
+                ),
+                ein.Produces((ein.Alias("C"),)),
+            )
+        )
+
+        b_coords = sparse_B.coords.flatten()
+        expected = A[:, b_coords, :]
+        self.run_einsum_plan(prgm, {"A": A, "B": sparse_B}, expected)
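Finally, a plain-NumPy sketch (illustrative only, not part of the patch) of what the first indirect kernel in these tests, C[i] = AElems[i] * B[ACoords[i]], computes for a COO A: one output value per stored element of A, gathered from B at A's coordinates.

import numpy as np

rng = np.random.default_rng(0)
A, B = rng.random((5, 5)), rng.random((5, 5))

coords = np.argwhere(A != 0)  # (nnz, 2), row-major like np.where
elems = A[coords[:, 0], coords[:, 1]]  # AElems[i]
flat = np.ravel_multi_index(coords.T, B.shape)  # ACoords[i] as flat offsets
C = elems * B.flat[flat]  # B[ACoords[i]] gathered per stored element

assert np.allclose(C, (A * B).flatten()[A.flatten() != 0])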