Skip to content

Commit ceedc7c

Browse files
committed
apply suggestions from reviews
1 parent 110495c commit ceedc7c

File tree

3 files changed

+17
-16
lines changed

3 files changed

+17
-16
lines changed

tests/buffer/queue_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ async def test_priority_queue_reuse_count_control(self):
326326
path=BUFFER_FILE_PATH,
327327
replay_buffer=ReplayBufferConfig(
328328
enable=True,
329-
priority_fn="linear_decay_use_count_control_randomization",
329+
priority_fn="decay_limit_randomization",
330330
reuse_cooldown_time=0.5,
331331
priority_fn_args={"decay": 1.2, "use_count_limit": 2, "sigma": 0.0},
332332
),

trinity/buffer/storage/queue.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,18 @@ class PriorityFunction(ABC):
3434
Each priority_fn,
3535
Args:
3636
item: List[Experience], assume that all experiences in it have the same model_version and use_count
37-
kwargs: storage_config.replay_buffer_kwargs (except priority_fn)
37+
priority_fn_args: Dict, the arguments for priority_fn
38+
3839
Returns:
3940
priority: float
4041
put_into_queue: bool, decide whether to put item into queue
42+
4143
Note that put_into_queue takes effect both for new item from the explorer and for item sampled from the buffer.
4244
"""
4345

44-
def __init__(self, **kwargs):
45-
pass
46-
4746
@abstractmethod
48-
def __call__(self, items: List[Experience]) -> Tuple[float, bool]:
49-
"""Calculate the priority of items."""
47+
def __call__(self, item: List[Experience]) -> Tuple[float, bool]:
48+
"""Calculate the priority of item."""
5049

5150
@classmethod
5251
@abstractmethod
@@ -61,11 +60,11 @@ class LinearDecayPriority(PriorityFunction):
6160
Priority is calculated as `model_version - decay * use_count. The item is always put back into the queue for reuse (as long as `reuse_cooldown_time` is not None).
6261
"""
6362

64-
def __init__(self, decay: float = 2.0, **kwargs):
63+
def __init__(self, decay: float = 2.0):
6564
self.decay = decay
6665

67-
def __call__(self, items: List[Experience]) -> Tuple[float, bool]:
68-
priority = float(items[0].info["model_version"] - self.decay * items[0].info["use_count"])
66+
def __call__(self, item: List[Experience]) -> Tuple[float, bool]:
67+
priority = float(item[0].info["model_version"] - self.decay * item[0].info["use_count"])
6968
put_into_queue = True
7069
return priority, put_into_queue
7170

@@ -83,17 +82,17 @@ class LinearDecayUseCountControlPriority(PriorityFunction):
8382
Priority is calculated as `model_version - decay * use_count`; if `sigma` is non-zero, priority is further perturbed by random Gaussian noise with standard deviation `sigma`. The item will be put back into the queue only if use count does not exceed `use_count_limit`.
8483
"""
8584

86-
def __init__(self, decay: float = 2.0, use_count_limit: int = 3, sigma: float = 0.0, **kwargs):
85+
def __init__(self, decay: float = 2.0, use_count_limit: int = 3, sigma: float = 0.0):
8786
self.decay = decay
8887
self.use_count_limit = use_count_limit
8988
self.sigma = sigma
9089

91-
def __call__(self, items: List[Experience]) -> Tuple[float, bool]:
92-
priority = float(items[0].info["model_version"] - self.decay * items[0].info["use_count"])
90+
def __call__(self, item: List[Experience]) -> Tuple[float, bool]:
91+
priority = float(item[0].info["model_version"] - self.decay * item[0].info["use_count"])
9392
if self.sigma > 0.0:
9493
priority += float(np.random.randn() * self.sigma)
9594
put_into_queue = (
96-
items[0].info["use_count"] < self.use_count_limit if self.use_count_limit > 0 else True
95+
item[0].info["use_count"] < self.use_count_limit if self.use_count_limit > 0 else True
9796
)
9897
return priority, put_into_queue
9998

@@ -203,7 +202,8 @@ def __init__(
203202
self.item_count = 0
204203
self.priority_groups = SortedDict() # Maps priority -> deque of items
205204
priority_fn_cls = PRIORITY_FUNC.get(priority_fn)
206-
kwargs = priority_fn_cls.default_config().update(priority_fn_args or {})
205+
kwargs = priority_fn_cls.default_config()
206+
kwargs.update(priority_fn_args or {})
207207
self.priority_fn = priority_fn_cls(**kwargs)
208208
self.reuse_cooldown_time = reuse_cooldown_time
209209
self._condition = asyncio.Condition() # For thread-safe operations

trinity/manager/config_registry/buffer_config_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,8 @@ def parse_priority_fn_args(raw_data: str):
337337
if data["priority_fn"] != st.session_state["priority_fn"]:
338338
raise ValueError
339339
return data["fn_args"]
340-
except Exception:
340+
except (json.JSONDecodeError, KeyError, ValueError):
341+
print(f"Use `default_config` for {st.session_state['priority_fn']}")
341342
return PRIORITY_FUNC.get(st.session_state["priority_fn"]).default_config()
342343

343344

0 commit comments

Comments
 (0)