diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py
new file mode 100644
index 0000000000..2fcf1242ce
--- /dev/null
+++ b/tests/buffer/file_test.py
@@ -0,0 +1,32 @@
+import unittest
+
+from tests.tools import get_template_config, get_unittest_dataset_config
+from trinity.buffer.buffer import get_buffer_reader
+
+
+class TestFileReader(unittest.TestCase):
+    def test_file_reader(self):
+        """Test file reader."""
+        config = get_template_config()
+        dataset_config = get_unittest_dataset_config("countdown", "train")
+        config.buffer.explorer_input.taskset = dataset_config
+        reader = get_buffer_reader(config.buffer.explorer_input.taskset, config.buffer)
+
+        tasks = []
+        while True:
+            try:
+                tasks.extend(reader.read())
+            except StopIteration:
+                break
+        self.assertEqual(len(tasks), 16)
+
+        config.buffer.explorer_input.taskset.total_epochs = 2
+        config.buffer.explorer_input.taskset.index = 4
+        reader = get_buffer_reader(config.buffer.explorer_input.taskset, config.buffer)
+        tasks = []
+        while True:
+            try:
+                tasks.extend(reader.read())
+            except StopIteration:
+                break
+        self.assertEqual(len(tasks), 16 * 2 - 4)
diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py
index e06b133256..262a2bcd3e 100644
--- a/tests/buffer/queue_test.py
+++ b/tests/buffer/queue_test.py
@@ -1,3 +1,5 @@
+import os
+
 import torch
 
 from tests.tools import RayUnittestBase
@@ -7,6 +9,8 @@
 from trinity.common.constants import AlgorithmType, StorageType
 from trinity.common.experience import Experience
 
+file_path = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl")
+
 
 class TestQueueBuffer(RayUnittestBase):
     def test_queue_buffer(self):
@@ -17,6 +21,7 @@ def test_queue_buffer(self):
             name="test_buffer",
             algorithm_type=AlgorithmType.PPO,
             storage_type=StorageType.QUEUE,
+            path=file_path,
         )
         config = BufferConfig(
             max_retry_times=3,
@@ -36,9 +41,29 @@ def test_queue_buffer(self):
         ]
         for _ in range(total_num // put_batch_size):
             writer.write(exps)
-        writer.finish()
         for _ in range(total_num // read_batch_size):
             exps = reader.read()
             self.assertEqual(len(exps), read_batch_size)
             print(f"finish read {read_batch_size} experience")
+        writer.write(
+            [
+                Experience(
+                    tokens=torch.tensor([float(j) for j in range(i + 1)]),
+                    prompt_length=i,
+                    reward=float(i),
+                    logprobs=torch.tensor([0.1]),
+                    action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
+                )
+                for i in range(1, put_batch_size * 2 + 1)
+            ]
+        )
+        exps = reader.read(batch_size=put_batch_size * 2)
+        self.assertEqual(len(exps), put_batch_size * 2)
+        writer.finish()
         self.assertRaises(StopIteration, reader.read)
+        with open(file_path, "r") as f:
+            self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2)
+
+    def setUp(self):
+        if os.path.exists(file_path):
+            os.remove(file_path)
diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py
index 222beebbb3..2146794ebd 100644
--- a/tests/buffer/sql_test.py
+++ b/tests/buffer/sql_test.py
@@ -47,5 +47,21 @@ def test_create_sql_buffer(self) -> None:
         for _ in range(total_num // read_batch_size):
             exps = sql_reader.read()
             self.assertEqual(len(exps), read_batch_size)
+
+        # dynamic read/write
+        sql_writer.write(
+            [
+                Experience(
+                    tokens=torch.tensor([float(j) for j in range(i + 1)]),
+                    prompt_length=i,
+                    reward=float(i),
+                    logprobs=torch.tensor([0.1]),
+                    action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
+                )
+                for i in range(1, put_batch_size * 2 + 1)
+            ]
+        )
+        exps = sql_reader.read(batch_size=put_batch_size * 2)
+        self.assertEqual(len(exps), put_batch_size * 2)
         db_wrapper = ray.get_actor("sql-test_buffer")
         self.assertIsNotNone(db_wrapper)
diff --git a/trinity/buffer/buffer_reader.py b/trinity/buffer/buffer_reader.py
index 4676607d65..e5894b7521 100644
--- a/trinity/buffer/buffer_reader.py
+++ b/trinity/buffer/buffer_reader.py
@@ -9,5 +9,7 @@ class BufferReader(ABC):
     """Interface of the buffer reader."""
 
     @abstractmethod
-    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
+    def read(
+        self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
+    ) -> List:
         """Read from buffer."""
diff --git a/trinity/buffer/db_wrapper.py b/trinity/buffer/db_wrapper.py
index 977aaae493..bbf96c176e 100644
--- a/trinity/buffer/db_wrapper.py
+++ b/trinity/buffer/db_wrapper.py
@@ -61,7 +61,9 @@ def write(self, data: list) -> None:
             experience_models = [self.table_model_cls.from_experience(exp) for exp in data]
             session.add_all(experience_models)
 
-    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
+    def read(
+        self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
+    ) -> List:
         if strategy is None:
             strategy = ReadStrategy.LFU
 
@@ -78,7 +80,8 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
             raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage")
 
         exp_list = []
-        while len(exp_list) < self.batch_size:
+        batch_size = batch_size or self.batch_size
+        while len(exp_list) < batch_size:
             if len(exp_list):
                 self.logger.info("waiting for experiences...")
                 time.sleep(1)
@@ -90,7 +93,7 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
                     session.query(self.table_model_cls)
                    .filter(self.table_model_cls.reward.isnot(None))
                    .order_by(*sortOrder)  # TODO: very slow
-                    .limit(self.batch_size - len(exp_list))
+                    .limit(batch_size - len(exp_list))
                    .with_for_update()
                    .all()
                 )
diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py
index 8490c44506..c6a54650aa 100644
--- a/trinity/buffer/queue.py
+++ b/trinity/buffer/queue.py
@@ -5,11 +5,20 @@
 
 import ray
 
+from trinity.buffer.writer.file_writer import JSONWriter
 from trinity.buffer.writer.sql_writer import SQLWriter
 from trinity.common.config import BufferConfig, StorageConfig
 from trinity.common.constants import StorageType
 
 
+def is_database_url(path: str) -> bool:
+    return any(path.startswith(prefix) for prefix in ["sqlite:///", "postgresql://", "mysql://"])
+
+
+def is_json_file(path: str) -> bool:
+    return path.endswith(".json") or path.endswith(".jsonl")
+
+
 @ray.remote
 class QueueActor:
     """An asyncio.Queue based queue actor."""
@@ -21,12 +30,21 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None
         self.capacity = getattr(config, "capacity", 10000)
         self.queue = asyncio.Queue(self.capacity)
         if storage_config.path is not None and len(storage_config.path) > 0:
-            sql_config = deepcopy(storage_config)
-            sql_config.storage_type = StorageType.SQL
-            sql_config.wrap_in_ray = False
-            self.sql_writer = SQLWriter(sql_config, self.config)
+            if is_database_url(storage_config.path):
+                storage_config.storage_type = StorageType.SQL
+                sql_config = deepcopy(storage_config)
+                sql_config.storage_type = StorageType.SQL
+                sql_config.wrap_in_ray = False
+                self.writer = SQLWriter(sql_config, self.config)
+            elif is_json_file(storage_config.path):
+                storage_config.storage_type = StorageType.FILE
+                json_config = deepcopy(storage_config)
+                json_config.storage_type = StorageType.FILE
+                self.writer = JSONWriter(json_config, self.config)
+            else:
+                self.writer = None
         else:
-            self.sql_writer = None
+            self.writer = None
 
     def length(self) -> int:
         """The length of the queue."""
@@ -35,8 +53,8 @@
     async def put_batch(self, exp_list: List) -> None:
         """Put batch of experience."""
         await self.queue.put(exp_list)
-        if self.sql_writer is not None:
-            self.sql_writer.write(exp_list)
+        if self.writer is not None:
+            self.writer.write(exp_list)
 
     async def finish(self) -> None:
         """Stop the queue."""
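Note (not part of the patch): a minimal sketch of the path-based dispatch that QueueActor.__init__ now performs when picking a backup writer. The helper name pick_backup_writer_kind is hypothetical and only mirrors is_database_url / is_json_file above.

# Editor-added illustration only: database URLs go to SQLWriter, .json/.jsonl paths
# go to JSONWriter, and any other path leaves the queue without a backup writer.
def pick_backup_writer_kind(path: str) -> str:  # hypothetical helper, not in the patch
    if any(path.startswith(prefix) for prefix in ["sqlite:///", "postgresql://", "mysql://"]):
        return "sql"  # handled by SQLWriter with storage_type forced to StorageType.SQL
    if path.endswith(".json") or path.endswith(".jsonl"):
        return "json"  # handled by JSONWriter with storage_type forced to StorageType.FILE
    return "none"  # no backup writer

assert pick_backup_writer_kind("sqlite:///test.db") == "sql"
assert pick_backup_writer_kind("test_queue_buffer.jsonl") == "json"
assert pick_backup_writer_kind("experiences.parquet") == "none"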
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index 69472a3547..9d32d4ba04 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -4,7 +4,7 @@
 
 import datasets
 import transformers
-from datasets import load_dataset
+from datasets import Dataset, load_dataset
 
 from trinity.buffer.buffer_reader import BufferReader
 from trinity.common.config import BufferConfig, StorageConfig
@@ -17,6 +17,42 @@
 FILE_READERS = Registry("file_readers")
 
 
+class _HFBatchReader:
+    def __init__(self, dataset: Dataset, max_epoch: int = 1, offset: int = 0):
+        self.dataset = dataset
+        self.dataset_size = len(dataset)
+        self.current_batch_size = None
+        self.max_epoch = max_epoch
+        if offset >= self.dataset_size:
+            self.current_epoch = offset // self.dataset_size
+            self.current_offset = offset % self.dataset_size
+        else:
+            self.current_epoch = 0
+            self.current_offset = offset
+        self.iter = iter(self.dataset)
+
+        for _ in range(self.current_offset):
+            next(self.iter)
+
+    def read_batch(self, batch_size: int) -> List:
+        batch = []
+
+        while len(batch) < batch_size:
+            try:
+                item = next(self.iter)
+                batch.append(item)
+                self.current_offset += 1
+
+            except StopIteration:
+                self.current_epoch += 1
+                self.current_offset = 0
+
+                if self.current_epoch >= self.max_epoch:
+                    raise StopIteration
+                self.iter = iter(self.dataset)
+        return batch
+
+
 @FILE_READERS.register_module(AlgorithmType.SFT.value)
 class SFTDataReader(BufferReader):
     """Reader for SFT file data."""
@@ -29,22 +65,20 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
         self.prompt_key = meta.format.prompt_key
         self.response_key = meta.format.response_key
         self.read_batch_size = config.read_batch_size
-        self.dataset = load_dataset(
-            meta.path, name=subset_name, split=self.split
+        self.dataset = _HFBatchReader(
+            load_dataset(meta.path, name=subset_name, split=self.split)
         )  # TODO: support resume
         self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
         self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
 
-    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
-        try:
-            batch_data = next(self.data_iter)
-        except StopIteration:
-            self.dataset = self.dataset.shuffle()
-            self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
-            batch_data = next(self.data_iter)
+    def read(
+        self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
+    ) -> List:
+        samples = self.dataset.read_batch(batch_size or self.read_batch_size)
         exp_list = []
         if self.prompt_type == PromptType.MESSAGES:
-            for messages in batch_data[self.messages_key]:
+            for sample in samples:
+                messages = sample[self.messages_key]
                 tokens = self.tokenizer.apply_chat_template(
                     messages, add_generation_prompt=False, return_tensors="pt"
                 )[0]
@@ -58,9 +92,9 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
                 exp_list.append(experience)
 
         elif self.prompt_type == PromptType.CHATPAIR:
-            for prompt_messages, response_messages in zip(
-                batch_data[self.prompt_key], batch_data[self.response_key]
-            ):
+            for sample in samples:
+                prompt_messages = sample[self.prompt_key]
+                response_messages = sample[self.response_key]
                 if not isinstance(prompt_messages, list):
                     prompt_messages = [prompt_messages]
                 if not isinstance(response_messages, list):
@@ -83,7 +117,9 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
 
         elif self.prompt_type == PromptType.PLAINTEXT:
             # TODO: support HF format without chat template
-            for prompt, response in zip(batch_data[self.prompt_key], batch_data[self.response_key]):
+            for sample in samples:
+                prompt = sample[self.prompt_key]
+                response = sample[self.response_key]
                 tokens = self.tokenizer(prompt + response, return_tensors="pt")["input_ids"][0]
                 prompt_tokens = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0]
                 experience = Experience(
@@ -106,8 +142,8 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
         self.chosen_key = meta.format.chosen_key
         self.rejected_key = meta.format.rejected_key
         self.read_batch_size = config.read_batch_size
-        self.dataset = load_dataset(
-            meta.path, name=subset_name, split=self.split
+        self.dataset = _HFBatchReader(
+            load_dataset(meta.path, name=subset_name, split=self.split)
         )  # TODO: support resume
         self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
         self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
@@ -120,17 +156,16 @@ def _get_assistant_message(self, item) -> dict:
         else:
             return item
 
-    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
-        try:
-            batch_data = next(self.data_iter)
-        except StopIteration:
-            self.dataset = self.dataset.shuffle()
-            self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
-            batch_data = next(self.data_iter)
+    def read(
+        self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
+    ) -> List:
+        batch_data = self.dataset.read_batch(batch_size or self.read_batch_size)
         exp_list = []
-        for prompt, chosen, rejected in zip(
-            batch_data[self.prompt_key], batch_data[self.chosen_key], batch_data[self.rejected_key]
-        ):
+        for sample in batch_data:
+            prompt = sample[self.prompt_key]
+            chosen = sample[self.chosen_key]
+            rejected = sample[self.rejected_key]
+
             if self.prompt_type == PromptType.MESSAGES:
                 prompt_messages = prompt
 
@@ -177,18 +212,14 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
         self.split = meta.split
         subset_name = meta.subset_name
         # disable datasets caching to avoid reuse old-version dataset
+        self.epoch = 0
         datasets.disable_caching()
-        self.dataset = load_dataset(
-            meta.path, name=subset_name, split=self.split
-        )  # TODO: may from db_url
-        # if task_type != TaskType.EVAL and config.db_url != "":
-        #     logger.info(f"Loading dataset from database with url: {config.db_url}")
-        #     db_type = config.db_url.split(":")[0]
-        #     db_name = config.db_url.split("/")[-1]
-        #     dataset = Dataset.from_sql(RftDatasetModel.__tablename__, f"{db_type}:///{db_name}")
-        datasets.enable_caching()
-        self.index = meta.index  # TODO: apply shuffle
-
+        self.dataset = _HFBatchReader(
+            load_dataset(meta.path, name=subset_name, split=self.split),
+            max_epoch=self.meta.total_epochs if meta.task_type == TaskType.EXPLORE else 1,
+            offset=self.meta.index,
+        )
+        self.read_batch_size = config.batch_size
         self.prompt_key = meta.format.prompt_key
         self.response_key = meta.format.response_key
         self.workflow_key = meta.format.workflow_key
@@ -197,36 +228,35 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
         self.task_type = meta.task_type
         self.default_workflow_cls = WORKFLOWS.get(meta.default_workflow_type)  # type: ignore
         self.default_reward_fn_cls = REWARD_FUNCTIONS.get(meta.default_reward_fn_type)  # type: ignore
-        self.total_epochs = meta.total_epochs if self.task_type == TaskType.EXPLORE else 1
-
-    def __len__(self):
-        return len(self.dataset)
-
-    def read(self, strategy: Optional[ReadStrategy] = None):
-        if self.index >= len(self.dataset) * self.total_epochs:
-            raise StopIteration
-        sample = self.dataset[self.index % len(self.dataset)]
-        workflow_class = (
-            WORKFLOWS.get(sample[self.workflow_key])
-            if self.workflow_key in sample
-            else self.default_workflow_cls
-        )
-        reward_fn = (
-            REWARD_FUNCTIONS.get(sample[self.reward_fn_key])
-            if self.reward_fn_key in sample
-            else self.default_reward_fn_cls
-        )
-        assert workflow_class is not None, "`default_workflow_type` or `workflow_key` is required"
-        task = Task(
-            workflow=workflow_class,
-            format_args=self.meta.format,
-            rollout_args=self.meta.rollout_args,
-            workflow_args=self.meta.workflow_args,
-            is_eval=self.meta.task_type == TaskType.EVAL,
-            reward_fn=reward_fn,
-            raw_task=sample,
-        )
-        self.index += 1
-        if self.task_type == TaskType.EVAL and self.index == len(self.dataset):
-            self.index = 0
-        return task
+
+    def read(
+        self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
+    ) -> List:
+        batch_size = batch_size or self.read_batch_size
+        tasks = []
+        samples = self.dataset.read_batch(batch_size)
+        for sample in samples:
+            workflow_class = (
+                WORKFLOWS.get(sample[self.workflow_key])
+                if self.workflow_key in sample
+                else self.default_workflow_cls
+            )
+            reward_fn = (
+                REWARD_FUNCTIONS.get(sample[self.reward_fn_key])
+                if self.reward_fn_key in sample
+                else self.default_reward_fn_cls
+            )
+            assert (
+                workflow_class is not None
+            ), "`default_workflow_type` or `workflow_key` is required"
+            task = Task(
+                workflow=workflow_class,
+                format_args=self.meta.format,
+                rollout_args=self.meta.rollout_args,
+                workflow_args=self.meta.workflow_args,
+                is_eval=self.meta.task_type == TaskType.EVAL,
+                reward_fn=reward_fn,
+                raw_task=sample,
+            )
+            tasks.append(task)
+        return tasks
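Note (not part of the patch): a minimal sketch of the epoch/offset behaviour of _HFBatchReader above, using a plain Python list in place of a datasets.Dataset (both support len() and iteration); the numbers mirror tests/buffer/file_test.py.

reader = _HFBatchReader(list(range(16)), max_epoch=2, offset=4)  # 16-sample dataset, resume at index 4
seen = []
try:
    while True:
        seen.extend(reader.read_batch(batch_size=4))  # items 4..15 of epoch 0, then 0..15 of epoch 1
except StopIteration:
    pass  # raised once max_epoch is exhausted
assert len(seen) == 16 * 2 - 4  # same count the new file_test.py expects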
diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py
index 3b26014fc4..f696c6decb 100644
--- a/trinity/buffer/reader/queue_reader.py
+++ b/trinity/buffer/reader/queue_reader.py
@@ -18,17 +18,20 @@ class QueueReader(BufferReader):
 
     def __init__(self, storage_config: StorageConfig, config: BufferConfig):
         assert storage_config.storage_type == StorageType.QUEUE
-        self.config = config
+        self.read_batch_size = config.read_batch_size
         self.queue = QueueActor.options(
             name=f"queue-{storage_config.name}",
             get_if_exists=True,
         ).remote(storage_config, config)
 
-    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
+    def read(
+        self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
+    ) -> List:
         if strategy is not None and strategy != ReadStrategy.FIFO:
             raise NotImplementedError(f"Read strategy {strategy} not supported for Queue Reader.")
         try:
-            exps = ray.get(self.queue.get_batch.remote(self.config.read_batch_size))
+            batch_size = batch_size or self.read_batch_size
+            exps = ray.get(self.queue.get_batch.remote(batch_size))
         except StopAsyncIteration:
             raise StopIteration()
         return exps
diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py
index dcd9d942bb..eb21c92b95 100644
--- a/trinity/buffer/reader/sql_reader.py
+++ b/trinity/buffer/reader/sql_reader.py
@@ -18,8 +18,10 @@ def __init__(self, meta: StorageConfig, config: BufferConfig) -> None:
         self.wrap_in_ray = meta.wrap_in_ray
         self.db_wrapper = DBWrapper.get_wrapper(meta, config)
 
-    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
+    def read(
+        self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
+    ) -> List:
         if self.wrap_in_ray:
-            return ray.get(self.db_wrapper.read.remote(strategy))
+            return ray.get(self.db_wrapper.read.remote(batch_size, strategy))
         else:
-            return self.db_wrapper.read(strategy)
+            return self.db_wrapper.read(batch_size, strategy)
diff --git a/trinity/buffer/writer/file_writer.py b/trinity/buffer/writer/file_writer.py
new file mode 100644
index 0000000000..b163b86b67
--- /dev/null
+++ b/trinity/buffer/writer/file_writer.py
@@ -0,0 +1,39 @@
+import json
+import os
+from typing import List
+
+from trinity.buffer.buffer_writer import BufferWriter
+from trinity.common.config import BufferConfig, StorageConfig
+from trinity.common.constants import StorageType
+from trinity.common.experience import Experience
+from trinity.common.workflows import Task
+
+
+class _Encoder(json.JSONEncoder):
+    def default(self, o):
+        if isinstance(o, Experience):
+            return o.to_dict()
+        if isinstance(o, Task):
+            return o.to_dict()
+        return super().default(o)
+
+
+class JSONWriter(BufferWriter):
+    def __init__(self, meta: StorageConfig, config: BufferConfig):
+        assert meta.storage_type == StorageType.FILE
+        if meta.path is None:
+            raise ValueError("File path cannot be None for JSONWriter")
+        ext = os.path.splitext(meta.path)[-1]
+        if ext != ".jsonl" and ext != ".json":
+            raise ValueError(f"File path must end with .json or .jsonl, got {meta.path}")
+        self.file = open(meta.path, "a", encoding="utf-8")
+        self.encoder = _Encoder(ensure_ascii=False)
+
+    def write(self, data: List) -> None:
+        for item in data:
+            json_str = self.encoder.encode(item)
+            self.file.write(json_str + "\n")
+        self.file.flush()
+
+    def finish(self):
+        self.file.close()
diff --git a/trinity/common/experience.py b/trinity/common/experience.py
index a1b5008681..a31b778563 100644
--- a/trinity/common/experience.py
+++ b/trinity/common/experience.py
@@ -43,6 +43,23 @@ def deserialize(data: bytes) -> Experience:
         """Deserialize the experience from bytes."""
         return pickle.loads(data)
 
+    def to_dict(self) -> dict:
+        """Convert the experience to a dictionary."""
+        res = {
+            "prompt_text": self.prompt_text,
+            "info": self.info,
+            "metrics": self.metrics,
+        }
+        if self.response_text is not None:
+            res["response_text"] = self.response_text
+        if self.chosen is not None:
+            res["chosen"] = self.chosen.tolist()
+        if self.rejected is not None:
+            res["rejected"] = self.rejected.tolist()
+        if self.reward is not None:
+            res["reward"] = float(self.reward)
+        return res
+
 
 @dataclass(frozen=True)
 class Experiences:
diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py
index fc4a87556b..2e45804863 100644
--- a/trinity/common/workflows/workflow.py
+++ b/trinity/common/workflows/workflow.py
@@ -68,6 +68,9 @@ def truth(self) -> Union[str, None]:
         response_key = self.format_args.response_key
         return self.raw_task[response_key] if response_key in self.raw_task else None  # type: ignore
 
+    def to_dict(self) -> dict:
+        return self.raw_task  # type: ignore
+
 
 class Workflow(ABC):
     """The base workflow class.
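Note (not part of the patch): a small usage sketch of JSONWriter together with Experience.to_dict; storage_config and buffer_config are placeholder StorageConfig/BufferConfig objects set up as in tests/buffer/queue_test.py (storage_type=StorageType.FILE, path ending in .jsonl).

import torch

exp = Experience(
    tokens=torch.tensor([1.0, 2.0, 3.0]),
    prompt_length=2,
    reward=1.0,
    logprobs=torch.tensor([0.1]),
    action_mask=torch.tensor([0, 1, 0]),
)
writer = JSONWriter(storage_config, buffer_config)  # placeholders, see note above
writer.write([exp])  # each Experience is appended as one JSON line via to_dict()
writer.finish()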
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index bdc0228a65..5f05973d4d 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -43,9 +43,6 @@ def __init__(self, config: Config):
         self.taskset = get_buffer_reader(
             self.config.buffer.explorer_input.taskset, self.config.buffer
         )
-        self.eval_tasksets = []
-        for eval_taskset_config in self.config.buffer.explorer_input.eval_tasksets:
-            self.eval_tasksets.append(get_buffer_reader(eval_taskset_config, self.config.buffer))
         self.runner_pool = self._init_runner_pool()
         self.monitor = MONITOR.get(self.config.monitor.monitor_type)(
             project=self.config.project,
@@ -177,14 +174,14 @@ def explore_one_period(self) -> Tuple[bool, int]:
             explore_status: whether there are more tasks to explore.
             explore_step_num: the number of explore steps
         """
-        task_num_per_period = self.config.synchronizer.sync_interval * self.config.buffer.batch_size
-
         st = time.time()
         all_metrics = defaultdict(list)
 
         # submit tasks of this step
         try:
-            tasks = [self.taskset.read() for _ in range(task_num_per_period)]
+            tasks = []
+            for _ in range(self.config.synchronizer.sync_interval):
+                tasks.extend(self.taskset.read())
             self.runner_pool.run_tasks(tasks)  # type: ignore
         except StopIteration:
             self.experience_buffer.finish()
@@ -218,7 +215,8 @@ def explore_one_period(self) -> Tuple[bool, int]:
 
         # save explore checkpoint
         self.cache.save_explorer(
             current_step=self.step_num,
-            current_task_index=self.taskset.index,
+            current_task_index=self.step_num * self.config.buffer.batch_size,
+            # TODO: remove current_task_index
         )
         self.logger.info(f"Explore step {self.step_num} finished.")
@@ -226,13 +224,16 @@ def explore_one_period(self) -> Tuple[bool, int]:
 
     def eval(self) -> Tuple[bool, int]:
         """Evaluation on all evaluation data samples."""
-        if len(self.eval_tasksets) == 0:
+        eval_tasksets = []
+        for eval_taskset_config in self.config.buffer.explorer_input.eval_tasksets:
+            eval_tasksets.append(get_buffer_reader(eval_taskset_config, self.config.buffer))
+        if len(eval_tasksets) == 0:
             self.logger.warning("No evaluation data samples. Skip evaluation.")
             return True, self.step_num
         self.logger.info("Evaluation started.")
         all_st = time.time()
         log_metrics = {}
-        for eval_taskset in self.eval_tasksets:
+        for eval_taskset in eval_tasksets:
             st = time.time()
             all_metrics = defaultdict(list)
 
@@ -247,10 +248,13 @@ def wait():
                 for metric_name, metric_value in status.metric.items():
                     all_metrics[metric_name].append(metric_value)
 
-            for _ in range(len(eval_taskset)):  # type: ignore
+            while True:
                 if not self.runner_pool.has_free():
                     wait()
-                self.runner_pool.run_tasks([eval_taskset.read()])  # type: ignore
+                try:
+                    self.runner_pool.run_tasks(eval_taskset.read())
+                except StopIteration:
+                    break
             while self.runner_pool.has_next():
                 wait()
             metrics = self.monitor.calculate_metrics(all_metrics, prefix=f"eval/{eval_taskset.name}")  # type: ignore
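Note (not part of the patch): the per-period task count is preserved by the new loop in explore_one_period; each taskset.read() now returns one batch of config.buffer.batch_size tasks (the RFT reader's read_batch_size), so sync_interval reads still cover sync_interval * batch_size tasks. The numbers below are illustrative only.

sync_interval = 4  # stands in for self.config.synchronizer.sync_interval
batch_size = 8     # stands in for self.config.buffer.batch_size
tasks = []
for _ in range(sync_interval):
    tasks.extend([object()] * batch_size)  # each read() returns batch_size tasks
assert len(tasks) == sync_interval * batch_size  # same total as the old task_num_per_period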