11from abc import ABC , abstractmethod
2- from typing import Any , Dict , List , Tuple
2+ from typing import Dict , List , Tuple
33
44from trinity .algorithm .sample_strategy .utils import representative_sample
55from trinity .buffer import get_buffer_reader
66from trinity .common .config import BufferConfig
7- from trinity .common .experience import Experiences
7+ from trinity .common .experience import Experience , Experiences
88from trinity .utils .annotations import Deprecated
99from trinity .utils .monitor import gather_metrics
1010from trinity .utils .registry import Registry
@@ -17,6 +17,14 @@ class SampleStrategy(ABC):
1717 def __init__ (self , buffer_config : BufferConfig , ** kwargs ) -> None :
1818 self .pad_token_id = buffer_config .pad_token_id
1919
20+ def set_model_version_metric (self , exp_list : List [Experience ], metrics : Dict ):
21+ metric_list = [
22+ {"model_version" : exp .info ["model_version" ]}
23+ for exp in exp_list
24+ if "model_version" in exp .info
25+ ]
26+ metrics .update (gather_metrics (metric_list , "sample" ))
27+
2028 @abstractmethod
2129 async def sample (self , step : int ) -> Tuple [Experiences , Dict , List ]:
2230 """Sample data from buffer.
@@ -42,13 +50,12 @@ def __init__(self, buffer_config: BufferConfig, **kwargs):
4250 super ().__init__ (buffer_config )
4351 self .exp_buffer = get_buffer_reader (buffer_config .trainer_input .experience_buffer ) # type: ignore[arg-type]
4452
45- async def sample (self , step : int , ** kwargs ) -> Tuple [Any , Dict , List ]:
53+ async def sample (self , step : int , ** kwargs ) -> Tuple [Experiences , Dict , List ]:
4654 metrics = {}
4755 with Timer (metrics , "time/read_experience" ):
4856 exp_list = await self .exp_buffer .read_async ()
4957 repr_samples = representative_sample (exp_list )
50- metric_list = [{"model_version" : exp .info ["model_version" ]} for exp in exp_list ]
51- metrics .update (gather_metrics (metric_list , "sample" ))
58+ self .set_model_version_metric (exp_list , metrics )
5259 with Timer (metrics , "time/gather_experience" ):
5360 exps = Experiences .gather_experiences (exp_list , self .pad_token_id ) # type: ignore
5461 return exps , metrics , repr_samples
0 commit comments