Consistent Indices for Addressable Shards of Multi-Host Arrays? #19319
Unanswered · andyehrenberg asked this question in Q&A
Suppose I have three arrays `a`, `b`, and `c`, where `b` is an input to some compiled function `f`, and `c = f(b)`. They all have the same leading dimension (consistent batch size). `f` is a jitted function with an output sharding `NamedSharding(mesh, PartitionSpec("data"))`, and `a` and `b` are also multi-host arrays sharded along `"data"`.

For a given host, do the `addressable_shards` of `a`, `b`, and `c` correspond to the same indices? For example, if `a` is a batch of target sequences and `c` is predictions, are we guaranteed that computing a word error rate between the detokenizations of the addressable shards of `a` and `c` will actually compare the correct targets and predictions? Or are we only guaranteed to have the correct order after all-gathering?

Code like https://github.com/apple/axlearn/blob/2c7e9f8fcd6fa2ee4e9596d4d18f870e4b653289/axlearn/common/utils.py#L585 seems to imply that the indices addressable by a host remain consistent, but that their order may still change.