Skip to content

Commit 4dae3ad

Browse files
EddyLXJ authored and facebook-github-bot committed
Free mem trigger with all2all for sync trigger eviction
Summary: Previously, KVZCH used the ID_COUNT and MEM_UTIL eviction trigger modes. Both are tricky to configure: it is hard for a model engineer to decide what value to use for the ID count or the memory-utilization threshold. In addition, the eviction start time drifts out of sync after some time in training, which can cause a large QPS drop during eviction. This diff adds support for free-memory-triggered eviction. Every N batches, each rank checks how much free memory is left; if the free memory is below the threshold, eviction is triggered in all TBEs of all ranks using all-reduce. In this way, the eviction start time is forced to be the same in all ranks. Differential Revision: D83896528
1 parent 283e2f8 commit 4dae3ad

File tree

2 files changed

+42
-8
lines changed

2 files changed

+42
-8
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -336,11 +336,40 @@ def _populate_zero_collision_tbe_params(
336336
eviction_strategy = -1
337337
table_names = [table.name for table in config.embedding_tables]
338338
l2_cache_size = tbe_params["l2_cache_size"]
339-
if "kvzch_eviction_trigger_mode" in tbe_params:
340-
eviction_tirgger_mode = tbe_params["kvzch_eviction_trigger_mode"]
341-
tbe_params.pop("kvzch_eviction_trigger_mode")
342-
else:
343-
eviction_tirgger_mode = 2 # 2 means mem_util based eviction
339+
340+
# Eviction tbe config default values
341+
eviction_tirgger_mode = 2 # 2 means mem_util based eviction
342+
eviction_free_mem_threshold_gb = (
343+
10 # Eviction free memory trigger threshold in GB
344+
)
345+
eviction_free_mem_check_interval_batch = (
346+
1000,
347+
) # how many batches between free-memory checks when the trigger mode is free_mem
348+
threshold_calculation_bucket_stride = 0.2
349+
threshold_calculation_bucket_num = 1000000 # 1M
350+
if "kvzch_eviction_tbe_config" in tbe_params:
351+
eviction_tbe_config = tbe_params["kvzch_eviction_tbe_config"]
352+
tbe_params.pop("kvzch_eviction_tbe_config")
353+
354+
if eviction_tbe_config.kvzch_eviction_trigger_mode is not None:
355+
eviction_tirgger_mode = eviction_tbe_config.kvzch_eviction_trigger_mode
356+
if eviction_tbe_config.eviction_free_mem_threshold_gb is not None:
357+
eviction_free_mem_threshold_gb = (
358+
eviction_tbe_config.eviction_free_mem_threshold_gb
359+
)
360+
if eviction_tbe_config.eviction_free_mem_check_interval_batch is not None:
361+
eviction_free_mem_check_interval_batch = (
362+
eviction_tbe_config.eviction_free_mem_check_interval_batch
363+
)
364+
if eviction_tbe_config.threshold_calculation_bucket_stride is not None:
365+
threshold_calculation_bucket_stride = (
366+
eviction_tbe_config.threshold_calculation_bucket_stride
367+
)
368+
if eviction_tbe_config.threshold_calculation_bucket_num is not None:
369+
threshold_calculation_bucket_num = (
370+
eviction_tbe_config.threshold_calculation_bucket_num
371+
)
372+
344373
for i, table in enumerate(config.embedding_tables):
345374
policy_t = table.virtual_table_eviction_policy
346375
if policy_t is not None:
@@ -420,6 +449,10 @@ def _populate_zero_collision_tbe_params(
420449
training_id_keep_count=training_id_keep_count,
421450
l2_weight_thresholds=l2_weight_thresholds,
422451
meta_header_lens=meta_header_lens,
452+
eviction_free_mem_threshold_gb=eviction_free_mem_threshold_gb,
453+
eviction_free_mem_check_interval_batch=eviction_free_mem_check_interval_batch,
454+
threshold_calculation_bucket_stride=threshold_calculation_bucket_stride,
455+
threshold_calculation_bucket_num=threshold_calculation_bucket_num,
423456
)
424457
else:
425458
eviction_policy = EvictionPolicy(meta_header_lens=meta_header_lens)

torchrec/distributed/types.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
3131
BoundsCheckMode,
3232
CacheAlgorithm,
33+
KVZCHEvictionTBEConfig,
3334
MultiPassPrefetchConfig,
3435
)
3536

@@ -644,7 +645,7 @@ class KeyValueParams:
644645
lazy_bulk_init_enabled: bool: whether to enable lazy(async) bulk init for SSD TBE
645646
enable_raw_embedding_streaming: Optional[bool]: enable raw embedding streaming for SSD TBE
646647
res_store_shards: Optional[int] = None: the number of shards to store the raw embeddings
647-
kvzch_eviction_trigger_mode: Optional[int]: eviction trigger mode for KVZCH
648+
kvzch_eviction_tbe_config: Optional[KVZCHEvictionTBEConfig]: KVZCH eviction config for TBE
648649
649650
# Parameter Server (PS) Attributes
650651
ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses
@@ -670,7 +671,7 @@ class KeyValueParams:
670671
None # enable raw embedding streaming for SSD TBE
671672
)
672673
res_store_shards: Optional[int] = None # shards to store the raw embeddings
673-
kvzch_eviction_trigger_mode: Optional[int] = None # eviction trigger mode for KVZCH
674+
kvzch_eviction_tbe_config: Optional[KVZCHEvictionTBEConfig] = None
674675

675676
# Parameter Server (PS) Attributes
676677
ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None
@@ -699,7 +700,7 @@ def __hash__(self) -> int:
699700
self.lazy_bulk_init_enabled,
700701
self.enable_raw_embedding_streaming,
701702
self.res_store_shards,
702-
self.kvzch_eviction_trigger_mode,
703+
self.kvzch_eviction_tbe_config,
703704
)
704705
)
705706

0 commit comments

Comments
 (0)