Skip to content

Commit c9a5902

Browse files
[jax] Typing on common_devices_indices_map
PiperOrigin-RevId: 702053791
1 parent 46c748b commit c9a5902

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

jax/_src/sharding.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def _addressable_devices_indices_map(
4343
if d.process_index == d.client.process_index()}
4444

4545
@cache(max_size=4096, trace_context_in_key=False)
46-
def common_devices_indices_map(s, global_shape: Shape) -> Mapping[Device, Index]:
46+
def common_devices_indices_map(
47+
s: Sharding, global_shape: Shape) -> Mapping[Device, Index]:
4748
s.shard_shape(global_shape) # raises a good error message
4849
hlo_sharding = s._to_xla_hlo_sharding(len(global_shape))
4950
indices = op_sharding_to_indices(hlo_sharding, global_shape,

0 commit comments

Comments
 (0)