Skip to content

Commit b94d1ba

Browse files
EddyLXJ and facebook-github-bot
authored and committed
Adding new ID_COUNT eviction trigger condition (#4829)
Summary: X-link: meta-pytorch/torchrec#3357 Pull Request resolved: #4829 X-link: facebookresearch/FBGEMM#1855 The ID_COUNT trigger is based on the total number of IDs across all kvzch tables. The memory budget is shared across all tables, but the eviction trigger is evaluated per TBE. For example, suppose the target total memory is 10G; tbe1 has tables [table0, table1] with training_id_eviction_trigger_count [200, 300]; and tbe2 has table [table2] with training_id_eviction_trigger_count [500]. Then the target memory is 2G for table0, 3G for table1, and 5G for table2, but the trigger condition is evaluated at the TBE level. If the current ID counts in tbe1 are [100, 350], tbe1 will not trigger eviction even though the count in table1 is greater than 300, because the total, 100 + 350 = 450, is less than 500. If the ID counts in tbe1 are [100, 450], eviction is triggered on tbe1 (100 + 450 = 550 > 500), but only table1 is evicted, because the count in table0, 100, is less than 200. Reviewed By: emlin Differential Revision: D81151216 fbshipit-source-id: 693bd47e950cc63498b4cdfc7691636197c832dc
1 parent 0322d2e commit b94d1ba

File tree

8 files changed

+143
-107
lines changed

8 files changed

+143
-107
lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def from_str(cls, key: str):
6262

6363
class EvictionPolicy(NamedTuple):
6464
eviction_trigger_mode: int = (
65-
0 # disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual
65+
0 # disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual 4: id count
6666
)
6767
eviction_strategy: int = (
6868
0 # 0: timestamp, 1: counter , 2: counter + timestamp, 3: feature l2 norm 4: timestamp threshold 5: feature score
@@ -85,11 +85,11 @@ class EvictionPolicy(NamedTuple):
8585
feature_score_counter_decay_rates: Optional[List[float]] = (
8686
None # feature_score_counter_decay_rates for each table if eviction strategy is feature score
8787
)
88-
max_training_id_num_per_table: Optional[List[int]] = (
89-
None # max_training_id_num_per_table for each table
88+
training_id_eviction_trigger_count: Optional[List[int]] = (
89+
None # training_id_eviction_trigger_count for each table
9090
)
91-
target_eviction_percent_per_table: Optional[List[float]] = (
92-
None # target_eviction_percent_per_table for each table
91+
training_id_keep_count: Optional[List[int]] = (
92+
None # training_id_keep_count for each table
9393
)
9494
l2_weight_thresholds: Optional[List[float]] = (
9595
None # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
@@ -116,8 +116,8 @@ class EvictionPolicy(NamedTuple):
116116
meta_header_lens: Optional[List[int]] = None # metaheader length for each table
117117

118118
def validate(self) -> None:
119-
assert self.eviction_trigger_mode in [0, 1, 2, 3], (
120-
"eviction_trigger_mode must be 0, 1, 2, or 3, "
119+
assert self.eviction_trigger_mode in [0, 1, 2, 3, 4], (
120+
"eviction_trigger_mode must be 0, 1, 2, 3 or 4 "
121121
f"actual {self.eviction_trigger_mode}"
122122
)
123123
if self.eviction_trigger_mode == 0:
@@ -139,6 +139,10 @@ def validate(self) -> None:
139139
assert (
140140
self.eviction_mem_threshold_gb is not None
141141
), "eviction_mem_threshold_gb must be set if eviction_trigger_mode is 2"
142+
elif self.eviction_trigger_mode == 4:
143+
assert (
144+
self.training_id_eviction_trigger_count is not None
145+
), "training_id_eviction_trigger_count must be set if eviction_trigger_mode is 4"
142146

143147
if self.eviction_strategy == 0:
144148
assert self.ttls_in_mins is not None, (
@@ -184,13 +188,13 @@ def validate(self) -> None:
184188
"feature_score_counter_decay_rates must be set if eviction_strategy is 5, "
185189
f"actual {self.feature_score_counter_decay_rates}"
186190
)
187-
assert self.max_training_id_num_per_table is not None, (
188-
"max_training_id_num_per_table must be set if eviction_strategy is 5,"
189-
f"actual {self.max_training_id_num_per_table}"
191+
assert self.training_id_eviction_trigger_count is not None, (
192+
"training_id_eviction_trigger_count must be set if eviction_strategy is 5,"
193+
f"actual {self.training_id_eviction_trigger_count}"
190194
)
191-
assert self.target_eviction_percent_per_table is not None, (
192-
"target_eviction_percent_per_table must be set if eviction_strategy is 5,"
193-
f"actual {self.target_eviction_percent_per_table}"
195+
assert self.training_id_keep_count is not None, (
196+
"training_id_keep_count must be set if eviction_strategy is 5,"
197+
f"actual {self.training_id_keep_count}"
194198
)
195199
assert self.threshold_calculation_bucket_stride is not None, (
196200
"threshold_calculation_bucket_stride must be set if eviction_strategy is 5,"
@@ -201,12 +205,12 @@ def validate(self) -> None:
201205
f"actual {self.threshold_calculation_bucket_num}"
202206
)
203207
assert (
204-
len(self.target_eviction_percent_per_table)
208+
len(self.training_id_keep_count)
205209
== len(self.feature_score_counter_decay_rates)
206-
== len(self.max_training_id_num_per_table)
210+
== len(self.training_id_eviction_trigger_count)
207211
), (
208-
"feature_score_thresholds, max_training_id_num_per_table and target_eviction_percent_per_table must have the same length, "
209-
f"actual {self.target_eviction_percent_per_table} vs {self.feature_score_counter_decay_rates} vs {self.max_training_id_num_per_table}"
212+
"feature_score_thresholds, training_id_eviction_trigger_count and training_id_keep_count must have the same length, "
213+
f"actual {self.training_id_keep_count} vs {self.feature_score_counter_decay_rates} vs {self.training_id_eviction_trigger_count}"
210214
)
211215

212216

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -690,16 +690,16 @@ def __init__(
690690
)
691691
# Please refer to https://fburl.com/gdoc/nuupjwqq for the following eviction parameters.
692692
eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
693-
self.kv_zch_params.eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual
693+
self.kv_zch_params.eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count
694694
self.kv_zch_params.eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold 5: feature score
695695
self.kv_zch_params.eviction_policy.eviction_step_intervals, # trigger_step_interval if trigger mode is iteration
696696
eviction_mem_threshold_gb, # mem_util_threshold_in_GB if trigger mode is mem_util
697697
self.kv_zch_params.eviction_policy.ttls_in_mins, # ttls_in_mins for each table if eviction strategy is timestamp
698698
self.kv_zch_params.eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is counter
699699
self.kv_zch_params.eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is counter
700700
self.kv_zch_params.eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score
701-
self.kv_zch_params.eviction_policy.max_training_id_num_per_table, # max_training_id_num for each table
702-
self.kv_zch_params.eviction_policy.target_eviction_percent_per_table, # target_eviction_percent for each table
701+
self.kv_zch_params.eviction_policy.training_id_eviction_trigger_count, # training_id_eviction_trigger_count for each table
702+
self.kv_zch_params.eviction_policy.training_id_keep_count, # training_id_keep_count for each table
703703
self.kv_zch_params.eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
704704
table_dims.tolist() if table_dims is not None else None,
705705
self.kv_zch_params.eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score
@@ -1047,9 +1047,6 @@ def __init__(
10471047
self.stats_reporter.register_stats(
10481048
f"eviction.feature_table.{t}.processed_counts"
10491049
)
1050-
self.stats_reporter.register_stats(
1051-
f"eviction.feature_table.{t}.eviction_threshold_with_dry_run"
1052-
)
10531050
self.stats_reporter.register_stats(
10541051
f"eviction.feature_table.{t}.evict_rate"
10551052
)
@@ -3926,11 +3923,6 @@ def _report_eviction_stats(self) -> None:
39263923
data_bytes=int(processed_counts[t].item()),
39273924
enable_tb_metrics=True,
39283925
)
3929-
stats_reporter.report_data_amount(
3930-
iteration_step=self.step,
3931-
event_name=f"eviction.feature_table.{t}.eviction_threshold_with_dry_run",
3932-
data_bytes=float(eviction_threshold_with_dry_run[t].item()),
3933-
)
39343926
if processed_counts[t].item() != 0:
39353927
stats_reporter.report_data_amount(
39363928
iteration_step=self.step,

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,14 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
11011101
}
11021102
break;
11031103
}
1104+
case EvictTriggerMode::ID_COUNT: {
1105+
auto used_id_count = get_num_rows();
1106+
if (used_id_count > feature_evict_config_.value()
1107+
->total_id_eviction_trigger_count_.value()) {
1108+
trigger_feature_evict();
1109+
}
1110+
break;
1111+
}
11041112
default:
11051113
break;
11061114
}

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ void DramKVEmbeddingInferenceWrapper::init(
5656
std::nullopt /* counter_thresholds */,
5757
std::nullopt /* counter_decay_rates */,
5858
std::nullopt /* feature_score_counter_decay_rates */,
59-
std::nullopt /* max_training_id_num_per_table */,
60-
std::nullopt /* target_eviction_percent_per_table */,
59+
std::nullopt /* training_id_eviction_trigger_count */,
60+
std::nullopt /* training_id_keep_count */,
6161
std::nullopt /* l2_weight_thresholds */,
6262
std::nullopt /* embedding_dims */,
6363
std::nullopt /* threshold_calculation_bucket_stride */,

fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h

Lines changed: 63 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ enum class EvictTriggerMode {
3333
DISABLED, // Do not use feature evict
3434
ITERATION, // Trigger based on iteration steps
3535
MEM_UTIL, // Trigger based on memory usage
36-
MANUAL // Manually triggered by upstream
36+
MANUAL, // Manually triggered by upstream
37+
ID_COUNT // Trigger based on id count
3738
};
3839
inline std::string to_string(EvictTriggerMode mode) {
3940
switch (mode) {
@@ -45,6 +46,8 @@ inline std::string to_string(EvictTriggerMode mode) {
4546
return "MEM_UTIL";
4647
case EvictTriggerMode::MANUAL:
4748
return "MANUAL";
49+
case EvictTriggerMode::ID_COUNT:
50+
return "ID_COUNT";
4851
}
4952
}
5053

@@ -102,8 +105,8 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
102105
std::optional<std::vector<int64_t>> counter_thresholds,
103106
std::optional<std::vector<double>> counter_decay_rates,
104107
std::optional<std::vector<double>> feature_score_counter_decay_rates,
105-
std::optional<std::vector<int64_t>> max_training_id_num_per_table,
106-
std::optional<std::vector<double>> target_eviction_percent_per_table,
108+
std::optional<std::vector<int64_t>> training_id_eviction_trigger_count,
109+
std::optional<std::vector<int64_t>> training_id_keep_count,
107110
std::optional<std::vector<double>> l2_weight_thresholds,
108111
std::optional<std::vector<int64_t>> embedding_dims,
109112
std::optional<double> threshold_calculation_bucket_stride = 0.2,
@@ -120,10 +123,9 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
120123
counter_decay_rates_(counter_decay_rates),
121124
feature_score_counter_decay_rates_(
122125
std::move(feature_score_counter_decay_rates)),
123-
max_training_id_num_per_table_(
124-
std::move(max_training_id_num_per_table)),
125-
target_eviction_percent_per_table_(
126-
std::move(target_eviction_percent_per_table)),
126+
training_id_eviction_trigger_count_(
127+
std::move(training_id_eviction_trigger_count)),
128+
training_id_keep_count_(std::move(training_id_keep_count)),
127129
l2_weight_thresholds_(l2_weight_thresholds),
128130
embedding_dims_(embedding_dims),
129131
threshold_calculation_bucket_stride_(
@@ -160,6 +162,28 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
160162
case EvictTriggerMode::MANUAL: {
161163
break;
162164
}
165+
case EvictTriggerMode::ID_COUNT: {
166+
CHECK(
167+
training_id_eviction_trigger_count_.has_value() &&
168+
!training_id_eviction_trigger_count_.value().empty());
169+
const auto& vec = training_id_eviction_trigger_count_.value();
170+
eviction_trigger_stats_log = ", training_id_eviction_trigger_count: [";
171+
total_id_eviction_trigger_count_ = 0;
172+
for (size_t i = 0; i < vec.size(); ++i) {
173+
total_id_eviction_trigger_count_ =
174+
total_id_eviction_trigger_count_.value() + vec[i];
175+
if (vec[i] <= 0) {
176+
throw std::runtime_error(
177+
"Invalid training_id_eviction_trigger_count, must be positive if ID_COUNT trigger mode is used");
178+
}
179+
eviction_trigger_stats_log += std::to_string(vec[i]);
180+
if (i + 1 < vec.size()) {
181+
eviction_trigger_stats_log += ", ";
182+
}
183+
}
184+
eviction_trigger_stats_log += "]";
185+
break;
186+
}
163187
default:
164188
throw std::runtime_error("Unknown evict trigger mode");
165189
}
@@ -178,18 +202,21 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
178202

179203
case EvictTriggerStrategy::BY_FEATURE_SCORE: {
180204
CHECK(feature_score_counter_decay_rates_.has_value());
181-
CHECK(max_training_id_num_per_table_.has_value());
182-
CHECK(target_eviction_percent_per_table_.has_value());
205+
CHECK(training_id_eviction_trigger_count_.has_value());
206+
CHECK(training_id_keep_count_.has_value());
207+
CHECK(total_id_eviction_trigger_count_.has_value());
183208
CHECK(threshold_calculation_bucket_stride_.has_value());
184209
CHECK(threshold_calculation_bucket_num_.has_value());
185210
CHECK(ttls_in_mins_.has_value());
186211
LOG(INFO) << "eviction config, trigger mode:"
187212
<< to_string(trigger_mode_) << eviction_trigger_stats_log
188213
<< ", strategy: " << to_string(trigger_strategy_)
189-
<< ", max_training_id_num_per_table: "
190-
<< max_training_id_num_per_table_.value()
191-
<< ", target_eviction_percent_per_table:"
192-
<< target_eviction_percent_per_table_.value()
214+
<< ", training_id_eviction_trigger_count: "
215+
<< training_id_eviction_trigger_count_.value()
216+
<< ", training_id_keep_count:"
217+
<< training_id_keep_count_.value()
218+
<< ", total_id_eviction_trigger_count: "
219+
<< total_id_eviction_trigger_count_.value()
193220
<< ", ttls_in_mins: " << ttls_in_mins_.value()
194221
<< ", threshold_calculation_bucket_stride: "
195222
<< threshold_calculation_bucket_stride_.value()
@@ -252,8 +279,9 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
252279
std::optional<std::vector<int64_t>> counter_thresholds_;
253280
std::optional<std::vector<double>> counter_decay_rates_;
254281
std::optional<std::vector<double>> feature_score_counter_decay_rates_;
255-
std::optional<std::vector<int64_t>> max_training_id_num_per_table_;
256-
std::optional<std::vector<double>> target_eviction_percent_per_table_;
282+
std::optional<std::vector<int64_t>> training_id_eviction_trigger_count_;
283+
std::optional<std::vector<int64_t>> training_id_keep_count_;
284+
std::optional<int64_t> total_id_eviction_trigger_count_;
257285
std::optional<std::vector<double>> l2_weight_thresholds_;
258286
std::optional<std::vector<int64_t>> embedding_dims_;
259287
std::optional<double> threshold_calculation_bucket_stride_;
@@ -953,8 +981,8 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
953981
SynchronizedShardedMap<int64_t, weight_type*>& kv_store,
954982
const std::vector<int64_t>& sub_table_hash_cumsum,
955983
const std::vector<double>& decay_rates,
956-
const std::vector<int64_t>& max_training_id_num_per_table,
957-
const std::vector<double>& target_eviction_percent_per_table,
984+
const std::vector<int64_t>& training_id_eviction_trigger_count,
985+
const std::vector<int64_t>& training_id_keep_count,
958986
const std::vector<int64_t>& ttls_in_mins,
959987
const double threshold_calculation_bucket_stride,
960988
const int64_t threshold_calculation_bucket_num,
@@ -972,8 +1000,8 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
9721000
is_training,
9731001
test_mode),
9741002
decay_rates_(decay_rates),
975-
max_training_id_num_per_table_(max_training_id_num_per_table),
976-
target_eviction_percent_per_table_(target_eviction_percent_per_table),
1003+
training_id_eviction_trigger_count_(training_id_eviction_trigger_count),
1004+
training_id_keep_count_(training_id_keep_count),
9771005
ttls_in_mins_(ttls_in_mins),
9781006
threshold_calculation_bucket_stride_(
9791007
threshold_calculation_bucket_stride),
@@ -1007,6 +1035,13 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
10071035
bool add_new_block = false) {
10081036
int64_t key = FixedBlockPool::get_key(block);
10091037
int sub_table_id = this->get_sub_table_id(key);
1038+
1039+
if (ttls_in_mins_[sub_table_id] > 0) {
1040+
// If ttl is enabled, we don't need to populate the feature score
1041+
// bucket.
1042+
return;
1043+
}
1044+
10101045
if (add_new_block) {
10111046
double ratio = FixedBlockPool::get_feature_score_rate(block);
10121047
int64_t idx = get_bucket_id_from_ratio(ratio);
@@ -1124,11 +1159,7 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
11241159
this->local_blocks_num_per_shard_per_table_[table_id][shard_id];
11251160
}
11261161

1127-
const double target_keep_ratio =
1128-
1 - target_eviction_percent_per_table_[table_id];
1129-
int64_t max_id =
1130-
static_cast<int64_t>(max_training_id_num_per_table_[table_id]);
1131-
int64_t keep_count = static_cast<int64_t>(max_id * target_keep_ratio);
1162+
int64_t keep_count = training_id_keep_count_[table_id];
11321163
int64_t evict_count = total - keep_count;
11331164

11341165
int64_t acc_count = 0;
@@ -1202,10 +1233,12 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
12021233
std::vector<double> thresholds_; // Threshold for eviction.
12031234

12041235
const std::vector<int64_t>&
1205-
max_training_id_num_per_table_; // training max id for each table.
1206-
const std::vector<double>&
1207-
target_eviction_percent_per_table_; // target eviction percent for
1208-
// each table
1236+
training_id_eviction_trigger_count_; // training id num for trigger
1237+
// eviction.
1238+
const std::vector<int64_t>&
1239+
training_id_keep_count_; // target keep training id num per table after
1240+
// eviction.
1241+
12091242
const std::vector<int64_t>& ttls_in_mins_; // Time-to-live for eviction.
12101243
std::vector<std::vector<std::vector<size_t>>>
12111244
local_buckets_per_shard_per_table_;
@@ -1453,8 +1486,8 @@ std::unique_ptr<FeatureEvict<weight_type>> create_feature_evict(
14531486
kv_store,
14541487
sub_table_hash_cumsum,
14551488
config->feature_score_counter_decay_rates_.value(),
1456-
config->max_training_id_num_per_table_.value(),
1457-
config->target_eviction_percent_per_table_.value(),
1489+
config->training_id_eviction_trigger_count_.value(),
1490+
config->training_id_keep_count_.value(),
14581491
config->ttls_in_mins_.value(),
14591492
config->threshold_calculation_bucket_stride_.value(),
14601493
config->threshold_calculation_bucket_num_.value(),

0 commit comments

Comments
 (0)