Commit 72e5ca9
[JAX] Fix a small bug if shardings is tuple.
# Details
`jax.tree.map` requests all its arguments to have the same data type.
From ```[None] * len(tensorstore_specs) if global_shapes is None else global_shapes```,
The data type is already decided to be a list. So if we pass `sharding` or `tspecs` as a tuple, it will fail.
Here we add an explicit conversion to a list for sharding and tspecs.
PiperOrigin-RevId: 7075768661 parent 464e5a2 commit 72e5ca9
File tree
2 files changed
+35
-2
lines changed- jax/experimental/array_serialization
2 files changed
+35
-2
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
420 | 420 | | |
421 | 421 | | |
422 | 422 | | |
423 | | - | |
424 | 423 | | |
425 | 424 | | |
426 | | - | |
| 425 | + | |
427 | 426 | | |
428 | 427 | | |
429 | 428 | | |
| |||
Lines changed: 34 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
50 | 50 | | |
51 | 51 | | |
52 | 52 | | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
53 | 87 | | |
54 | 88 | | |
55 | 89 | | |
| |||
0 commit comments