Skip to content
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/finchlite/supertensor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .supertensor import SuperTensor
from .interpreter import SuperTensorEinsumInterpreter
216 changes: 216 additions & 0 deletions src/finchlite/supertensor/interpreter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
import operator

import numpy as np
import copy

from ..algebra import overwrite, promote_max, promote_min
from typing import List, Tuple, Set, FrozenSet, Dict
from itertools import chain, combinations

from finchlite import finch_einsum as ein
from finchlite import symbolic as sym
from . import supertensor as stns

class SuperTensorEinsumInterpreter:
def __init__(self, xp=None, bindings=None):
if bindings is None:
bindings = {}
if xp is None:
xp = np
self.bindings = bindings
self.xp = xp

def __call__(self, node):
match node:
case ein.Plan(bodies):
res = None
for body in bodies:
res = self(body)
return res
case ein.Produces(args):
return tuple(self(arg) for arg in args)
case ein.Einsum(op, ein.Alias(output_name), output_idxs, arg):
accesses = SuperTensorEinsumInterpreter._collect_accesses(arg)

# Group the indices which appear in exactly the same set of tensors.
output_idxs = [idx.name for idx in output_idxs]

input_idx_list = []
for access in accesses:
tns_name = access.tns.name
idxs = [idx.name for idx in access.idxs]
input_idx_list.append((tns_name, idxs))

idx_groups = SuperTensorEinsumInterpreter._group_indices(output_name, output_idxs, input_idx_list)

# Assign a new index name to each group of original indices.
new_idxs = {}
for k, (tensor_set, _) in enumerate(idx_groups):
new_idxs[tensor_set] = f"i{k}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use Namespace from Symbolic.py to create fresh index variable names.



# Compute the corrected SuperTensor representation for each tensor.
inputs = []
for access in accesses:
tns_name = access.tns.name
supertensor = self.bindings[access.tns.name]
idxs = [idx.name for idx in access.idxs]
inputs.append((tns_name, supertensor, idxs))

corrected_bindings = {}
corrected_idx_lists = {}
for tns_name, supertensor, input_idx_list in inputs:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would fuse this list with the previous, they are the same loop, basically.

new_idx_list = []
mode_map = []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can construct a dictionary globally, which maps idx -> newidx, which I think would be helpful here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

new_idx_list = sort(list(set(global_idx_map[idx] for idx in access.idxs)))
mode_map = [[access.idxs.index(idx) for idx in idx_groups[new_idx] if idx in access.idxs] for new_idx in new_idx_list]

for (tns_set, idx_group) in idx_groups:
if tns_name in tns_set:
new_idx = new_idxs[tns_set]
new_idx_list.append(new_idx)

logical_modes = []
for idx in idx_group:
logical_modes.append(input_idx_list.index(idx))

mode_map.append(logical_modes)

# Restore the logical shape of the SuperTensor.
logical_tns = np.empty(supertensor.shape, dtype=supertensor.base.dtype)
for idx in np.ndindex(supertensor.shape):
logical_tns[idx] = supertensor[idx]

# Reshape the base tensor using the updated mode map.
corrected_supertensor = stns.SuperTensor.from_logical(logical_tns, mode_map)

# Map the tensor name to its corrected base representation and index list.
corrected_bindings[tns_name] = corrected_supertensor.base
corrected_idx_lists[tns_name] = new_idx_list

# Construct the correct mode map and index list for the output SuperTensor.
# TODO: Fix the code repetition here...

new_output_idx_list = []
output_mode_map = []
for (tns_set, idx_group) in idx_groups:
if output_name in tns_set:
new_idx = new_idxs[tns_set]
new_output_idx_list.append(new_idx)

logical_modes = []
for idx in idx_group:
logical_modes.append(output_idxs.index(idx))

output_mode_map.append(logical_modes)

corrected_idx_lists[output_name] = new_output_idx_list

# Compute the shape of the output SuperTensor.
output_shape = [0] * len(output_mode_map)
for idx in new_output_idx_list:
# Find an input SuperTensor which contains this index.
for base_tns, idx_list in zip(corrected_bindings.values(), corrected_idx_lists.values()):
if idx in idx_list:
dim = base_tns.shape[idx_list.index(idx)]
output_shape[new_output_idx_list.index(idx)] = dim
break
output_shape = tuple(output_shape)

# Replace each ein.Alias node with the proper base representation (i.e., update index lists).
def reshape_supertensors(node):
match node:
case ein.Access(tns, _):
# TODO: What to do when tns isn't an ein.Alias?
if not isinstance(tns, ein.Alias):
return node
updated_idxs = [ein.Index(idx) for idx in corrected_idx_lists[tns.name]]
return ein.Access(tns, tuple(updated_idxs))
case ein.Einsum(op, ein.Alias(output_name), _, arg):
updated_output_idxs = [ein.Index(idx) for idx in corrected_idx_lists[output_name]]
return ein.Einsum(op, ein.Alias(output_name), tuple(updated_output_idxs), arg)
case _:
return node

corrected_AST = sym.Rewrite(sym.PostWalk(reshape_supertensors))(node)

# Use a regular EinsumInterpreter to execute the einsum on the SuperTensors.
ctx = ein.EinsumInterpreter(bindings=corrected_bindings)
result_alias = ctx(corrected_AST)
output_base = corrected_bindings[result_alias[0]]

self.bindings[output_name] = stns.SuperTensor(output_shape, output_base, output_mode_map)
return (output_name,)
case _:
pass

@classmethod
def _collect_accesses(cls, node: ein.EinsumExpr) -> List[ein.Access]:
"""
Collect all `ein.Access` nodes in an einsum AST.

Args:
node: `ein.EinsumExpr`
The root node of the einsum expression AST, i.e., the `arg` field of an `ein.Einsum` node.

Returns:
`List[ein.Access]`
A list of `ein.Access` nodes found in the AST.
"""

accesses = []

def postorder(curr):
match curr:
case ein.Access():
for child in curr.children:
postorder(child)
accesses.append(curr)
case _:
if hasattr(curr, "children"):
for child in curr.children:
postorder(child)

postorder(node)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use PostOrderDFS from symbolic.py, inline this function

return accesses

@classmethod
def _group_indices(cls, output_name: str, output_idxs: List[str], inputs: List[Tuple[str, List[str]]]) -> List[Tuple[FrozenSet[str], List[str]]]:
"""
Groups indices based on the set of tensors they appear in.

Establishes the canonical ordering for the indices within each group and also for the groups themselves.

Args:
output_name: `str`
The name of output tensor.
output_idxs: `List[str]`
The list of indices in the output tensor.
inputs: `List[Tuple[str, List[str]]]`
The list of input tensors, each represented by a tuple containing the tensor name and its list of indices.

Returns:
`List[Tuple[FrozenSet[str], List[str]]]`
A list of tuples, each containing a set of tensor names and the corresponding list of indices that appear in exactly those tensors.
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this function, would it be possible to instead just construct a dictionary mapping sets of tensors to lists of indices, then for each index compute the set of tensors and add that index to the corresponding list in the dictionary?

idx_groups = Dict[Index, Set[Alias]]()
for node in PostOrderDFS(einsum):
   match node:
      case Access(tns, idxs):
         for idx in idxs:
             idx_groups.setdefault(idx, Set[Alias]()).add(tns)
group_idxs = Dict[Tuple[Alias], Set[Index]]()
for idx, group in idx_groups:
    tns_groups[group].setdefault(tuple(sort(group)), Set[Index]()).add(idx)

Something like this might be simpler, could you try to refactor a bit to simplify?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also inline this logic.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the detailed feedback! I applied these changes and I think it definitely made the code significantly more concise and readable.


# Associate each tensor name with its index set.
tensors = [(name, set(idxs)) for (name, idxs) in inputs]
tensors.append((output_name, set(output_idxs)))

# Generate all non-empty subsets of the set of tensors.
powerset = chain.from_iterable(combinations(tensors, n) for n in range(1, len(tensors) + 1))

# Associate each subset of tensors to the group of indices that appear in exactly those tensors.
groups = []
for subset in powerset:
included_tensors = [name for (name, _) in subset]
included_idx_sets = [idxs for (_, idxs) in subset]
excluded_idx_sets = [idxs for (name, idxs) in tensors if name not in included_tensors]

included_intersection = set.intersection(*included_idx_sets)
excluded_union = set.union(*excluded_idx_sets) if excluded_idx_sets else set()
idx_group = included_intersection.difference(excluded_union)

if idx_group:
tensor_set = frozenset(included_tensors)
groups.append((tensor_set, sorted(list(idx_group))))

return groups
108 changes: 108 additions & 0 deletions src/finchlite/supertensor/supertensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import numpy as np
from typing import List, Tuple

class SuperTensor():
    """
    Represents a tensor using a base tensor of lower order.

    Attributes:
        shape: `Tuple[int, ...]`
            The logical shape of the tensor.
        base: `np.ndarray`
            The base tensor.
        map: `List[List[int]]`
            Maps each mode of the base tensor to an ordered list of the logical modes which are flattened into the base mode.
            The ordering of each list defines the order in which the logical modes are flattened.

            Example: map = [[0, 2], [3], [4, 1]] indicates that the base tensor has three modes and the logical tensor has five modes.
                - Base mode 0 corresponds to logical modes 0 and 2.
                - Base mode 1 corresponds to logical mode 3.
                - Base mode 2 corresponds to logical modes 4 and 1.
    """

    shape: Tuple[int, ...]
    base: np.ndarray
    map: List[List[int]]

    def __init__(self, shape: Tuple[int, ...], base: np.ndarray, map: List[List[int]]):
        self.shape = shape
        self.base = base
        self.map = map

    @property
    def N(self) -> int:
        """The order (number of modes) of the logical tensor."""
        return len(self.shape)

    @property
    def B(self) -> int:
        """The order (number of modes) of the base tensor."""
        return self.base.ndim

    @classmethod
    def from_logical(cls, tns: np.ndarray, map: List[List[int]]):
        """
        Constructs a SuperTensor from a logical tensor and a mode map.

        Args:
            tns: `np.ndarray`
                The logical tensor.
            map: `List[List[int]]`
                The mode map. Together, its groups must list every mode of `tns` exactly once,
                because they are concatenated into the axis permutation passed to `np.transpose`.

        Returns:
            A SuperTensor whose base tensor is `tns` with each group of logical modes flattened into one base mode.
        """
        shape = tns.shape

        # Each base mode is as large as the product of the logical modes flattened into it
        # (an empty group yields a base mode of size 1).
        base_shape = []
        for logical_idx_group in map:
            dims = [shape[m] for m in logical_idx_group]
            base_shape.append(int(np.prod(dims)) if dims else 1)

        # Bring the logical modes into flattening order, then collapse each group.
        perm = [m for logical_idx_group in map for m in logical_idx_group]
        permuted_tns = np.transpose(tns, perm)
        base = permuted_tns.reshape(tuple(base_shape))

        return cls(shape, base, map)

    def __getitem__(self, coords: Tuple[int, ...]):
        """
        Accesses an element of the SuperTensor using logical coordinates.

        Args:
            coords: `Tuple[int, ...]`
                The logical coordinates to access. Negative coordinates count from the end of their mode.
                A bare integer is accepted for a tensor of order one.

        Returns:
            The value in the SuperTensor at the given logical coordinates.

        Raises:
            IndexError: The number of input coordinates does not match the order of the logical tensor,
                or a coordinate is out of bounds for its logical mode.
        """
        if isinstance(coords, (int, np.integer)):
            coords = (coords,)

        if len(coords) != len(self.shape):
            raise IndexError(f"Expected {len(self.shape)} indices, got {len(coords)} indices")

        # Validate and normalize every coordinate before linearizing: a coordinate which is out of
        # range for its own mode could otherwise still land inside the flattened base mode and
        # silently address a different element.
        logical = []
        for mode, (coord, dim) in enumerate(zip(coords, self.shape)):
            if not -dim <= coord < dim:
                raise IndexError(f"Index {coord} is out of bounds for logical mode {mode} with size {dim}")
            logical.append(coord + dim if coord < 0 else coord)

        base_coords = []
        for logical_modes in self.map:
            # Row-major linearization of the coordinates of the modes flattened into this base mode.
            linear_idx = 0
            for m in logical_modes:
                linear_idx = linear_idx * self.shape[m] + logical[m]
            base_coords.append(linear_idx)

        return self.base[tuple(base_coords)]

    def __repr__(self):
        """
        Returns a string representation of the SuperTensor.

        Includes the logical shape, the base shape, the mode map, and the logical tensor itself.

        Returns:
            A string representation of the SuperTensor.
        """
        logical_tns = np.empty(self.shape, dtype=self.base.dtype)
        for idx in np.ndindex(self.shape):
            logical_tns[idx] = self[idx]
        return f"SuperTensor(shape={self.shape}, base.shape={self.base.shape}, map={self.map})\n{logical_tns}"
49 changes: 49 additions & 0 deletions src/finchlite/supertensor/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import numpy as np
from finchlite import finch_einsum as ein
from finchlite import supertensor as stns

# ========= TEST SuperTensorEinsumInterpreter =========
# Very simple test case: contract "ikl,klj->ij" once on plain arrays and once on
# SuperTensors wrapping the same data, then check that both results agree.

# A fixed seed keeps the inputs (and therefore any failure) reproducible.
rng = np.random.default_rng(0)

A = rng.integers(1, 5, size=(3, 2, 4))
supertensor_A = stns.SuperTensor.from_logical(A, [[0], [1, 2]])

B = rng.integers(1, 5, size=(2, 4, 5))
supertensor_B = stns.SuperTensor.from_logical(B, [[0, 1], [2]])

print(f"SuperTensor A:\n{supertensor_A}\n")
print(f"SuperTensor B:\n{supertensor_B}\n")

# Using regular EinsumInterpreter
einsum_AST, bindings = ein.parse_einsum("ikl,klj->ij", A, B)
interpreter = ein.EinsumInterpreter(bindings=bindings)
regular_output = interpreter(einsum_AST)
regular_result = bindings[regular_output[0]]
print(f"Regular einsum interpreter result:\n{regular_result}\n")

# Using SuperTensorEinsumInterpreter
supertensor_einsum_AST, supertensor_bindings = ein.parse_einsum("ikl,klj->ij", supertensor_A, supertensor_B)

# print(f"SuperTensor einsum AST info:\n")
# print(f"{supertensor_einsum_AST}\n")
# print(f"{supertensor_bindings}\n")

supertensor_interpreter = stns.SuperTensorEinsumInterpreter(bindings=supertensor_bindings)
supertensor_output = supertensor_interpreter(supertensor_einsum_AST)
supertensor_result = supertensor_bindings[supertensor_output[0]]
print(f"SuperTensor einsum interpreter result:\n{supertensor_result}")

# Rebuild the logical tensor element by element and compare it with the reference result,
# so that a mismatch fails loudly instead of having to be spotted in the printed output.
logical_result = np.empty(supertensor_result.shape, dtype=supertensor_result.base.dtype)
for idx in np.ndindex(supertensor_result.shape):
    logical_result[idx] = supertensor_result[idx]
np.testing.assert_array_equal(logical_result, regular_result)

# TEST SuperTensor class
# A = np.random.randint(1,5,(2,3,4))
# print(A)
# supertensor = stns.SuperTensor.from_logical(A, [[0,1],[2]])
# print(supertensor)
# print(supertensor.base)

# TEST _group_indices()
# groups = stns.SuperTensorEinsumInterpreter._group_indices("out", ["p", "i", "j"], [("A", ["p", "q", "i", "k"]), ("B", ["p", "r", "k", "j"])])
# print(groups)

# TEST _collect_accesses()
# accesses = stns.SuperTensorEinsumInterpreter._collect_accesses(einsum_AST)
# print(f"\nAccesses:\n{accesses}")