25 changes: 21 additions & 4 deletions tests/explorer/explorer_test.py
@@ -15,14 +15,15 @@
RayUnittestBase,
RayUnittestBaseAysnc,
TensorBoardParser,
get_api_model_path,
get_checkpoint_path,
get_model_path,
get_template_config,
get_unittest_dataset_config,
)
from trinity.buffer import get_buffer_reader
from trinity.cli.launcher import explore, run_stage
from trinity.common.config import ExperienceBufferConfig
from trinity.common.config import ExperienceBufferConfig, InferenceModelConfig
from trinity.common.constants import StorageType
from trinity.explorer.explorer import Explorer
from trinity.manager.state_manager import StateManager
@@ -31,6 +32,7 @@
class BaseExplorerCase(RayUnittestBase):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
self.config.buffer.total_epochs = 2
self.config.buffer.batch_size = 4
self.config.model.model_path = get_model_path()
@@ -67,18 +69,33 @@ def test_explorer(self):
self.assertTrue("eval/eval_long/accuracy/max" in eval_metrics)


class TestExplorerCountdownNoEval(BaseExplorerCase):
class TestExplorerGSM8KRULERNoEval(BaseExplorerCase):
def test_explorer(self):
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
self.config.explorer.rollout_model.engine_num = 2
self.config.explorer.auxiliary_models = [
InferenceModelConfig(
model_path=get_api_model_path(),
tensor_parallel_size=1,
engine_num=2,
)
]
self.config.algorithm.repeat_times = 2
self.config.buffer.total_steps = 2
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k_ruler")
self.config.name = f"explore-no-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
self.config.algorithm.algorithm_type = "grpo"
self.config.algorithm.advantage_fn = "grpo"
self.config.algorithm.advantage_fn_args = {
"std_threshold": 0.0001,
}
self.config.check_and_update()
explore(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
rollout_metrics = parser.metric_list("rollout")
self.assertTrue(len(rollout_metrics) > 0)
eval_metrics = parser.metric_list("eval")
self.assertTrue(len(eval_metrics) == 0)
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)


class TestExplorerGSM8k(BaseExplorerCase):
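For context, the auxiliary-model setup this test now exercises can be written as a standalone config fragment. A minimal sketch, assuming get_template_config is imported from tests/tools.py as in the test imports above; the model path is a placeholder, not from this PR:

from tests.tools import get_template_config
from trinity.common.config import InferenceModelConfig

config = get_template_config()
config.mode = "explore"
# Two single-GPU replicas of one auxiliary model; each exposes an
# OpenAI-compatible API that is started lazily (see vllm_model.py below).
config.explorer.auxiliary_models = [
    InferenceModelConfig(
        model_path="/models/qwen2.5-0.5b-instruct",  # placeholder path
        tensor_parallel_size=1,
        engine_num=2,
    )
]
config.check_and_update()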
10 changes: 2 additions & 8 deletions tests/explorer/scheduler_test.py
@@ -134,10 +134,7 @@ def init_process_group(
) -> None:
pass

def has_api_server(self) -> bool:
return False

def get_api_server_url(self) -> Optional[str]:
async def get_api_server_url(self) -> Optional[str]:
return None


@@ -161,10 +158,7 @@ def init_process_group(
) -> None:
pass

def has_api_server(self) -> bool:
return True

def get_api_server_url(self) -> str:
async def get_api_server_url(self) -> str:
return "http://localhost:12345"


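The rewritten mocks capture the new contract: availability is signaled by the awaited return value of get_api_server_url, replacing the old has_api_server probe. A self-contained sketch of that calling convention (class names here are illustrative, not from the repo):

import asyncio
from typing import Optional

class NoServerModel:
    async def get_api_server_url(self) -> Optional[str]:
        return None  # no API server configured

class ServerModel:
    async def get_api_server_url(self) -> str:
        return "http://localhost:12345"

async def main() -> None:
    for model in (NoServerModel(), ServerModel()):
        url = await model.get_api_server_url()
        # None now means "disabled", so callers branch on the value itself.
        print("api server:", url or "disabled")

asyncio.run(main())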
11 changes: 11 additions & 0 deletions tests/tools.py
@@ -116,6 +116,17 @@ def get_unittest_dataset_config(dataset_name: str = "countdown", split: str = "t
default_workflow_type="math_workflow",
default_reward_fn_type="math_reward",
)
elif dataset_name == "gsm8k_ruler":
return TasksetConfig(
name=dataset_name,
path=os.path.join(os.path.dirname(__file__), "template", "data", "gsm8k"),
split="train",
format=FormatConfig(
prompt_key="question",
response_key="answer",
),
default_workflow_type="math_ruler_workflow",
)
elif dataset_name == "sft_for_gsm8k":
# SFT dataset with 8 samples
return ExperienceBufferConfig(
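A usage sketch for the new fixture (assuming the helper is imported from tests/tools.py; the values follow from the branch above):

from tests.tools import get_unittest_dataset_config

taskset = get_unittest_dataset_config("gsm8k_ruler")
print(taskset.default_workflow_type)  # "math_ruler_workflow"
# Unlike the countdown branch, no default_reward_fn_type is set here,
# presumably because the RULER workflow scores rollouts itself.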
2 changes: 1 addition & 1 deletion trinity/common/config.py
@@ -436,7 +436,7 @@ class InferenceModelConfig:
# ! DO NOT SET in explorer.rollout_model, automatically set from config.model.model_path
model_path: str = ""

engine_type: str = "vllm_async"
engine_type: str = "vllm"
engine_num: int = 1
tensor_parallel_size: int = 1
use_v1: bool = True
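Configs constructed without an explicit engine_type now resolve to "vllm". A minimal sketch of the observable effect (the model path is a placeholder):

from trinity.common.config import InferenceModelConfig

cfg = InferenceModelConfig(model_path="/models/example")  # placeholder path
assert cfg.engine_type == "vllm"  # previously defaulted to "vllm_async"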
12 changes: 3 additions & 9 deletions trinity/common/models/__init__.py
@@ -106,9 +106,7 @@ def create_inference_models(
config=config.explorer.rollout_model,
)
)
if config.explorer.rollout_model.enable_openai_api:
for engine in rollout_engines:
engine.run_api_server.remote()

if config.explorer.rollout_model.enable_history:
logger.info(
"Model History recording is enabled. Please periodically extract "
@@ -138,10 +136,6 @@ def create_inference_models(
.remote(config=model_config)
)
auxiliary_engines.append(engines)
# all auxiliary engines run api server
for engines in auxiliary_engines:
for engine in engines:
engine.run_api_server.remote()

return rollout_engines, auxiliary_engines

@@ -159,10 +153,10 @@ def create_debug_inference_model(config: Config) -> None:
rollout_models, auxiliary_models = create_inference_models(config)
# make sure models are started
for m in rollout_models:
ray.get(m.get_model_path.remote())
ray.get(m.run_api_server.remote())
for models in auxiliary_models:
for m in models:
ray.get(m.get_model_path.remote())
ray.get(m.run_api_server.remote())
logger.info(
"----------------------------------------------------\n"
"Inference models started successfully for debugging.\n"
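With auto-start removed from create_inference_models, every call site opts in explicitly, and run_api_server is safe to call more than once (see vllm_model.py below). A condensed sketch of the resulting debug-path flow, assuming a populated config object as in the code above:

import ray
from trinity.common.models import create_inference_models

rollout_models, auxiliary_models = create_inference_models(config)  # config assumed
# Block until each engine's API server reports started; the call is
# idempotent, so Explorer.prepare may later start the same servers safely.
for m in rollout_models:
    ray.get(m.run_api_server.remote())
for group in auxiliary_models:
    for m in group:
        ray.get(m.run_api_server.remote())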
44 changes: 19 additions & 25 deletions trinity/common/models/model.py
@@ -51,11 +51,7 @@ def get_available_address(self) -> Tuple[str, int]:
port = s.getsockname()[1]
return address, port

def has_api_server(self) -> bool:
"""Check if the model has an API server."""
return False

def get_api_server_url(self) -> Optional[str]:
async def get_api_server_url(self) -> Optional[str]:
"""Get the API server URL if available."""
return None

@@ -106,26 +102,24 @@ def __init__(

async def prepare(self) -> None:
"""Prepare the model wrapper."""
if await self.model.has_api_server.remote():
self.api_address = await self.model.get_api_server_url.remote()
if self.api_address is None:
raise RuntimeError(
"Failed to connect to the API server. Please set `enable_openai_api` to `True`."
)
max_retries = 30
interval = 2 # seconds
for i in range(max_retries):
try:
async with httpx.AsyncClient() as client:
response = await client.get(self.api_address + "/health", timeout=5)
if response.status_code == 200:
return
except Exception as e:
self.logger.info(f"API server not ready (attempt {i + 1}/{max_retries}): {e}")
await asyncio.sleep(interval)
raise RuntimeError(
f"API server at {self.api_address} not ready after {max_retries} attempts."
)
self.api_address = await self.model.get_api_server_url.remote()
if self.api_address is None:
self.logger.info("API server is not enabled for inference model.")
return
max_retries = 30
interval = 2 # seconds
for i in range(max_retries):
try:
async with httpx.AsyncClient() as client:
response = await client.get(self.api_address + "/health", timeout=5)
if response.status_code == 200:
return
except Exception as e:
self.logger.info(f"API server not ready (attempt {i + 1}/{max_retries}): {e}")
await asyncio.sleep(interval)
raise RuntimeError(
f"API server at {self.api_address} not ready after {max_retries} attempts."
)

def _record_history(self, exps: Union[Experience, List[Experience]]) -> None:
"""Record experiences to history."""
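The readiness loop in prepare generalizes to any OpenAI-compatible endpoint. A standalone sketch of the same polling pattern (function name is mine; the constants mirror those above):

import asyncio
import httpx

async def wait_for_health(url: str, max_retries: int = 30, interval: float = 2.0) -> None:
    # Poll GET {url}/health until HTTP 200, as ModelWrapper.prepare does.
    for attempt in range(max_retries):
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(url + "/health", timeout=5)
            if response.status_code == 200:
                return
        except Exception as exc:
            print(f"not ready (attempt {attempt + 1}/{max_retries}): {exc}")
        await asyncio.sleep(interval)
    raise RuntimeError(f"API server at {url} not ready after {max_retries} attempts.")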
54 changes: 32 additions & 22 deletions trinity/common/models/vllm_model.py
@@ -103,6 +103,7 @@ def __init__(
self.api_server_host = None
self.api_server_port = None
self.api_server = None
self.async_lock = asyncio.Lock()

async def _initialize_tokenizer(self):
if self.tokenizer is None:
@@ -455,35 +456,44 @@ async def init_process_group(
),
)

async def run_api_server(self):
"""Run the OpenAI API server in a Ray actor."""
if not (self.api_server_host is None or self.api_server_port is None):
raise RuntimeError("API server is already running.")
from trinity.common.models.api.vllm_patch import run_api_server_in_ray_actor

self.api_server_host, self.api_server_port = self.get_available_address()
self.api_server = asyncio.create_task(
run_api_server_in_ray_actor(
self.async_llm,
self.api_server_host,
self.api_server_port,
self.config.model_path,
self.config.enable_auto_tool_choice,
self.config.tool_call_parser,
self.config.reasoning_parser,
)
)
async def run_api_server(self) -> bool:
"""Run the OpenAI API server in a Ray actor.

def has_api_server(self) -> bool:
return self.config.enable_openai_api
Returns:
success (bool): Whether the API server is started successfully.
"""
async with self.async_lock:
if not self.config.enable_openai_api:
return False # Not enabled

if self.api_server_host is not None and self.api_server_port is not None:
return True # already running

from trinity.common.models.api.vllm_patch import run_api_server_in_ray_actor

api_server_host, api_server_port = self.get_available_address()
self.api_server = asyncio.create_task(
run_api_server_in_ray_actor(
self.async_llm,
api_server_host,
api_server_port,
self.config.model_path,
self.config.enable_auto_tool_choice,
self.config.tool_call_parser,
self.config.reasoning_parser,
)
)
self.api_server_host = api_server_host
self.api_server_port = api_server_port
return True

def get_api_server_url(self) -> Optional[str]:
async def get_api_server_url(self) -> Optional[str]:
"""Get the URL of the OpenAI API server.

Returns:
api_url (str): The URL of the OpenAI API server.
"""
if not self.has_api_server():
if not await self.run_api_server():
return None
return f"http://{self.api_server_host}:{self.api_server_port}"

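The lock-guarded start plus the lazy call in get_api_server_url form a small idempotent-initialization pattern. An illustrative reduction (class and URL are placeholders, not repo code):

import asyncio
from typing import Optional

class LazyApiServer:
    def __init__(self, enabled: bool) -> None:
        self.enabled = enabled
        self.url: Optional[str] = None
        self._lock = asyncio.Lock()

    async def run_api_server(self) -> bool:
        async with self._lock:  # serialize concurrent starters
            if not self.enabled:
                return False  # feature disabled in config
            if self.url is None:  # only the first caller starts the server
                self.url = "http://localhost:8000"  # stand-in for real startup
            return True

    async def get_api_server_url(self) -> Optional[str]:
        # Start on first request, mirroring the vLLM wrapper above.
        return self.url if await self.run_api_server() else None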
12 changes: 9 additions & 3 deletions trinity/explorer/explorer.py
@@ -150,12 +150,18 @@ async def _nccl_weights_update(self):
async def prepare(self) -> None:
"""Preparation before running."""
try:
# prepare experience pipeline
await self.experience_pipeline.prepare.remote()
self.logger.info("Experience pipeline is ready.")
# make sure all rollout models are ready
model_ready_ref = [model.__ray_ready__.remote() for model in self.models]
await asyncio.gather(*model_ready_ref)
self.logger.info("All rollout models are ready.")
run_api_ref = [model.run_api_server.remote() for model in self.models]
run_api_ref.extend(
model.run_api_server.remote()
for models in self.auxiliary_models
for model in models
)
await asyncio.gather(*run_api_ref)
self.logger.info("All models are ready.")

if not self.use_nccl_sync:
master_address, master_port = await self.models[0].get_available_address.remote()
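Explorer.prepare now fans out one start request per engine, rollout and auxiliary alike, and awaits them together. A trimmed sketch of that fan-out (the arguments are assumed to be Ray actor handles whose .remote() calls return awaitable refs):

import asyncio

async def start_api_servers(models, auxiliary_models) -> None:
    refs = [m.run_api_server.remote() for m in models]
    refs.extend(
        m.run_api_server.remote() for group in auxiliary_models for m in group
    )
    await asyncio.gather(*refs)  # idempotent, so double-starting is harmless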