Skip to content

Commit f2674a0

Browse files
EddyLXJfacebook-github-bot
authored andcommitted
Feature score eviction backend and frontend support (#3287)
Summary: Pull Request resolved: #3287 X-link: pytorch/FBGEMM#4681 X-link: facebookresearch/FBGEMM#1707 ## Context We need a new eviction policy for large embedding which has high id growth rate. The feature score eviction is based on engagement rate of id instead of only time or counter. This will help model to keep all relatively important ids during eviction. ## Detail * New Eviction Strategy: BY_FEATURE_SCORE Added a new eviction trigger strategy BY_FEATURE_SCORE in the eviction config and logic. This strategy uses feature scores derived from engagement rates to decide which IDs to evict. * FeatureScoreBasedEvict Class Implements the feature score based eviction logic. Maintains buckets of feature scores per shard and table to compute eviction thresholds. * Supports a dry-run mode to calculate thresholds before actual eviction. Eviction decisions are based on thresholds computed from feature score distributions. Supports decay of feature score statistics over time. * Async Metadata Update API Added set_kv_zch_eviction_metadata_async method to update feature score metadata asynchronously in the KV store. This method shards the input indices and engagement rates and updates the feature score statistics in parallel. * Dry Run Eviction Mode Introduced a dry run mode to simulate eviction rounds to compute thresholds without actually evicting. Dry run results are used to finalize thresholds for real eviction rounds. Reviewed By: emlin Differential Revision: D78138679 fbshipit-source-id: 6196c3676abf94b690f1ac776ca8f5c739cae1ea
1 parent 79fbb29 commit f2674a0

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
CountTimestampMixedEvictionPolicy,
8585
data_type_to_sparse_type,
8686
FeatureL2NormBasedEvictionPolicy,
87+
FeatureScoreBasedEvictionPolicy,
8788
NoEvictionPolicy,
8889
pooling_type_to_pooling_mode,
8990
TimestampBasedEvictionPolicy,
@@ -235,6 +236,9 @@ def _populate_zero_collision_tbe_params(
235236
counter_thresholds = [0] * len(config.embedding_tables)
236237
ttls_in_mins = [0] * len(config.embedding_tables)
237238
counter_decay_rates = [0.0] * len(config.embedding_tables)
239+
feature_score_counter_decay_rates = [0.0] * len(config.embedding_tables)
240+
max_training_id_num_per_table = [0] * len(config.embedding_tables)
241+
target_eviction_percent_per_table = [0.0] * len(config.embedding_tables)
238242
l2_weight_thresholds = [0.0] * len(config.embedding_tables)
239243
eviction_strategy = -1
240244
table_names = [table.name for table in config.embedding_tables]
@@ -251,6 +255,20 @@ def _populate_zero_collision_tbe_params(
251255
raise ValueError(
252256
f"Do not support multiple eviction strategy in one tbe {eviction_strategy} and 1 for tables {table_names}"
253257
)
258+
elif isinstance(policy_t, FeatureScoreBasedEvictionPolicy):
259+
feature_score_counter_decay_rates[i] = policy_t.decay_rate
260+
max_training_id_num_per_table[i] = (
261+
policy_t.max_training_id_num_per_rank
262+
)
263+
target_eviction_percent_per_table[i] = (
264+
policy_t.target_eviction_percent
265+
)
266+
if eviction_strategy == -1 or eviction_strategy == 5:
267+
eviction_strategy = 5
268+
else:
269+
raise ValueError(
270+
f"Do not support multiple eviction strategy in one tbe {eviction_strategy} and 5 for tables {table_names}"
271+
)
254272
elif isinstance(policy_t, TimestampBasedEvictionPolicy):
255273
ttls_in_mins[i] = policy_t.eviction_ttl_mins
256274
if eviction_strategy == -1 or eviction_strategy == 0:
@@ -288,6 +306,9 @@ def _populate_zero_collision_tbe_params(
288306
counter_thresholds=counter_thresholds,
289307
ttls_in_mins=ttls_in_mins,
290308
counter_decay_rates=counter_decay_rates,
309+
feature_score_counter_decay_rates=feature_score_counter_decay_rates,
310+
max_training_id_num_per_table=max_training_id_num_per_table,
311+
target_eviction_percent_per_table=target_eviction_percent_per_table,
291312
l2_weight_thresholds=l2_weight_thresholds,
292313
meta_header_lens=meta_header_lens,
293314
)

torchrec/modules/embedding_configs.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,17 @@ def __post_init__(self) -> None:
210210
self.inference_eviction_threshold = self.eviction_threshold
211211

212212

213+
@dataclass
214+
class FeatureScoreBasedEvictionPolicy(VirtualTableEvictionPolicy):
215+
"""
216+
Feature score based eviction policy for virtual table.
217+
"""
218+
219+
decay_rate: float = 0.99 # default decay by default #TODO: Change to real value
220+
max_training_id_num_per_rank: int = 0 # max number of training ids per rank
221+
target_eviction_percent: float = 0.0 # target eviction percent
222+
223+
213224
@dataclass
214225
class TimestampBasedEvictionPolicy(VirtualTableEvictionPolicy):
215226
"""

0 commit comments

Comments
 (0)