Skip to content

Commit 58c5599

Browse files
committed
support task selector
1 parent cb3f16b commit 58c5599

File tree

16 files changed

+463
-53
lines changed

16 files changed

+463
-53
lines changed

trinity/buffer/buffer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,18 @@ def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig
2020
from trinity.buffer.reader.file_reader import (
2121
ExperienceFileReader,
2222
TaskFileReader,
23+
TaskFileReaderWithSelector,
2324
)
2425

2526
schema_type = storage_config.schema_type
2627
if schema_type:
2728
# only trainer input has schema type
2829
return ExperienceFileReader(storage_config, buffer_config)
2930
else:
30-
return TaskFileReader(storage_config, buffer_config)
31+
if storage_config.task_selector:
32+
return TaskFileReaderWithSelector(storage_config, buffer_config)
33+
else:
34+
return TaskFileReader(storage_config, buffer_config)
3135
else:
3236
raise ValueError(f"{storage_config.storage_type} not supported.")
3337

trinity/buffer/buffer_reader.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Reader of the buffer."""
22
from abc import ABC, abstractmethod
3-
from typing import List, Optional
3+
from typing import Dict, List, Optional
44

55

66
class BufferReader(ABC):
@@ -13,3 +13,14 @@ def read(self, batch_size: Optional[int] = None) -> List:
1313
@abstractmethod
1414
async def read_async(self, batch_size: Optional[int] = None) -> List:
1515
"""Read from buffer asynchronously."""
16+
17+
@property
18+
@abstractmethod
19+
def index(self) -> int:
20+
"""Get the current index."""
21+
22+
def state_dict(self) -> Dict:
23+
return {}
24+
25+
def load_state_dict(self, state_dict: Dict) -> None:
26+
pass
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from collections import defaultdict
# BUG FIX: `Tuple` must come from `typing`, not `sqlalchemy` (whose `Tuple`
# is a SQL expression construct, not a generic type for annotations).
from typing import Dict, List, Optional, Tuple

import numpy as np

from trinity.buffer.operators.experience_operator import (
    EXPERIENCE_OPERATORS,
    ExperienceOperator,
)
from trinity.buffer.task_scheduler import TASKSET_SCHEDULE_METRIC
from trinity.common.experience import Experience


@EXPERIENCE_OPERATORS.register_module("pass_rate_calculator")
class PassRateCalculator(ExperienceOperator):
    """Compute the mean reward (pass rate) of each task from a batch of experiences.

    The per-task means are emitted under the ``TASKSET_SCHEDULE_METRIC`` key so
    downstream task scheduling can consume them.
    """

    def __init__(self, reward_shaping_configs: Optional[List[Dict]] = None):
        # NOTE(review): not used by process(); presumably kept for interface
        # parity with other experience operators — confirm intent.
        self.reward_shaping_configs = reward_shaping_configs

    def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
        """Group rewards by task index and report per-task mean rewards.

        Args:
            exps: Experiences carrying ``task_index`` and ``reward`` attributes.

        Returns:
            The experiences unchanged, plus a metrics dict mapping
            ``TASKSET_SCHEDULE_METRIC`` to ``{task_index: mean_reward}``.
        """
        rewards_by_task = defaultdict(list)
        for exp in exps:
            rewards_by_task[exp.task_index].append(exp.reward)
        metric = {task_index: np.mean(rewards) for task_index, rewards in rewards_by_task.items()}
        return exps, {TASKSET_SCHEDULE_METRIC: metric}
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
from typing import List
2+
3+
import numpy as np
4+
import torch
5+
6+
from trinity.common.config import DataSelectorConfig, StorageConfig
7+
8+
from .diff_estimator import InterpolationBetaPREstimator
9+
10+
11+
def build_diff_estimator(dataset, config: DataSelectorConfig):
    """Build an ``InterpolationBetaPREstimator`` from per-task feature columns.

    Args:
        dataset: Column-accessible dataset; ``dataset[k]`` yields one value per task.
        config: Selector configuration providing ``feature_keys``, ``m``, ``lamb``
            and ``rho`` (and optionally ``adaptive_rho``).

    Returns:
        An estimator over the stacked feature matrix of shape
        ``(num_tasks, len(feature_keys))``.
    """
    # np.concatenate works on all NumPy versions (np.concat is a 2.0+ alias).
    features = np.concatenate(
        [np.array(list(dataset[k]))[:, None] for k in config.feature_keys], axis=1
    )
    # getattr with a default replaces the `hasattr(...) and ...` idiom.
    adaptive_rho = getattr(config, "adaptive_rho", False)
    return InterpolationBetaPREstimator(
        features=features, m=config.m, lamb=config.lamb, rho=config.rho, adaptive_rho=adaptive_rho
    )
21+
22+
23+
class BaseSelector:
    """Abstract interface for task-selection strategies.

    A selector draws task indices from ``data_source`` for each batch and may
    receive feedback values for previously selected indices via ``update``.
    """

    def __init__(self, data_source, config: DataSelectorConfig):
        self.data_source = data_source
        self.config = config

    def get_indices(self, batch_size: int, return_extra_info: bool = False):
        """Return ``batch_size`` selected indices (optionally with extra info)."""
        raise NotImplementedError

    def update(self, indices: List[int], values: List[float]):
        """Incorporate feedback ``values`` for previously selected ``indices``."""
        raise NotImplementedError
33+
34+
35+
class RandomSelector(BaseSelector):
    """Select a uniformly random batch of distinct task indices."""

    def __init__(self, data_source, config: DataSelectorConfig):
        super().__init__(data_source, config)
        self.n = len(data_source)

    def get_indices(self, batch_size, return_extra_info=False):
        """Return ``batch_size`` distinct indices drawn uniformly without replacement.

        Returns a ``torch.Tensor`` of indices; with ``return_extra_info`` an empty
        info dict is appended for interface parity with other selectors.
        """
        # Debug prints removed: library code should not write to stdout per call.
        selected_indices = torch.from_numpy(np.random.permutation(self.n)[:batch_size])
        if return_extra_info:
            return selected_indices, {}
        return selected_indices

    def update(self, *args, **kwargs):
        """Random selection uses no feedback; this is a no-op."""
        pass
51+
52+
53+
class OfflineEasy2HardSelector(BaseSelector):
    """Serve tasks in a fixed order sorted by static features (descending).

    The sort order is computed once from ``config.feature_keys`` and then
    served cyclically, ``batch_size`` indices at a time.
    """

    def __init__(self, data_source, config: DataSelectorConfig):
        super().__init__(data_source, config)

        # (num_tasks, num_features) matrix of static per-task features.
        # np.concatenate works on all NumPy versions (np.concat is 2.0+ only).
        self.features = np.concatenate(
            [np.array(list(data_source[k]))[:, None] for k in config.feature_keys], axis=1
        )
        # Append the original index as the LAST element of each row so that
        # recovering it with [-1] works for any number of feature columns
        # (the original hard-coded [2], which breaks unless there are exactly
        # two feature keys). The unique index also makes the sort order total,
        # so sorted(..., reverse=True) equals sorted(...)[::-1].
        rows = [list(self.features[i]) + [i] for i in range(len(self.features))]
        rows = sorted(rows, reverse=True)
        self.sorted_index = np.array([row[-1] for row in rows])

        self.n = len(data_source)
        self.current_position = 0

    def update(self, *args, **kwargs) -> None:
        """The offline ordering is fixed; feedback is ignored."""
        pass

    def get_indices(self, batch_size, return_extra_info=False):
        """Return the next ``batch_size`` indices in sorted order, wrapping around."""
        if self.current_position + batch_size > self.n:
            new_position = self.current_position + batch_size - self.n
            selected_indices = np.concatenate(
                [self.sorted_index[self.current_position :], self.sorted_index[:new_position]]
            )
        else:
            new_position = self.current_position + batch_size
            selected_indices = self.sorted_index[self.current_position : new_position]
        self.current_position = new_position
        if not return_extra_info:
            return selected_indices
        extra_info = {"indices": selected_indices.tolist()}
        # One "featK" entry per feature column; generalizes the original
        # hard-coded "feat1"/"feat2" (identical keys for two columns).
        for col in range(self.features.shape[1]):
            extra_info[f"feat{col + 1}"] = self.features[selected_indices, col].tolist()
        return selected_indices, extra_info
91+
92+
93+
class DiffBasedSelector(BaseSelector):
    """Select tasks whose estimated pass rate is closest to a target reward.

    A Beta pass-rate estimator is updated online; each task is scored by
    ``-|target_reward - predicted_pass_rate|`` and a batch is chosen either by
    top-k (``tau == 0``) or by temperature-softmax sampling.
    """

    def __init__(self, data_source, config: DataSelectorConfig) -> None:
        super().__init__(data_source, config)
        self.diff_estimator = build_diff_estimator(data_source, config)

    def update(self, indices: List[int], values: List[float]) -> None:
        """Feed observed pass rates for ``indices`` back into the estimator."""
        self.diff_estimator.update(indices, values)

    def get_scores(self) -> np.ndarray:
        """Return per-task scores (higher = closer to the target reward).

        The return annotation is ``np.ndarray`` — the original ``List[float]``
        annotation did not match the actual NumPy return value.
        """
        predicted_pr = self.diff_estimator.predict_pr(do_sample=self.config.do_sample)
        return -np.abs(self.config.target_reward - predicted_pr)

    def get_indices(self, batch_size, return_extra_info=False):
        """Pick ``batch_size`` task indices by top-k (tau == 0) or softmax sampling."""
        # Debug prints removed: per-call stdout output is noise in library code.
        sampling_scores = torch.from_numpy(self.get_scores())
        if self.config.tau == 0:
            selected_indices = torch.topk(sampling_scores, batch_size).indices
        else:
            sampling_logits = sampling_scores / self.config.tau
            # Subtract the max for numerical stability before softmax.
            sampling_logits -= sampling_logits.max()
            sampling_probabilities = torch.softmax(sampling_logits, dim=0)
            selected_indices = torch.multinomial(
                sampling_probabilities, batch_size, replacement=False
            )

        if not return_extra_info:
            return selected_indices
        selected_indices_list = selected_indices.tolist()
        alphas = self.diff_estimator.alphas[selected_indices_list]
        betas = self.diff_estimator.betas[selected_indices_list]
        extra_info = {
            "indices": selected_indices_list,
            "scores": sampling_scores[selected_indices].tolist(),
            "alphas": alphas.tolist(),
            "betas": betas.tolist(),
            # Posterior-mean pass-rate estimate of the selected tasks.
            "point": (alphas / (alphas + betas)).tolist(),
        }
        return selected_indices, extra_info
137+
138+
139+
def build_selector(dataset, config: StorageConfig) -> BaseSelector:
    """Instantiate the selector named by ``config.task_selector.selector_type``.

    Raises:
        ValueError: If the configured selector type is not recognized.
    """
    selector_config = config.task_selector
    assert selector_config is not None
    registry = {
        "random": RandomSelector,
        "diff": DiffBasedSelector,
        "offline": OfflineEasy2HardSelector,
    }
    selector_type = selector_config.selector_type
    if selector_type not in registry:
        raise ValueError(f"Unknown selector type: {selector_type}")
    return registry[selector_type](dataset, selector_config)
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
from typing import List
2+
3+
import numpy as np
4+
5+
6+
class BaseBetaPREstimator:
    """Per-task pass-rate estimator keeping one Beta(alpha, beta) belief per task."""

    n: int  # number of tasks
    m: int  # repeat (rollout) count per task; scales observed/pseudo counts
    lamb: float  # decay factor applied to historical alpha/beta
    rho: float  # weight of pseudo counts vs. observed counts
    alphas: np.ndarray  # Beta "success" parameters, shape (n,)
    betas: np.ndarray  # Beta "failure" parameters, shape (n,)

    def __init__(self, n: int, m: int = 16, lamb: float = 0.2, rho: float = 0.2):
        """
        The update rule actually implemented in ``_update`` is

            alpha_{t+1} = (1 - lamb) * alpha_t + lamb + (1 - rho) * bar{s} + rho * tilde{p} * m
            beta_{t+1}  = (1 - lamb) * beta_t  + lamb + (1 - rho) * bar{f} + rho * (1 - tilde{p}) * m

        where bar{s} / bar{f} are observed success/failure counts and tilde{p}
        is a predicted pass rate contributing rho-weighted pseudo counts.
        NOTE(review): the constant ``+ lamb`` term was not in the original
        formula comment; it keeps the decayed parameters from shrinking toward
        zero — confirm it is intentional.

        :param n: number of tasks
        :param m: repeat times per tasks
        :param lamb: discount factor of historical estimation
        :param rho: weight of pseudo counts
        """
        self.n = n
        self.m = m
        self.lamb = lamb
        self.rho = rho
        # Beta(1, 1) uniform prior for every task.
        self.alphas = np.ones(n, dtype=float)
        self.betas = np.ones(n, dtype=float)
        print(
            f"[DEBUG] {self.n=}, {self.m=}, {self.lamb=}, {self.rho=}, {self.alphas=}, {self.betas=}"
        )

    def set(self, alphas, betas):
        """Overwrite the Beta parameters directly (e.g. when restoring state)."""
        self.alphas = alphas
        self.betas = betas

    def _update(self, s_bar, f_bar, p_tilde):
        """Blend decayed history, observed counts, and pseudo counts.

        :param s_bar: per-task observed success counts, shape (n,)
        :param f_bar: per-task observed failure counts, shape (n,)
        :param p_tilde: per-task predicted pass rates in [0, 1], shape (n,)
        """
        self.alphas = (
            (1 - self.lamb) * self.alphas
            + self.lamb
            + (1 - self.rho) * s_bar
            + self.rho * p_tilde * self.m
        )
        self.betas = (
            (1 - self.lamb) * self.betas
            + self.lamb
            + (1 - self.rho) * f_bar
            + self.rho * (1 - p_tilde) * self.m
        )

    def update(self, ref_indices: List[int], ref_pass_rates: List[float]):
        """Update beliefs from observed pass rates; implemented by subclasses."""
        raise NotImplementedError

    def predict_pr(self, indices=None, do_sample=False):
        """Predict pass rates for ``indices`` (all tasks when ``indices`` is None).

        With ``do_sample`` False, returns the posterior mean
        ``alpha / (alpha + beta)``; otherwise draws one sample per task from
        its Beta belief (Thompson-style sampling).
        """
        if indices is None:
            indices = np.arange(self.n)
        if not do_sample:
            return self.alphas[indices] / (self.alphas[indices] + self.betas[indices])
        else:
            return np.random.beta(self.alphas[indices], self.betas[indices])

    def equivalent_count(self, indices=None):
        """Return ``alpha + beta`` per task — the belief's effective sample size."""
        if indices is None:
            indices = np.arange(self.n)
        return self.alphas[indices] + self.betas[indices]
66+
67+
68+
class InterpolationBetaPREstimator(BaseBetaPREstimator):
    """Beta pass-rate estimator whose pseudo counts interpolate two anchor features.

    Each task carries two anchor pass rates (``features`` columns 0 and 1); the
    current model's pass rate is predicted as a per-batch-calibrated linear
    interpolation between them (see ``update``).
    """

    def __init__(
        self,
        features: np.ndarray,
        m: int,
        lamb,
        rho,
        cap_coef_update_discount=0.9,
        adaptive_rho=False,
    ):
        super(InterpolationBetaPREstimator, self).__init__(len(features), m, lamb, rho)
        self.features = features  # [D, 2]; columns are the two anchor pass rates
        # Interpolation coefficient; None until the first update() estimates it.
        self.cap_coef = None
        # EMA discount for smoothing cap_coef across updates.
        self.cap_coef_update_discount = cap_coef_update_discount
        # When True, rho is halved whenever the interpolation error is large.
        self.adaptive_rho = adaptive_rho

    def update(self, ref_indices: List[int], ref_pass_rates: List[float]):
        """Update Beta beliefs from observed pass rates of the tasks in ``ref_indices``.

        :param ref_indices: indices of tasks observed in this batch
        :param ref_pass_rates: their observed pass rates, aligned with ``ref_indices``
        """
        ref_pass_rate = np.mean(ref_pass_rates)
        ref_anchor_pass_rates = np.mean(self.features[ref_indices], axis=0)
        # Where the observed batch mean sits between the two anchor means
        # (0 -> anchor 0, 1 -> anchor 1); epsilon guards near-equal anchors.
        cap_estimate = (ref_pass_rate - ref_anchor_pass_rates[0]) / (
            ref_anchor_pass_rates[1] - ref_anchor_pass_rates[0] + 1e-6
        )
        if self.cap_coef is None:
            self.cap_coef = cap_estimate
        else:
            # Exponential moving average to smooth per-batch noise.
            self.cap_coef = (
                self.cap_coef_update_discount * self.cap_coef
                + (1 - self.cap_coef_update_discount) * cap_estimate
            )
        # Observed success/failure counts, nonzero only for observed tasks.
        s_bar = np.zeros(self.n, dtype=float)
        s_bar[ref_indices] = np.array(ref_pass_rates) * self.m
        f_bar = np.zeros(self.n, dtype=float)
        f_bar[ref_indices] = (1 - np.array(ref_pass_rates)) * self.m
        # Predicted pass rate per task: interpolate between anchors, clip to [0, 1].
        p_tilde = np.clip(
            (self.features[:, 1] - self.features[:, 0]) * self.cap_coef + self.features[:, 0], 0, 1
        )

        predicted_pass_rates = p_tilde[ref_indices]
        mean_abs_error = np.mean(np.abs(np.array(predicted_pass_rates) - np.array(ref_pass_rates)))
        # If the interpolation is badly calibrated, trust pseudo counts less.
        if self.adaptive_rho and mean_abs_error >= 0.25:
            self.rho = self.rho * 0.5
        print(f"[DEBUG]: {mean_abs_error=}, {self.rho=}")
        # For observed tasks, the actual observation replaces the prediction.
        p_tilde[ref_indices] = np.array(ref_pass_rates)

        self._update(s_bar, f_bar, p_tilde)

0 commit comments

Comments
 (0)