Commit 05ca023
[shard-map] in eager shmap, handle all rep rule output cases
By convention, rep_rules can return three kinds of thing:
1. a sequence (tuple or list),
2. a single set, or
3. a single None.
Even rules for primitives with multiple results can return single objects rather than sequences; the reason is that it's convenient not ot have to infer the number of outputs for higher-order primitives.
In the latter two cases we rely on the caller (in this case, ShardMapTrace.process_primitive) to 'broadcast' the singleton result to a list of results equal to the number of outputs.
Previously, the code was checking `if type(out_rep) is set`, which doesn't handle case 3.
(We briefly tried another fix direction where we don't allow case 3, because we don't have case 3 in the upcoming VMA type system which replaces this stuff. But until that lands the easiest fix is just to handle all cases correctly.)
fixes jax-ml#26148, fixes jax-ml#27673
Co-authored-by: Justin Fu <[email protected]>1 parent e1e37f8 commit 05ca023
2 files changed
+15
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
958 | 958 | | |
959 | 959 | | |
960 | 960 | | |
961 | | - | |
| 961 | + | |
| 962 | + | |
962 | 963 | | |
963 | 964 | | |
964 | 965 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
685 | 685 | | |
686 | 686 | | |
687 | 687 | | |
| 688 | + | |
| 689 | + | |
| 690 | + | |
| 691 | + | |
| 692 | + | |
| 693 | + | |
| 694 | + | |
| 695 | + | |
| 696 | + | |
| 697 | + | |
| 698 | + | |
| 699 | + | |
| 700 | + | |
688 | 701 | | |
689 | 702 | | |
690 | 703 | | |
| |||
0 commit comments