     ShardingType,
 )
 from torchrec.distributed.fused_params import (
+    ENABLE_FEATURE_SCORE_WEIGHT_ACCUMULATION,
     FUSED_PARAM_IS_SSD_TABLE,
     FUSED_PARAM_SSD_TABLE_LIST,
 )
@@ -419,6 +420,20 @@ def __init__(
         module_fqn: Optional[str] = None,
     ) -> None:
         super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
+        self._enable_feature_score_weight_accumulation: bool = False
+
+        if (
+            fused_params is not None
+            and ENABLE_FEATURE_SCORE_WEIGHT_ACCUMULATION in fused_params
+        ):
+            self._enable_feature_score_weight_accumulation = cast(
+                bool, fused_params[ENABLE_FEATURE_SCORE_WEIGHT_ACCUMULATION]
+            )
+            fused_params.pop(ENABLE_FEATURE_SCORE_WEIGHT_ACCUMULATION)
+        logger.info(
+            f"EC feature score weight accumulation enabled: {self._enable_feature_score_weight_accumulation}."
+        )
+
         self._module_fqn = module_fqn
         self._embedding_configs: List[EmbeddingConfig] = module.embedding_configs()
         self._table_names: List[str] = [
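For context on how this flag reaches `__init__` above: a minimal usage sketch, assuming the usual TorchRec pattern of passing `fused_params` through the sharder constructor (the sharder wiring shown here is illustrative and not part of this diff). The new key is read once, cached on the sharded module, and popped so it is not forwarded to the embedding kernels as a fused optimizer argument.

```python
# Minimal sketch (assumption: EmbeddingCollectionSharder forwards its
# fused_params dict to the sharded module's __init__ shown above).
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.fused_params import (
    ENABLE_FEATURE_SCORE_WEIGHT_ACCUMULATION,
)

sharder = EmbeddingCollectionSharder(
    fused_params={
        # New key added in this diff; cached and then popped in __init__ so
        # downstream kernel construction never sees it.
        ENABLE_FEATURE_SCORE_WEIGHT_ACCUMULATION: True,
        # ... other fused params (e.g. optimizer settings) ...
    }
)
```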
@@ -1321,11 +1336,32 @@ def _dedup_indices(
                 input_feature.offsets().to(torch.int64),
                 input_feature.values().to(torch.int64),
             )
+                acc_weights = None
+                if (
+                    self._enable_feature_score_weight_accumulation
+                    and input_feature.weights_or_none() is not None
+                ):
+                    source_weights = input_feature.weights()
+                    assert (
+                        source_weights.dtype == torch.float32
+                    ), "Only float32 weights are supported for feature score eviction weights."
+
+                    acc_weights = torch.ops.fbgemm.jagged_acc_weights_and_counts(
+                        source_weights.view(-1),
+                        reverse_indices,
+                        unique_indices.numel(),
+                    )
+
                 dedup_features = KeyedJaggedTensor(
                     keys=input_feature.keys(),
                     lengths=lengths,
                     offsets=offsets,
                     values=unique_indices,
+                    weights=(
+                        acc_weights.view(torch.float64).view(-1)
+                        if acc_weights is not None
+                        else None
+                    ),
                 )

                 ctx.input_features.append(input_feature)
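As a rough reference for what the accumulation in `_dedup_indices` computes, here is a plain-torch stand-in, not the fbgemm kernel used above: each pre-dedup weight is summed into the slot of its deduplicated index via `reverse_indices`. The packing of per-index counts into the float64 payload implied by the `.view(torch.float64)` reinterpretation is assumed and not reproduced here; names are illustrative.

```python
# Illustrative stand-in (not torch.ops.fbgemm.jagged_acc_weights_and_counts).
import torch


def acc_weights_reference(
    source_weights: torch.Tensor,   # float32, one weight per pre-dedup value
    reverse_indices: torch.Tensor,  # int64, maps each value to its unique slot
    num_unique: int,
) -> torch.Tensor:
    # Sum weights of duplicate indices into a single slot per unique index.
    acc = torch.zeros(num_unique, dtype=torch.float32)
    acc.index_add_(0, reverse_indices, source_weights.view(-1))
    return acc


# Example: values [5, 7, 5] dedup to uniques [5, 7] with reverse_indices
# [0, 1, 0]; weights [0.5, 1.0, 0.25] accumulate to [0.75, 1.0].
print(
    acc_weights_reference(
        torch.tensor([0.5, 1.0, 0.25]), torch.tensor([0, 1, 0]), 2
    )
)
```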