
Commit f9d4bbf

emlin authored and facebook-github-bot committed
enable feature score data collection in torchrec (meta-pytorch#3285)
Summary:
Pull Request resolved: meta-pytorch#3285

Add an enable_feature_score_weight_accumulation flag to ShardedEmbeddingCollection. When this flag is true and EC index dedup is enabled, we accumulate each id's KJT weight and occurrence count and write the result back as the KJT weight, so that input dist can distribute the feature scores.

This change is part of the ZCH v.Next feature score eviction story:
- Collect a score for every feature id in the model, e.g. 0.5 for ids from positive examples and 0.2 for ids from negative examples.
- Set the score as the weight value of the input id list feature KJT.
- In EC forward, when id dedup is enabled, aggregate the score and occurrence count of each id.
- Distribute the id scores in the KJT weights.
- In the KVZCH embedding kernel, call forward with weights as an optional parameter.

In the ZCH TBE backend (separate diffs):
- Set the feature scores on the ZCH TBE backend.
- Run eviction based on the id score values.

For the whole story, see: https://docs.google.com/document/d/1TJHKvO1m3-5tYAKZGhacXnGk7iCNAzz7wQlrFbX_LDI/edit?tab=t.0

Reviewed By: duduyi2013

Differential Revision: D79864431

fbshipit-source-id: 4830ff41c79770e83d20a7e49f84a33f938870e4
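To make the first two steps concrete, here is a minimal sketch of attaching per-id feature scores as the KJT's float32 weights. The 0.5/0.2 score values come from the summary above; the feature name and id values are illustrative only, not the production setup.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Assign a score per feature id: e.g. 0.5 for ids from positive examples,
# 0.2 for ids from negative examples (values from the commit summary).
ids = torch.tensor([100, 200, 100, 300])
is_positive = torch.tensor([True, False, True, False])
scores = torch.where(
    is_positive, torch.tensor(0.5), torch.tensor(0.2)
)  # float32, as required by the assertion in the dedup path

# The scores ride along as the KJT weights, so input dist can carry them
# to the embedding kernel.
features = KeyedJaggedTensor(
    keys=["item_id"],
    values=ids,
    lengths=torch.tensor([4]),
    weights=scores,
)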
1 parent f678776 commit f9d4bbf

File tree

3 files changed: +44 −2 lines changed


torchrec/distributed/batched_embedding_kernel.py

Lines changed: 5 additions & 2 deletions
@@ -88,7 +88,6 @@
     NoEvictionPolicy,
     pooling_type_to_pooling_mode,
     TimestampBasedEvictionPolicy,
-    VirtualTableEvictionPolicy,
 )
 from torchrec.optim.fused import (
     EmptyFusedOptimizer,
@@ -1713,7 +1712,11 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         self._split_weights_res = None
         self._optim.set_sharded_embedding_weight_ids(sharded_embedding_weight_ids=None)

-        return super().forward(features)
+        return self.emb_module(
+            indices=features.values().long(),
+            offsets=features.offsets().long(),
+            weights=features.weights_or_none(),
+        )


 class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule):
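For context on why forward is overridden here: the base class's forward presumably invokes the kernel without the KJT's weights, so attached feature scores would be dropped before reaching the ZCH TBE backend. A stand-alone sketch of the behavioral difference; the emb_module stub below is hypothetical and does not claim to match the real TBE signature.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Hypothetical stand-in for the embedding kernel call.
def emb_module(indices, offsets, weights=None):
    print("weights seen by kernel:", weights)

features = KeyedJaggedTensor(
    keys=["f1"],
    values=torch.tensor([10, 11, 10]),
    lengths=torch.tensor([3]),
    weights=torch.tensor([0.5, 0.2, 0.5]),  # per-id feature scores
)

# Without the override, weights never reach the kernel:
emb_module(indices=features.values().long(), offsets=features.offsets().long())
# With the override, the scores are forwarded for feature score eviction:
emb_module(
    indices=features.values().long(),
    offsets=features.offsets().long(),
    weights=features.weights_or_none(),
)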

torchrec/distributed/embedding.py

Lines changed: 36 additions & 0 deletions
@@ -47,6 +47,7 @@
     ShardingType,
 )
 from torchrec.distributed.fused_params import (
+    ENABLE_FEATURE_SCORE_WEIGHT_ACCUMULATION,
     FUSED_PARAM_IS_SSD_TABLE,
     FUSED_PARAM_SSD_TABLE_LIST,
 )
@@ -419,6 +420,20 @@ def __init__(
         module_fqn: Optional[str] = None,
     ) -> None:
         super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
+        self._enable_feature_score_weight_accumulation: bool = False
+
+        if (
+            fused_params is not None
+            and ENABLE_FEATURE_SCORE_WEIGHT_ACCUMULATION in fused_params
+        ):
+            self._enable_feature_score_weight_accumulation = cast(
+                bool, fused_params[ENABLE_FEATURE_SCORE_WEIGHT_ACCUMULATION]
+            )
+            fused_params.pop(ENABLE_FEATURE_SCORE_WEIGHT_ACCUMULATION)
+        logger.info(
+            f"EC feature score weight accumulation enabled: {self._enable_feature_score_weight_accumulation}."
+        )
+
         self._module_fqn = module_fqn
         self._embedding_configs: List[EmbeddingConfig] = module.embedding_configs()
         self._table_names: List[str] = [
@@ -1321,11 +1336,32 @@ def _dedup_indices(
                 input_feature.offsets().to(torch.int64),
                 input_feature.values().to(torch.int64),
             )
+            acc_weights = None
+            if (
+                self._enable_feature_score_weight_accumulation
+                and input_feature.weights_or_none() is not None
+            ):
+                source_weights = input_feature.weights()
+                assert (
+                    source_weights.dtype == torch.float32
+                ), "Only float32 weights are supported for feature score eviction weights."
+
+                acc_weights = torch.ops.fbgemm.jagged_acc_weights_and_counts(
+                    source_weights.view(-1),
+                    reverse_indices,
+                    unique_indices.numel(),
+                )
+
             dedup_features = KeyedJaggedTensor(
                 keys=input_feature.keys(),
                 lengths=lengths,
                 offsets=offsets,
                 values=unique_indices,
+                weights=(
+                    acc_weights.view(torch.float64).view(-1)
+                    if acc_weights is not None
+                    else None
+                ),
             )

             ctx.input_features.append(input_feature)
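A reference sketch of what the accumulation step computes, in plain PyTorch. The packing layout, a (score sum, count) float32 pair per unique id reinterpreted as one float64 element, is inferred from the acc_weights.view(torch.float64).view(-1) call above; the real fbgemm jagged_acc_weights_and_counts op's layout may differ.

import torch

def acc_weights_and_counts_ref(
    source_weights: torch.Tensor,   # float32 score per pre-dedup id
    reverse_indices: torch.Tensor,  # int64 map: original position -> unique slot
    num_unique: int,
) -> torch.Tensor:
    # Per unique id, accumulate the sum of its scores and its occurrence count.
    sums = torch.zeros(num_unique).index_add_(0, reverse_indices, source_weights)
    counts = torch.zeros(num_unique).index_add_(
        0, reverse_indices, torch.ones_like(source_weights)
    )
    # Pack each (sum, count) float32 pair into one float64 slot so the result
    # aligns one-to-one with the deduplicated KJT values.
    return torch.stack([sums, counts], dim=1).view(-1).view(torch.float64)

# Example: ids [7, 9, 7] dedup to [7, 9], so reverse_indices is [0, 1, 0].
packed = acc_weights_and_counts_ref(
    torch.tensor([0.5, 0.2, 0.5]), torch.tensor([0, 1, 0]), 2
)
print(packed.view(torch.float32).view(-1, 2))  # [[1.0, 2.0], [0.2, 1.0]]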

torchrec/distributed/fused_params.py

Lines changed: 3 additions & 0 deletions
@@ -33,6 +33,9 @@
 FUSED_PARAM_SSD_TABLE_LIST: str = "__register_ssd_table_list"
 # Bool fused param per table to check if the table is offloaded to SSD
 FUSED_PARAM_IS_SSD_TABLE: str = "__register_is_ssd_table"
+ENABLE_FEATURE_SCORE_WEIGHT_ACCUMULATION: str = (
+    "enable_feature_score_weight_accumulation"
+)


 class TBEToRegisterMixIn:
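A hedged usage sketch for this new constant. Only the key name and the fact that ShardedEmbeddingCollection.__init__ reads and pops it from fused_params come from this diff; routing the dict through EmbeddingCollectionSharder is an assumption about one plausible way to pass it in.

from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.fused_params import ENABLE_FEATURE_SCORE_WEIGHT_ACCUMULATION

fused_params = {
    ENABLE_FEATURE_SCORE_WEIGHT_ACCUMULATION: True,
    # ...any other fused params destined for the embedding kernel
}

# ShardedEmbeddingCollection.__init__ pops the flag, so the remaining
# fused params reach the TBE backend untouched.
sharder = EmbeddingCollectionSharder(fused_params=fused_params)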
