diff --git a/.gitignore b/.gitignore
index 5b5b47357..9cdd7aecf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -39,9 +39,10 @@ docs/_spelling/
/skyrl-gym/dist
*.log
+trials/
# SQLite database files
*.db
# uv lock files
-uv.lock
\ No newline at end of file
+uv.lock
diff --git a/skyrl-train/examples/async/main_async.py b/skyrl-train/examples/async/main_async.py
index 4a25dbe77..0d71a3913 100644
--- a/skyrl-train/examples/async/main_async.py
+++ b/skyrl-train/examples/async/main_async.py
@@ -5,7 +5,7 @@
import hydra
from omegaconf import DictConfig
from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg
-from .async_trainer import AsyncRayPPOTrainer
+from skyrl_train.fully_async_trainer import FullyAsyncRayPPOTrainer
import asyncio
from skyrl_train.utils import initialize_ray
import ray
@@ -23,7 +23,7 @@ def get_trainer(
generator,
colocate_pg,
):
- return AsyncRayPPOTrainer(
+ return FullyAsyncRayPPOTrainer(
cfg=cfg,
tracker=tracker,
tokenizer=tokenizer,
@@ -33,7 +33,7 @@ def get_trainer(
generator=generator,
colocate_pg=colocate_pg,
)
-
+
def run(self):
trainer = self._setup_trainer()
# Start the async training loop
diff --git a/skyrl-train/examples/on_policy_distillation/README.md b/skyrl-train/examples/on_policy_distillation/README.md
index 1c25c261c..58194035f 100644
--- a/skyrl-train/examples/on_policy_distillation/README.md
+++ b/skyrl-train/examples/on_policy_distillation/README.md
@@ -15,8 +15,8 @@ In `main_on_policy_distill.py` we provide a simple example for modifying SkyRL t
To get started, first set up the dataset from the DAPO example:
```bash
# Run from the `skyrl-train` directory
-bash examples/algorithms/dapo/prepare_dapo_data.sh
+uv run examples/algorithms/dapo/prepare_dapo_data.sh
```
Then, just make sure to set the path to your desired teacher model, and you're ready to kick off training!
diff --git a/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_1.7b.sh b/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_1.7b.sh
index 250e8252a..697371d25 100644
--- a/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_1.7b.sh
+++ b/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_1.7b.sh
@@ -2,7 +2,7 @@ set -x
# Running on policy distillation for Math on the DAPO math dataset, with eval on AIME 2024.
# Uses Qwen-3-1.7B-Base as the student model and an RL trained Qwen-3-4B as the teacher model
-# bash examples/algorithms/dapo/prepare_dapo_data.sh
+# uv run examples/algorithms/dapo/prepare_dapo_data.sh
# bash examples/on_policy_distillation/run_on_policy_distill_math_qwen3_1.7b.sh
DATA_DIR="$HOME/data/dapo"
diff --git a/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_4b.sh b/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_4b.sh
index 670a35152..c59ee5257 100644
--- a/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_4b.sh
+++ b/skyrl-train/examples/on_policy_distillation/run_on_policy_distill_math_qwen3_4b.sh
@@ -2,7 +2,7 @@ set -x
# Running on policy distillation for Math on the DAPO math dataset, with eval on AIME 2024.
# Uses Qwen-3-4B-Base as the student model and an RL trained Qwen-3-4B as the teacher model
-# bash examples/algorithms/dapo/prepare_dapo_data.sh
+# uv run examples/algorithms/dapo/prepare_dapo_data.sh
# bash examples/on_policy_distillation/run_on_policy_distill_math_qwen3_4b.sh
DATA_DIR="$HOME/data/dapo"
diff --git a/skyrl-train/examples/terminal_bench/entrypoints/main_tbench.py b/skyrl-train/examples/terminal_bench/entrypoints/main_tbench.py
index db922953e..25a9db72a 100644
--- a/skyrl-train/examples/terminal_bench/entrypoints/main_tbench.py
+++ b/skyrl-train/examples/terminal_bench/entrypoints/main_tbench.py
@@ -10,7 +10,7 @@
from skyrl_train.utils.utils import initialize_ray
from examples.terminal_bench.terminal_bench_generator import TerminalBenchGenerator
from examples.terminal_bench.dataset import TerminalBenchTaskDataset
-
+from skyrl_train.fully_async_trainer import FullyAsyncRayPPOTrainer
class TerminalBenchExp(BasePPOExp):
def get_generator(self, cfg, tokenizer, inference_engine_client):
@@ -52,6 +52,28 @@ def get_eval_dataset(self):
return prompts_dataset
return None
+ def get_trainer(
+ self,
+ cfg,
+ tracker,
+ tokenizer,
+ train_dataset,
+ eval_dataset,
+ inference_engine_client,
+ generator,
+ colocate_pg,
+ ):
+ return FullyAsyncRayPPOTrainer(
+ cfg=cfg,
+ tracker=tracker,
+ tokenizer=tokenizer,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ inference_engine_client=inference_engine_client,
+ generator=generator,
+ colocate_pg=colocate_pg,
+ )
+
@ray.remote(num_cpus=1)
def skyrl_entrypoint(cfg: DictConfig):
diff --git a/skyrl-train/examples/terminal_bench/entrypoints/main_tbench_generate.py b/skyrl-train/examples/terminal_bench/entrypoints/main_tbench_generate.py
index cd8f478fb..e8f8054c3 100644
--- a/skyrl-train/examples/terminal_bench/entrypoints/main_tbench_generate.py
+++ b/skyrl-train/examples/terminal_bench/entrypoints/main_tbench_generate.py
@@ -18,7 +18,7 @@
config_dir,
)
from skyrl_train.generators.base import GeneratorInput
-from examples.terminal_bench.generator.terminal_bench_generator import TerminalBenchGenerator
+from examples.terminal_bench.terminal_bench_generator import TerminalBenchGenerator
from examples.terminal_bench.dataset import TerminalBenchTaskDataset
diff --git a/skyrl-train/examples/terminal_bench/entrypoints/main_tbench_opd.py b/skyrl-train/examples/terminal_bench/entrypoints/main_tbench_opd.py
new file mode 100644
index 000000000..bc2827290
--- /dev/null
+++ b/skyrl-train/examples/terminal_bench/entrypoints/main_tbench_opd.py
@@ -0,0 +1,36 @@
+"""
+Main entrypoint for on-policy distillation training on terminal bench tasks.
+"""
+import ray
+import hydra
+from omegaconf import DictConfig
+from skyrl_train.entrypoints.main_base import config_dir
+from skyrl_train.utils import validate_cfg
+from skyrl_train.utils.utils import initialize_ray
+from examples.terminal_bench.entrypoints.main_tbench import TerminalBenchExp
+from examples.on_policy_distillation.main_on_policy_distill import OnPolicyDistillationTrainer
+
+class OnPolicyDistillationTerminalBenchExp(TerminalBenchExp):
+ def get_trainer(self, *args, **kwargs):
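+        # Swap in the on-policy distillation trainer; all other experiment setup
+        # is inherited from TerminalBenchExp.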
+ return OnPolicyDistillationTrainer(*args, **kwargs)
+
+
+@ray.remote(num_cpus=1)
+def skyrl_entrypoint(cfg: DictConfig):
+ # make sure that the training loop is not run on the head node.
+ exp = OnPolicyDistillationTerminalBenchExp(cfg)
+ exp.run()
+
+@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None)
+def main(cfg: DictConfig) -> None:
+ # validate the arguments
+ validate_cfg(cfg)
+
+ initialize_ray(cfg)
+ ray.get(skyrl_entrypoint.remote(cfg))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/skyrl-train/examples/terminal_bench/generator/terminal_bench_generator.py b/skyrl-train/examples/terminal_bench/generator/terminal_bench_generator.py
deleted file mode 100644
index 57495d093..000000000
--- a/skyrl-train/examples/terminal_bench/generator/terminal_bench_generator.py
+++ /dev/null
@@ -1,183 +0,0 @@
-import asyncio
-from dataclasses import dataclass
-from typing import List, Optional
-from uuid import uuid4
-from skyrl_train.generators.base import GeneratorInterface, GeneratorInput, GeneratorOutput
-from skyrl_train.generators.utils import get_rollout_metrics, get_response_ids_and_loss_mask_from_messages
-from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
-from skyrl_train.inference_engines.base import ConversationType
-from omegaconf import DictConfig
-from pathlib import Path
-from harbor.models.trial.config import TrialConfig, AgentConfig, TaskConfig, EnvironmentConfig
-from harbor.models.environment_type import EnvironmentType
-from harbor.models.agent.name import AgentName
-from harbor.trial.trial import Trial
-
-
-@dataclass
-class TerminalBenchAgentOutput:
- response_ids: List[int]
- reward: float
- stop_reason: str
- loss_mask: List[int]
- prompt_ids: List[int]
- rollout_logprobs: Optional[List[float]]
-
-
-class TerminalBenchGenerator(GeneratorInterface):
- def __init__(
- self,
- generator_cfg: DictConfig,
- terminal_bench_cfg: DictConfig,
- inference_engine_client: InferenceEngineClient,
- tokenizer,
- ):
- """
- Args:
- generator_cfg: DictConfig object containing the generator configuration
- terminal_bench_cfg: DictConfig object containing the terminal bench configuration
- inference_engine_client: InferenceEngineClient object for interacting with the inference engines
- tokenizer: tokenizer object for encoding and decoding text
- """
- self.base_url = f"http://{generator_cfg.http_endpoint_host}:{generator_cfg.http_endpoint_port}"
- self.generator_cfg = generator_cfg
- self.tokenizer = tokenizer
- self.model_name = generator_cfg.model_name
-
- # TerminalBench config
- self.trials_dir = terminal_bench_cfg.trials_dir
- self.agent_name = terminal_bench_cfg.agent_name
- self.max_episodes = terminal_bench_cfg.max_episodes
-
- if self.generator_cfg.chat_template.name_or_path is not None:
- raise NotImplementedError("TerminalBenchGenerator doesn't support custom chat template")
-
- async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
- prompts = input_batch["prompts"]
- tasks = []
- for prompt in prompts:
- tasks.append(
- self.terminal_bench_agent_loop(
- prompt=prompt,
- )
- )
-
- all_outputs = await asyncio.gather(*tasks)
-
- responses = [output.response_ids for output in all_outputs]
- rewards = [output.reward for output in all_outputs]
- rollout_metrics = get_rollout_metrics(responses, rewards)
-
- generator_output: GeneratorOutput = {
- "prompt_token_ids": [output.prompt_ids for output in all_outputs],
- "response_ids": responses,
- "rewards": rewards,
- "loss_masks": [output.loss_mask for output in all_outputs],
- "stop_reasons": [output.stop_reason for output in all_outputs],
- "rollout_metrics": rollout_metrics,
- "rollout_logprobs": [output.rollout_logprobs for output in all_outputs],
- }
-
- return generator_output
-
- async def terminal_bench_agent_loop(
- self,
- prompt: ConversationType,
- ) -> TerminalBenchAgentOutput:
- """
- Run a single terminal_bench agent.
- """
- # Generate session_id for sticky routing to inference engines
- # All LLM requests in this trial will share the same session_id
- session_id = uuid4().hex
-
- if self.agent_name == "terminus":
- trial_config = TrialConfig(
- task=TaskConfig(path=prompt),
- trials_dir=Path(self.trials_dir),
- environment=EnvironmentConfig(type=EnvironmentType.DAYTONA),
- agent=AgentConfig(
- name=AgentName.TERMINUS_2.value,
- model_name=f"hosted_vllm/{self.model_name}",
- kwargs={
- "api_base": f"{self.base_url}/v1",
- "key": "fake_key",
- "session_id": session_id,
- "max_episodes": self.max_episodes,
- },
- ),
- )
- elif self.agent_name == "oracle":
- trial_config = TrialConfig(
- task=TaskConfig(path=prompt),
- trials_dir=Path(self.trials_dir),
- environment=EnvironmentConfig(type=EnvironmentType.DAYTONA),
- agent=AgentConfig(
- name=AgentName.ORACLE,
- model_name=f"hosted_vllm/{self.model_name}",
- ),
- )
- else:
- raise ValueError(f"Invalid agent name: {self.agent_name}")
-
- trial = Trial(trial_config)
- # Run the trial
- while True:
- try:
- results = await trial.run()
- print(f"Results: {results}")
- if not results.verifier_result:
- print(f"[WARNING] Exception info: {results.exception_info}")
- continue
- reward = results.verifier_result.reward
- chat_history = results.agent_result.all_messages
- if len(chat_history) > 0:
- break
- else:
- print(f"[WARNING] Agent {self.agent_name} did not return a response")
- except Exception as e:
- print(f"Error running trial: {e}")
- continue
-
- # Use the first message as the prompt. We assume to be no systems messages.
- assert chat_history[0]["role"] == "user", "The first message should be a user message"
- prompt = [chat_history[0]]
- prompt_ids = self.tokenizer.apply_chat_template(
- prompt,
- add_generation_prompt=False, # the message below will add it themselves
- tokenize=True,
- )
- initial_prompt_length = len(prompt_ids)
-
- # Process response messages (everything after the first message)
- response_messages = chat_history[1:]
- assistant_logprobs = getattr(results.agent_result, "output_logprobs", None)
- response_ids, loss_mask, rollout_logprobs = get_response_ids_and_loss_mask_from_messages(
- response_messages, self.tokenizer, assistant_logprobs
- )
-
- # Determine stop reason
- max_response_tokens = (
- self.generator_cfg.sampling_params.max_generate_length
- + self.generator_cfg.max_input_length
- - initial_prompt_length
- )
- stop_reason = "complete" # Default for trial completion
- if len(response_ids) > max_response_tokens:
- stop_reason = "length"
- # TODO(Charlie): should we do rewards = self._zero_reward_if_not_stop(rewards, stop_reasons)?
-
- # Truncate to maximum allowed length
- response_ids = response_ids[:max_response_tokens]
- loss_mask = loss_mask[:max_response_tokens]
- rollout_logprobs = rollout_logprobs[:max_response_tokens]
-
- return TerminalBenchAgentOutput(
- response_ids=response_ids,
- reward=reward,
- stop_reason=stop_reason,
- loss_mask=loss_mask,
- prompt_ids=prompt_ids,
- # in case harbor doesn't return logprobs, use None
- rollout_logprobs=rollout_logprobs if assistant_logprobs is not None else None,
- )
diff --git a/skyrl-train/examples/terminal_bench/qwen3_thinking_acc.jinja2 b/skyrl-train/examples/terminal_bench/qwen3_thinking_acc.jinja2
new file mode 100644
index 000000000..1643aa5fb
--- /dev/null
+++ b/skyrl-train/examples/terminal_bench/qwen3_thinking_acc.jinja2
@@ -0,0 +1,72 @@
+{%- if tools %}
+ {{- '<|im_start|>system\n' }}
+ {%- if messages[0].role == 'system' %}
+ {{- messages[0].content + '\n\n' }}
+ {%- endif %}
+    {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
+ {%- for tool in tools %}
+ {{- "\n" }}
+ {{- tool | tojson }}
+ {%- endfor %}
+    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
+{%- else %}
+ {%- if messages[0].role == 'system' %}
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
+ {%- endif %}
+{%- endif %}
+{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
+{%- for message in messages[::-1] %}
+ {%- set index = (messages|length - 1) - loop.index0 %}
+    {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
+ {%- set ns.multi_step_tool = false %}
+ {%- set ns.last_query_index = index %}
+ {%- endif %}
+{%- endfor %}
+{%- for message in messages %}
+ {%- if message.content is string %}
+ {%- set content = message.content %}
+ {%- else %}
+ {%- set content = '' %}
+ {%- endif %}
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
+ {%- elif message.role == "assistant" %}
+ {{- '<|im_start|>' + message.role + '\n' + content }}
+ {%- if message.tool_calls %}
+ {%- for tool_call in message.tool_calls %}
+ {%- if (loop.first and content) or (not loop.first) %}
+ {{- '\n' }}
+ {%- endif %}
+ {%- if tool_call.function %}
+ {%- set tool_call = tool_call.function %}
+ {%- endif %}
+            {{- '<tool_call>\n{"name": "' }}
+ {{- tool_call.name }}
+ {{- '", "arguments": ' }}
+ {%- if tool_call.arguments is string %}
+ {{- tool_call.arguments }}
+ {%- else %}
+ {{- tool_call.arguments | tojson }}
+ {%- endif %}
+            {{- '}\n</tool_call>' }}
+ {%- endfor %}
+ {%- endif %}
+ {{- '<|im_end|>\n' }}
+ {%- elif message.role == "tool" %}
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
+ {{- '<|im_start|>user' }}
+ {%- endif %}
+        {{- '\n<tool_response>\n' }}
+ {{- content }}
+        {{- '\n</tool_response>' }}
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+ {{- '<|im_end|>\n' }}
+ {%- endif %}
+ {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+ {{- '<|im_start|>assistant\n' }}
+ {%- if enable_thinking is defined and enable_thinking is false %}
+        {{- '<think>\n\n</think>\n\n' }}
+ {%- endif %}
+{%- endif %}
diff --git a/skyrl-train/examples/terminal_bench/run_tbench.sh b/skyrl-train/examples/terminal_bench/run_tbench.sh
index a44627324..e028877e7 100644
--- a/skyrl-train/examples/terminal_bench/run_tbench.sh
+++ b/skyrl-train/examples/terminal_bench/run_tbench.sh
@@ -6,13 +6,13 @@ set -x
# export WANDB_API_KEY=
# bash examples/terminal_bench/run_tbench.sh
-DATA_DIR="$(pwd)/harbor/examples/tasks"
+DATA_DIR="$HOME/data/tasks"
NUM_GPUS=1
LOGGER="console" # change to "console" to print to stdout
TBENCH_CONFIG_DIR="examples/terminal_bench"
SANDBOXES_DIR="sandboxes" # TODO: For now, `sandboxes` is cloned into SkyRL/skyrl-train.
-uv run --isolated --extra vllm --extra sandboxes --with "harbor@./harbor" -m examples.terminal_bench.entrypoints.main_tbench \
+uv run --isolated --extra vllm --extra sandboxes --with "sandbox@./sandboxes" -m examples.terminal_bench.entrypoints.main_tbench \
data.train_data="['$DATA_DIR']" \
data.val_data="['$DATA_DIR']" \
hydra.searchpath=[file://$TBENCH_CONFIG_DIR] \
@@ -33,7 +33,7 @@ uv run --isolated --extra vllm --extra sandboxes --with "harbor@./harbor" -m exa
trainer.eval_interval=-1 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=8 \
- trainer.policy_mini_batch_size=2 \
+ trainer.policy_mini_batch_size=8 \
trainer.micro_forward_batch_size_per_gpu=1 \
trainer.micro_train_batch_size_per_gpu=1 \
trainer.ckpt_interval=-1 \
diff --git a/skyrl-train/examples/terminal_bench/run_tbench_gen.sh b/skyrl-train/examples/terminal_bench/run_tbench_gen.sh
index 82dffcdc3..5d233dbc1 100644
--- a/skyrl-train/examples/terminal_bench/run_tbench_gen.sh
+++ b/skyrl-train/examples/terminal_bench/run_tbench_gen.sh
@@ -6,13 +6,13 @@ set -x
# export WANDB_API_KEY=
# bash examples/terminal_bench/run_tbench.sh
-DATA_DIR="$(pwd)/harbor/examples/tasks"
+DATA_DIR="$HOME/data/tasks"
NUM_GPUS=1
LOGGER="console" # change to "console" to print to stdout
TBENCH_CONFIG_DIR="examples/terminal_bench"
SANDBOXES_DIR="sandboxes" # TODO: For now, `sandboxes` is cloned into SkyRL/skyrl-train.
-uv run --isolated --extra vllm --extra sandboxes --with "harbor@./harbor" -m examples.terminal_bench.entrypoints.main_tbench_generate \
+uv run --isolated --extra vllm --extra sandboxes --with "sandbox@./sandboxes" -m examples.terminal_bench.entrypoints.main_tbench_generate \
data.train_data="['$DATA_DIR']" \
hydra.searchpath=[file://$TBENCH_CONFIG_DIR] \
+terminal_bench_config=terminal_bench \
diff --git a/skyrl-train/examples/terminal_bench/terminal_bench_config/terminal_bench.yaml b/skyrl-train/examples/terminal_bench/terminal_bench_config/terminal_bench.yaml
index f95d7adb8..e69de29bb 100644
--- a/skyrl-train/examples/terminal_bench/terminal_bench_config/terminal_bench.yaml
+++ b/skyrl-train/examples/terminal_bench/terminal_bench_config/terminal_bench.yaml
@@ -1,5 +0,0 @@
-# @package terminal_bench_config
-
-trials_dir: "~/trials"
-agent_name: "terminus"
-max_episodes: 16
\ No newline at end of file
diff --git a/skyrl-train/examples/terminal_bench/terminal_bench_generator.py b/skyrl-train/examples/terminal_bench/terminal_bench_generator.py
new file mode 100644
index 000000000..f0eb2e5f8
--- /dev/null
+++ b/skyrl-train/examples/terminal_bench/terminal_bench_generator.py
@@ -0,0 +1,305 @@
+import asyncio
+from dataclasses import dataclass
+from typing import List, Optional
+from loguru import logger
+from uuid import uuid4
+from skyrl_train.generators.base import GeneratorInterface, GeneratorInput, GeneratorOutput, TrajectoryID
+from skyrl_train.generators.utils import get_rollout_metrics
+from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
+from skyrl_train.inference_engines.base import ConversationType
+from omegaconf import DictConfig
+from pathlib import Path
+from harbor.models.trial.config import TrialConfig, AgentConfig, TaskConfig, EnvironmentConfig
+from harbor.models.environment_type import EnvironmentType
+from harbor.models.agent.name import AgentName
+from harbor.trial.trial import Trial
+
+# We allow N retries for each trial; if one of the rollouts (out of n_samples_per_prompt)
+# fails after N attempts, we skip this prompt altogether.
+MAX_NUM_RETRIES_PER_TRIAL = 2
+
+@dataclass
+class TerminalBenchAgentOutput:
+ response_ids: List[int]
+ reward: float
+ stop_reason: str
+ loss_mask: List[int]
+ prompt_ids: List[int]
+ trajectory_id: TrajectoryID
+ summarization_count: Optional[int] = None
+ rollout_logprobs: Optional[List[float]] = None
+
+class TerminalBenchGenerator(GeneratorInterface):
+ def __init__(
+ self,
+ generator_cfg: DictConfig,
+ terminal_bench_cfg: DictConfig,
+ inference_engine_client: InferenceEngineClient,
+ tokenizer,
+ ):
+ """
+ Args:
+ generator_cfg: DictConfig object containing the generator configuration
+ terminal_bench_cfg: DictConfig object containing the terminal bench configuration
+ inference_engine_client: InferenceEngineClient object for interacting with the inference engines
+ tokenizer: tokenizer object for encoding and decoding text
+ """
+ self.base_url = f"http://{generator_cfg.http_endpoint_host}:{generator_cfg.http_endpoint_port}"
+ self.generator_cfg = generator_cfg
+ self.tokenizer = tokenizer
+ self.model_name = generator_cfg.model_name
+
+ # TerminalBench config. Parse here to ensure everything is passed in.
+ self.trials_dir = terminal_bench_cfg.trials_dir
+ self.agent_name = terminal_bench_cfg.agent_name
+ self.max_episodes = terminal_bench_cfg.max_episodes
+ self.enable_summarize = terminal_bench_cfg.get("enable_summarize", True)
+
+ # Optional overrides for the environment
+ self.override_memory_mb = terminal_bench_cfg.get("override_memory_mb")
+ self.override_storage_mb = terminal_bench_cfg.get("override_storage_mb")
+ self.override_cpus = terminal_bench_cfg.get("override_cpus")
+
+ logger.info(f"TerminalBenchGenerator initialized with overrides: memory={self.override_memory_mb}, storage={self.override_storage_mb}, cpus={self.override_cpus}")
+
+ # Read custom chat template
+ custom_chat_template_path = generator_cfg.engine_init_kwargs.get("custom_chat_template_chat_completion_path", None)
+ if custom_chat_template_path:
+ with open(custom_chat_template_path, "r") as f:
+ self.custom_chat_template_content = f.read()
+ logger.info(f"TerminalBenchGenerator initialized with custom chat template read from: {custom_chat_template_path}")
+ else:
+ self.custom_chat_template_content = None
+
+ async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
+ tasks = []
+ for i in range(len(input_batch["prompts"])):
+ tasks.append(
+ self.terminal_bench_agent_loop(
+ prompt=input_batch["prompts"][i],
+ trajectory_id=input_batch["trajectory_ids"][i],
+ )
+ )
+
+ all_outputs: List[TerminalBenchAgentOutput] = await asyncio.gather(*tasks)
+
+        # For a group of trajectories (n_samples_per_prompt trajectories for the same prompt), if one
+        # of the trajectories fails, we skip the entire group and exclude it from rollout metric aggregation.
+ failed_instance_ids = set()
+ num_failed_trajectories = 0 # per-trajectory, rather than per-instance
+ successful_outputs: List[TerminalBenchAgentOutput] = [] # only for metrics purpose
+ for output in all_outputs:
+ if output.stop_reason == "error":
+ failed_instance_ids.add(output.trajectory_id.instance_id)
+ num_failed_trajectories += 1
+
+ for output in all_outputs:
+ if output.trajectory_id.instance_id in failed_instance_ids:
+ output.response_ids = [0]
+ output.stop_reason = "error"
+ output.loss_mask = [0]
+ output.prompt_ids = [0]
+ output.reward = 0
+ else:
+ successful_outputs.append(output)
+
+ # Calculate rollout metrics for successful outputs
+ if len(successful_outputs) > 0:
+ rollout_metrics = get_rollout_metrics(
+ [output.response_ids for output in successful_outputs],
+ [output.reward for output in successful_outputs],
+ )
+ rollout_metrics["generate/trajectories_summarized"] = sum(1 for output in successful_outputs if output.summarization_count > 0)
+ rollout_metrics["generate/trajectories_truncated"] = sum(1 for output in successful_outputs if output.stop_reason == "length")
+ else:
+ rollout_metrics = {}
+ rollout_metrics["generate/num_failed_instances"] = len(failed_instance_ids)
+ rollout_metrics["generate/num_failed_trajectories"] = num_failed_trajectories
+
+ generator_output: GeneratorOutput = {
+ "prompt_token_ids": [output.prompt_ids for output in all_outputs],
+ "response_ids": [output.response_ids for output in all_outputs],
+ "rewards": [output.reward for output in all_outputs],
+ "loss_masks": [output.loss_mask for output in all_outputs],
+ "stop_reasons": [output.stop_reason for output in all_outputs],
+ "rollout_metrics": rollout_metrics,
+ "rollout_logprobs": None,
+ }
+
+ return generator_output
+
+ async def terminal_bench_agent_loop(
+ self,
+ prompt: ConversationType,
+ trajectory_id: TrajectoryID,
+ ) -> TerminalBenchAgentOutput:
+ """
+ Run a single terminal_bench agent.
+ """
+ # Generate session_id for sticky routing to inference engines
+ # All LLM requests in this trial will share the same session_id
+ session_id = uuid4().hex
+
+ environment_config = EnvironmentConfig(
+ type=EnvironmentType.DAYTONA,
+ override_cpus=self.override_cpus,
+ override_memory_mb=self.override_memory_mb,
+ override_storage_mb=self.override_storage_mb,
+ )
+
+ if self.agent_name == "terminus":
+ trial_config = TrialConfig(
+ task=TaskConfig(path=prompt),
+ trials_dir=Path(self.trials_dir),
+ environment=environment_config,
+ agent=AgentConfig(
+ name=AgentName.TERMINUS_2.value,
+ model_name=f"hosted_vllm/{self.model_name}",
+ kwargs={
+ "api_base": f"{self.base_url}/v1",
+ "key": "fake_key",
+ "max_episodes": self.max_episodes,
+ "session_id": session_id,
+ "enable_summarize": self.enable_summarize,
+ "store_all_messages": True,
+ "collect_rollout_details": True,
+ },
+ ),
+ )
+ elif self.agent_name == "oracle":
+ trial_config = TrialConfig(
+ task=TaskConfig(path=prompt),
+ trials_dir=Path(self.trials_dir),
+ environment=environment_config,
+ agent=AgentConfig(
+ name=AgentName.ORACLE,
+ model_name=f"hosted_vllm/{self.model_name}",
+ ),
+ )
+ else:
+ raise ValueError(f"Invalid agent name: {self.agent_name}")
+
+ trial = Trial(trial_config)
+
+ # Run the trial to get `rewards`, `chat_history`, and `summarization_count`
+ successful = False
+ reward = None
+ chat_history = None
+ summarization_count = None
+ for i in range(MAX_NUM_RETRIES_PER_TRIAL):
+ prefix = f"Trajectory {trajectory_id} attempt {i+1}/{MAX_NUM_RETRIES_PER_TRIAL}"
+ results = None
+ try:
+ results = await trial.run()
+ if not results.verifier_result:
+ logger.warning(f"{prefix} failed: Exception info: {results.exception_info}. Results: {results}")
+ continue
+
+ reward = results.verifier_result.rewards["reward"]
+ chat_history = results.agent_result.metadata['all_messages']
+ summarization_count = results.agent_result.metadata['summarization_count']
+ if len(chat_history) > 1 and chat_history[0]["role"] == "user":
+ successful = True
+ logger.info(f"{prefix} successful: Results: {results.agent_result.metadata}")
+ break
+ else:
+ logger.warning(f"{prefix} failed: Agent {self.agent_name} did not return a chat history with a user message. chat_history: {chat_history}\n\nResults: {results}")
+ except Exception as e:
+ logger.warning(f"{prefix} failed: Error running trial: {e}. Results: {results}")
+ continue
+
+ if not successful:
+ # We make loss mask 0 so it does not contribute to model updates
+ logger.warning(f"Trajectory {trajectory_id} failed after {MAX_NUM_RETRIES_PER_TRIAL} attempts, will set loss mask to [0].")
+ return TerminalBenchAgentOutput(
+ response_ids=[0],
+ reward=0,
+ stop_reason="error",
+ loss_mask=[0],
+ prompt_ids=[0],
+ trajectory_id=trajectory_id,
+ )
+
+        # Use the first message as the prompt. We assume there are no system messages.
+ assert chat_history[0]["role"] == "user", "The first message should be a user message"
+ prompt = [chat_history[0]]
+ prompt_ids = self.tokenizer.apply_chat_template(
+ prompt,
+            add_generation_prompt=False,  # the messages below add it themselves
+ tokenize=True,
+ chat_template=self.custom_chat_template_content,
+ )
+ initial_prompt_length = len(prompt_ids)
+
+ # Process response messages (everything after the first message)
+ response_messages = chat_history[1:]
+ rollout_details = getattr(results.agent_result, "rollout_details", None)
+
+ # Extract from rollout_details tuple: ([{'prompt_token_ids': [...], 'completion_token_ids': [...], 'logprobs': [...]}],)
+ try:
+ details = rollout_details[0]
+ if isinstance(details, list):
+ details = details[0]
+        except (TypeError, IndexError) as e:
+            logger.error(f"Failed to unpack rollout_details: {rollout_details}. Error: {e}")
+            raise
+ prompt_ids_list = details['prompt_token_ids']
+ completion_ids_list = details['completion_token_ids']
+ logprobs_list = details['logprobs']
+
+ # Initial prompt is the first one
+ prompt_ids = list(prompt_ids_list[0])
+
+ # Build response_ids, loss_mask, and rollout_logprobs turn by turn
+ # Structure: completion[0] + user_feedback[0] + completion[1] + user_feedback[1] + ... + completion[n-1]
+ # where user_feedback[i] = prompt[i+1][len(prompt[i]) + len(completion[i]):]
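+        # For example, with two turns: prompt[1] == prompt[0] + completion[0] + feedback[0],
+        # so the reconstructed response is completion[0] + feedback[0] + completion[1],
+        # with loss mask 1 over completion tokens and 0 over feedback tokens.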
+ response_ids = []
+ loss_mask = []
+ rollout_logprobs = []
+
+ num_turns = len(prompt_ids_list)
+ for turn in range(num_turns):
+ # Add completion tokens (mask = 1)
+ completion = list(completion_ids_list[turn])
+ logprobs = list(logprobs_list[turn])
+ response_ids.extend(completion)
+ loss_mask.extend([1] * len(completion))
+ rollout_logprobs.extend(logprobs)
+
+ # Add user feedback tokens after this completion (if not last turn)
+ if turn < num_turns - 1:
+ # user_feedback = prompt[turn+1] minus (prompt[turn] + completion[turn])
+ prev_len = len(prompt_ids_list[turn]) + len(completion_ids_list[turn])
+ user_tokens = list(prompt_ids_list[turn + 1][prev_len:])
+ response_ids.extend(user_tokens)
+ loss_mask.extend([0] * len(user_tokens))
+ rollout_logprobs.extend([0.0] * len(user_tokens))
+
+ # Determine stop reason
+ max_response_tokens = (
+ self.generator_cfg.sampling_params.max_generate_length
+ + self.generator_cfg.max_input_length
+ - initial_prompt_length
+ )
+ stop_reason = "complete" # Default for trial completion
+ if len(response_ids) > max_response_tokens:
+ stop_reason = "length"
+
+ # Truncate to maximum allowed length
+ response_ids = response_ids[:max_response_tokens]
+ loss_mask = loss_mask[:max_response_tokens]
+ if rollout_logprobs is not None:
+ rollout_logprobs = rollout_logprobs[:max_response_tokens]
+ return TerminalBenchAgentOutput(
+ response_ids=response_ids,
+ reward=reward,
+ stop_reason=stop_reason,
+ loss_mask=loss_mask,
+ prompt_ids=prompt_ids,
+ trajectory_id=trajectory_id,
+ rollout_logprobs=rollout_logprobs,
+ summarization_count=summarization_count,
+ )
diff --git a/skyrl-train/skyrl_train/entrypoints/main_base.py b/skyrl-train/skyrl_train/entrypoints/main_base.py
index 0fb2dc4b5..a2d56a64e 100644
--- a/skyrl-train/skyrl_train/entrypoints/main_base.py
+++ b/skyrl-train/skyrl_train/entrypoints/main_base.py
@@ -22,6 +22,7 @@
import hydra
from loguru import logger
from skyrl_train.utils.tracking import Tracking
+import asyncio
import multiprocessing as mp
# NOTE (sumanthrh): We use ray heavily and thus disable `fork` start method.
@@ -312,7 +313,8 @@ def _setup_trainer(self):
def run(self):
trainer = self._setup_trainer()
# Start the training loop
- trainer.train()
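+        # train() is now a coroutine (see trainer.py); run it to completion here.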
+ asyncio.run(trainer.train())
@ray.remote(num_cpus=1)
diff --git a/skyrl-train/skyrl_train/generators/utils.py b/skyrl-train/skyrl_train/generators/utils.py
index d65ea9bec..692daf69d 100644
--- a/skyrl-train/skyrl_train/generators/utils.py
+++ b/skyrl-train/skyrl_train/generators/utils.py
@@ -104,13 +104,15 @@ def get_custom_chat_template(chat_template_config: Optional[Union[dict, DictConf
raise ValueError(f"Invalid source '{source}'. Must be 'name' or 'file'")
-def get_generation_prompt_ids(tokenizer) -> List[int]:
+def get_generation_prompt_ids(tokenizer, custom_chat_template=None) -> List[int]:
"""
Helper function to get the generation prompt ids for a given tokenizer.
"""
- empty_user = tokenizer.apply_chat_template([{"role": "user", "content": ""}], tokenize=True)
+ empty_user = tokenizer.apply_chat_template([{"role": "user", "content": ""}], tokenize=True, chat_template=custom_chat_template)
empty_user_with_generation_prompt = tokenizer.apply_chat_template(
- [{"role": "user", "content": ""}], add_generation_prompt=True, tokenize=True
+ [{"role": "user", "content": ""}], add_generation_prompt=True, tokenize=True, chat_template=custom_chat_template
)
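+    # The generation prompt is the suffix the chat template appends after a bare
+    # user message, e.g. "<|im_start|>assistant\n" for Qwen-style templates.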
generation_prompt_ids = empty_user_with_generation_prompt[len(empty_user) :]
diff --git a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py
index 0a254695f..26e57e652 100644
--- a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py
+++ b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py
@@ -1,6 +1,8 @@
import os
from typing import List, Any, Dict, Optional, Tuple, Iterator
from dataclasses import dataclass
+import time
+from loguru import logger
from http import HTTPStatus
import ray
import torch
@@ -19,6 +20,7 @@
CompletionRequest,
CompletionResponse,
)
+from vllm.v1.metrics.loggers import LoggingStatLogger
from vllm.lora.request import LoRARequest
from torch.distributed import destroy_process_group
from skyrl_train.distributed.utils import init_custom_process_group
@@ -47,6 +49,7 @@ class Logprob:
def setup_envvars_for_vllm(kwargs, bundle_indices):
noset_visible_devices = kwargs.pop("noset_visible_devices")
+ os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "0" # TODO(Charlie): may not be needed.
if kwargs.get("distributed_executor_backend") == "ray":
# a hack to make the script work.
# stop ray from manipulating *_VISIBLE_DEVICES
@@ -362,6 +366,23 @@ async def _destroy_weights_update_group(self):
engine = self._get_engine()
return await asyncio.to_thread(engine.collective_rpc, "destroy_weights_update_group")
+class V1LoggingStatLoggerFixed(LoggingStatLogger):
+ """
+ A fixed version of LoggingStatLogger that actually logs during the record method.
+ The log method is otherwise not called in the VLLM codebase.
+ """
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ super().__init__(*args, **kwargs)
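+        # Flush stats accumulated by record() at most once per log_interval seconds.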
+ self.log_interval = 5
+
+ def record(self, *args: Any, **kwargs: Any) -> None:
+ super().record(*args, **kwargs)
+ now = time.monotonic()
+ if now - self.last_log_time > self.log_interval:
+ self.log()
+ self.last_log_time = now
class AsyncVLLMInferenceEngine(BaseVLLMInferenceEngine):
"""Asynchronous VLLM engine."""
@@ -373,11 +394,14 @@ def __init__(self, *args, **kwargs):
def _create_engine(self, *args, **kwargs):
openai_kwargs = pop_openai_kwargs(kwargs)
# TODO (erictang000): potentially enable log requests for a debugging mode
+ custom_chat_template_path = kwargs.pop("custom_chat_template_chat_completion_path", None)
+ stat_loggers = [V1LoggingStatLoggerFixed]
if version.parse(vllm.__version__) >= version.parse("0.10.0"):
engine_args = vllm.AsyncEngineArgs(enable_log_requests=False, **kwargs)
else:
engine_args = vllm.AsyncEngineArgs(disable_log_requests=True, **kwargs)
- engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)
+ engine = vllm.AsyncLLMEngine.from_engine_args(engine_args, stat_loggers=stat_loggers)
+
# Adapted from https://github.com/volcengine/verl/blob/e90f18c40aa639cd25092b78a5ff7e2d2508c088/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L327
model_config = engine.model_config
@@ -386,16 +410,24 @@ def _create_engine(self, *args, **kwargs):
model_name = model_path
base_model_paths = [BaseModelPath(name=model_name, model_path=model_path)]
- models = OpenAIServingModels(engine, model_config, base_model_paths)
+ models = OpenAIServingModels(engine, base_model_paths)
+
+        # TODO(Charlie): passing a custom chat template for chat completions this way is hacky; revisit.
+ if custom_chat_template_path:
+ with open(custom_chat_template_path, "r") as f:
+ custom_chat_template_content = f.read()
+ logger.info(f"Initializing OpenAIServingChat with custom_chat_template read from: {custom_chat_template_path}")
+ else:
+ custom_chat_template_content = None
+
# TODO(Charlie): revisit kwargs `enable_auto_tools` and `tool_parser` when we need to
# support OAI-style tool calling; and `request_logger` for better debugging.
self.openai_serving_chat = OpenAIServingChat(
engine_client=engine,
- model_config=model_config,
models=models,
response_role="assistant",
request_logger=None,
- chat_template=None,
+ chat_template=custom_chat_template_content,
chat_template_content_format="auto",
**openai_kwargs,
)
@@ -404,7 +436,6 @@ def _create_engine(self, *args, **kwargs):
# `enable_prompt_tokens_details`, `enable_force_include_usage`.
self.openai_serving_completion = OpenAIServingCompletion(
engine_client=engine,
- model_config=model_config,
models=models,
request_logger=None,
)
@@ -526,6 +557,15 @@ async def _handle_openai_request(self, request_payload: Dict[str, Any], endpoint
body = request_payload.get("json", {})
headers = request_payload.get("headers", {})
+        # TODO(Charlie): Hacky! We hijack every request here to override its sampling params.
+        # Can we allow Harbor to pass customized sampling params instead?
+ body.update({
+ "temperature": 1.0,
+ "top_p": 1.0,
+ "top_k": -1,
+ "min_p": 0.0,
+ })
+
# 1. Build request
try:
if endpoint == "/chat/completions":
diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py
index 0f5c41151..c71987217 100644
--- a/skyrl-train/skyrl_train/trainer.py
+++ b/skyrl-train/skyrl_train/trainer.py
@@ -1,4 +1,3 @@
-import asyncio
import math
import os
import shutil
@@ -148,7 +147,7 @@ async def eval(self) -> Dict[str, float]:
)
return eval_metrics
- def train(self):
+ async def train(self):
"""
Main training loop for PPO
"""
diff --git a/skyrl-train/skyrl_train/utils/tracking.py b/skyrl-train/skyrl_train/utils/tracking.py
index 5a6f45d8e..0f703c2c5 100644
--- a/skyrl-train/skyrl_train/utils/tracking.py
+++ b/skyrl-train/skyrl_train/utils/tracking.py
@@ -23,6 +23,37 @@
from loguru import logger
from omegaconf import DictConfig, OmegaConf
import pprint
+import ray
+from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
+
+
+@ray.remote
+class WandbNodeLogger:
+ """
+ A Ray actor that initializes wandb on a specific node to capture system metrics.
+ """
+ def __init__(self, project_name, experiment_name, config, run_id, group_name, x_label):
+ import wandb
+ # Initialize wandb with the same run ID to aggregate system metrics
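+        # `mode="shared"` lets multiple processes write to a single wandb run;
+        # `x_primary=False` and `x_update_finish_state=False` mark this writer as
+        # secondary so it does not finalize the run when the actor exits.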
+ run = wandb.init(
+ project=project_name,
+ name=experiment_name,
+ id=run_id,
+ config=config,
+ resume="allow",
+ group=group_name,
+ job_type="worker_monitor",
+ settings=wandb.Settings(
+ mode="shared",
+ x_primary=False,
+ x_update_finish_state=False,
+ x_label=x_label,
+ )
+ )
+ self.wandb = run
# TODO(tgriggs): Test all backends.
@@ -41,8 +72,28 @@ def __init__(self, project_name, experiment_name, backends: Union[str, List[str]
import wandb
from omegaconf import OmegaConf
- wandb.init(project=project_name, name=experiment_name, config=OmegaConf.to_container(config, resolve=True))
- self.logger["wandb"] = wandb
+ current_node_ip = "head"
+ if ray.is_initialized():
+ try:
+ current_node_ip = ray.util.get_node_ip_address()
+ except Exception as e:
+ logger.warning(f"Failed to get node IP address, defaulting to 'head'. Error: {e}")
+
+ run = wandb.init(
+ project=project_name,
+ name=experiment_name,
+ config=OmegaConf.to_container(config, resolve=True),
+ group=experiment_name,
+ resume="allow",
+ settings=wandb.Settings(
+                mode="shared",  # mainly for multi-node training's system metrics aggregation
+ x_primary=True,
+ x_label=f"node-{current_node_ip}"
+ )
+ )
+ run_id = run.id
+ self.logger["wandb"] = run
+ self._prepare_worker_nodes_systems_logging_wandb(project_name, experiment_name, run_id, config, current_node_ip)
if "mlflow" in backends:
self.logger["mlflow"] = _MlflowLoggingAdapter(project_name, experiment_name, config)
@@ -73,6 +124,59 @@ def __init__(self, project_name, experiment_name, backends: Union[str, List[str]
self.console_logger = ConsoleLogger()
self.logger["console"] = self.console_logger
+
+ def _prepare_worker_nodes_systems_logging_wandb(self, project_name, experiment_name, run_id, config, current_node_ip):
+ """
+        In multi-node training, we spawn WandbNodeLogger actors on each worker node to capture system metrics like
+        GPU utilization. We use `mode="shared"` to aggregate system metrics from all nodes into the same Wandb run.
+        With this approach, however, the system metrics panels do not appear in the Wandb UI automatically; we have
+        to create the panels manually. The alternative is to create a run per node and group the runs by group_name.
+        We prefer to keep all nodes' metrics in the same run.
+ """
+ # Start WandB loggers on other nodes using Ray
+ if ray.is_initialized():
+ try:
+ # current_node_ip (head) is already set in __init__
+ nodes = ray.nodes()
+ self.remote_loggers = []
+
+ for node in nodes:
+ if not node["Alive"]:
+ continue
+
+ node_ip = node["NodeManagerAddress"]
+ # Skip the current node as it's already logging
+ if node_ip == current_node_ip:
+ continue
+
+ try:
+ # Launch a logger on this node
+ logger_actor = WandbNodeLogger.options(
+ num_cpus=0,
+ scheduling_strategy=NodeAffinitySchedulingStrategy(
+ node_id=node["NodeID"],
+ soft=False, # fail if that node is gone or infeasible
+ ),
+ ).remote(
+ project_name=project_name,
+ experiment_name=experiment_name,
+ config=OmegaConf.to_container(config, resolve=True),
+ run_id=run_id,
+ group_name=experiment_name,
+ x_label=f"node-{node_ip}"
+ )
+ self.remote_loggers.append(logger_actor)
+                    except Exception as e:
+                        logger.warning(f"Failed to spawn WandbNodeLogger on {node_ip}: {e}")
+                    else:
+                        logger.info(f"WandbNodeLogger initialized on 'node-{node_ip}'")
+
+ except Exception as e:
+ logger.warning(f"Failed to setup distributed wandb logging: {e}")
+ else:
+ logger.warning("Ray is not initialized, skipping distributed wandb logging")
+
+
def log(self, data, step, commit=False):
for logger_name, logger_instance in self.logger.items():
if logger_name == "wandb":
diff --git a/skyrl-train/skyrl_train/weights_manager.py b/skyrl-train/skyrl_train/weights_manager.py
index bfd3aa5b2..d63322c0a 100644
--- a/skyrl-train/skyrl_train/weights_manager.py
+++ b/skyrl-train/skyrl_train/weights_manager.py
@@ -27,6 +27,18 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return self.weights_manager.__exit__(exc_type, exc_val, exc_tb)
return False
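+    # Async variants of the context-manager protocol, so the wrapper can be used
+    # with `async with` inside the fully async trainer's event loop.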
+ async def __aenter__(self):
+ if self.condition:
+ await self.weights_manager.__aenter__()
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ if self.condition:
+ return await self.weights_manager.__aexit__(exc_type, exc_val, exc_tb)
+ return False
+
class InferenceWeightsManager:
"""Manages weight syncing and offloading/backloading between the policy model and the InferenceEngines.
diff --git a/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py b/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
index 9aea864ac..3d7bfbc82 100644
--- a/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
+++ b/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
@@ -129,7 +129,5 @@ def init_model(self, model_path, num_training_steps: int = None):
use_flash_attention_2=self.cfg.trainer.flash_attn,
-# NOTE (sumanthrh): Model initialization should always be in fp32
-# during training
-bf16=False,
+bf16=True,
lora_rank=self.cfg.trainer.policy.model.lora.rank,
lora_alpha=self.cfg.trainer.policy.model.lora.alpha,
lora_dropout=self.cfg.trainer.policy.model.lora.dropout,
@@ -357,7 +355,5 @@ def init_model(self, model_path, num_training_steps: int = None):
use_flash_attention_2=self.cfg.trainer.flash_attn,
-# NOTE (sumanthrh): Model initialization should always be in fp32
-# during training
-bf16=False,
+bf16=True,
lora_rank=self.cfg.trainer.critic.model.lora.rank,
lora_alpha=self.cfg.trainer.critic.model.lora.alpha,
lora_dropout=self.cfg.trainer.critic.model.lora.dropout,
diff --git a/skyrl-train/tests/gpu/gpu_ci/qwen3_acc_thinking.jinja2 b/skyrl-train/tests/gpu/gpu_ci/qwen3_acc_thinking.jinja2
new file mode 100644
index 000000000..15691e71b
--- /dev/null
+++ b/skyrl-train/tests/gpu/gpu_ci/qwen3_acc_thinking.jinja2
@@ -0,0 +1,85 @@
+{%- if tools %}
+ {{- '<|im_start|>system\n' }}
+ {%- if messages[0].role == 'system' %}
+ {{- messages[0].content + '\n\n' }}
+ {%- endif %}
+    {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
+ {%- for tool in tools %}
+ {{- "\n" }}
+ {{- tool | tojson }}
+ {%- endfor %}
+    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
+{%- else %}
+ {%- if messages[0].role == 'system' %}
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
+ {%- endif %}
+{%- endif %}
+{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
+{%- for message in messages[::-1] %}
+ {%- set index = (messages|length - 1) - loop.index0 %}
+    {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
+ {%- set ns.multi_step_tool = false %}
+ {%- set ns.last_query_index = index %}
+ {%- endif %}
+{%- endfor %}
+{%- for message in messages %}
+ {%- if message.content is string %}
+ {%- set content = message.content %}
+ {%- else %}
+ {%- set content = '' %}
+ {%- endif %}
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
+ {%- elif message.role == "assistant" %}
+ {%- set reasoning_content = '' %}
+ {%- if message.reasoning_content is string %}
+ {%- set reasoning_content = message.reasoning_content %}
+ {%- else %}
+            {%- if '</think>' in content %}
+                {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
+                {%- set content = content.split('</think>')[-1].lstrip('\n') %}
+ {%- endif %}
+ {%- endif %}
+ {%- if reasoning_content %}
+        {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
+ {%- else %}
+ {{- '<|im_start|>' + message.role + '\n' + content }}
+ {%- endif %}
+ {%- if message.tool_calls %}
+ {%- for tool_call in message.tool_calls %}
+ {%- if (loop.first and content) or (not loop.first) %}
+ {{- '\n' }}
+ {%- endif %}
+ {%- if tool_call.function %}
+ {%- set tool_call = tool_call.function %}
+ {%- endif %}
+            {{- '<tool_call>\n{"name": "' }}
+ {{- tool_call.name }}
+ {{- '", "arguments": ' }}
+ {%- if tool_call.arguments is string %}
+ {{- tool_call.arguments }}
+ {%- else %}
+ {{- tool_call.arguments | tojson }}
+ {%- endif %}
+            {{- '}\n</tool_call>' }}
+ {%- endfor %}
+ {%- endif %}
+ {{- '<|im_end|>\n' }}
+ {%- elif message.role == "tool" %}
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
+ {{- '<|im_start|>user' }}
+ {%- endif %}
+        {{- '\n<tool_response>\n' }}
+ {{- content }}
+        {{- '\n</tool_response>' }}
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+ {{- '<|im_end|>\n' }}
+ {%- endif %}
+ {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+ {{- '<|im_start|>assistant\n' }}
+ {%- if enable_thinking is defined and enable_thinking is false %}
+        {{- '<think>\n\n</think>\n\n' }}
+ {%- endif %}
+{%- endif %}
\ No newline at end of file