[DRAFT] SuperTensor einsum interpreter #233
@@ -0,0 +1,2 @@
from .supertensor import SuperTensor
from .interpreter import SuperTensorEinsumInterpreter
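These two re-exports define the package's public surface. Assuming finchlite exposes this directory as finchlite.supertensor (the test script at the end of this diff relies on exactly that), callers can write, for example:

import numpy as np
from finchlite import supertensor as stns

st = stns.SuperTensor.from_logical(np.ones((2, 3)), [[0], [1]])
interp = stns.SuperTensorEinsumInterpreter(bindings={"A": st})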
@@ -0,0 +1,216 @@
import copy
import operator
from itertools import chain, combinations
from typing import Dict, FrozenSet, List, Set, Tuple

import numpy as np

from finchlite import finch_einsum as ein
from finchlite import symbolic as sym

from ..algebra import overwrite, promote_max, promote_min
from . import supertensor as stns
class SuperTensorEinsumInterpreter:
    def __init__(self, xp=None, bindings=None):
        if bindings is None:
            bindings = {}
        if xp is None:
            xp = np
        self.bindings = bindings
        self.xp = xp

    def __call__(self, node):
        match node:
            case ein.Plan(bodies):
                res = None
                for body in bodies:
                    res = self(body)
                return res
            case ein.Produces(args):
                return tuple(self(arg) for arg in args)
            case ein.Einsum(op, ein.Alias(output_name), output_idxs, arg):
                accesses = SuperTensorEinsumInterpreter._collect_accesses(arg)

                # Group the indices which appear in exactly the same set of tensors.
                output_idxs = [idx.name for idx in output_idxs]

                input_idx_list = []
                for access in accesses:
                    tns_name = access.tns.name
                    idxs = [idx.name for idx in access.idxs]
                    input_idx_list.append((tns_name, idxs))

                idx_groups = SuperTensorEinsumInterpreter._group_indices(
                    output_name, output_idxs, input_idx_list
                )
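                # Worked example (comments only): for C[i, j] = A[i, k, l] * B[k, l, j],
                # the inputs are [("A", ["i", "k", "l"]), ("B", ["k", "l", "j"])] and
                # the output is ("C", ["i", "j"]). _group_indices then returns
                #   [(frozenset({"A", "B"}), ["k", "l"]),
                #    (frozenset({"A", "C"}), ["i"]),
                #    (frozenset({"B", "C"}), ["j"])],
                # i.e. k and l always co-occur, so they can be fused into one index.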

                # Assign a new index name to each group of original indices.
                new_idxs = {}
                for k, (tensor_set, _) in enumerate(idx_groups):
                    new_idxs[tensor_set] = f"i{k}"

                # Compute the corrected SuperTensor representation for each tensor.
                inputs = []
                for access in accesses:
                    tns_name = access.tns.name
                    supertensor = self.bindings[access.tns.name]
                    idxs = [idx.name for idx in access.idxs]
                    inputs.append((tns_name, supertensor, idxs))

                corrected_bindings = {}
                corrected_idx_lists = {}
                for tns_name, supertensor, idx_list in inputs:
                    new_idx_list = []
                    mode_map = []

                    for tns_set, idx_group in idx_groups:
                        if tns_name in tns_set:
                            new_idx = new_idxs[tns_set]
                            new_idx_list.append(new_idx)

                            logical_modes = [idx_list.index(idx) for idx in idx_group]
                            mode_map.append(logical_modes)

                    # Restore the logical shape of the SuperTensor.
                    logical_tns = np.empty(supertensor.shape, dtype=supertensor.base.dtype)
                    for idx in np.ndindex(supertensor.shape):
                        logical_tns[idx] = supertensor[idx]

                    # Reshape the base tensor using the updated mode map.
                    corrected_supertensor = stns.SuperTensor.from_logical(logical_tns, mode_map)

                    # Map the tensor name to its corrected base representation and index list.
                    corrected_bindings[tns_name] = corrected_supertensor.base
                    corrected_idx_lists[tns_name] = new_idx_list

                # Construct the corresponding mode map and index list for the output
                # SuperTensor. TODO: Fix the code repetition here...
                new_output_idx_list = []
                output_mode_map = []
                for tns_set, idx_group in idx_groups:
                    if output_name in tns_set:
                        new_idx = new_idxs[tns_set]
                        new_output_idx_list.append(new_idx)

                        logical_modes = [output_idxs.index(idx) for idx in idx_group]
                        output_mode_map.append(logical_modes)

                corrected_idx_lists[output_name] = new_output_idx_list

                # Compute the shape of the output SuperTensor: each output index takes
                # its extent from any input base tensor that carries the same index.
                output_shape = [0] * len(output_mode_map)
                for idx in new_output_idx_list:
                    # Find an input SuperTensor which contains this index.
                    for tns_name, base_tns in corrected_bindings.items():
                        idx_list = corrected_idx_lists[tns_name]
                        if idx in idx_list:
                            dim = base_tns.shape[idx_list.index(idx)]
                            output_shape[new_output_idx_list.index(idx)] = dim
                            break
                output_shape = tuple(output_shape)

                # Rewrite the einsum AST so every access and the einsum output use the
                # corrected (fused) index lists.
                def reshape_supertensors(n):
                    match n:
                        case ein.Access(tns, _):
                            # TODO: What to do when tns isn't an ein.Alias?
                            if not isinstance(tns, ein.Alias):
                                return n
                            updated_idxs = [ein.Index(idx) for idx in corrected_idx_lists[tns.name]]
                            return ein.Access(tns, tuple(updated_idxs))
                        case ein.Einsum(op, ein.Alias(name), _, body):
                            updated_output_idxs = [ein.Index(idx) for idx in corrected_idx_lists[name]]
                            return ein.Einsum(op, ein.Alias(name), tuple(updated_output_idxs), body)
                        case _:
                            return n

                corrected_AST = sym.Rewrite(sym.PostWalk(reshape_supertensors))(node)

                # Use a regular EinsumInterpreter to execute the einsum on the base tensors.
                ctx = ein.EinsumInterpreter(bindings=corrected_bindings)
                result_alias = ctx(corrected_AST)
                output_base = corrected_bindings[result_alias[0]]

                self.bindings[output_name] = stns.SuperTensor(output_shape, output_base, output_mode_map)
                return (output_name,)
            case _:
                pass

    @classmethod
    def _collect_accesses(cls, node: ein.EinsumExpr) -> List[ein.Access]:
        """
        Collect all `ein.Access` nodes in an einsum AST.

        Args:
            node: `ein.EinsumExpr`
                The root node of the einsum expression AST, i.e., the `arg` field of
                an `ein.Einsum` node.

        Returns:
            `List[ein.Access]`
                A list of the `ein.Access` nodes found in the AST, in postorder.
        """
        accesses = []

        def postorder(curr):
            match curr:
                case ein.Access():
                    for child in curr.children:
                        postorder(child)
                    accesses.append(curr)
                case _:
                    if hasattr(curr, "children"):
                        for child in curr.children:
                            postorder(child)

        postorder(node)
        return accesses

    @classmethod
    def _group_indices(
        cls,
        output_name: str,
        output_idxs: List[str],
        inputs: List[Tuple[str, List[str]]],
    ) -> List[Tuple[FrozenSet[str], List[str]]]:
        """
        Groups indices based on the set of tensors they appear in.

        Establishes a canonical ordering for the indices within each group (sorted by
        name) and for the groups themselves (by subset size, then enumeration order).

        Args:
            output_name: `str`
                The name of the output tensor.
            output_idxs: `List[str]`
                The list of indices in the output tensor.
            inputs: `List[Tuple[str, List[str]]]`
                The list of input tensors, each represented by a tuple containing the
                tensor name and its list of indices.

        Returns:
            `List[Tuple[FrozenSet[str], List[str]]]`
                A list of tuples, each containing a set of tensor names and the sorted
                list of indices that appear in exactly those tensors.
        """
        # Associate each tensor name with its index set.
        tensors = [(name, set(idxs)) for (name, idxs) in inputs]
        tensors.append((output_name, set(output_idxs)))

        # Generate all non-empty subsets of the set of tensors.
        powerset = chain.from_iterable(
            combinations(tensors, n) for n in range(1, len(tensors) + 1)
        )

        # Associate each subset of tensors with the group of indices that appear in
        # exactly those tensors.
        groups = []
        for subset in powerset:
            included_tensors = [name for (name, _) in subset]
            included_idx_sets = [idxs for (_, idxs) in subset]
            excluded_idx_sets = [idxs for (name, idxs) in tensors if name not in included_tensors]

            included_intersection = set.intersection(*included_idx_sets)
            excluded_union = set.union(*excluded_idx_sets) if excluded_idx_sets else set()
            idx_group = included_intersection.difference(excluded_union)

            if idx_group:
                tensor_set = frozenset(included_tensors)
                groups.append((tensor_set, sorted(idx_group)))

        return groups
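As a concrete illustration (a sketch, assuming this PR's finchlite package layout), running the grouping helper on the inputs from the commented-out test at the bottom of this diff:

from finchlite import supertensor as stns

groups = stns.SuperTensorEinsumInterpreter._group_indices(
    "out", ["p", "i", "j"],
    [("A", ["p", "q", "i", "k"]), ("B", ["p", "r", "k", "j"])],
)
print(groups)
# Expected (singleton subsets are enumerated first, then pairs, then the full set):
# [(frozenset({"A"}), ["q"]),
#  (frozenset({"B"}), ["r"]),
#  (frozenset({"A", "B"}), ["k"]),
#  (frozenset({"A", "out"}), ["i"]),
#  (frozenset({"B", "out"}), ["j"]),
#  (frozenset({"A", "B", "out"}), ["p"])]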
@@ -0,0 +1,108 @@
import numpy as np
from typing import List, Tuple


class SuperTensor:
    """
    Represents a tensor using a base tensor of lower order.

    Attributes:
        shape: `Tuple[int, ...]`
            The logical shape of the tensor.
        base: `np.ndarray`
            The base tensor.
        map: `List[List[int]]`
            Maps each mode of the base tensor to an ordered list of the logical modes
            which are flattened into that base mode. The ordering of each list defines
            the order in which the logical modes are flattened.

            Example: map = [[0, 2], [3], [4, 1]] indicates that the base tensor has
            three modes and the logical tensor has five modes.
            - Base mode 0 corresponds to logical modes 0 and 2.
            - Base mode 1 corresponds to logical mode 3.
            - Base mode 2 corresponds to logical modes 4 and 1.
    """
    shape: Tuple[int, ...]
    base: np.ndarray
    map: List[List[int]]

    def __init__(self, shape: Tuple[int, ...], base: np.ndarray, map: List[List[int]]):
        self.shape = shape
        self.base = base
        self.map = map

    @property
    def N(self) -> int:
        """The order (number of modes) of the logical tensor."""
        return len(self.shape)

    @property
    def B(self) -> int:
        """The order (number of modes) of the base tensor."""
        return self.base.ndim

    @classmethod
    def from_logical(cls, tns: np.ndarray, map: List[List[int]]):
        """
        Constructs a SuperTensor from a logical tensor and a mode map.

        Args:
            tns: `np.ndarray`
                The logical tensor.
            map: `List[List[int]]`
                The mode map.
        """
        shape = tns.shape

        # Each base dimension is the product of the logical dimensions fused into it.
        base_shape = [0] * len(map)
        for b, logical_idx_group in enumerate(map):
            dims = [shape[m] for m in logical_idx_group]
            base_shape[b] = np.prod(dims) if dims else 1

        # Bring the logical modes into mode-map order, then fuse each group.
        perm = [i for logical_idx_group in map for i in logical_idx_group]
        permuted_tns = np.transpose(tns, perm)
        base = permuted_tns.reshape(tuple(base_shape))

        return SuperTensor(shape, base, map)
    def __getitem__(self, coords: Tuple[int, ...]):
        """
        Accesses an element of the SuperTensor using logical coordinates.

        Args:
            coords: `Tuple[int, ...]`
                The logical coordinates to access.

        Returns:
            The value in the SuperTensor at the given logical coordinates.

        Raises:
            IndexError: The number of input coordinates does not match the order of
                the logical tensor.
        """
        if len(coords) != len(self.shape):
            raise IndexError(f"Expected {len(self.shape)} indices, got {len(coords)}")

        base_coords = [0] * self.B
        for b, logical_modes in enumerate(self.map):
            if len(logical_modes) == 1:
                base_coords[b] = coords[logical_modes[0]]
            else:
                # Row-major linearization of the fused logical coordinates.
                subshape = tuple(self.shape[m] for m in logical_modes)
                subidcs = tuple(coords[m] for m in logical_modes)
                linear_idx = 0
                for dim, idx in zip(subshape, subidcs):
                    linear_idx = linear_idx * dim + idx
                base_coords[b] = linear_idx

        return self.base[tuple(base_coords)]

    def __repr__(self):
        """
        Returns a string representation of the SuperTensor, including the logical
        shape, the base shape, the mode map, and the logical tensor itself.
        """
        logical_tns = np.empty(self.shape, dtype=self.base.dtype)
        for idx in np.ndindex(self.shape):
            logical_tns[idx] = self[idx]
        return (
            f"SuperTensor(shape={self.shape}, base.shape={self.base.shape}, "
            f"map={self.map})\n{logical_tns}"
        )
@@ -0,0 +1,49 @@
import numpy as np
from finchlite import finch_einsum as ein
from finchlite import supertensor as stns

# ========= TEST SuperTensorEinsumInterpreter =========
# Very simple test case

A = np.random.randint(1, 5, (3, 2, 4))
supertensor_A = stns.SuperTensor.from_logical(A, [[0], [1, 2]])

B = np.random.randint(1, 5, (2, 4, 5))
supertensor_B = stns.SuperTensor.from_logical(B, [[0, 1], [2]])

print(f"SuperTensor A:\n{supertensor_A}\n")
print(f"SuperTensor B:\n{supertensor_B}\n")

# Using the regular EinsumInterpreter
einsum_AST, bindings = ein.parse_einsum("ikl,klj->ij", A, B)
interpreter = ein.EinsumInterpreter(bindings=bindings)
output = interpreter(einsum_AST)
result = bindings[output[0]]
print(f"Regular einsum interpreter result:\n{result}\n")

# Using the SuperTensorEinsumInterpreter
supertensor_einsum_AST, supertensor_bindings = ein.parse_einsum(
    "ikl,klj->ij", supertensor_A, supertensor_B
)

# print(f"SuperTensor einsum AST info:\n")
# print(f"{supertensor_einsum_AST}\n")
# print(f"{supertensor_bindings}\n")

supertensor_interpreter = stns.SuperTensorEinsumInterpreter(bindings=supertensor_bindings)
output = supertensor_interpreter(supertensor_einsum_AST)
result = supertensor_bindings[output[0]]
print(f"SuperTensor einsum interpreter result:\n{result}")

# TEST SuperTensor class
# A = np.random.randint(1, 5, (2, 3, 4))
# print(A)
# supertensor = stns.SuperTensor.from_logical(A, [[0, 1], [2]])
# print(supertensor)
# print(supertensor.base)

# TEST _group_indices()
# groups = stns.SuperTensorEinsumInterpreter._group_indices(
#     "out", ["p", "i", "j"],
#     [("A", ["p", "q", "i", "k"]), ("B", ["p", "r", "k", "j"])],
# )
# print(groups)

# TEST _collect_accesses()
# accesses = stns.SuperTensorEinsumInterpreter._collect_accesses(einsum_AST)
# print(f"\nAccesses:\n{accesses}")
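The script above compares the two interpreters only by printed output. A hypothetical follow-up assertion (not part of this PR; it reuses the densification loop from SuperTensor.__repr__) would make the check automatic:

# Densify the SuperTensor result and compare it against np.einsum on the
# logical tensors (assumes the script above has run, so `result` holds the
# SuperTensor output and A, B the logical inputs).
expected = np.einsum("ikl,klj->ij", A, B)
dense = np.empty(result.shape, dtype=result.base.dtype)
for idx in np.ndindex(result.shape):
    dense[idx] = result[idx]
np.testing.assert_array_equal(dense, expected)
print("SuperTensor result matches np.einsum")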
Review comment: use Namespace from Symbolic.py to create fresh index variable names.