diff --git a/src/finchlite/supertensor/__init__.py b/src/finchlite/supertensor/__init__.py
new file mode 100644
index 00000000..0cdf5a95
--- /dev/null
+++ b/src/finchlite/supertensor/__init__.py
@@ -0,0 +1,7 @@
+from .interpreter import SuperTensorEinsumInterpreter
+from .supertensor import SuperTensor
+
+__all__ = [
+    "SuperTensor",
+    "SuperTensorEinsumInterpreter",
+]
diff --git a/src/finchlite/supertensor/interpreter.py b/src/finchlite/supertensor/interpreter.py
new file mode 100644
index 00000000..22af8ddb
--- /dev/null
+++ b/src/finchlite/supertensor/interpreter.py
@@ -0,0 +1,146 @@
+import numpy as np
+
+from ..finch_einsum import EinsumInterpreter
+from ..finch_einsum import nodes as ein
+from ..symbolic import Namespace, PostOrderDFS, PostWalk, Rewrite
+from .supertensor import SuperTensor
+
+
+class SuperTensorEinsumInterpreter:
+    def __init__(self, xp=None, bindings=None):
+        if bindings is None:
+            bindings = {}
+        if xp is None:
+            xp = np
+        self.bindings = bindings
+        self.xp = xp
+
+    def __call__(self, node):
+        match node:
+            case ein.Plan(bodies):
+                res = None
+                for body in bodies:
+                    res = self(body)
+                return res
+            case ein.Produces(args):
+                return tuple(self(arg) for arg in args)
+            case ein.Einsum(_, output_tns, output_idxs, arg):
+                # Collect all Access nodes in the einsum AST.
+                accesses = [
+                    node for node in PostOrderDFS(arg) if isinstance(node, ein.Access)
+                ]
+
+                # For each index, collect the set of tensors in which the index appears.
+                idx_appearances: dict[ein.Index, set[ein.Alias]] = {}
+                for access in accesses:
+                    for idx in access.idxs:
+                        idx_appearances.setdefault(idx, set()).add(access.tns)
+                for idx in output_idxs:
+                    idx_appearances.setdefault(idx, set()).add(output_tns)
+
+                # Map each set of tensors to the list of indices.
+                tensor_sets: dict[tuple[ein.Alias, ...], list[ein.Index]] = {}
+                for idx, tensors in idx_appearances.items():
+                    tensor_sets.setdefault(
+                        tuple(sorted(tensors, key=lambda t: t.name)), []
+                    ).append(idx)
+
+                # Assign a new index name to each group of original indices.
+                idx_groups: dict[ein.Index, list[ein.Index]] = {}
+                old_to_new_idx_map: dict[ein.Index, ein.Index] = {}
+
+                namespace = Namespace()
+                for idx_group in tensor_sets.values():
+                    new_idx = ein.Index(namespace.freshen("i"))
+                    idx_groups[new_idx] = idx_group
+                    for old_idx in idx_group:
+                        old_to_new_idx_map[old_idx] = new_idx
+
+                # Compute the corrected SuperTensor representations.
+                corrected_bindings: dict[str, np.ndarray] = {}
+                corrected_idx_lists: dict[str, list[ein.Index]] = {}
+                for access in accesses:
+                    new_idx_list = sorted(
+                        {old_to_new_idx_map[idx] for idx in access.idxs},
+                        key=lambda i: i.name,
+                    )
+                    mode_map = [
+                        [
+                            access.idxs.index(idx)
+                            for idx in idx_groups[new_idx]
+                            if idx in access.idxs
+                        ]
+                        for new_idx in new_idx_list
+                    ]
+
+                    # Restore the logical shape of the SuperTensor.
+                    supertensor = self.bindings[access.tns.name]
+                    logical_tns = np.empty(
+                        supertensor.shape, dtype=supertensor.base.dtype
+                    )
+                    for idx in np.ndindex(supertensor.shape):
+                        logical_tns[idx] = supertensor[idx]
+
+                    # Reshape the base tensor using the corrected mode map.
+                    corrected_supertensor = SuperTensor.from_logical(
+                        logical_tns, mode_map
+                    )
+
+                    # Map the tensor name to its corrected base tensor and index list.
+                    corrected_bindings[access.tns.name] = corrected_supertensor.base
+                    corrected_idx_lists[access.tns.name] = new_idx_list
+
+                # Construct the corrected index list and mode map for the output.
+                new_output_idx_list = sorted(
+                    {old_to_new_idx_map[idx] for idx in output_idxs},
+                    key=lambda i: i.name,
+                )
+                output_mode_map = [
+                    [
+                        output_idxs.index(idx)
+                        for idx in idx_groups[new_idx]
+                        if idx in output_idxs
+                    ]
+                    for new_idx in new_output_idx_list
+                ]
+                corrected_idx_lists[output_tns.name] = new_output_idx_list
+
+                # Compute the logical shape of the output SuperTensor.
+                output_shape = [0] * len(output_idxs)
+                for idx in output_idxs:
+                    # Find an input tensor which contains this logical index.
+                    for access in accesses:
+                        if idx in access.idxs:
+                            supertensor = self.bindings[access.tns.name]
+                            output_shape[output_idxs.index(idx)] = supertensor.shape[
+                                access.idxs.index(idx)
+                            ]
+                            break
+                output_shape = tuple(output_shape)
+
+                # Rewrite the einsum AST to use the corrected indices.
+                def reshape_supertensors(node):
+                    match node:
+                        case ein.Access(tns, _):
+                            updated_idxs = corrected_idx_lists[tns.name]
+                            return ein.Access(tns, tuple(updated_idxs))
+                        case ein.Einsum(op, tns, _, arg):
+                            updated_output_idxs = corrected_idx_lists[tns.name]
+                            return ein.Einsum(op, tns, tuple(updated_output_idxs), arg)
+                        case _:
+                            return node
+
+                corrected_AST = Rewrite(PostWalk(reshape_supertensors))(node)
+
+                # Compute the output base tensor.
+                ctx = EinsumInterpreter(bindings=corrected_bindings)
+                result = ctx(corrected_AST)
+                output_base = corrected_bindings[result[0]]
+
+                # Wrap the output base tensor into a SuperTensor.
+                self.bindings[output_tns.name] = SuperTensor(
+                    output_shape, output_base, output_mode_map
+                )
+                return (output_tns.name,)
+            case _:
+                return None
diff --git a/src/finchlite/supertensor/supertensor.py b/src/finchlite/supertensor/supertensor.py
new file mode 100644
index 00000000..ced5215a
--- /dev/null
+++ b/src/finchlite/supertensor/supertensor.py
@@ -0,0 +1,123 @@
+import numpy as np
+
+
+class SuperTensor:
+    """
+    Represents a tensor using a base tensor of lower order.
+
+    Attributes:
+        shape: `tuple[int, ...]`
+            The logical shape of the tensor.
+        base: `np.ndarray`
+            The base tensor.
+        mode_map: `list[list[int]]`
+            Maps each mode of the base tensor to an ordered list of the logical modes
+            which are flattened into the base mode. The ordering of each list defines
+            the order in which the logical modes are flattened.
+
+            Example: mode_map = [[0, 2], [3], [4, 1]] indicates that the base tensor
+            has three modes and the logical tensor has five modes.
+            - Base mode 0 corresponds to logical modes 0 and 2.
+            - Base mode 1 corresponds to logical mode 3.
+            - Base mode 2 corresponds to logical modes 4 and 1.
+    """
+
+    shape: tuple[int, ...]
+    base: np.ndarray
+    mode_map: list[list[int]]
+
+    def __init__(
+        self, shape: tuple[int, ...], base: np.ndarray, mode_map: list[list[int]]
+    ):
+        self.shape = shape
+        self.base = base
+        self.mode_map = mode_map
+
+    @property
+    def N(self) -> int:
+        return len(self.shape)
+
+    @property
+    def B(self) -> int:
+        return self.base.ndim
+
+    @classmethod
+    def from_logical(cls, tns: np.ndarray, mode_map: list[list[int]]):
+        """
+        Constructs a SuperTensor from a logical tensor and a mode map.
+
+        Args:
+            tns: `np.ndarray`
+                The logical tensor.
+            mode_map: `list[list[int]]`
+                The mode map.
+ """ + shape = tns.shape + + base_shape = [0] * len(mode_map) + for b, logical_idx_group in enumerate(mode_map): + dims = [shape[m] for m in logical_idx_group] + base_shape[b] = np.prod(dims) if dims else 1 + + perm = [i for logical_idx_group in mode_map for i in logical_idx_group] + permuted_tns = np.transpose(tns, perm) + base = permuted_tns.reshape(tuple(base_shape)) + + return SuperTensor(shape, base, mode_map) + + def __getitem__(self, coords: tuple[int, ...]): + """ + Accesses an element of the SuperTensor using logical coordinates. + + Args: + coords: `tuple[int, ...]` + The logical coordinates to access. + + Returns: + The value in the SuperTensor at the given logical coordinates. + + Raises: + IndexError: The number of input coordinates does not match the + order of the logical tensor. + """ + + if len(coords) != len(self.shape): + raise IndexError( + f"Expected {len(self.shape)} indices, got {len(coords)} indices" + ) + + base_coords = [0] * self.B + for b, logical_modes in enumerate(self.mode_map): + if len(logical_modes) == 1: + base_coords[b] = coords[logical_modes[0]] + else: + subshape = tuple(self.shape[m] for m in logical_modes) + subidcs = tuple(coords[m] for m in logical_modes) + linear_idx = 0 + for dim, idx in zip(subshape, subidcs, strict=True): + linear_idx = linear_idx * dim + idx + base_coords[b] = linear_idx + + return self.base[tuple(base_coords)] + + def __repr__(self): + """ + Returns a string representation of the SuperTensor. + + Includes the logical shape, the base shape, the mode map, + and the logical tensor itself. + + Returns: + A string representation of the SuperTensor. + """ + logical_tns = np.empty(self.shape, dtype=self.base.dtype) + for idx in np.ndindex(self.shape): + logical_tns[idx] = self[idx] + + return ( + f"SuperTensor(" + f"shape={self.shape}, " + f"base.shape={self.base.shape}, " + f"mode_map={self.mode_map})\n" + f"{logical_tns}" + ) diff --git a/src/finchlite/supertensor/test.py b/src/finchlite/supertensor/test.py new file mode 100644 index 00000000..722c1abd --- /dev/null +++ b/src/finchlite/supertensor/test.py @@ -0,0 +1,41 @@ +import numpy as np + +from finchlite import finch_einsum as ein +from finchlite import supertensor as stns + +# ========= TEST SuperTensorEinsumInterpreter ========= +# Very simple test case + +rng = np.random.default_rng() + +A = rng.integers(1, 5, (3, 2, 4)) +supertensor_A = stns.SuperTensor.from_logical(A, [[0], [1, 2]]) + +B = rng.integers(1, 5, (2, 5, 4, 3)) +supertensor_B = stns.SuperTensor.from_logical(B, [[3], [1], [2, 0]]) + +print(f"SuperTensor A:\n{supertensor_A}\n") +print(f"SuperTensor B:\n{supertensor_B}\n") + +# Using regular EinsumInterpreter +einsum_AST, bindings = ein.parse_einsum("ikl,kjln->ijn", A, B) +interpreter = ein.EinsumInterpreter(bindings=bindings) +output = interpreter(einsum_AST) +result = bindings[output[0]] +print(f"Regular einsum interpreter result:\n{result}\n") + +# Using SuperTensorEinsumInterpreter +supertensor_einsum_AST, supertensor_bindings = ein.parse_einsum( + "ikl,kjln->ijn", supertensor_A, supertensor_B +) + +# print(f"SuperTensor einsum AST info:\n") +# print(f"{supertensor_einsum_AST}\n") +# print(f"{supertensor_bindings}\n") + +supertensor_interpreter = stns.SuperTensorEinsumInterpreter( + bindings=supertensor_bindings +) +output = supertensor_interpreter(supertensor_einsum_AST) +result = supertensor_bindings[output[0]] +print(f"SuperTensor einsum interpreter result:\n{result}")