Skip to content

Commit f40e0a4

Browse files
committed
bug fix
1 parent 9894eed commit f40e0a4

File tree

3 files changed

+11
-3
lines changed

3 files changed

+11
-3
lines changed

tests/explorer/runner_pool_test.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import ray
77
import torch
88

9+
from tests.tools import get_unittest_dataset_config
910
from trinity.buffer.reader.queue_reader import QueueReader
1011
from trinity.common.config import StorageConfig, load_config
1112
from trinity.common.constants import AlgorithmType, StorageType
@@ -21,7 +22,7 @@
2122
@WORKFLOWS.register_module("dummy_workflow")
2223
class DummyWorkflow(Workflow):
2324
def __init__(self, model, **kwargs):
24-
super().__init__(model)
25+
super().__init__(model, **kwargs)
2526
self.error_type = kwargs.get("task_desc")
2627
self.seconds = None
2728
if "timeout" in self.error_type:
@@ -81,30 +82,37 @@ def setUp(self):
8182

8283
def test_runner_pool(self):
8384
pool = RunnerPool(self.config, [DummyModel.remote(), DummyModel.remote()])
85+
taskset_config = get_unittest_dataset_config("countdown")
8486
tasks = [
8587
Task(
8688
task_desc="timeout_100",
8789
workflow=DummyWorkflow,
90+
taskset_config=taskset_config,
8891
),
8992
Task(
9093
task_desc="exception",
9194
workflow=DummyWorkflow,
95+
taskset_config=taskset_config,
9296
),
9397
Task(
9498
task_desc="timeout_2",
9599
workflow=DummyWorkflow,
100+
taskset_config=taskset_config,
96101
),
97102
Task(
98103
task_desc="success",
99104
workflow=DummyWorkflow,
105+
taskset_config=taskset_config,
100106
),
101107
Task(
102108
task_desc="timeout_101",
103109
workflow=DummyWorkflow,
110+
taskset_config=taskset_config,
104111
),
105112
Task(
106113
task_desc="exit",
107114
workflow=DummyWorkflow,
115+
taskset_config=taskset_config,
108116
),
109117
]
110118

tests/template/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ buffer:
1818
taskset:
1919
name: taskset
2020
storage_type: file
21-
path: ''
21+
path: 'placeholder'
2222
split: 'train'
2323
default_workflow_type: ''
2424
default_reward_fn_type: ''

trinity/explorer/workflow_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def run_task(self, task: Task) -> Status:
7777
if metrics:
7878
for k, v in metrics.items():
7979
metric[k] = sum(v) / len(v) # type: ignore
80-
if not task.storage_config.task_type == TaskType.EVAL:
80+
if not task.taskset_config.task_type == TaskType.EVAL:
8181
self.experience_buffer.write(exps)
8282
return Status(True, metric=metric)
8383
except Exception as e:

0 commit comments

Comments
 (0)