Skip to content

Commit 110495c

Browse files
committed
refactor priority function
1 parent db8bed6 commit 110495c

File tree

3 files changed

+102
-44
lines changed

3 files changed

+102
-44
lines changed

trinity/buffer/storage/queue.py

Lines changed: 70 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
from abc import ABC, abstractmethod
66
from collections import deque
77
from copy import deepcopy
8-
from functools import partial
9-
from typing import List, Optional, Tuple
8+
from typing import Dict, List, Optional, Tuple
109

1110
import numpy as np
1211
import ray
@@ -28,48 +27,83 @@ def is_json_file(path: str) -> bool:
2827

2928

3029
PRIORITY_FUNC = Registry("priority_fn")
31-
"""
32-
Each priority_fn,
33-
Args:
34-
item: List[Experience], assume that all experiences in it have the same model_version and use_count
35-
kwargs: storage_config.replay_buffer_kwargs (except priority_fn)
36-
Returns:
37-
priority: float
38-
put_into_queue: bool, decide whether to put item into queue
39-
Note that put_into_queue takes effect both for new item from the explorer and for item sampled from the buffer.
40-
"""
30+
31+
32+
class PriorityFunction(ABC):
33+
"""
34+
Abstract base class that each priority_fn must subclass.
35+
Args:
36+
item: List[Experience], assume that all experiences in it have the same model_version and use_count
37+
kwargs: storage_config.replay_buffer_kwargs (except priority_fn)
38+
Returns:
39+
priority: float
40+
put_into_queue: bool, decide whether to put item into queue
41+
Note that put_into_queue takes effect both for new item from the explorer and for item sampled from the buffer.
42+
"""
43+
44+
def __init__(self, **kwargs):
45+
pass
46+
47+
@abstractmethod
48+
def __call__(self, items: List[Experience]) -> Tuple[float, bool]:
49+
"""Calculate the priority of items."""
50+
51+
@classmethod
52+
@abstractmethod
53+
def default_config(cls) -> Dict:
54+
"""Return the default config."""
4155

4256

4357
@PRIORITY_FUNC.register_module("linear_decay")
44-
def linear_decay_priority(
45-
item: List[Experience],
46-
decay: float = 2.0,
47-
) -> Tuple[float, bool]:
58+
class LinearDecayPriority(PriorityFunction):
4859
"""Calculate priority by linear decay.
4960
5061
Priority is calculated as `model_version - decay * use_count`. The item is always put back into the queue for reuse (as long as `reuse_cooldown_time` is not None).
5162
"""
52-
priority = float(item[0].info["model_version"] - decay * item[0].info["use_count"])
53-
put_into_queue = True
54-
return priority, put_into_queue
55-
56-
57-
@PRIORITY_FUNC.register_module("linear_decay_use_count_control_randomization")
58-
def linear_decay_use_count_control_priority(
59-
item: List[Experience],
60-
decay: float = 2.0,
61-
use_count_limit: int = 3,
62-
sigma: float = 0.0,
63-
) -> Tuple[float, bool]:
63+
64+
def __init__(self, decay: float = 2.0, **kwargs):
65+
self.decay = decay
66+
67+
def __call__(self, items: List[Experience]) -> Tuple[float, bool]:
68+
priority = float(items[0].info["model_version"] - self.decay * items[0].info["use_count"])
69+
put_into_queue = True
70+
return priority, put_into_queue
71+
72+
@classmethod
73+
def default_config(cls) -> Dict:
74+
return {
75+
"decay": 2.0,
76+
}
77+
78+
79+
@PRIORITY_FUNC.register_module("decay_limit_randomization")
80+
class LinearDecayUseCountControlPriority(PriorityFunction):
6481
"""Calculate priority by linear decay, use count control, and randomization.
6582
6683
Priority is calculated as `model_version - decay * use_count`; if `sigma` is greater than zero, priority is further perturbed by random Gaussian noise with standard deviation `sigma`. The item will be put back into the queue only if its use count is strictly less than `use_count_limit`; a non-positive `use_count_limit` disables the check, so the item is always put back.
6784
"""
68-
priority = float(item[0].info["model_version"] - decay * item[0].info["use_count"])
69-
if sigma > 0.0:
70-
priority += float(np.random.randn() * sigma)
71-
put_into_queue = item[0].info["use_count"] < use_count_limit if use_count_limit > 0 else True
72-
return priority, put_into_queue
85+
86+
def __init__(self, decay: float = 2.0, use_count_limit: int = 3, sigma: float = 0.0, **kwargs):
87+
self.decay = decay
88+
self.use_count_limit = use_count_limit
89+
self.sigma = sigma
90+
91+
def __call__(self, items: List[Experience]) -> Tuple[float, bool]:
92+
priority = float(items[0].info["model_version"] - self.decay * items[0].info["use_count"])
93+
if self.sigma > 0.0:
94+
priority += float(np.random.randn() * self.sigma)
95+
put_into_queue = (
96+
items[0].info["use_count"] < self.use_count_limit if self.use_count_limit > 0 else True
97+
)
98+
return priority, put_into_queue
99+
100+
@classmethod
101+
def default_config(cls) -> Dict:
102+
return {
103+
"decay": 2.0,
104+
"use_count_limit": 3,
105+
"sigma": 0.0,
106+
}
73107

74108

75109
class QueueBuffer(ABC):
@@ -168,7 +202,9 @@ def __init__(
168202
self.capacity = capacity
169203
self.item_count = 0
170204
self.priority_groups = SortedDict() # Maps priority -> deque of items
171-
self.priority_fn = partial(PRIORITY_FUNC.get(priority_fn), **(priority_fn_args or {}))
205+
priority_fn_cls = PRIORITY_FUNC.get(priority_fn)
206+
kwargs = {**priority_fn_cls.default_config(), **(priority_fn_args or {})}  # dict.update() returns None, so merge defaults and overrides into a new dict
207+
self.priority_fn = priority_fn_cls(**kwargs)
172208
self.reuse_cooldown_time = reuse_cooldown_time
173209
self._condition = asyncio.Condition() # For thread-safe operations
174210
self._closed = False

trinity/manager/config_manager.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN
1515
from trinity.algorithm.sample_strategy.sample_strategy import SAMPLE_STRATEGY
1616
from trinity.common.constants import StorageType
17-
from trinity.manager.config_registry.buffer_config_manager import get_train_batch_size
17+
from trinity.manager.config_registry.buffer_config_manager import (
18+
get_train_batch_size,
19+
parse_priority_fn_args,
20+
)
1821
from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
1922
from trinity.manager.config_registry.trainer_config_manager import use_critic
2023
from trinity.utils.plugin_loader import load_plugins
@@ -190,7 +193,8 @@ def _expert_buffer_part(self):
190193
self.get_configs("storage_type")
191194
self.get_configs("experience_buffer_path")
192195
self.get_configs("enable_replay_buffer")
193-
self.get_configs("reuse_cooldown_time", "priority_fn", "priority_decay")
196+
self.get_configs("reuse_cooldown_time", "priority_fn")
197+
self.get_configs("priority_fn_args")
194198

195199
# TODO: used for SQL storage
196200
# self.buffer_advanced_tab = st.expander("Advanced Config")
@@ -592,9 +596,7 @@ def _gen_buffer_config(self):
592596
"enable": st.session_state["enable_replay_buffer"],
593597
"priority_fn": st.session_state["priority_fn"],
594598
"reuse_cooldown_time": st.session_state["reuse_cooldown_time"],
595-
"priority_fn_args": {
596-
"decay": st.session_state["priority_decay"],
597-
},
599+
"priority_fn_args": parse_priority_fn_args(st.session_state["priority_fn_args"]),
598600
}
599601

600602
if st.session_state["mode"] != "train":

trinity/manager/config_registry/buffer_config_manager.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import json
2+
3+
import pandas as pd
14
import streamlit as st
25

36
from trinity.buffer.storage.queue import PRIORITY_FUNC
@@ -328,13 +331,30 @@ def set_priority_fn(**kwargs):
328331
)
329332

330333

334+
def parse_priority_fn_args(raw_data: str):
335+
try:
336+
data = json.loads(raw_data)
337+
if data["priority_fn"] != st.session_state["priority_fn"]:
338+
raise ValueError
339+
return data["fn_args"]
340+
except Exception:
341+
return PRIORITY_FUNC.get(st.session_state["priority_fn"]).default_config()
342+
343+
331344
@CONFIG_GENERATORS.register_config(
332-
default_value=0.1, visible=lambda: st.session_state["enable_replay_buffer"]
345+
default_value="", visible=lambda: st.session_state["enable_replay_buffer"]
333346
)
334-
def set_priority_decay(**kwargs):
335-
st.number_input(
336-
"Priority Decay",
337-
**kwargs,
347+
def set_priority_fn_args(**kwargs):
348+
key = kwargs.get("key")
349+
df = pd.DataFrame([parse_priority_fn_args(st.session_state[key])])
350+
df.index = [st.session_state["priority_fn"]]
351+
st.caption("Priority Function Args")
352+
df = st.data_editor(df)
353+
st.session_state[key] = json.dumps(
354+
{
355+
"fn_args": df.to_dict(orient="records")[0],
356+
"priority_fn": st.session_state["priority_fn"],
357+
}
338358
)
339359

340360

0 commit comments

Comments
 (0)