diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index cbcd0d78b..10edcd53f 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -336,11 +336,40 @@ def _populate_zero_collision_tbe_params( eviction_strategy = -1 table_names = [table.name for table in config.embedding_tables] l2_cache_size = tbe_params["l2_cache_size"] - if "kvzch_eviction_trigger_mode" in tbe_params: - eviction_tirgger_mode = tbe_params["kvzch_eviction_trigger_mode"] - tbe_params.pop("kvzch_eviction_trigger_mode") - else: - eviction_tirgger_mode = 2 # 2 means mem_util based eviction + + # Eviction tbe config default values + eviction_tirgger_mode = 2 # 2 means mem_util based eviction + eviction_free_mem_threshold_gb = ( + 10 # Eviction free memory trigger threshold in GB + ) + eviction_free_mem_check_interval_batch = ( + 1000 + ) # number of batches between free memory checks when trigger mode is free_mem + threshold_calculation_bucket_stride = 0.2 + threshold_calculation_bucket_num = 1000000 # 1M + if "kvzch_eviction_tbe_config" in tbe_params: + eviction_tbe_config = tbe_params["kvzch_eviction_tbe_config"] + tbe_params.pop("kvzch_eviction_tbe_config") + + if eviction_tbe_config.kvzch_eviction_trigger_mode is not None: + eviction_tirgger_mode = eviction_tbe_config.kvzch_eviction_trigger_mode + if eviction_tbe_config.eviction_free_mem_threshold_gb is not None: + eviction_free_mem_threshold_gb = ( + eviction_tbe_config.eviction_free_mem_threshold_gb + ) + if eviction_tbe_config.eviction_free_mem_check_interval_batch is not None: + eviction_free_mem_check_interval_batch = ( + eviction_tbe_config.eviction_free_mem_check_interval_batch + ) + if eviction_tbe_config.threshold_calculation_bucket_stride is not None: + threshold_calculation_bucket_stride = ( + eviction_tbe_config.threshold_calculation_bucket_stride + ) + if eviction_tbe_config.threshold_calculation_bucket_num is not 
None: + threshold_calculation_bucket_num = ( + eviction_tbe_config.threshold_calculation_bucket_num + ) + for i, table in enumerate(config.embedding_tables): policy_t = table.virtual_table_eviction_policy if policy_t is not None: @@ -420,6 +449,10 @@ def _populate_zero_collision_tbe_params( training_id_keep_count=training_id_keep_count, l2_weight_thresholds=l2_weight_thresholds, meta_header_lens=meta_header_lens, + eviction_free_mem_threshold_gb=eviction_free_mem_threshold_gb, + eviction_free_mem_check_interval_batch=eviction_free_mem_check_interval_batch, + threshold_calculation_bucket_stride=threshold_calculation_bucket_stride, + threshold_calculation_bucket_num=threshold_calculation_bucket_num, ) else: eviction_policy = EvictionPolicy(meta_header_lens=meta_header_lens) @@ -1760,6 +1793,7 @@ def __init__( feature_table_map=self._feature_table_map, ssd_cache_location=embedding_location, pooling_mode=PoolingMode.NONE, + pg=pg, **ssd_tbe_params, ).to(device) @@ -1992,6 +2026,7 @@ def __init__( ssd_cache_location=embedding_location, pooling_mode=PoolingMode.NONE, backend_type=backend_type, + pg=pg, **ssd_tbe_params, ).to(device) @@ -2672,6 +2707,7 @@ def __init__( feature_table_map=self._feature_table_map, ssd_cache_location=embedding_location, pooling_mode=self._pooling, + pg=pg, **ssd_tbe_params, ).to(device) @@ -2892,6 +2928,7 @@ def __init__( ssd_cache_location=embedding_location, pooling_mode=self._pooling, backend_type=backend_type, + pg=pg, **ssd_tbe_params, ).to(device) diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index b23c8524a..12a86ca6f 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -30,6 +30,7 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( BoundsCheckMode, CacheAlgorithm, + KVZCHEvictionTBEConfig, MultiPassPrefetchConfig, ) @@ -644,7 +645,7 @@ class KeyValueParams: lazy_bulk_init_enabled: bool: whether to enable lazy(async) bulk init for SSD TBE 
enable_raw_embedding_streaming: Optional[bool]: enable raw embedding streaming for SSD TBE res_store_shards: Optional[int] = None: the number of shards to store the raw embeddings - kvzch_eviction_trigger_mode: Optional[int]: eviction trigger mode for KVZCH + kvzch_eviction_tbe_config: Optional[KVZCHEvictionTBEConfig]: KVZCH eviction config for TBE # Parameter Server (PS) Attributes ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses @@ -670,7 +671,7 @@ class KeyValueParams: None # enable raw embedding streaming for SSD TBE ) res_store_shards: Optional[int] = None # shards to store the raw embeddings - kvzch_eviction_trigger_mode: Optional[int] = None # eviction trigger mode for KVZCH + kvzch_eviction_tbe_config: Optional[KVZCHEvictionTBEConfig] = None # Parameter Server (PS) Attributes ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None @@ -699,7 +700,7 @@ def __hash__(self) -> int: self.lazy_bulk_init_enabled, self.enable_raw_embedding_streaming, self.res_store_shards, - self.kvzch_eviction_trigger_mode, + self.kvzch_eviction_tbe_config, ) )