We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 46c748b commit c9a5902Copy full SHA for c9a5902
jax/_src/sharding.py
@@ -43,7 +43,8 @@ def _addressable_devices_indices_map(
43
if d.process_index == d.client.process_index()}
44
45
@cache(max_size=4096, trace_context_in_key=False)
46
-def common_devices_indices_map(s, global_shape: Shape) -> Mapping[Device, Index]:
+def common_devices_indices_map(
47
+ s: Sharding, global_shape: Shape) -> Mapping[Device, Index]:
48
s.shard_shape(global_shape) # raises a good error message
49
hlo_sharding = s._to_xla_hlo_sharding(len(global_shape))
50
indices = op_sharding_to_indices(hlo_sharding, global_shape,
0 commit comments