Skip to content

Commit 9a91baf

Browse files
EddyLXJfacebook-github-bot
authored andcommitted
Fix EvictionPolicy init for inference_eviction_threshold (pytorch#3266)
Summary: Pull Request resolved: pytorch#3266 As title Reviewed By: yixin94, emlin Differential Revision: D79822053 fbshipit-source-id: 330a17285c99d86edef2001b954b3bbf2973e97c
1 parent 0167503 commit 9a91baf

File tree

1 file changed

+27
-8
lines changed

1 file changed

+27
-8
lines changed

torchrec/modules/embedding_configs.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,14 @@ class CountBasedEvictionPolicy(VirtualTableEvictionPolicy):
194194
15 # eviction threshold for count based eviction policy. 0 means no eviction
195195
)
196196
decay_rate: float = 0.99 # default decay by default
197-
inference_eviction_threshold: int = (
198-
eviction_threshold # eviction threshold for inference count based eviction policy. 0 means no eviction
197+
inference_eviction_threshold: Optional[int] = (
198+
None # eviction threshold for inference count based eviction policy. 0 means no eviction
199199
)
200200

201+
def __post_init__(self) -> None:
202+
if self.inference_eviction_threshold is None:
203+
self.inference_eviction_threshold = self.eviction_threshold
204+
201205

202206
@dataclass
203207
class TimestampBasedEvictionPolicy(VirtualTableEvictionPolicy):
@@ -206,7 +210,11 @@ class TimestampBasedEvictionPolicy(VirtualTableEvictionPolicy):
206210
"""
207211

208212
eviction_ttl_mins: int = 24 * 60 # 1 day. 0 means no eviction
209-
inference_eviction_ttl_mins: int = eviction_ttl_mins # 0 means no eviction
213+
inference_eviction_ttl_mins: Optional[int] = None # 0 means no eviction
214+
215+
def __post_init__(self) -> None:
216+
if self.inference_eviction_ttl_mins is None:
217+
self.inference_eviction_ttl_mins = self.eviction_ttl_mins
210218

211219

212220
@dataclass
@@ -220,14 +228,21 @@ class CountTimestampMixedEvictionPolicy(VirtualTableEvictionPolicy):
220228
)
221229
decay_rate: float = 0.99 # default decay by default
222230
eviction_ttl_mins: int = 24 * 60 # 1 day. 0 means no eviction based on timestamp
223-
inference_eviction_threshold: int = (
224-
eviction_threshold # eviction threshold for inference count based eviction policy. 0 means no eviction based on count
231+
inference_eviction_threshold: Optional[int] = (
232+
None # eviction threshold for inference count based eviction policy. 0 means no eviction based on count
225233
)
226234

227-
inference_eviction_ttl_mins: int = (
228-
eviction_ttl_mins # 0 means no eviction based on timestamp
235+
inference_eviction_ttl_mins: Optional[int] = (
236+
None # 0 means no eviction based on timestamp
229237
)
230238

239+
def __post_init__(self) -> None:
240+
if self.inference_eviction_ttl_mins is None:
241+
self.inference_eviction_ttl_mins = self.eviction_ttl_mins
242+
243+
if self.inference_eviction_threshold is None:
244+
self.inference_eviction_threshold = self.eviction_threshold
245+
231246

232247
@dataclass
233248
class FeatureL2NormBasedEvictionPolicy(VirtualTableEvictionPolicy):
@@ -238,7 +253,11 @@ class FeatureL2NormBasedEvictionPolicy(VirtualTableEvictionPolicy):
238253
eviction_threshold: float = (
239254
0.0 # eviction threshold for feature l2 norm based eviction policy. 0.0 means no eviction
240255
)
241-
inference_eviction_threshold: float = eviction_threshold
256+
inference_eviction_threshold: Optional[float] = None
257+
258+
def __post_init__(self) -> None:
259+
if self.inference_eviction_threshold is None:
260+
self.inference_eviction_threshold = self.eviction_threshold
242261

243262

244263
@dataclass

0 commit comments

Comments
 (0)