[DRAFT] SuperTensor einsum interpreter #233
willow-ahrens
left a comment
Looks like the correct logic! For readability and maintainability, consider using generator expressions, sets, and dicts more frequently. Take a look at my specific comments for more detail.
for child in curr.children:
    postorder(child)

postorder(node)
Use PostOrderDFS from symbolic.py and inline this function.
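As a rough illustration of why a generator-style traversal helps here, the sketch below yields nodes in post-order so the caller can iterate directly instead of threading a recursive helper through the code. The `Node` class and `post_order_dfs` function are hypothetical stand-ins, not the actual `PostOrderDFS` from symbolic.py, whose signature may differ.

```python
from dataclasses import dataclass, field
from typing import Iterator, List


@dataclass
class Node:
    """Hypothetical tree node; stands in for the project's AST nodes."""
    name: str
    children: List["Node"] = field(default_factory=list)


def post_order_dfs(node: Node) -> Iterator[Node]:
    """Yield every child subtree first, then the node itself."""
    for child in node.children:
        yield from post_order_dfs(child)
    yield node


tree = Node("root", [Node("a", [Node("a1")]), Node("b")])
visited = [n.name for n in post_order_dfs(tree)]
```

With a traversal like this, the body of the old `postorder` helper can move into a plain `for` loop at the call site.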
Returns:
    `List[Tuple[FrozenSet[str], List[str]]]`
        A list of tuples, each containing a set of tensor names and the corresponding list of indices that appear in exactly those tensors.
"""
In this function, would it be possible to instead just construct a dictionary mapping sets of tensors to lists of indices, then for each index compute the set of tensors and add that index to the corresponding list in the dictionary?
# Map each index to the set of tensors it appears in.
idx_groups: dict[Index, set[Alias]] = {}
for node in PostOrderDFS(einsum):
    match node:
        case Access(tns, idxs):
            for idx in idxs:
                idx_groups.setdefault(idx, set()).add(tns)
# Invert the mapping: group indices by the exact set of tensors they share.
group_idxs: dict[tuple[Alias, ...], set[Index]] = {}
for idx, group in idx_groups.items():
    group_idxs.setdefault(tuple(sorted(group)), set()).add(idx)
Something like this might be simpler, could you try to refactor a bit to simplify?
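To make the two-pass idea concrete, here is a self-contained sketch that runs on plain data rather than the einsum AST. The `accesses` list and the variable names are assumptions made for illustration; in the PR the accesses would come from walking `Access` nodes.

```python
from typing import Dict, FrozenSet, List, Set, Tuple

# Hypothetical flattened view of an einsum: (tensor name, index names) per access.
accesses: List[Tuple[str, Tuple[str, ...]]] = [
    ("A", ("i", "j", "k")),
    ("B", ("j", "k", "l")),
]

# Pass 1: for each index, collect the set of tensors it appears in.
tensors_of_idx: Dict[str, Set[str]] = {}
for tns, idxs in accesses:
    for idx in idxs:
        tensors_of_idx.setdefault(idx, set()).add(tns)

# Pass 2: invert the mapping, grouping indices by their exact tensor set.
idxs_of_group: Dict[FrozenSet[str], List[str]] = {}
for idx, tensors in tensors_of_idx.items():
    idxs_of_group.setdefault(frozenset(tensors), []).append(idx)
```

Using `frozenset` as the key avoids having to sort the tensor names, at the cost of a less readable key when debugging.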
I would also inline this logic.
Thank you for the detailed feedback! I applied these changes and I think it definitely made the code significantly more concise and readable.
# Assign a new index name to each group of original indices.
new_idxs = {}
for k, (tensor_set, _) in enumerate(idx_groups):
    new_idxs[tensor_set] = f"i{k}"
Use Namespace from symbolic.py to create fresh index variable names.
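One likely motivation is that hard-coded names such as `f"i{k}"` can collide with an index already used in the expression. The sketch below shows the general idea of a fresh-name generator; `fresh_names` is a hypothetical helper written for this note and is not the `Namespace` API from symbolic.py.

```python
import itertools
from typing import Iterator, Set


def fresh_names(prefix: str, taken: Set[str]) -> Iterator[str]:
    """Yield prefix0, prefix1, ... skipping any name already in `taken`."""
    for k in itertools.count():
        name = f"{prefix}{k}"
        if name not in taken:
            taken.add(name)  # reserve the name so it is never handed out twice
            yield name


used = {"i0", "j"}  # names assumed to be present in the expression already
gen = fresh_names("i", used)
first, second = next(gen), next(gen)
```

Here `i0` is skipped because it is already taken, which is the collision the plain `enumerate` version would not detect.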
corrected_bindings = {}
corrected_idx_lists = {}
for tns_name, supertensor, input_idx_list in inputs:
I would fuse this list with the previous one; they are basically the same loop.
corrected_idx_lists = {}
for tns_name, supertensor, input_idx_list in inputs:
    new_idx_list = []
    mode_map = []
We can construct a global dictionary that maps idx -> newidx, which I think would be helpful here.
new_idx_list = sorted(set(global_idx_map[idx] for idx in access.idxs))
mode_map = [[access.idxs.index(idx) for idx in idx_groups[new_idx] if idx in access.idxs] for new_idx in new_idx_list]
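The two comprehensions above can be checked on a small example. In this sketch the contents of `idx_groups`, the index names, and `access_idxs` (standing in for `access.idxs`) are all made up for illustration; only the shape of the computation follows the suggestion.

```python
from typing import Dict, List

# Hypothetical grouping: each new index stands for a group of original indices.
idx_groups: Dict[str, List[str]] = {"i0": ["i"], "i1": ["j", "k"], "i2": ["l"]}

# Global map from each original index to the new index of its group.
global_idx_map = {idx: new for new, group in idx_groups.items() for idx in group}

access_idxs = ["i", "j", "k"]  # indices of one hypothetical Access node

# New indices used by this access, deduplicated and in a deterministic order.
new_idx_list = sorted({global_idx_map[idx] for idx in access_idxs})

# For each new index, the positions (modes) of the original indices it replaces.
mode_map = [
    [access_idxs.index(idx) for idx in idx_groups[new_idx] if idx in access_idxs]
    for new_idx in new_idx_list
]
```

For this input, `j` and `k` collapse into the single new index `i1`, so the second entry of `mode_map` lists both of their positions.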
supertensor.py
interpreter.py
[...] Access node is bound to a SuperTensor, and the indices of each Access node are logical indices on the SuperTensor
[...] SuperTensor in the input AST will have some arbitrary shape.
[...] Access nodes in the einsum expression.
[...] Access node with the proper indices to access the correctly-shaped base tensor.
[...] SuperTensor.
test.py