Skip to content

Commit 1b1e2b3

Browse files
emlinfacebook-github-bot
authored andcommitted
call zch forward with kjt weights to support feature score eviction (#3352)
Summary: Pull Request resolved: #3352 as title Reviewed By: steven1327, EddyLXJ Differential Revision: D81726703 fbshipit-source-id: 9e1393da2f09404db7166c510b75a92b0d803a99
1 parent 93eae33 commit 1b1e2b3

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2217,6 +2217,17 @@ def split_embedding_weights(
22172217
]:
22182218
return self.emb_module.split_embedding_weights(no_snapshot, should_flush)
22192219

2220+
def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
2221+
# reset split weights during training
2222+
self._split_weights_res = None
2223+
self._optim.set_sharded_embedding_weight_ids(sharded_embedding_weight_ids=None)
2224+
2225+
return self.emb_module(
2226+
indices=features.values().long(),
2227+
offsets=features.offsets().long(),
2228+
weights=features.weights_or_none(),
2229+
)
2230+
22202231

22212232
class ZeroCollisionEmbeddingCache(ZeroCollisionKeyValueEmbedding):
22222233
def __init__(

0 commit comments

Comments
 (0)