7 changes: 7 additions & 0 deletions src/finchlite/supertensor/__init__.py
@@ -0,0 +1,7 @@
from .interpreter import SuperTensorEinsumInterpreter
from .supertensor import SuperTensor

__all__ = [
"SuperTensor",
"SuperTensorEinsumInterpreter",
]
146 changes: 146 additions & 0 deletions src/finchlite/supertensor/interpreter.py
@@ -0,0 +1,146 @@
import numpy as np

from ..finch_einsum import EinsumInterpreter
from ..finch_einsum import nodes as ein
from ..symbolic import Namespace, PostOrderDFS, PostWalk, Rewrite
from .supertensor import SuperTensor


class SuperTensorEinsumInterpreter:
def __init__(self, xp=None, bindings=None):
if bindings is None:
bindings = {}
if xp is None:
xp = np
self.bindings = bindings
self.xp = xp

def __call__(self, node):
match node:
case ein.Plan(bodies):
res = None
for body in bodies:
res = self(body)
return res
case ein.Produces(args):
return tuple(self(arg) for arg in args)
case ein.Einsum(_, output_tns, output_idxs, arg):
# Collect all Access nodes in the einsum AST.
accesses = [
node for node in PostOrderDFS(arg) if isinstance(node, ein.Access)
]

# For each index, collect the set of tensors in which the index appears.
idx_appearances: dict[ein.Index, set[ein.Alias]] = {}
for access in accesses:
for idx in access.idxs:
idx_appearances.setdefault(idx, set()).add(access.tns)
for idx in output_idxs:
idx_appearances.setdefault(idx, set()).add(output_tns)

                # Group the indices by the exact set of tensors in which they appear.
                tensor_sets: dict[tuple[ein.Alias, ...], list[ein.Index]] = {}
for idx, tensors in idx_appearances.items():
tensor_sets.setdefault(
tuple(sorted(tensors, key=lambda t: t.name)), []
).append(idx)

# Assign a new index name to each group of original indices.
idx_groups: dict[ein.Index, list[ein.Index]] = {}
old_to_new_idx_map: dict[ein.Index, ein.Index] = {}

namespace = Namespace()
for idx_group in tensor_sets.values():
new_idx = ein.Index(namespace.freshen("i"))
idx_groups[new_idx] = idx_group
for old_idx in idx_group:
old_to_new_idx_map[old_idx] = new_idx

# Compute the corrected SuperTensor representations.
corrected_bindings: dict[str, np.ndarray] = {}
corrected_idx_lists: dict[str, list[ein.Index]] = {}
for access in accesses:
new_idx_list = sorted(
{old_to_new_idx_map[idx] for idx in access.idxs},
key=lambda i: i.name,
)
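                    # For each fused index, record where its constituent logical
                    # modes sit in this access's original index list.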
mode_map = [
[
access.idxs.index(idx)
for idx in idx_groups[new_idx]
if idx in access.idxs
]
for new_idx in new_idx_list
]

                    # Materialize the full logical tensor from the SuperTensor.
supertensor = self.bindings[access.tns.name]
logical_tns = np.empty(
supertensor.shape, dtype=supertensor.base.dtype
)
for idx in np.ndindex(supertensor.shape):
logical_tns[idx] = supertensor[idx]

                    # Re-flatten the logical tensor into a new base tensor using the corrected mode map.
corrected_supertensor = SuperTensor.from_logical(
logical_tns, mode_map
)

# Map the tensor name to its corrected base tensor and index list.
corrected_bindings[access.tns.name] = corrected_supertensor.base
corrected_idx_lists[access.tns.name] = new_idx_list

# Construct the corrected index list and mode map for the output.
new_output_idx_list = sorted(
{old_to_new_idx_map[idx] for idx in output_idxs},
key=lambda i: i.name,
)
output_mode_map = [
[
output_idxs.index(idx)
for idx in idx_groups[new_idx]
if idx in output_idxs
]
for new_idx in new_output_idx_list
]
corrected_idx_lists[output_tns.name] = new_output_idx_list

# Compute the logical shape of the output SuperTensor.
output_shape = [0] * len(output_idxs)
for idx in output_idxs:
                    # Find an input tensor that contains this logical index.
for access in accesses:
if idx in access.idxs:
supertensor = self.bindings[access.tns.name]
output_shape[output_idxs.index(idx)] = supertensor.shape[
access.idxs.index(idx)
]
break
output_shape = tuple(output_shape)

# Rewrite the einsum AST to use the corrected indices.
def reshape_supertensors(node):
match node:
case ein.Access(tns, _):
updated_idxs = corrected_idx_lists[tns.name]
return ein.Access(tns, tuple(updated_idxs))
case ein.Einsum(op, tns, _, arg):
updated_output_idxs = corrected_idx_lists[tns.name]
return ein.Einsum(op, tns, tuple(updated_output_idxs), arg)
case _:
return node

                corrected_ast = Rewrite(PostWalk(reshape_supertensors))(node)

# Compute the output base tensor.
ctx = EinsumInterpreter(bindings=corrected_bindings)
                result = ctx(corrected_ast)
output_base = corrected_bindings[result[0]]

# Wrap the output base tensor into a SuperTensor.
self.bindings[output_tns.name] = SuperTensor(
output_shape, output_base, output_mode_map
)
return (output_tns.name,)
case _:
return None
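
For intuition, the index-grouping step above can be sketched standalone on plain strings (a minimal illustration only — the interpreter itself operates on ein.Index and ein.Alias nodes; the einsum string matches the test at the bottom of this PR):

from collections import defaultdict

inputs = {"A": "ikl", "B": "kjln"}
output = ("out", "ijn")

# For each index, the set of tensors in which it appears.
idx_appearances: dict[str, set[str]] = defaultdict(set)
for tns, idxs in [*inputs.items(), output]:
    for idx in idxs:
        idx_appearances[idx].add(tns)

# Indices that appear in exactly the same set of tensors can be fused
# into a single mode of the base tensors.
tensor_sets: dict[tuple[str, ...], list[str]] = defaultdict(list)
for idx, tensors in idx_appearances.items():
    tensor_sets[tuple(sorted(tensors))].append(idx)

print(dict(tensor_sets))
# {('A', 'out'): ['i'], ('A', 'B'): ['k', 'l'], ('B', 'out'): ['j', 'n']}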
123 changes: 123 additions & 0 deletions src/finchlite/supertensor/supertensor.py
@@ -0,0 +1,123 @@
import numpy as np


class SuperTensor:
"""
Represents a tensor using a base tensor of lower order.

Attributes:
shape: `tuple[int, ...]`
The logical shape of the tensor.
base: `np.ndarray`
The base tensor.
mode_map: `list[list[int]]`
            Maps each mode of the base tensor to an ordered list of the logical
            modes that are flattened into that base mode. The order of each list
            is the order in which the logical modes are flattened, with the
            first entry outermost (most significant).

Example: mode_map = [[0, 2], [3], [4, 1]] indicates that the base tensor
has three modes and the logical tensor has five modes.
- Base mode 0 corresponds to logical modes 0 and 2.
- Base mode 1 corresponds to logical mode 3.
- Base mode 2 corresponds to logical modes 4 and 1.
"""

shape: tuple[int, ...]
base: np.ndarray
mode_map: list[list[int]]

def __init__(
self, shape: tuple[int, ...], base: np.ndarray, mode_map: list[list[int]]
):
self.shape = shape
self.base = base
self.mode_map = mode_map

@property
def N(self) -> int:
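        """The order (number of modes) of the logical tensor."""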
return len(self.shape)

@property
def B(self) -> int:
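        """The order (number of modes) of the base tensor."""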
return self.base.ndim

@classmethod
def from_logical(cls, tns: np.ndarray, mode_map: list[list[int]]):
"""
Constructs a SuperTensor from a logical tensor and a mode map.

Args:
tns: `np.ndarray`
The logical tensor.
mode_map: `list[list[int]]`
The mode map.
"""
shape = tns.shape

base_shape = [0] * len(mode_map)
for b, logical_idx_group in enumerate(mode_map):
dims = [shape[m] for m in logical_idx_group]
            base_shape[b] = int(np.prod(dims)) if dims else 1

perm = [i for logical_idx_group in mode_map for i in logical_idx_group]
permuted_tns = np.transpose(tns, perm)
base = permuted_tns.reshape(tuple(base_shape))

        return cls(shape, base, mode_map)

def __getitem__(self, coords: tuple[int, ...]):
"""
Accesses an element of the SuperTensor using logical coordinates.

Args:
coords: `tuple[int, ...]`
The logical coordinates to access.

Returns:
The value in the SuperTensor at the given logical coordinates.

        Raises:
            IndexError: If the number of coordinates does not match the
                order of the logical tensor.
"""

if len(coords) != len(self.shape):
            raise IndexError(
                f"Expected {len(self.shape)} indices, got {len(coords)}"
            )

base_coords = [0] * self.B
for b, logical_modes in enumerate(self.mode_map):
if len(logical_modes) == 1:
base_coords[b] = coords[logical_modes[0]]
else:
subshape = tuple(self.shape[m] for m in logical_modes)
subidcs = tuple(coords[m] for m in logical_modes)
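                # Row-major (mixed-radix) linearization: coordinates (c1, c2, ...)
                # over sizes (d1, d2, ...) map to ((c1 * d2 + c2) * d3 + c3) ...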
linear_idx = 0
for dim, idx in zip(subshape, subidcs, strict=True):
linear_idx = linear_idx * dim + idx
base_coords[b] = linear_idx

return self.base[tuple(base_coords)]

def __repr__(self):
"""
Returns a string representation of the SuperTensor.

Includes the logical shape, the base shape, the mode map,
and the logical tensor itself.

Returns:
A string representation of the SuperTensor.
"""
logical_tns = np.empty(self.shape, dtype=self.base.dtype)
for idx in np.ndindex(self.shape):
logical_tns[idx] = self[idx]

return (
f"SuperTensor("
f"shape={self.shape}, "
f"base.shape={self.base.shape}, "
f"mode_map={self.mode_map})\n"
f"{logical_tns}"
)
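
As a quick sanity check on the docstring semantics, a logical tensor round-trips through from_logical and __getitem__ (a minimal standalone sketch, not part of the diff):

import numpy as np
from finchlite.supertensor import SuperTensor

# Order-3 logical tensor stored as an order-2 base: mode 0 stays alone,
# modes 1 and 2 are flattened together (mode 1 outermost).
t = np.arange(24).reshape(2, 3, 4)
st = SuperTensor.from_logical(t, [[0], [1, 2]])

assert st.base.shape == (2, 12)
# Logical coordinate (i, j, k) lands at base coordinate (i, j * 4 + k).
assert st[(1, 2, 3)] == t[1, 2, 3]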
41 changes: 41 additions & 0 deletions src/finchlite/supertensor/test.py
@@ -0,0 +1,41 @@
import numpy as np

from finchlite import finch_einsum as ein
from finchlite import supertensor as stns

# ========= TEST SuperTensorEinsumInterpreter =========
# Very simple test case

rng = np.random.default_rng()

A = rng.integers(1, 5, (3, 2, 4))
supertensor_A = stns.SuperTensor.from_logical(A, [[0], [1, 2]])

B = rng.integers(1, 5, (2, 5, 4, 3))
supertensor_B = stns.SuperTensor.from_logical(B, [[3], [1], [2, 0]])

print(f"SuperTensor A:\n{supertensor_A}\n")
print(f"SuperTensor B:\n{supertensor_B}\n")

# Using regular EinsumInterpreter
einsum_AST, bindings = ein.parse_einsum("ikl,kjln->ijn", A, B)
interpreter = ein.EinsumInterpreter(bindings=bindings)
output = interpreter(einsum_AST)
result = bindings[output[0]]
print(f"Regular einsum interpreter result:\n{result}\n")

# Using SuperTensorEinsumInterpreter
supertensor_einsum_AST, supertensor_bindings = ein.parse_einsum(
"ikl,kjln->ijn", supertensor_A, supertensor_B
)

# print(f"SuperTensor einsum AST info:\n")
# print(f"{supertensor_einsum_AST}\n")
# print(f"{supertensor_bindings}\n")

supertensor_interpreter = stns.SuperTensorEinsumInterpreter(
bindings=supertensor_bindings
)
output = supertensor_interpreter(supertensor_einsum_AST)
result = supertensor_bindings[output[0]]
print(f"SuperTensor einsum interpreter result:\n{result}")