Skip to content

Commit a915128

Browse files
authored
Add config_validator.py and refactor config (#487)
1 parent 837f826 commit a915128

File tree

10 files changed

+1286
-870
lines changed

10 files changed

+1286
-870
lines changed

.github/workflows/docker/docker-compose.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
services:
22
trinity-node-1:
33
image: trinity-rft-unittest:20260115
4+
cap_add:
5+
- SYS_PTRACE
46
pull_policy: never
57
command: bash -c "source /opt/venv/bin/activate && uv pip install -e .[dev] && ray start --head --dashboard-host 0.0.0.0 --include-dashboard true --block"
68
environment:
@@ -31,6 +33,8 @@ services:
3133

3234
trinity-node-2:
3335
image: trinity-rft-unittest:20260115
36+
cap_add:
37+
- SYS_PTRACE
3438
pull_policy: never
3539
command: bash -c "source /opt/venv/bin/activate && uv pip install -e .[dev] && ray start --address=trinity-node-1:6379 --block"
3640
environment:

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ dependencies = [
4141
"jsonlines",
4242
"sortedcontainers",
4343
"word2number",
44-
"transformers",
45-
"datasets",
44+
"transformers>=4.51.0",
45+
"datasets>=4.0.0",
4646
]
4747

4848
[project.scripts]

tests/common/config_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_update_config_from_ray_cluster(self):
9191
config.cluster.node_num = None
9292
config.cluster.gpu_per_node = None
9393

94-
config._update_config_from_ray_cluster()
94+
config.check_and_update()
9595
self.assertEqual(config.cluster.node_num, 2)
9696
self.assertEqual(config.cluster.gpu_per_node, 2)
9797

tests/tools.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import gc
12
import os
23
import unittest
34
from collections import defaultdict
@@ -243,21 +244,51 @@ def metric_list(self, metric_prefix: str) -> List[str]:
243244
return [name for name in self._metrics if name.startswith(metric_prefix)]
244245

245246

246-
class RayUnittestBase(unittest.TestCase):
247+
class RayCleanupPlugin:
248+
@classmethod
249+
def _cleanup_ray_data_state(cls):
250+
"""clean up the global states of Ray Data"""
251+
try:
252+
# reset execution context
253+
if hasattr(ray.data._internal.execution.streaming_executor, "_execution_context"):
254+
ray.data._internal.execution.streaming_executor._execution_context = None
255+
256+
# trigger gc.collect() on all workers in the cluster
257+
ray._private.internal_api.global_gc()
258+
259+
# clean up stats manager
260+
from ray.data._internal.stats import StatsManager
261+
262+
if hasattr(StatsManager, "_instance"):
263+
StatsManager._instance = None
264+
265+
except Exception:
266+
pass
267+
268+
269+
class RayUnittestBase(unittest.TestCase, RayCleanupPlugin):
247270
@classmethod
248271
def setUpClass(cls):
249272
ray.init(ignore_reinit_error=True, namespace="trinity_unittest")
250273

274+
# erase existing resources
275+
cls._cleanup_ray_data_state()
276+
gc.collect()
277+
251278
@classmethod
252279
def tearDownClass(cls):
253280
ray.shutdown(_exiting_interpreter=True)
254281

255282

256-
class RayUnittestBaseAsync(unittest.IsolatedAsyncioTestCase):
283+
class RayUnittestBaseAsync(unittest.IsolatedAsyncioTestCase, RayCleanupPlugin):
257284
@classmethod
258285
def setUpClass(cls):
259286
ray.init(ignore_reinit_error=True, namespace="trinity_unittest")
260287

288+
# erase existing resources
289+
cls._cleanup_ray_data_state()
290+
gc.collect()
291+
261292
@classmethod
262293
def tearDownClass(cls):
263294
ray.shutdown(_exiting_interpreter=True)

tests/trainer/trainer_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ def test_trainer(self, mock_load):
350350

351351
with self.assertRaises(Exception):
352352
run(config_path="dummy.yaml")
353+
ray.shutdown(_exiting_interpreter=True)
353354

354355
stage_configs = [cfg.check_and_update() for cfg in deepcopy(self.config)]
355356

@@ -372,6 +373,7 @@ def test_trainer(self, mock_load):
372373

373374
self.config.stages[1].buffer.explorer_input.taskset.path = old_taskset_path
374375
mock_load.return_value = deepcopy(self.config)
376+
ray.init(ignore_reinit_error=True, namespace=self.config.ray_namespace)
375377
run(config_path="dummy.yaml")
376378

377379
# grpo stage

0 commit comments

Comments
 (0)