Skip to content

Commit 7601bb9

Browse files
authored
Bug fix in last checkpoint save. (#460)
1 parent 3178d4f commit 7601bb9

17 files changed

+697
-437
lines changed

tests/buffer/experience_pipeline_test.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import ray
55
import torch
66

7-
from tests.tools import RayUnittestBaseAysnc, get_template_config
7+
from tests.tools import RayUnittestBaseAsync, get_template_config
88
from trinity.buffer import get_buffer_reader
99
from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline
1010
from trinity.common.config import (
@@ -34,7 +34,7 @@ def get_experiences(task_num: int, repeat_times: int = 1, step_num: int = 1) ->
3434
]
3535

3636

37-
class TestExperiencePipeline(RayUnittestBaseAysnc):
37+
class TestExperiencePipeline(RayUnittestBaseAsync):
3838
def setUp(self):
3939
if os.path.exists(BUFFER_FILE_PATH):
4040
os.remove(BUFFER_FILE_PATH)

tests/buffer/experience_storage_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
from parameterized import parameterized
99

10-
from tests.tools import RayUnittestBaseAysnc
10+
from tests.tools import RayUnittestBaseAsync
1111
from trinity.buffer.reader.sql_reader import SQLReader
1212
from trinity.buffer.writer.sql_writer import SQLWriter
1313
from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig
@@ -17,7 +17,7 @@
1717
DB_PATH = os.path.join(os.path.dirname(__file__), "test.db")
1818

1919

20-
class ExperienceStorageTest(RayUnittestBaseAysnc):
20+
class ExperienceStorageTest(RayUnittestBaseAsync):
2121
def setUp(self):
2222
self.total_num = 8
2323
self.put_batch_size = 2

tests/buffer/queue_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
from parameterized import parameterized
99

10-
from tests.tools import RayUnittestBaseAysnc
10+
from tests.tools import RayUnittestBaseAsync
1111
from trinity.buffer.reader.queue_reader import QueueReader
1212
from trinity.buffer.writer.queue_writer import QueueWriter
1313
from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig
@@ -17,7 +17,7 @@
1717
BUFFER_FILE_PATH = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl")
1818

1919

20-
class TestQueueBuffer(RayUnittestBaseAysnc):
20+
class TestQueueBuffer(RayUnittestBaseAsync):
2121
@parameterized.expand(
2222
[
2323
(

tests/buffer/reader_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from tests.tools import RayUnittestBaseAysnc, get_unittest_dataset_config
1+
from tests.tools import RayUnittestBaseAsync, get_unittest_dataset_config
22
from trinity.buffer.buffer import get_buffer_reader
33
from trinity.buffer.reader import READER
44
from trinity.buffer.reader.file_reader import FileReader, TaskFileReader
@@ -12,7 +12,7 @@ def __init__(self, config):
1212
super().__init__(config)
1313

1414

15-
class TestBufferReader(RayUnittestBaseAysnc):
15+
class TestBufferReader(RayUnittestBaseAsync):
1616
async def test_buffer_reader_registration(self) -> None:
1717
config = get_unittest_dataset_config("countdown", "train")
1818
config.batch_size = 2

tests/buffer/sample_strategy_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
from parameterized import parameterized_class
77

8-
from tests.tools import RayUnittestBaseAysnc, get_template_config
8+
from tests.tools import RayUnittestBaseAsync, get_template_config
99
from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY
1010
from trinity.algorithm.sample_strategy.sample_strategy import SampleStrategy
1111
from trinity.buffer.buffer import get_buffer_writer
@@ -21,7 +21,7 @@
2121
(6,),
2222
],
2323
)
24-
class ExperienceStorageTest(RayUnittestBaseAysnc):
24+
class ExperienceStorageTest(RayUnittestBaseAsync):
2525
def setUp(self):
2626
self.config = get_template_config()
2727
self.num_steps = 20
@@ -249,5 +249,5 @@ async def test_sql_staleness_control_sample_strategy(self):
249249

250250
def tearDown(self):
251251
asyncio.run(self.buffer_writer.release())
252-
shutil.rmtree(self.config.checkpoint_job_dir)
252+
shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)
253253
return super().tearDown()

tests/buffer/sql_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
from parameterized import parameterized
66

7-
from tests.tools import RayUnittestBaseAysnc
7+
from tests.tools import RayUnittestBaseAsync
88
from trinity.buffer import get_buffer_reader
99
from trinity.buffer.reader.sql_reader import SQLReader
1010
from trinity.buffer.writer.sql_writer import SQLWriter
@@ -19,7 +19,7 @@
1919
db_path = os.path.join(os.path.dirname(__file__), "test.db")
2020

2121

22-
class TestSQLBuffer(RayUnittestBaseAysnc):
22+
class TestSQLBuffer(RayUnittestBaseAsync):
2323
@parameterized.expand(
2424
[
2525
(True,),

tests/buffer/task_scheduler_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def setUpClass(cls):
3333
def tearDownClass(cls):
3434
super().tearDownClass()
3535
if os.path.exists(cls.temp_output_path):
36-
shutil.rmtree(cls.temp_output_path)
36+
shutil.rmtree(cls.temp_output_path, ignore_errors=True)
3737

3838
def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, int]]) -> None:
3939
for task, index in zip(batch_tasks, indices):

tests/common/config_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,4 +184,4 @@ def test_chat_template_path(self):
184184

185185
def tearDown(self):
186186
if os.path.exists(CHECKPOINT_ROOT_DIR):
187-
shutil.rmtree(CHECKPOINT_ROOT_DIR)
187+
shutil.rmtree(CHECKPOINT_ROOT_DIR, ignore_errors=True)

tests/common/vllm_test.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from transformers import AutoTokenizer
99

1010
from tests.tools import (
11-
RayUnittestBaseAysnc,
11+
RayUnittestBaseAsync,
1212
get_api_model_path,
1313
get_model_path,
1414
get_template_config,
@@ -113,7 +113,7 @@ async def prepare_engines(engines, auxiliary_engines):
113113
(2, 1, 3, True, True),
114114
],
115115
)
116-
class ModelWrapperTest(RayUnittestBaseAysnc):
116+
class ModelWrapperTest(RayUnittestBaseAsync):
117117
def setUp(self):
118118
# configure the model
119119
self.config = get_template_config()
@@ -233,7 +233,7 @@ async def test_generate(self):
233233
(20, 5, 15),
234234
],
235235
)
236-
class TestModelLen(RayUnittestBaseAysnc):
236+
class TestModelLen(RayUnittestBaseAsync):
237237
def setUp(self):
238238
self.config = get_template_config()
239239
self.config.mode = "explore"
@@ -302,7 +302,7 @@ def _check_experience(exp):
302302
)
303303

304304

305-
class TestModelLenWithoutPromptTruncation(RayUnittestBaseAysnc):
305+
class TestModelLenWithoutPromptTruncation(RayUnittestBaseAsync):
306306
def setUp(self):
307307
self.config = get_template_config()
308308
self.config.mode = "explore"
@@ -351,7 +351,7 @@ async def test_model_len(self):
351351
)
352352

353353

354-
class TestAPIServer(RayUnittestBaseAysnc):
354+
class TestAPIServer(RayUnittestBaseAsync):
355355
def setUp(self):
356356
self.config = get_template_config()
357357
self.config.mode = "explore"
@@ -482,7 +482,7 @@ async def test_api(self):
482482
"""
483483

484484

485-
class TestLogprobs(RayUnittestBaseAysnc):
485+
class TestLogprobs(RayUnittestBaseAsync):
486486
def setUp(self):
487487
self.config = get_template_config()
488488
self.config.mode = "explore"
@@ -669,7 +669,7 @@ async def test_logprobs_api(self):
669669
)
670670

671671

672-
class TestAsyncAPIServer(RayUnittestBaseAysnc):
672+
class TestAsyncAPIServer(RayUnittestBaseAsync):
673673
def setUp(self):
674674
self.config = get_template_config()
675675
self.config.mode = "explore"
@@ -880,7 +880,7 @@ def test_action_mask_with_tools(self):
880880
(False, None),
881881
],
882882
)
883-
class TestAPIServerToolCall(RayUnittestBaseAysnc):
883+
class TestAPIServerToolCall(RayUnittestBaseAsync):
884884
def setUp(self):
885885
self.config = get_template_config()
886886
self.config.mode = "explore"
@@ -1161,7 +1161,7 @@ async def test_api_tool_calls(self):
11611161
)
11621162

11631163

1164-
class TestSuperLongGeneration(RayUnittestBaseAysnc):
1164+
class TestSuperLongGeneration(RayUnittestBaseAsync):
11651165
def setUp(self):
11661166
self.config = get_template_config()
11671167
self.config.mode = "explore"
@@ -1217,7 +1217,7 @@ async def test_generate(self):
12171217
self.assertGreater(response.logprobs.shape[0], 1000)
12181218

12191219

1220-
class TestTinkerAPI(RayUnittestBaseAysnc):
1220+
class TestTinkerAPI(RayUnittestBaseAsync):
12211221
"""Test the Tinker API integration with the vLLM engine."""
12221222

12231223
def setUp(self):

tests/explorer/explorer_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from tests.tools import (
1414
RayUnittestBase,
15-
RayUnittestBaseAysnc,
15+
RayUnittestBaseAsync,
1616
TensorBoardParser,
1717
get_api_model_path,
1818
get_checkpoint_path,
@@ -180,7 +180,7 @@ def run_agent(proxy_url, model_path: str):
180180
return response.choices[0].message.content
181181

182182

183-
class ServeTest(RayUnittestBaseAysnc):
183+
class ServeTest(RayUnittestBaseAsync):
184184
def setUp(self):
185185
self.config = get_template_config()
186186
self.config.mode = "serve"

0 commit comments

Comments
 (0)