1 parent 93eae33 commit 1b1e2b3
torchrec/distributed/batched_embedding_kernel.py
@@ -2217,6 +2217,17 @@ def split_embedding_weights(
    ]:
        return self.emb_module.split_embedding_weights(no_snapshot, should_flush)

+   def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+       # reset split weights during training
+       self._split_weights_res = None
+       self._optim.set_sharded_embedding_weight_ids(sharded_embedding_weight_ids=None)
+
+       return self.emb_module(
+           indices=features.values().long(),
+           offsets=features.offsets().long(),
+           weights=features.weights_or_none(),
+       )
+


class ZeroCollisionEmbeddingCache(ZeroCollisionKeyValueEmbedding):
    def __init__(
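For context, a minimal sketch (not part of the commit) of the input the new forward() consumes: it shows how a KeyedJaggedTensor is built and what the values()/offsets()/weights_or_none() accessors used in the diff return. The feature names and id values below are made up for illustration.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two features, batch size 2; lengths has one entry per (feature, sample) pair.
features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([10, 11, 12, 13, 14]),
    lengths=torch.tensor([2, 1, 1, 1]),
)

# These are the tensors the new forward() hands to the underlying emb_module.
indices = features.values().long()    # flat ids: tensor([10, 11, 12, 13, 14])
offsets = features.offsets().long()   # bag boundaries: tensor([0, 2, 3, 4, 5])
weights = features.weights_or_none()  # None here; per-id weights if provided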