Skip to content

Commit d3e42bf

Browse files
committed
add unittest
1 parent c00811b commit d3e42bf

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

tests/common/vllm_test.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import unittest
23

34
import ray
@@ -935,3 +936,59 @@ async def test_api_tool_calls(self):
935936
print_debug(
936937
"\n" + "=" * 28 + f" test_api_tool_calls PASSED in {total_time:.2f}s " + "=" * 28 + "\n"
937938
)
939+
940+
941+
class TestSuperLongGeneration(RayUnittestBaseAysnc):
942+
def setUp(self):
943+
self.config = get_template_config()
944+
self.config.mode = "explore"
945+
self.config.model.model_path = get_model_path()
946+
self.config.model.max_model_len = 81920
947+
self.config.model.max_prompt_tokens = 61440
948+
self.config.model.max_response_tokens = 20480
949+
self.config.model.rope_scaling = {
950+
"rope_type": "yarn",
951+
"factor": 2.0,
952+
"original_max_position_embeddings": 40960,
953+
}
954+
self.config.explorer.rollout_model.engine_type = "vllm"
955+
self.config.explorer.rollout_model.engine_num = 1
956+
self.config.explorer.rollout_model.tensor_parallel_size = 1
957+
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
958+
959+
self.config.check_and_update()
960+
self.engines, self.auxiliary_engines = create_inference_models(self.config)
961+
self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
962+
963+
async def test_generate(self):
964+
base_dir = os.path.dirname(__file__)
965+
target_dir = os.path.join(base_dir, "..", "..", "trinity", "trainer", "verl")
966+
with open(os.path.join(target_dir, "fsdp_workers.py")) as f:
967+
fsdp_code = f.read()
968+
with open(os.path.join(target_dir, "megatron_workers.py")) as f:
969+
megatron_code = f.read()
970+
target_dir = os.path.join(base_dir, "..", "..", "trinity", "common")
971+
with open(os.path.join(target_dir, "config.py")) as f:
972+
config_code = f.read()
973+
target_dir = os.path.join(base_dir, "..", "..", "trinity", "manager")
974+
with open(os.path.join(target_dir, "config_manager.py")) as f:
975+
config_manager_code = f.read()
976+
977+
messages = [
978+
{"role": "system", "content": "You are a helpful assistant."},
979+
{
980+
"role": "user",
981+
"content": """# Please add comments and documentation for these following code, """
982+
"""make sure the code is well-structured and easy to read, """
983+
"""and the complete code must be shown, do not omit any parts.\n"""
984+
f"""## fsdp_workers.py\n{fsdp_code}\n"""
985+
f"""## megatron_workers.py\n{megatron_code}\n"""
986+
f"""## config.py\n{config_code}\n"""
987+
f"""## config_manager.py\n{config_manager_code}\n""",
988+
},
989+
]
990+
response = self.model_wrapper.chat(messages, n=1, temperature=0.7, logprobs=True)[0]
991+
self.assertGreater(
992+
response.prompt_length, 40960
993+
) # If not long enough, please add more files to prompt
994+
self.assertGreater(response.logprobs.shape[0], 1000)

tests/trainer/trainer_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@ class TestTrainerCountdown(BaseTrainerCase):
7373
def test_trainer(self):
7474
"""Test the both and bench mode."""
7575
# test both mode
76+
self.config.model.rope_scaling = {
77+
"rope_type": "yarn",
78+
"factor": 2.0,
79+
"original_max_position_embeddings": 16384,
80+
}
81+
self.config.model.rope_theta = 10000
7682
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
7783
self.config.buffer.explorer_input.taskset.task_selector = TaskSelectorConfig(
7884
selector_type="shuffle", seed=42

0 commit comments

Comments
 (0)