diff --git a/.gitignore b/.gitignore
index 5b5b47357..9cdd7aecf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -39,9 +39,10 @@ docs/_spelling/
 /skyrl-gym/dist
 
 *.log
+trials/
 
 # SQLite database files
 *.db
 
 # uv lock files
-uv.lock
\ No newline at end of file
+uv.lock
diff --git a/skyrl-train/examples/async/main_async.py b/skyrl-train/examples/async/main_async.py
index 4a25dbe77..0d71a3913 100644
--- a/skyrl-train/examples/async/main_async.py
+++ b/skyrl-train/examples/async/main_async.py
@@ -5,7 +5,7 @@ import hydra
 from omegaconf import DictConfig
 from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg
-from .async_trainer import AsyncRayPPOTrainer
+from skyrl_train.fully_async_trainer import FullyAsyncRayPPOTrainer
 import asyncio
 from skyrl_train.utils import initialize_ray
 import ray
@@ -23,7 +23,7 @@ def get_trainer(
     generator,
     colocate_pg,
 ):
-    return AsyncRayPPOTrainer(
+    return FullyAsyncRayPPOTrainer(
         cfg=cfg,
         tracker=tracker,
         tokenizer=tokenizer,
@@ -33,7 +33,7 @@ def get_trainer(
         generator=generator,
         colocate_pg=colocate_pg,
     )
-    
+
     def run(self):
         trainer = self._setup_trainer()
         # Start the async training loop
diff --git a/skyrl-train/examples/on_policy_distillation/README.md b/skyrl-train/examples/on_policy_distillation/README.md
index 1c25c261c..58194035f 100644
--- a/skyrl-train/examples/on_policy_distillation/README.md
+++ b/skyrl-train/examples/on_policy_distillation/README.md
@@ -15,8 +15,8 @@ In `main_on_policy_distill.py` we provide a simple example for modifying SkyRL t
 To get started, first set up the dataset from the DAPO example:
 
 ```bash
 # Run from the `skyrl-train` directory
-bash examples/algorithms/dapo/prepare_dapo_data.sh
+uv run examples/algorithms/dapo/prepare_dapo_data.sh
 ```
 
 Then, just make sure to set the path to your desired teacher model, and you're ready to kick off training!
diff --git a/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_1.7b.sh b/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_1.7b.sh
index 250e8252a..697371d25 100644
--- a/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_1.7b.sh
+++ b/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_1.7b.sh
@@ -2,7 +2,7 @@ set -x
 
 # Running on policy distillation for Math on the DAPO math dataset, with eval on AIME 2024.
 # Uses Qwen-3-1.7B-Base as the student model and an RL trained Qwen-3-4B as the teacher model
-# bash examples/algorithms/dapo/prepare_dapo_data.sh
+# uv run examples/algorithms/dapo/prepare_dapo_data.sh
 # bash examples/on_policy_distillation/run_on_policy_distill_math_qwen3_1.7b.sh
 
 DATA_DIR="$HOME/data/dapo"
diff --git a/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_4b.sh b/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_4b.sh
index 670a35152..c59ee5257 100644
--- a/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_4b.sh
+++ b/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_4b.sh
@@ -2,7 +2,7 @@ set -x
 
 # Running on policy distillation for Math on the DAPO math dataset, with eval on AIME 2024.
 # Uses Qwen-3-4B-Base as the student model and an RL trained Qwen-3-4B as the teacher model
-# bash examples/algorithms/dapo/prepare_dapo_data.sh
+# uv run examples/algorithms/dapo/prepare_dapo_data.sh
 # bash examples/on_policy_distillation/run_on_policy_distill_math_qwen3_4b.sh
 
 DATA_DIR="$HOME/data/dapo"
diff --git a/skyrl-train/examples/terminal_bench/entrypoints/main_tbench.py b/skyrl-train/examples/terminal_bench/entrypoints/main_tbench.py
index db922953e..25a9db72a 100644
--- a/skyrl-train/examples/terminal_bench/entrypoints/main_tbench.py
+++ b/skyrl-train/examples/terminal_bench/entrypoints/main_tbench.py
@@ -10,7 +10,7 @@ from skyrl_train.utils.utils import initialize_ray
 from examples.terminal_bench.terminal_bench_generator import TerminalBenchGenerator
 from examples.terminal_bench.dataset import TerminalBenchTaskDataset
-
+from skyrl_train.fully_async_trainer import FullyAsyncRayPPOTrainer
 
 
 class TerminalBenchExp(BasePPOExp):
     def get_generator(self, cfg, tokenizer, inference_engine_client):
@@ -52,6 +52,28 @@ def get_eval_dataset(self):
             return prompts_dataset
         return None
 
+    def get_trainer(
+        self,
+        cfg,
+        tracker,
+        tokenizer,
+        train_dataset,
+        eval_dataset,
+        inference_engine_client,
+        generator,
+        colocate_pg,
+    ):
+        return FullyAsyncRayPPOTrainer(
+            cfg=cfg,
+            tracker=tracker,
+            tokenizer=tokenizer,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            inference_engine_client=inference_engine_client,
+            generator=generator,
+            colocate_pg=colocate_pg,
+        )
+
 
 @ray.remote(num_cpus=1)
 def skyrl_entrypoint(cfg: DictConfig):
diff --git a/skyrl-train/examples/terminal_bench/entrypoints/main_tbench_generate.py b/skyrl-train/examples/terminal_bench/entrypoints/main_tbench_generate.py
index cd8f478fb..e8f8054c3 100644
--- a/skyrl-train/examples/terminal_bench/entrypoints/main_tbench_generate.py
+++ b/skyrl-train/examples/terminal_bench/entrypoints/main_tbench_generate.py
@@ -18,7 +18,7 @@ config_dir,
 )
 from skyrl_train.generators.base import GeneratorInput
-from examples.terminal_bench.generator.terminal_bench_generator import TerminalBenchGenerator
+from examples.terminal_bench.terminal_bench_generator import TerminalBenchGenerator
 from examples.terminal_bench.dataset import TerminalBenchTaskDataset
 
 
diff --git a/skyrl-train/examples/terminal_bench/entrypoints/main_tbench_opd.py b/skyrl-train/examples/terminal_bench/entrypoints/main_tbench_opd.py
new file mode 100644
index 000000000..bc2827290
--- /dev/null
+++ b/skyrl-train/examples/terminal_bench/entrypoints/main_tbench_opd.py
@@ -0,0 +1,36 @@
+"""
+Main entrypoint for on-policy distillation training on terminal bench tasks.
+"""
+import ray
+import hydra
+from omegaconf import DictConfig
+from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir
+from skyrl_train.utils import validate_cfg
+from skyrl_train.utils.utils import initialize_ray
+from examples.terminal_bench.terminal_bench_generator import TerminalBenchGenerator
+from examples.terminal_bench.dataset import TerminalBenchTaskDataset
+from examples.terminal_bench.entrypoints.main_tbench import TerminalBenchExp
+from examples.on_policy_distillation.main_on_policy_distill import OnPolicyDistillationTrainer
+
+class OnPolicyDistillationTerminalBenchExp(TerminalBenchExp):
+    def get_trainer(self, *args, **kwargs):
+        return OnPolicyDistillationTrainer(*args, **kwargs)
+
+
+@ray.remote(num_cpus=1)
+def skyrl_entrypoint(cfg: DictConfig):
+    # make sure that the training loop is not run on the head node.
+ exp = OnPolicyDistillationTerminalBenchExp(cfg) + exp.run() + +@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None) +def main(cfg: DictConfig) -> None: + # validate the arguments + validate_cfg(cfg) + + initialize_ray(cfg) + ray.get(skyrl_entrypoint.remote(cfg)) + + +if __name__ == "__main__": + main() diff --git a/skyrl-train/examples/terminal_bench/generator/terminal_bench_generator.py b/skyrl-train/examples/terminal_bench/generator/terminal_bench_generator.py deleted file mode 100644 index 57495d093..000000000 --- a/skyrl-train/examples/terminal_bench/generator/terminal_bench_generator.py +++ /dev/null @@ -1,183 +0,0 @@ -import asyncio -from dataclasses import dataclass -from typing import List, Optional -from uuid import uuid4 -from skyrl_train.generators.base import GeneratorInterface, GeneratorInput, GeneratorOutput -from skyrl_train.generators.utils import get_rollout_metrics, get_response_ids_and_loss_mask_from_messages -from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient -from skyrl_train.inference_engines.base import ConversationType -from omegaconf import DictConfig -from pathlib import Path -from harbor.models.trial.config import TrialConfig, AgentConfig, TaskConfig, EnvironmentConfig -from harbor.models.environment_type import EnvironmentType -from harbor.models.agent.name import AgentName -from harbor.trial.trial import Trial - - -@dataclass -class TerminalBenchAgentOutput: - response_ids: List[int] - reward: float - stop_reason: str - loss_mask: List[int] - prompt_ids: List[int] - rollout_logprobs: Optional[List[float]] - - -class TerminalBenchGenerator(GeneratorInterface): - def __init__( - self, - generator_cfg: DictConfig, - terminal_bench_cfg: DictConfig, - inference_engine_client: InferenceEngineClient, - tokenizer, - ): - """ - Args: - generator_cfg: DictConfig object containing the generator configuration - terminal_bench_cfg: DictConfig object containing the terminal bench configuration - inference_engine_client: InferenceEngineClient object for interacting with the inference engines - tokenizer: tokenizer object for encoding and decoding text - """ - self.base_url = f"http://{generator_cfg.http_endpoint_host}:{generator_cfg.http_endpoint_port}" - self.generator_cfg = generator_cfg - self.tokenizer = tokenizer - self.model_name = generator_cfg.model_name - - # TerminalBench config - self.trials_dir = terminal_bench_cfg.trials_dir - self.agent_name = terminal_bench_cfg.agent_name - self.max_episodes = terminal_bench_cfg.max_episodes - - if self.generator_cfg.chat_template.name_or_path is not None: - raise NotImplementedError("TerminalBenchGenerator doesn't support custom chat template") - - async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput: - prompts = input_batch["prompts"] - tasks = [] - for prompt in prompts: - tasks.append( - self.terminal_bench_agent_loop( - prompt=prompt, - ) - ) - - all_outputs = await asyncio.gather(*tasks) - - responses = [output.response_ids for output in all_outputs] - rewards = [output.reward for output in all_outputs] - rollout_metrics = get_rollout_metrics(responses, rewards) - - generator_output: GeneratorOutput = { - "prompt_token_ids": [output.prompt_ids for output in all_outputs], - "response_ids": responses, - "rewards": rewards, - "loss_masks": [output.loss_mask for output in all_outputs], - "stop_reasons": [output.stop_reason for output in all_outputs], - "rollout_metrics": rollout_metrics, - "rollout_logprobs": [output.rollout_logprobs 
for output in all_outputs], - } - - return generator_output - - async def terminal_bench_agent_loop( - self, - prompt: ConversationType, - ) -> TerminalBenchAgentOutput: - """ - Run a single terminal_bench agent. - """ - # Generate session_id for sticky routing to inference engines - # All LLM requests in this trial will share the same session_id - session_id = uuid4().hex - - if self.agent_name == "terminus": - trial_config = TrialConfig( - task=TaskConfig(path=prompt), - trials_dir=Path(self.trials_dir), - environment=EnvironmentConfig(type=EnvironmentType.DAYTONA), - agent=AgentConfig( - name=AgentName.TERMINUS_2.value, - model_name=f"hosted_vllm/{self.model_name}", - kwargs={ - "api_base": f"{self.base_url}/v1", - "key": "fake_key", - "session_id": session_id, - "max_episodes": self.max_episodes, - }, - ), - ) - elif self.agent_name == "oracle": - trial_config = TrialConfig( - task=TaskConfig(path=prompt), - trials_dir=Path(self.trials_dir), - environment=EnvironmentConfig(type=EnvironmentType.DAYTONA), - agent=AgentConfig( - name=AgentName.ORACLE, - model_name=f"hosted_vllm/{self.model_name}", - ), - ) - else: - raise ValueError(f"Invalid agent name: {self.agent_name}") - - trial = Trial(trial_config) - # Run the trial - while True: - try: - results = await trial.run() - print(f"Results: {results}") - if not results.verifier_result: - print(f"[WARNING] Exception info: {results.exception_info}") - continue - reward = results.verifier_result.reward - chat_history = results.agent_result.all_messages - if len(chat_history) > 0: - break - else: - print(f"[WARNING] Agent {self.agent_name} did not return a response") - except Exception as e: - print(f"Error running trial: {e}") - continue - - # Use the first message as the prompt. We assume to be no systems messages. - assert chat_history[0]["role"] == "user", "The first message should be a user message" - prompt = [chat_history[0]] - prompt_ids = self.tokenizer.apply_chat_template( - prompt, - add_generation_prompt=False, # the message below will add it themselves - tokenize=True, - ) - initial_prompt_length = len(prompt_ids) - - # Process response messages (everything after the first message) - response_messages = chat_history[1:] - assistant_logprobs = getattr(results.agent_result, "output_logprobs", None) - response_ids, loss_mask, rollout_logprobs = get_response_ids_and_loss_mask_from_messages( - response_messages, self.tokenizer, assistant_logprobs - ) - - # Determine stop reason - max_response_tokens = ( - self.generator_cfg.sampling_params.max_generate_length - + self.generator_cfg.max_input_length - - initial_prompt_length - ) - stop_reason = "complete" # Default for trial completion - if len(response_ids) > max_response_tokens: - stop_reason = "length" - # TODO(Charlie): should we do rewards = self._zero_reward_if_not_stop(rewards, stop_reasons)? 
- - # Truncate to maximum allowed length - response_ids = response_ids[:max_response_tokens] - loss_mask = loss_mask[:max_response_tokens] - rollout_logprobs = rollout_logprobs[:max_response_tokens] - - return TerminalBenchAgentOutput( - response_ids=response_ids, - reward=reward, - stop_reason=stop_reason, - loss_mask=loss_mask, - prompt_ids=prompt_ids, - # in case harbor doesn't return logprobs, use None - rollout_logprobs=rollout_logprobs if assistant_logprobs is not None else None, - ) diff --git a/skyrl-train/examples/terminal_bench/qwen3_thinking_acc.jinja2 b/skyrl-train/examples/terminal_bench/qwen3_thinking_acc.jinja2 new file mode 100644 index 000000000..1643aa5fb --- /dev/null +++ b/skyrl-train/examples/terminal_bench/qwen3_thinking_acc.jinja2 @@ -0,0 +1,72 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} +{%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} diff --git a/skyrl-train/examples/terminal_bench/run_tbench.sh b/skyrl-train/examples/terminal_bench/run_tbench.sh index a44627324..e028877e7 100644 
--- a/skyrl-train/examples/terminal_bench/run_tbench.sh +++ b/skyrl-train/examples/terminal_bench/run_tbench.sh @@ -6,13 +6,13 @@ set -x # export WANDB_API_KEY= # bash examples/terminal_bench/run_tbench.sh -DATA_DIR="$(pwd)/harbor/examples/tasks" +DATA_DIR="$HOME/data/tasks" NUM_GPUS=1 LOGGER="console" # change to "console" to print to stdout TBENCH_CONFIG_DIR="examples/terminal_bench" SANDBOXES_DIR="sandboxes" # TODO: For now, `sandboxes` is cloned into SkyRL/skyrl-train. -uv run --isolated --extra vllm --extra sandboxes --with "harbor@./harbor" -m examples.terminal_bench.entrypoints.main_tbench \ +uv run --isolated --extra vllm --extra sandboxes --with "sandbox@./sandboxes" -m examples.terminal_bench.entrypoints.main_tbench \ data.train_data="['$DATA_DIR']" \ data.val_data="['$DATA_DIR']" \ hydra.searchpath=[file://$TBENCH_CONFIG_DIR] \ @@ -33,7 +33,7 @@ uv run --isolated --extra vllm --extra sandboxes --with "harbor@./harbor" -m exa trainer.eval_interval=-1 \ trainer.update_epochs_per_batch=1 \ trainer.train_batch_size=8 \ - trainer.policy_mini_batch_size=2 \ + trainer.policy_mini_batch_size=8 \ trainer.micro_forward_batch_size_per_gpu=1 \ trainer.micro_train_batch_size_per_gpu=1 \ trainer.ckpt_interval=-1 \ diff --git a/skyrl-train/examples/terminal_bench/run_tbench_gen.sh b/skyrl-train/examples/terminal_bench/run_tbench_gen.sh index 82dffcdc3..5d233dbc1 100644 --- a/skyrl-train/examples/terminal_bench/run_tbench_gen.sh +++ b/skyrl-train/examples/terminal_bench/run_tbench_gen.sh @@ -6,13 +6,13 @@ set -x # export WANDB_API_KEY= # bash examples/terminal_bench/run_tbench.sh -DATA_DIR="$(pwd)/harbor/examples/tasks" +DATA_DIR="$HOME/data/tasks" NUM_GPUS=1 LOGGER="console" # change to "console" to print to stdout TBENCH_CONFIG_DIR="examples/terminal_bench" SANDBOXES_DIR="sandboxes" # TODO: For now, `sandboxes` is cloned into SkyRL/skyrl-train. 
-uv run --isolated --extra vllm --extra sandboxes --with "harbor@./harbor" -m examples.terminal_bench.entrypoints.main_tbench_generate \
+uv run --isolated --extra vllm --extra sandboxes --with "sandbox@./sandboxes" -m examples.terminal_bench.entrypoints.main_tbench_generate \
   data.train_data="['$DATA_DIR']" \
   hydra.searchpath=[file://$TBENCH_CONFIG_DIR] \
   +terminal_bench_config=terminal_bench \
diff --git a/skyrl-train/examples/terminal_bench/terminal_bench_config/terminal_bench.yaml b/skyrl-train/examples/terminal_bench/terminal_bench_config/terminal_bench.yaml
index f95d7adb8..e69de29bb 100644
--- a/skyrl-train/examples/terminal_bench/terminal_bench_config/terminal_bench.yaml
+++ b/skyrl-train/examples/terminal_bench/terminal_bench_config/terminal_bench.yaml
@@ -1,5 +0,0 @@
-# @package terminal_bench_config
-
-trials_dir: "~/trials"
-agent_name: "terminus"
-max_episodes: 16
\ No newline at end of file
diff --git a/skyrl-train/examples/terminal_bench/terminal_bench_generator.py b/skyrl-train/examples/terminal_bench/terminal_bench_generator.py
new file mode 100644
index 000000000..f0eb2e5f8
--- /dev/null
+++ b/skyrl-train/examples/terminal_bench/terminal_bench_generator.py
@@ -0,0 +1,301 @@
+import asyncio
+from dataclasses import dataclass
+from typing import List, Optional
+from loguru import logger
+from uuid import uuid4
+from skyrl_train.generators.base import GeneratorInterface, GeneratorInput, GeneratorOutput, TrajectoryID
+from skyrl_train.generators.utils import get_rollout_metrics
+from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
+from skyrl_train.inference_engines.base import ConversationType
+from omegaconf import DictConfig
+from pathlib import Path
+from harbor.models.trial.config import TrialConfig, AgentConfig, TaskConfig, EnvironmentConfig
+from harbor.models.environment_type import EnvironmentType
+from harbor.models.agent.name import AgentName
+from harbor.trial.trial import Trial
+
+# We have N retries for each trial; if one of the rollouts (out of n_samples_per_prompt) fails
+# after N attempts, we skip this prompt altogether.
+MAX_NUM_RETRIES_PER_TRIAL = 2
+
+@dataclass
+class TerminalBenchAgentOutput:
+    response_ids: List[int]
+    reward: float
+    stop_reason: str
+    loss_mask: List[int]
+    prompt_ids: List[int]
+    trajectory_id: TrajectoryID
+    summarization_count: Optional[int] = None
+    rollout_logprobs: Optional[List[float]] = None
+
+class TerminalBenchGenerator(GeneratorInterface):
+    def __init__(
+        self,
+        generator_cfg: DictConfig,
+        terminal_bench_cfg: DictConfig,
+        inference_engine_client: InferenceEngineClient,
+        tokenizer,
+    ):
+        """
+        Args:
+            generator_cfg: DictConfig object containing the generator configuration
+            terminal_bench_cfg: DictConfig object containing the terminal bench configuration
+            inference_engine_client: InferenceEngineClient object for interacting with the inference engines
+            tokenizer: tokenizer object for encoding and decoding text
+        """
+        self.base_url = f"http://{generator_cfg.http_endpoint_host}:{generator_cfg.http_endpoint_port}"
+        self.generator_cfg = generator_cfg
+        self.tokenizer = tokenizer
+        self.model_name = generator_cfg.model_name
+
+        # TerminalBench config. Parse here to ensure everything is passed in.
+ self.trials_dir = terminal_bench_cfg.trials_dir + self.agent_name = terminal_bench_cfg.agent_name + self.max_episodes = terminal_bench_cfg.max_episodes + self.enable_summarize = terminal_bench_cfg.get("enable_summarize", True) + + # Optional overrides for the environment + self.override_memory_mb = terminal_bench_cfg.get("override_memory_mb") + self.override_storage_mb = terminal_bench_cfg.get("override_storage_mb") + self.override_cpus = terminal_bench_cfg.get("override_cpus") + + logger.info(f"TerminalBenchGenerator initialized with overrides: memory={self.override_memory_mb}, storage={self.override_storage_mb}, cpus={self.override_cpus}") + + # Read custom chat template + custom_chat_template_path = generator_cfg.engine_init_kwargs.get("custom_chat_template_chat_completion_path", None) + if custom_chat_template_path: + with open(custom_chat_template_path, "r") as f: + self.custom_chat_template_content = f.read() + logger.info(f"TerminalBenchGenerator initialized with custom chat template read from: {custom_chat_template_path}") + else: + self.custom_chat_template_content = None + + async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput: + tasks = [] + for i in range(len(input_batch["prompts"])): + tasks.append( + self.terminal_bench_agent_loop( + prompt=input_batch["prompts"][i], + trajectory_id=input_batch["trajectory_ids"][i], + ) + ) + + all_outputs: List[TerminalBenchAgentOutput] = await asyncio.gather(*tasks) + + # For a group of trajectories (n_samples_per_prompt trajectories for the same prompt), if one + # of the trajectories fails, we skip the entire group. We also skip the group for rollout metric aggregation + failed_instance_ids = set() + num_failed_trajectories = 0 # per-trajectory, rather than per-instance + successful_outputs: List[TerminalBenchAgentOutput] = [] # only for metrics purpose + for output in all_outputs: + if output.stop_reason == "error": + failed_instance_ids.add(output.trajectory_id.instance_id) + num_failed_trajectories += 1 + + for output in all_outputs: + if output.trajectory_id.instance_id in failed_instance_ids: + output.response_ids = [0] + output.stop_reason = "error" + output.loss_mask = [0] + output.prompt_ids = [0] + output.reward = 0 + else: + successful_outputs.append(output) + + # Calculate rollout metrics for successful outputs + if len(successful_outputs) > 0: + rollout_metrics = get_rollout_metrics( + [output.response_ids for output in successful_outputs], + [output.reward for output in successful_outputs], + ) + rollout_metrics["generate/trajectories_summarized"] = sum(1 for output in successful_outputs if output.summarization_count > 0) + rollout_metrics["generate/trajectories_truncated"] = sum(1 for output in successful_outputs if output.stop_reason == "length") + else: + rollout_metrics = {} + rollout_metrics["generate/num_failed_instances"] = len(failed_instance_ids) + rollout_metrics["generate/num_failed_trajectories"] = num_failed_trajectories + + generator_output: GeneratorOutput = { + "prompt_token_ids": [output.prompt_ids for output in all_outputs], + "response_ids": [output.response_ids for output in all_outputs], + "rewards": [output.reward for output in all_outputs], + "loss_masks": [output.loss_mask for output in all_outputs], + "stop_reasons": [output.stop_reason for output in all_outputs], + "rollout_metrics": rollout_metrics, + "rollout_logprobs": None, + } + + return generator_output + + async def terminal_bench_agent_loop( + self, + prompt: ConversationType, + trajectory_id: TrajectoryID, + ) -> 
TerminalBenchAgentOutput:
+        """
+        Run a single terminal_bench agent.
+        """
+        # Generate session_id for sticky routing to inference engines
+        # All LLM requests in this trial will share the same session_id
+        session_id = uuid4().hex
+
+        environment_config = EnvironmentConfig(
+            type=EnvironmentType.DAYTONA,
+            override_cpus=self.override_cpus,
+            override_memory_mb=self.override_memory_mb,
+            override_storage_mb=self.override_storage_mb,
+        )
+
+        if self.agent_name == "terminus":
+            trial_config = TrialConfig(
+                task=TaskConfig(path=prompt),
+                trials_dir=Path(self.trials_dir),
+                environment=environment_config,
+                agent=AgentConfig(
+                    name=AgentName.TERMINUS_2.value,
+                    model_name=f"hosted_vllm/{self.model_name}",
+                    kwargs={
+                        "api_base": f"{self.base_url}/v1",
+                        "key": "fake_key",
+                        "max_episodes": self.max_episodes,
+                        "session_id": session_id,
+                        "enable_summarize": self.enable_summarize,
+                        "store_all_messages": True,
+                        "collect_rollout_details": True,
+                    },
+                ),
+            )
+        elif self.agent_name == "oracle":
+            trial_config = TrialConfig(
+                task=TaskConfig(path=prompt),
+                trials_dir=Path(self.trials_dir),
+                environment=environment_config,
+                agent=AgentConfig(
+                    name=AgentName.ORACLE,
+                    model_name=f"hosted_vllm/{self.model_name}",
+                ),
+            )
+        else:
+            raise ValueError(f"Invalid agent name: {self.agent_name}")
+
+        trial = Trial(trial_config)
+
+        # Run the trial to get `reward`, `chat_history`, and `summarization_count`
+        successful = False
+        reward = None
+        chat_history = None
+        summarization_count = None
+        for i in range(MAX_NUM_RETRIES_PER_TRIAL):
+            prefix = f"Trajectory {trajectory_id} attempt {i+1}/{MAX_NUM_RETRIES_PER_TRIAL}"
+            results = None
+            try:
+                results = await trial.run()
+                if not results.verifier_result:
+                    logger.warning(f"{prefix} failed: Exception info: {results.exception_info}. Results: {results}")
+                    continue
+
+                reward = results.verifier_result.rewards["reward"]
+                chat_history = results.agent_result.metadata['all_messages']
+                summarization_count = results.agent_result.metadata['summarization_count']
+                if len(chat_history) > 1 and chat_history[0]["role"] == "user":
+                    successful = True
+                    logger.info(f"{prefix} successful: Results: {results.agent_result.metadata}")
+                    break
+                else:
+                    logger.warning(f"{prefix} failed: Agent {self.agent_name} did not return a chat history with a user message. chat_history: {chat_history}\n\nResults: {results}")
+            except Exception as e:
+                logger.warning(f"{prefix} failed: Error running trial: {e}. Results: {results}")
+                continue
+
+        if not successful:
+            # We set the loss mask to 0 so the failed trajectory does not contribute to model updates
+            logger.warning(f"Trajectory {trajectory_id} failed after {MAX_NUM_RETRIES_PER_TRIAL} attempts, will set loss mask to [0].")
+            return TerminalBenchAgentOutput(
+                response_ids=[0],
+                reward=0,
+                stop_reason="error",
+                loss_mask=[0],
+                prompt_ids=[0],
+                trajectory_id=trajectory_id,
+            )
+
+        # Use the first message as the prompt. We assume there are no system messages.
+        assert chat_history[0]["role"] == "user", "The first message should be a user message"
+        prompt = [chat_history[0]]
+        prompt_ids = self.tokenizer.apply_chat_template(
+            prompt,
+            add_generation_prompt=False,  # the messages below add the generation prompt themselves
+            tokenize=True,
+            chat_template=self.custom_chat_template_content,
+        )
+        initial_prompt_length = len(prompt_ids)
+
+        # Process response messages (everything after the first message)
+        response_messages = chat_history[1:]
+        rollout_details = getattr(results.agent_result, "rollout_details", None)
+
+        # Extract from rollout_details tuple: ([{'prompt_token_ids': [...], 'completion_token_ids': [...], 'logprobs': [...]}],)
+        try:
+            details = rollout_details[0]
+            if isinstance(details, list):
+                details = details[0]
+        except (TypeError, IndexError) as e:
+            logger.error(f"Could not unpack rollout details ({e}): {rollout_details}")
+            raise
+        prompt_ids_list = details['prompt_token_ids']
+        completion_ids_list = details['completion_token_ids']
+        logprobs_list = details['logprobs']
+
+        # Initial prompt is the first one
+        prompt_ids = list(prompt_ids_list[0])
+
+        # Build response_ids, loss_mask, and rollout_logprobs turn by turn
+        # Structure: completion[0] + user_feedback[0] + completion[1] + user_feedback[1] + ... + completion[n-1]
+        # where user_feedback[i] = prompt[i+1][len(prompt[i]) + len(completion[i]):]
+        response_ids = []
+        loss_mask = []
+        rollout_logprobs = []
+
+        num_turns = len(prompt_ids_list)
+        for turn in range(num_turns):
+            # Add completion tokens (mask = 1)
+            completion = list(completion_ids_list[turn])
+            logprobs = list(logprobs_list[turn])
+            response_ids.extend(completion)
+            loss_mask.extend([1] * len(completion))
+            rollout_logprobs.extend(logprobs)
+
+            # Add user feedback tokens after this completion (if not last turn)
+            if turn < num_turns - 1:
+                # user_feedback = prompt[turn+1] minus (prompt[turn] + completion[turn])
+                prev_len = len(prompt_ids_list[turn]) + len(completion_ids_list[turn])
+                user_tokens = list(prompt_ids_list[turn + 1][prev_len:])
+                response_ids.extend(user_tokens)
+                loss_mask.extend([0] * len(user_tokens))
+                rollout_logprobs.extend([0.0] * len(user_tokens))
+
+        # Determine stop reason
+        max_response_tokens = (
+            self.generator_cfg.sampling_params.max_generate_length
+            + self.generator_cfg.max_input_length
+            - initial_prompt_length
+        )
+        stop_reason = "complete"  # Default for trial completion
+        if len(response_ids) > max_response_tokens:
+            stop_reason = "length"
+
+        # Truncate to maximum allowed length
+        response_ids = response_ids[:max_response_tokens]
+        loss_mask = loss_mask[:max_response_tokens]
+        rollout_logprobs = rollout_logprobs[:max_response_tokens]
+        return TerminalBenchAgentOutput(
+            response_ids=response_ids,
+            reward=reward,
+            stop_reason=stop_reason,
+            loss_mask=loss_mask,
+            prompt_ids=prompt_ids,
+            trajectory_id=trajectory_id,
+            rollout_logprobs=rollout_logprobs,
+            summarization_count=summarization_count,
+        )
diff --git a/skyrl-train/skyrl_train/entrypoints/main_base.py b/skyrl-train/skyrl_train/entrypoints/main_base.py
index 0fb2dc4b5..a2d56a64e 100644
--- a/skyrl-train/skyrl_train/entrypoints/main_base.py
+++ b/skyrl-train/skyrl_train/entrypoints/main_base.py
@@ -22,6 +22,7 @@ import hydra
 from loguru import logger
 from skyrl_train.utils.tracking import Tracking
+import asyncio
 import multiprocessing as mp
 
 # NOTE (sumanthrh): We use ray heavily and thus disable `fork` start method.
@@ -312,7 +313,7 @@ def _setup_trainer(self):
     def run(self):
         trainer = self._setup_trainer()
         # Start the training loop
-        trainer.train()
+        asyncio.run(trainer.train())
 
 
 @ray.remote(num_cpus=1)
diff --git a/skyrl-train/skyrl_train/generators/utils.py b/skyrl-train/skyrl_train/generators/utils.py
index d65ea9bec..692daf69d 100644
--- a/skyrl-train/skyrl_train/generators/utils.py
+++ b/skyrl-train/skyrl_train/generators/utils.py
@@ -104,13 +104,13 @@ def get_custom_chat_template(chat_template_config: Optional[Union[dict, DictConf
         raise ValueError(f"Invalid source '{source}'. Must be 'name' or 'file'")
 
 
-def get_generation_prompt_ids(tokenizer) -> List[int]:
+def get_generation_prompt_ids(tokenizer, custom_chat_template=None) -> List[int]:
     """
     Helper function to get the generation prompt ids for a given tokenizer.
     """
-    empty_user = tokenizer.apply_chat_template([{"role": "user", "content": ""}], tokenize=True)
+    empty_user = tokenizer.apply_chat_template([{"role": "user", "content": ""}], tokenize=True, chat_template=custom_chat_template)
     empty_user_with_generation_prompt = tokenizer.apply_chat_template(
-        [{"role": "user", "content": ""}], add_generation_prompt=True, tokenize=True
+        [{"role": "user", "content": ""}], add_generation_prompt=True, tokenize=True, chat_template=custom_chat_template
     )
     generation_prompt_ids = empty_user_with_generation_prompt[len(empty_user) :]
diff --git a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py
index 0a254695f..26e57e652 100644
--- a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py
+++ b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py
@@ -1,6 +1,7 @@
 import os
 from typing import List, Any, Dict, Optional, Tuple, Iterator
 from dataclasses import dataclass
+from loguru import logger
 from http import HTTPStatus
 import ray
 import torch
@@ -19,6 +20,7 @@
     CompletionRequest,
     CompletionResponse,
 )
+from vllm.v1.metrics.loggers import LoggingStatLogger
 from vllm.lora.request import LoRARequest
 from torch.distributed import destroy_process_group
 from skyrl_train.distributed.utils import init_custom_process_group
@@ -47,6 +49,7 @@ class Logprob:
 
 def setup_envvars_for_vllm(kwargs, bundle_indices):
     noset_visible_devices = kwargs.pop("noset_visible_devices")
+    os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "0"  # TODO(Charlie): may not be needed.
     if kwargs.get("distributed_executor_backend") == "ray":
         # a hack to make the script work.
        # stop ray from manipulating *_VISIBLE_DEVICES
@@ -362,6 +365,22 @@ async def _destroy_weights_update_group(self):
         engine = self._get_engine()
         return await asyncio.to_thread(engine.collective_rpc, "destroy_weights_update_group")
 
+class V1LoggingStatLoggerFixed(LoggingStatLogger):
+    """
+    A fixed version of LoggingStatLogger that actually logs during the record method.
+    The log method is otherwise not called in the vLLM codebase.
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self.log_interval = 5
+
+    def record(self, *args: Any, **kwargs: Any) -> None:
+        super().record(*args, **kwargs)
+        now = time.monotonic()
+        if now - self.last_log_time > self.log_interval:
+            self.log()
+            self.last_log_time = now
 
 class AsyncVLLMInferenceEngine(BaseVLLMInferenceEngine):
     """Asynchronous VLLM engine."""
@@ -373,11 +392,15 @@ def __init__(self, *args, **kwargs):
     def _create_engine(self, *args, **kwargs):
         openai_kwargs = pop_openai_kwargs(kwargs)
         # TODO (erictang000): potentially enable log requests for a debugging mode
+        custom_chat_template_path = kwargs.pop("custom_chat_template_chat_completion_path", None)
+        stat_loggers = [V1LoggingStatLoggerFixed]
+
         if version.parse(vllm.__version__) >= version.parse("0.10.0"):
             engine_args = vllm.AsyncEngineArgs(enable_log_requests=False, **kwargs)
         else:
             engine_args = vllm.AsyncEngineArgs(disable_log_requests=True, **kwargs)
-        engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)
+        engine = vllm.AsyncLLMEngine.from_engine_args(engine_args, stat_loggers=stat_loggers)
+
         # Adapted from https://github.com/volcengine/verl/blob/e90f18c40aa639cd25092b78a5ff7e2d2508c088/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L327
 
         model_config = engine.model_config
@@ -386,16 +409,24 @@ def _create_engine(self, *args, **kwargs):
             model_name = model_path
 
         base_model_paths = [BaseModelPath(name=model_name, model_path=model_path)]
-        models = OpenAIServingModels(engine, model_config, base_model_paths)
+        models = OpenAIServingModels(engine, base_model_paths)
+
+        # TODO(Charlie): adding custom chat template for chat completion. Hacky!
+        if custom_chat_template_path:
+            with open(custom_chat_template_path, "r") as f:
+                custom_chat_template_content = f.read()
+            logger.info(f"Initializing OpenAIServingChat with custom_chat_template read from: {custom_chat_template_path}")
+        else:
+            custom_chat_template_content = None
+
         # TODO(Charlie): revisit kwargs `enable_auto_tools` and `tool_parser` when we need to
         # support OAI-style tool calling; and `request_logger` for better debugging.
         self.openai_serving_chat = OpenAIServingChat(
             engine_client=engine,
-            model_config=model_config,
             models=models,
             response_role="assistant",
             request_logger=None,
-            chat_template=None,
+            chat_template=custom_chat_template_content,
             chat_template_content_format="auto",
             **openai_kwargs,
         )
@@ -404,7 +435,6 @@ def _create_engine(self, *args, **kwargs):
         # `enable_prompt_tokens_details`, `enable_force_include_usage`.
         self.openai_serving_completion = OpenAIServingCompletion(
             engine_client=engine,
-            model_config=model_config,
             models=models,
             request_logger=None,
         )
@@ -526,6 +556,15 @@ async def _handle_openai_request(self, request_payload: Dict[str, Any], endpoint
         body = request_payload.get("json", {})
         headers = request_payload.get("headers", {})
 
+        # TODO(Charlie): Hacky! We are hijacking the request to update the sampling params.
+        # Can we allow Harbor to use customized sampling params?
+        body.update({
+            "temperature": 1.0,
+            "top_p": 1.0,
+            "top_k": -1,
+            "min_p": 0.0,
+        })
+
         # 1. Build request
         try:
             if endpoint == "/chat/completions":
diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py
index 0f5c41151..c71987217 100644
--- a/skyrl-train/skyrl_train/trainer.py
+++ b/skyrl-train/skyrl_train/trainer.py
@@ -1,4 +1,3 @@
-import asyncio
 import math
 import os
 import shutil
@@ -148,7 +147,7 @@ async def eval(self) -> Dict[str, float]:
         )
         return eval_metrics
 
-    def train(self):
+    async def train(self):
         """
         Main training loop for PPO
         """
diff --git a/skyrl-train/skyrl_train/utils/tracking.py b/skyrl-train/skyrl_train/utils/tracking.py
index 5a6f45d8e..0f703c2c5 100644
--- a/skyrl-train/skyrl_train/utils/tracking.py
+++ b/skyrl-train/skyrl_train/utils/tracking.py
@@ -23,6 +23,34 @@
 from loguru import logger
 from omegaconf import DictConfig, OmegaConf
 import pprint
+import ray
+from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
+
+
+@ray.remote
+class WandbNodeLogger:
+    """
+    A Ray actor that initializes wandb on a specific node to capture system metrics.
+    """
+    def __init__(self, project_name, experiment_name, config, run_id, group_name, x_label):
+        import wandb
+        # Initialize wandb with the same run ID to aggregate system metrics
+        run = wandb.init(
+            project=project_name,
+            name=experiment_name,
+            id=run_id,
+            config=config,
+            resume="allow",
+            group=group_name,
+            job_type="worker_monitor",
+            settings=wandb.Settings(
+                mode="shared",
+                x_primary=False,
+                x_update_finish_state=False,
+                x_label=x_label,
+            )
+        )
+        self.wandb = run
 
 
 # TODO(tgriggs): Test all backends.
@@ -41,8 +69,28 @@ def __init__(self, project_name, experiment_name, backends: Union[str, List[str]
             import wandb
             from omegaconf import OmegaConf
 
-            wandb.init(project=project_name, name=experiment_name, config=OmegaConf.to_container(config, resolve=True))
-            self.logger["wandb"] = wandb
+            current_node_ip = "head"
+            if ray.is_initialized():
+                try:
+                    current_node_ip = ray.util.get_node_ip_address()
+                except Exception as e:
+                    logger.warning(f"Failed to get node IP address, defaulting to 'head'. Error: {e}")
+
+            run = wandb.init(
+                project=project_name,
+                name=experiment_name,
+                config=OmegaConf.to_container(config, resolve=True),
+                group=experiment_name,
+                resume="allow",
+                settings=wandb.Settings(
+                    mode="shared",  # mainly for multi-node training's system metrics aggregation
+                    x_primary=True,
+                    x_label=f"node-{current_node_ip}"
+                )
+            )
+            run_id = run.id
+            self.logger["wandb"] = run
+            self._prepare_worker_nodes_systems_logging_wandb(project_name, experiment_name, run_id, config, current_node_ip)
 
         if "mlflow" in backends:
             self.logger["mlflow"] = _MlflowLoggingAdapter(project_name, experiment_name, config)
@@ -73,6 +121,58 @@ def __init__(self, project_name, experiment_name, backends: Union[str, List[str]
             self.console_logger = ConsoleLogger()
             self.logger["console"] = self.console_logger
 
+    def _prepare_worker_nodes_systems_logging_wandb(self, project_name, experiment_name, run_id, config, current_node_ip):
+        """
+        In multi-node training, we spawn WandbNodeLogger actors on each worker node to capture system metrics like
+        GPU utilization. We use `mode="shared"` to aggregate system metrics from all nodes into the same Wandb run.
+        However, with this approach, the system metrics panels do not appear in the Wandb UI automatically but require
+        us to create the panels manually. The alternative is to create a run for each node and group them by group_name.
+        We prefer to keep all nodes' metrics in the same run.
+        """
+        # Start WandB loggers on other nodes using Ray
+        if ray.is_initialized():
+            try:
+                # current_node_ip (head) is already set in __init__
+                nodes = ray.nodes()
+                self.remote_loggers = []
+
+                for node in nodes:
+                    if not node["Alive"]:
+                        continue
+
+                    node_ip = node["NodeManagerAddress"]
+                    # Skip the current node as it's already logging
+                    if node_ip == current_node_ip:
+                        continue
+
+                    try:
+                        # Launch a logger on this node
+                        logger_actor = WandbNodeLogger.options(
+                            num_cpus=0,
+                            scheduling_strategy=NodeAffinitySchedulingStrategy(
+                                node_id=node["NodeID"],
+                                soft=False,  # fail if that node is gone or infeasible
+                            ),
+                        ).remote(
+                            project_name=project_name,
+                            experiment_name=experiment_name,
+                            config=OmegaConf.to_container(config, resolve=True),
+                            run_id=run_id,
+                            group_name=experiment_name,
+                            x_label=f"node-{node_ip}"
+                        )
+                        self.remote_loggers.append(logger_actor)
+                        logger.info(f"WandbNodeLogger initialized on 'node-{node_ip}'")
+                    except Exception as e:
+                        logger.warning(f"Failed to spawn WandbNodeLogger on {node_ip}: {e}")
+
+            except Exception as e:
+                logger.warning(f"Failed to setup distributed wandb logging: {e}")
+        else:
+            logger.warning("Ray is not initialized, skipping distributed wandb logging")
+
+
     def log(self, data, step, commit=False):
         for logger_name, logger_instance in self.logger.items():
             if logger_name == "wandb":
diff --git a/skyrl-train/skyrl_train/weights_manager.py b/skyrl-train/skyrl_train/weights_manager.py
index bfd3aa5b2..d63322c0a 100644
--- a/skyrl-train/skyrl_train/weights_manager.py
+++ b/skyrl-train/skyrl_train/weights_manager.py
@@ -27,6 +27,16 @@ def __exit__(self, exc_type, exc_val, exc_tb):
             return self.weights_manager.__exit__(exc_type, exc_val, exc_tb)
         return False
 
+    async def __aenter__(self):
+        if self.condition:
+            await self.weights_manager.__aenter__()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        if self.condition:
+            return await self.weights_manager.__aexit__(exc_type, exc_val, exc_tb)
+        return False
+
 
 class InferenceWeightsManager:
     """Manages weight syncing and offloading/backloading between the policy model and the InferenceEngines.
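A minimal usage sketch of the async context-manager support added in the hunk above, as it would be driven from the now-async `train()` loop. This is illustrative only: `weights_manager` stands for the conditional wrapper shown above, and `generate_batch` is a hypothetical coroutine, not part of this diff.

```python
# Sketch only: assumes `weights_manager` is the conditional wrapper above and that
# the wrapped manager implements __aenter__/__aexit__.
async def rollout(weights_manager, generate_batch):
    # When the wrapper's condition is false, enter/exit are no-ops,
    # mirroring the existing synchronous `with` path.
    async with weights_manager:
        return await generate_batch()
```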
diff --git a/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py b/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py index 9aea864ac..3d7bfbc82 100644 --- a/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py +++ b/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py @@ -129,7 +129,7 @@ def init_model(self, model_path, num_training_steps: int = None): use_flash_attention_2=self.cfg.trainer.flash_attn, # NOTE (sumanthrh): Model initialization should always be in fp32 # during training - bf16=False, + bf16=True, lora_rank=self.cfg.trainer.policy.model.lora.rank, lora_alpha=self.cfg.trainer.policy.model.lora.alpha, lora_dropout=self.cfg.trainer.policy.model.lora.dropout, @@ -357,7 +357,7 @@ def init_model(self, model_path, num_training_steps: int = None): use_flash_attention_2=self.cfg.trainer.flash_attn, # NOTE (sumanthrh): Model initialization should always be in fp32 # during training - bf16=False, + bf16=True, lora_rank=self.cfg.trainer.critic.model.lora.rank, lora_alpha=self.cfg.trainer.critic.model.lora.alpha, lora_dropout=self.cfg.trainer.critic.model.lora.dropout, diff --git a/skyrl-train/tests/gpu/gpu_ci/qwen3_acc_thinking.jinja2 b/skyrl-train/tests/gpu/gpu_ci/qwen3_acc_thinking.jinja2 new file mode 100644 index 000000000..15691e71b --- /dev/null +++ b/skyrl-train/tests/gpu/gpu_ci/qwen3_acc_thinking.jinja2 @@ -0,0 +1,85 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} +{%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if reasoning_content %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first 
and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file
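A minimal sketch of exercising a custom chat template file such as the one above; the model name and message are placeholders, and `chat_template=` is the standard `transformers` override that the updated `get_generation_prompt_ids` and `TerminalBenchGenerator` also rely on.

```python
from transformers import AutoTokenizer

# Assumption: any Qwen3-family tokenizer; the template path matches this diff.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B")
with open("skyrl-train/tests/gpu/gpu_ci/qwen3_acc_thinking.jinja2") as f:
    custom_template = f.read()

# Render and tokenize a conversation with the custom template instead of the
# tokenizer's built-in one.
prompt_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": "List the files in /tmp"}],
    chat_template=custom_template,
    add_generation_prompt=True,
    tokenize=True,
)
```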