|
| 1 | +import os |
1 | 2 | import unittest |
2 | 3 |
|
3 | 4 | import ray |
@@ -935,3 +936,59 @@ async def test_api_tool_calls(self): |
935 | 936 | print_debug( |
936 | 937 | "\n" + "=" * 28 + f" test_api_tool_calls PASSED in {total_time:.2f}s " + "=" * 28 + "\n" |
937 | 938 | ) |
| 939 | + |
| 940 | + |
class TestSuperLongGeneration(RayUnittestBaseAysnc):
    """Exercise generation with a prompt far longer than the base context window.

    Configures YaRN rope scaling (factor 2.0 over an original 40960-token
    window) so the rollout model accepts prompts up to 61440 tokens, then
    feeds it a prompt built from several large project source files and
    checks that both the prompt and the sampled response actually exceed
    the un-scaled limits.
    """

    # Repository-relative paths (from this test file's parent's parent) of the
    # source files concatenated into the long prompt. The section title in the
    # prompt is each file's basename. If the resulting prompt is not long
    # enough, add more entries here.
    _PROMPT_SOURCE_FILES = (
        ("trinity", "trainer", "verl", "fsdp_workers.py"),
        ("trinity", "trainer", "verl", "megatron_workers.py"),
        ("trinity", "common", "config.py"),
        ("trinity", "manager", "config_manager.py"),
    )

    def setUp(self):
        """Build an explore-mode config with YaRN-extended context and spin up
        a single vLLM rollout engine wrapped for chat with history enabled."""
        self.config = get_template_config()
        self.config.mode = "explore"
        self.config.model.model_path = get_model_path()
        # 61440 prompt + 20480 response = 81920 total model length,
        # i.e. 2x the original 40960-token window via YaRN below.
        self.config.model.max_model_len = 81920
        self.config.model.max_prompt_tokens = 61440
        self.config.model.max_response_tokens = 20480
        self.config.model.rope_scaling = {
            "rope_type": "yarn",
            "factor": 2.0,
            "original_max_position_embeddings": 40960,
        }
        self.config.explorer.rollout_model.engine_type = "vllm"
        self.config.explorer.rollout_model.engine_num = 1
        self.config.explorer.rollout_model.tensor_parallel_size = 1
        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE

        self.config.check_and_update()
        self.engines, self.auxiliary_engines = create_inference_models(self.config)
        self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)

    def _read_prompt_sections(self):
        """Read each configured source file; return (basename, content) pairs.

        Files are read as UTF-8 explicitly so the test does not depend on the
        platform's locale encoding.
        """
        base_dir = os.path.dirname(__file__)
        sections = []
        for parts in self._PROMPT_SOURCE_FILES:
            path = os.path.join(base_dir, "..", "..", *parts)
            with open(path, encoding="utf-8") as f:
                sections.append((parts[-1], f.read()))
        return sections

    async def test_generate(self):
        """Generate once from a >40960-token prompt and sanity-check lengths."""
        sections = self._read_prompt_sections()
        instruction = (
            "# Please add comments and documentation for these following code, "
            "make sure the code is well-structured and easy to read, "
            "and the complete code must be shown, do not omit any parts.\n"
        )
        file_sections = "".join(f"## {name}\n{code}\n" for name, code in sections)
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": instruction + file_sections},
        ]
        response = self.model_wrapper.chat(messages, n=1, temperature=0.7, logprobs=True)[0]
        # Below 40960 tokens the YaRN-extended window is never exercised;
        # if this trips, add more files to _PROMPT_SOURCE_FILES.
        self.assertGreater(response.prompt_length, 40960)
        # logprobs is per-generated-token, so this also bounds response length.
        self.assertGreater(response.logprobs.shape[0], 1000)
0 commit comments