Skip to content

Commit cfeb39c

Browse files
committed
apply suggestions from gemini
1 parent 87e95de commit cfeb39c

File tree

3 files changed

+14
-6
lines changed

3 files changed

+14
-6
lines changed

trinity/algorithm/sample_strategy/mix_sample_strategy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
8585
exp_list = usual_exp_list + expert_exp_list
8686
repr_samples = representative_sample(exp_list)
8787

88+
self.set_model_version_metric(exp_list, metrics)
8889
with Timer(metrics, "time/gather_experience"):
8990
exps = Experiences.gather_experiences(
9091
experiences=exp_list,

trinity/algorithm/sample_strategy/sample_strategy.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, Dict, List, Tuple
2+
from typing import Dict, List, Tuple
33

44
from trinity.algorithm.sample_strategy.utils import representative_sample
55
from trinity.buffer import get_buffer_reader
66
from trinity.common.config import BufferConfig
7-
from trinity.common.experience import Experiences
7+
from trinity.common.experience import Experience, Experiences
88
from trinity.utils.annotations import Deprecated
99
from trinity.utils.monitor import gather_metrics
1010
from trinity.utils.registry import Registry
@@ -17,6 +17,14 @@ class SampleStrategy(ABC):
1717
def __init__(self, buffer_config: BufferConfig, **kwargs) -> None:
1818
self.pad_token_id = buffer_config.pad_token_id
1919

20+
def set_model_version_metric(self, exp_list: List[Experience], metrics: Dict):
21+
metric_list = [
22+
{"model_version": exp.info["model_version"]}
23+
for exp in exp_list
24+
if "model_version" in exp.info
25+
]
26+
metrics.update(gather_metrics(metric_list, "sample"))
27+
2028
@abstractmethod
2129
async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
2230
"""Sample data from buffer.
@@ -42,13 +50,12 @@ def __init__(self, buffer_config: BufferConfig, **kwargs):
4250
super().__init__(buffer_config)
4351
self.exp_buffer = get_buffer_reader(buffer_config.trainer_input.experience_buffer) # type: ignore[arg-type]
4452

45-
async def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
53+
async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]:
4654
metrics = {}
4755
with Timer(metrics, "time/read_experience"):
4856
exp_list = await self.exp_buffer.read_async()
4957
repr_samples = representative_sample(exp_list)
50-
metric_list = [{"model_version": exp.info["model_version"]} for exp in exp_list]
51-
metrics.update(gather_metrics(metric_list, "sample"))
58+
self.set_model_version_metric(exp_list, metrics)
5259
with Timer(metrics, "time/gather_experience"):
5360
exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore
5461
return exps, metrics, repr_samples

trinity/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ async def _sample_data(self) -> Tuple[Experiences, Dict, List[Dict]]:
125125
List[Dict]: A list of representative samples for logging.
126126
"""
127127
batch, metrics, repr_samples = await self.sample_strategy.sample(self.train_step_num + 1)
128-
metrics["sample/task_count"] = len(set(f"{eid.batch}_{eid.task}" for eid in batch.eids))
128+
metrics["sample/task_count"] = len(set(eid.tid for eid in batch.eids))
129129
return batch, metrics, repr_samples
130130

131131
async def need_sync(self) -> bool:

0 commit comments

Comments
 (0)