Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/buffer/experience_pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import ray
import torch

from tests.tools import RayUnittestBaseAysnc, get_template_config
from tests.tools import RayUnittestBaseAsync, get_template_config
from trinity.buffer import get_buffer_reader
from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline
from trinity.common.config import (
Expand Down Expand Up @@ -34,7 +34,7 @@ def get_experiences(task_num: int, repeat_times: int = 1, step_num: int = 1) ->
]


class TestExperiencePipeline(RayUnittestBaseAysnc):
class TestExperiencePipeline(RayUnittestBaseAsync):
def setUp(self):
if os.path.exists(BUFFER_FILE_PATH):
os.remove(BUFFER_FILE_PATH)
Expand Down
4 changes: 2 additions & 2 deletions tests/buffer/experience_storage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from parameterized import parameterized

from tests.tools import RayUnittestBaseAysnc
from tests.tools import RayUnittestBaseAsync
from trinity.buffer.reader.sql_reader import SQLReader
from trinity.buffer.writer.sql_writer import SQLWriter
from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig
Expand All @@ -17,7 +17,7 @@
DB_PATH = os.path.join(os.path.dirname(__file__), "test.db")


class ExperienceStorageTest(RayUnittestBaseAysnc):
class ExperienceStorageTest(RayUnittestBaseAsync):
def setUp(self):
self.total_num = 8
self.put_batch_size = 2
Expand Down
4 changes: 2 additions & 2 deletions tests/buffer/queue_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from parameterized import parameterized

from tests.tools import RayUnittestBaseAysnc
from tests.tools import RayUnittestBaseAsync
from trinity.buffer.reader.queue_reader import QueueReader
from trinity.buffer.writer.queue_writer import QueueWriter
from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig
Expand All @@ -17,7 +17,7 @@
BUFFER_FILE_PATH = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl")


class TestQueueBuffer(RayUnittestBaseAysnc):
class TestQueueBuffer(RayUnittestBaseAsync):
@parameterized.expand(
[
(
Expand Down
4 changes: 2 additions & 2 deletions tests/buffer/reader_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from tests.tools import RayUnittestBaseAysnc, get_unittest_dataset_config
from tests.tools import RayUnittestBaseAsync, get_unittest_dataset_config
from trinity.buffer.buffer import get_buffer_reader
from trinity.buffer.reader import READER
from trinity.buffer.reader.file_reader import FileReader, TaskFileReader
Expand All @@ -12,7 +12,7 @@ def __init__(self, config):
super().__init__(config)


class TestBufferReader(RayUnittestBaseAysnc):
class TestBufferReader(RayUnittestBaseAsync):
async def test_buffer_reader_registration(self) -> None:
config = get_unittest_dataset_config("countdown", "train")
config.batch_size = 2
Expand Down
6 changes: 3 additions & 3 deletions tests/buffer/sample_strategy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from parameterized import parameterized_class

from tests.tools import RayUnittestBaseAysnc, get_template_config
from tests.tools import RayUnittestBaseAsync, get_template_config
from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY
from trinity.algorithm.sample_strategy.sample_strategy import SampleStrategy
from trinity.buffer.buffer import get_buffer_writer
Expand All @@ -21,7 +21,7 @@
(6,),
],
)
class ExperienceStorageTest(RayUnittestBaseAysnc):
class ExperienceStorageTest(RayUnittestBaseAsync):
def setUp(self):
self.config = get_template_config()
self.num_steps = 20
Expand Down Expand Up @@ -249,5 +249,5 @@ async def test_sql_staleness_control_sample_strategy(self):

def tearDown(self):
asyncio.run(self.buffer_writer.release())
shutil.rmtree(self.config.checkpoint_job_dir)
shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)
return super().tearDown()
4 changes: 2 additions & 2 deletions tests/buffer/sql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from parameterized import parameterized

from tests.tools import RayUnittestBaseAysnc
from tests.tools import RayUnittestBaseAsync
from trinity.buffer import get_buffer_reader
from trinity.buffer.reader.sql_reader import SQLReader
from trinity.buffer.writer.sql_writer import SQLWriter
Expand All @@ -19,7 +19,7 @@
db_path = os.path.join(os.path.dirname(__file__), "test.db")


class TestSQLBuffer(RayUnittestBaseAysnc):
class TestSQLBuffer(RayUnittestBaseAsync):
@parameterized.expand(
[
(True,),
Expand Down
2 changes: 1 addition & 1 deletion tests/buffer/task_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def setUpClass(cls):
def tearDownClass(cls):
super().tearDownClass()
if os.path.exists(cls.temp_output_path):
shutil.rmtree(cls.temp_output_path)
shutil.rmtree(cls.temp_output_path, ignore_errors=True)

def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, int]]) -> None:
for task, index in zip(batch_tasks, indices):
Expand Down
2 changes: 1 addition & 1 deletion tests/common/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,4 +184,4 @@ def test_chat_template_path(self):

def tearDown(self):
if os.path.exists(CHECKPOINT_ROOT_DIR):
shutil.rmtree(CHECKPOINT_ROOT_DIR)
shutil.rmtree(CHECKPOINT_ROOT_DIR, ignore_errors=True)
20 changes: 10 additions & 10 deletions tests/common/vllm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from transformers import AutoTokenizer

from tests.tools import (
RayUnittestBaseAysnc,
RayUnittestBaseAsync,
get_api_model_path,
get_model_path,
get_template_config,
Expand Down Expand Up @@ -113,7 +113,7 @@ async def prepare_engines(engines, auxiliary_engines):
(2, 1, 3, True, True),
],
)
class ModelWrapperTest(RayUnittestBaseAysnc):
class ModelWrapperTest(RayUnittestBaseAsync):
def setUp(self):
# configure the model
self.config = get_template_config()
Expand Down Expand Up @@ -233,7 +233,7 @@ async def test_generate(self):
(20, 5, 15),
],
)
class TestModelLen(RayUnittestBaseAysnc):
class TestModelLen(RayUnittestBaseAsync):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
Expand Down Expand Up @@ -302,7 +302,7 @@ def _check_experience(exp):
)


class TestModelLenWithoutPromptTruncation(RayUnittestBaseAysnc):
class TestModelLenWithoutPromptTruncation(RayUnittestBaseAsync):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
Expand Down Expand Up @@ -351,7 +351,7 @@ async def test_model_len(self):
)


class TestAPIServer(RayUnittestBaseAysnc):
class TestAPIServer(RayUnittestBaseAsync):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
Expand Down Expand Up @@ -482,7 +482,7 @@ async def test_api(self):
"""


class TestLogprobs(RayUnittestBaseAysnc):
class TestLogprobs(RayUnittestBaseAsync):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
Expand Down Expand Up @@ -669,7 +669,7 @@ async def test_logprobs_api(self):
)


class TestAsyncAPIServer(RayUnittestBaseAysnc):
class TestAsyncAPIServer(RayUnittestBaseAsync):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
Expand Down Expand Up @@ -880,7 +880,7 @@ def test_action_mask_with_tools(self):
(False, None),
],
)
class TestAPIServerToolCall(RayUnittestBaseAysnc):
class TestAPIServerToolCall(RayUnittestBaseAsync):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
Expand Down Expand Up @@ -1161,7 +1161,7 @@ async def test_api_tool_calls(self):
)


class TestSuperLongGeneration(RayUnittestBaseAysnc):
class TestSuperLongGeneration(RayUnittestBaseAsync):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
Expand Down Expand Up @@ -1217,7 +1217,7 @@ async def test_generate(self):
self.assertGreater(response.logprobs.shape[0], 1000)


class TestTinkerAPI(RayUnittestBaseAysnc):
class TestTinkerAPI(RayUnittestBaseAsync):
"""Test the Tinker API integration with the vLLM engine."""

def setUp(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/explorer/explorer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from tests.tools import (
RayUnittestBase,
RayUnittestBaseAysnc,
RayUnittestBaseAsync,
TensorBoardParser,
get_api_model_path,
get_checkpoint_path,
Expand Down Expand Up @@ -180,7 +180,7 @@ def run_agent(proxy_url, model_path: str):
return response.choices[0].message.content


class ServeTest(RayUnittestBaseAysnc):
class ServeTest(RayUnittestBaseAsync):
def setUp(self):
self.config = get_template_config()
self.config.mode = "serve"
Expand Down
34 changes: 24 additions & 10 deletions tests/manager/synchronizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,14 @@ def run_trainer(config: Config, max_steps: int, intervals: List[int]) -> None:
ray.init(ignore_reinit_error=True, namespace=config.ray_namespace)
trainer_monkey_patch(config, max_steps, intervals)
train(config)
ray.shutdown()


def run_explorer(config: Config, max_steps: int, intervals: List[int]) -> None:
ray.init(ignore_reinit_error=True, namespace=config.ray_namespace)
explorer_monkey_patch(config, max_steps, intervals)
explore(config)
ray.shutdown()


def run_both(
Expand All @@ -97,17 +99,26 @@ def run_both(
trainer_monkey_patch(config, max_steps, trainer_intervals)
explorer_monkey_patch(config, max_steps, explorer_intervals)
both(config)
ray.shutdown()


class BaseTestSynchronizer(unittest.TestCase):
def setUp(self):
if multiprocessing.get_start_method(allow_none=True) != "spawn":
multiprocessing.set_start_method("spawn", force=True)
self.process_list = []

def tearDown(self):
checkpoint_path = get_checkpoint_path()
ray.shutdown(_exiting_interpreter=True)
shutil.rmtree(os.path.join(checkpoint_path, "unittest"))
if os.path.exists(CHECKPOINT_ROOT_DIR):
shutil.rmtree(CHECKPOINT_ROOT_DIR, ignore_errors=True)
for process in self.process_list:
if process.is_alive():
process.terminate()
process.join(timeout=10)
if process.is_alive():
process.kill()
process.join()


class TestSynchronizerExit(BaseTestSynchronizer):
Expand Down Expand Up @@ -151,6 +162,8 @@ def test_synchronizer(self):
target=run_trainer, args=(trainer_config, 8, [2, 1, 2, 1, 2, 1, 2, 1])
)
trainer_process.start()
self.process_list.append(trainer_process)

ray.init(ignore_reinit_error=True)
while True:
try:
Expand All @@ -164,6 +177,7 @@ def test_synchronizer(self):
args=(explorer1_config, 8, [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5]),
)
explorer_process_1.start()
self.process_list.append(explorer_process_1)

self.assertEqual(
synchronizer, ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace)
Expand All @@ -176,14 +190,13 @@ def test_synchronizer(self):
except ValueError:
print("waiting for explorer1 to start.")
time.sleep(5)
trainer_process.terminate()
trainer_process.join()

trainer_process.join(timeout=200)
self.assertEqual(
synchronizer, ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace)
)

explorer_process_1.terminate()
explorer_process_1.join()
explorer_process_1.join(timeout=200)
time.sleep(6)
with self.assertRaises(ValueError):
ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace)
Expand Down Expand Up @@ -278,6 +291,8 @@ def test_synchronizer(self):
target=run_trainer, args=(trainer_config, self.max_steps, self.trainer_intervals)
)
trainer_process.start()
self.process_list.append(trainer_process)

ray.init(ignore_reinit_error=True)
while True:
try:
Expand All @@ -291,10 +306,12 @@ def test_synchronizer(self):
args=(explorer1_config, self.max_steps, self.explorer1_intervals),
)
explorer_process_1.start()
self.process_list.append(explorer_process_1)
explorer_process_2 = multiprocessing.Process(
target=run_explorer, args=(explorer2_config, self.max_steps, self.explorer2_intervals)
)
explorer_process_2.start()
self.process_list.append(explorer_process_2)

explorer_process_1.join(timeout=200)
explorer_process_2.join(timeout=200)
Expand Down Expand Up @@ -364,6 +381,7 @@ def test_synchronizer(self):
args=(config, self.max_steps, self.trainer_intervals, self.explorer_intervals),
)
both_process.start()
self.process_list.append(both_process)
both_process.join(timeout=200)

# check the tensorboard
Expand All @@ -375,7 +393,3 @@ def test_synchronizer(self):
)
rollout_metrics = parser.metric_list("rollout")
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)

def tearDown(self):
if os.path.exists(CHECKPOINT_ROOT_DIR):
shutil.rmtree(CHECKPOINT_ROOT_DIR)
2 changes: 1 addition & 1 deletion tests/service/data_juicer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ async def test_data_juicer_operators(self):
class TestDataJuicerTaskPipeline(RayUnittestBase):
def setUp(self):
if os.path.exists(TASKSET_OUTPUT_DIR):
shutil.rmtree(TASKSET_OUTPUT_DIR)
shutil.rmtree(TASKSET_OUTPUT_DIR, ignore_errors=True)

def test_data_juicer_task_pipeline(self):
config = get_template_config()
Expand Down
2 changes: 1 addition & 1 deletion tests/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def tearDownClass(cls):
ray.shutdown(_exiting_interpreter=True)


class RayUnittestBaseAysnc(unittest.IsolatedAsyncioTestCase):
class RayUnittestBaseAsync(unittest.IsolatedAsyncioTestCase):
@classmethod
def setUpClass(cls):
ray.init(ignore_reinit_error=True, namespace="trinity_unittest")
Expand Down
Loading