|
6 | 6 | import ray |
7 | 7 | import torch |
8 | 8 |
|
| 9 | +from tests.tools import get_unittest_dataset_config |
9 | 10 | from trinity.buffer.reader.queue_reader import QueueReader |
10 | 11 | from trinity.common.config import StorageConfig, load_config |
11 | 12 | from trinity.common.constants import AlgorithmType, StorageType |
|
21 | 22 | @WORKFLOWS.register_module("dummy_workflow") |
22 | 23 | class DummyWorkflow(Workflow): |
23 | 24 | def __init__(self, model, **kwargs): |
24 | | - super().__init__(model) |
| 25 | + super().__init__(model, **kwargs) |
25 | 26 | self.error_type = kwargs.get("task_desc") |
26 | 27 | self.seconds = None |
27 | 28 | if "timeout" in self.error_type: |
@@ -81,30 +82,37 @@ def setUp(self): |
81 | 82 |
|
82 | 83 | def test_runner_pool(self): |
83 | 84 | pool = RunnerPool(self.config, [DummyModel.remote(), DummyModel.remote()]) |
| 85 | + taskset_config = get_unittest_dataset_config("countdown") |
84 | 86 | tasks = [ |
85 | 87 | Task( |
86 | 88 | task_desc="timeout_100", |
87 | 89 | workflow=DummyWorkflow, |
| 90 | + taskset_config=taskset_config, |
88 | 91 | ), |
89 | 92 | Task( |
90 | 93 | task_desc="exception", |
91 | 94 | workflow=DummyWorkflow, |
| 95 | + taskset_config=taskset_config, |
92 | 96 | ), |
93 | 97 | Task( |
94 | 98 | task_desc="timeout_2", |
95 | 99 | workflow=DummyWorkflow, |
| 100 | + taskset_config=taskset_config, |
96 | 101 | ), |
97 | 102 | Task( |
98 | 103 | task_desc="success", |
99 | 104 | workflow=DummyWorkflow, |
| 105 | + taskset_config=taskset_config, |
100 | 106 | ), |
101 | 107 | Task( |
102 | 108 | task_desc="timeout_101", |
103 | 109 | workflow=DummyWorkflow, |
| 110 | + taskset_config=taskset_config, |
104 | 111 | ), |
105 | 112 | Task( |
106 | 113 | task_desc="exit", |
107 | 114 | workflow=DummyWorkflow, |
| 115 | + taskset_config=taskset_config, |
108 | 116 | ), |
109 | 117 | ] |
110 | 118 |
|
|
0 commit comments