diff --git a/dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile b/dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile
index 25da7f460..69561ac5b 100644
--- a/dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile
+++ b/dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile
@@ -27,23 +27,37 @@ RUN pip install keyring keyrings.google-artifactregistry-auth
 
 RUN pip install numba==0.61.2
 
-COPY tunix /tunix
-RUN pip uninstall -y google-tunix
-RUN pip install -e /tunix --no-cache-dir
+RUN pip install vllm-tpu
+# Install the pinned vLLM release from PyPI
+RUN pip install vllm==0.12.0
+# 1. TUNIX
+# Clone directly into /tunix instead of COPYing local files
+# RUN git clone -b make-moe-work https://github.com/abhinavclemson/tunix.git
+RUN pip uninstall -y tunix && git clone -b make-moe-work https://github.com/abhinavclemson/tunix.git && cd tunix && pip install -e . && cd ..
 
-COPY vllm /vllm
-RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir
+# 2. TPU-INFERENCE
+# Clone directly into /tpu-inference
+# RUN git clone https://github.com/vllm-project/tpu-inference.git /tpu-inference
+RUN pip uninstall -y tpu-inference && git clone https://github.com/abhinavclemson/tpu-inference.git && cd tpu-inference && pip install -e . && cd ..
+# Note: The repo name is 'tpu-inference' (dash), but the Python package may be 'tpu_inference'.
+# pip install handles this mapping automatically.
 
-COPY tpu-inference /tpu-inference
-RUN pip install -e /tpu-inference --no-cache-dir
+# 3. vLLM
+
+
+# RUN git clone https://github.com/vllm-project/vllm.git /vllm
+# Set the TPU target and install
+
+# --- REPLACEMENT END ---
 
 RUN pip install --no-deps qwix==0.1.4
 
+RUN pip install google-metrax numpy==2.2
+
 RUN if [ "$MODE" = "post-training-experimental" ]; then \
     echo "MODE=post-training-experimental: Re-installing JAX/libtpu"; \
     pip uninstall -y jax jaxlib libtpu && \
-    pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
-    pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
+    pip install --pre jax==0.8.0.dev20251013 jaxlib==0.8.0.dev20251013 libtpu==0.0.25.dev20251012+nightly -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
     fi
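Note: since the editable installs above now come from cloned forks rather than COPYed local checkouts, a quick import check inside the built image can catch a broken install early. This is a minimal sketch, not part of the Dockerfile; the importable package names (`tunix`, `tpu_inference`, `vllm`) are assumptions about what the editable installs register.

```python
# Hypothetical post-build check, run inside the image: confirm the editable
# installs resolve and report their versions. Package names are assumptions.
import importlib

for name in ("tunix", "tpu_inference", "vllm", "jax", "jaxlib"):
    mod = importlib.import_module(name)
    print(name, getattr(mod, "__version__", "unknown"))
```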
diff --git a/dependencies/dockerfiles/patch_work.sh b/dependencies/dockerfiles/patch_work.sh
new file mode 100644
index 000000000..1a843fef0
--- /dev/null
+++ b/dependencies/dockerfiles/patch_work.sh
@@ -0,0 +1,53 @@
+
+#!/bin/bash
+
+# 1. Define the target directory (expand the glob now so 'cd' below gets a concrete path)
+SITE_PACKAGES=$(echo /usr/local/lib/python*/site-packages)
+TEMP_DIR="temp_patch_work"
+
+# Ensure the script stops if any command fails
+set -e
+
+echo "Navigate to site-packages: $SITE_PACKAGES"
+cd "$SITE_PACKAGES"
+
+# 2. Create a temporary directory for cloning
+echo "Creating temporary directory..."
+# Remove it first if it exists from a previous failed run to ensure a clean slate
+if [ -d "$TEMP_DIR" ]; then rm -rf "$TEMP_DIR"; fi
+mkdir "$TEMP_DIR"
+cd "$TEMP_DIR"
+
+# 3. Clone the repositories
+echo "Cloning repositories..."
+git clone https://github.com/vllm-project/vllm.git
+git clone -b make-moe-work https://github.com/abhinavclemson/tunix.git
+git clone https://github.com/vllm-project/tpu-inference.git
+
+# Go back up to site-packages
+cd ..
+
+# 4. Copy files
+# We use 'cp -rf' to force overwrite existing files recursively.
+# We assume the destination folders (./tunix, ./vllm) already exist as installed packages.
+# If they don't exist, we create them.
+
+echo "Patching Tunix..."
+mkdir -p ./tunix
+cp -rf "$TEMP_DIR/tunix/tunix/"* ./tunix/
+
+echo "Patching TPU-Inference..."
+# Note: Verify if the installed package name is 'tpu_inference' (underscore) or 'tpu-inference' (dash).
+# The copy below assumes the installed layout uses 'tpu_inference' (underscore).
+mkdir -p ./tpu_inference
+cp -rf "$TEMP_DIR/tpu-inference/tpu_inference/"* ./tpu_inference/
+
+echo "Patching vLLM..."
+mkdir -p ./vllm
+cp -rf "$TEMP_DIR/vllm/vllm/"* ./vllm/
+
+# 5. Cleanup
+echo "Cleaning up temporary files..."
+rm -rf "$TEMP_DIR"
+
+echo "Done! Packages have been patched."
\ No newline at end of file
diff --git a/dependencies/scripts/patch_work.sh b/dependencies/scripts/patch_work.sh
new file mode 100644
index 000000000..1a843fef0
--- /dev/null
+++ b/dependencies/scripts/patch_work.sh
@@ -0,0 +1,53 @@
+
+#!/bin/bash
+
+# 1. Define the target directory (expand the glob now so 'cd' below gets a concrete path)
+SITE_PACKAGES=$(echo /usr/local/lib/python*/site-packages)
+TEMP_DIR="temp_patch_work"
+
+# Ensure the script stops if any command fails
+set -e
+
+echo "Navigate to site-packages: $SITE_PACKAGES"
+cd "$SITE_PACKAGES"
+
+# 2. Create a temporary directory for cloning
+echo "Creating temporary directory..."
+# Remove it first if it exists from a previous failed run to ensure a clean slate
+if [ -d "$TEMP_DIR" ]; then rm -rf "$TEMP_DIR"; fi
+mkdir "$TEMP_DIR"
+cd "$TEMP_DIR"
+
+# 3. Clone the repositories
+echo "Cloning repositories..."
+git clone https://github.com/vllm-project/vllm.git
+git clone -b make-moe-work https://github.com/abhinavclemson/tunix.git
+git clone https://github.com/vllm-project/tpu-inference.git
+
+# Go back up to site-packages
+cd ..
+
+# 4. Copy files
+# We use 'cp -rf' to force overwrite existing files recursively.
+# We assume the destination folders (./tunix, ./vllm) already exist as installed packages.
+# If they don't exist, we create them.
+
+echo "Patching Tunix..."
+mkdir -p ./tunix
+cp -rf "$TEMP_DIR/tunix/tunix/"* ./tunix/
+
+echo "Patching TPU-Inference..."
+# Note: Verify if the installed package name is 'tpu_inference' (underscore) or 'tpu-inference' (dash).
+# The copy below assumes the installed layout uses 'tpu_inference' (underscore).
+mkdir -p ./tpu_inference
+cp -rf "$TEMP_DIR/tpu-inference/tpu_inference/"* ./tpu_inference/
+
+echo "Patching vLLM..."
+mkdir -p ./vllm
+cp -rf "$TEMP_DIR/vllm/vllm/"* ./vllm/
+
+# 5. Cleanup
+echo "Cleaning up temporary files..."
+rm -rf "$TEMP_DIR"
+
+echo "Done! Packages have been patched."
\ No newline at end of file
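Both copies of this script locate site-packages with a hard-coded path glob. If a more portable lookup is ever needed, the active interpreter can report the directory itself; this is an optional standard-library sketch, not something the scripts above currently do.

```python
# Optional alternative to globbing /usr/local/lib/python*/site-packages:
# ask the running interpreter where pure-Python packages are installed.
import sysconfig

print(sysconfig.get_paths()["purelib"])
```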
diff --git a/patch_work.sh b/patch_work.sh
new file mode 100644
index 000000000..56a426ca5
--- /dev/null
+++ b/patch_work.sh
@@ -0,0 +1,57 @@
+
+#!/bin/bash
+
+# Ensure the script stops if any command fails
+set -e
+
+cd ..
+
+# 1. Define the target directory
+SITE_PACKAGES=$(find . -type d -name "*site-packages*" -print -quit)
+
+TEMP_DIR="temp_patch_work"
+
+echo "Navigate to site-packages: $SITE_PACKAGES"
+cd "$SITE_PACKAGES"
+
+# 2. Create a temporary directory for cloning
+echo "Creating temporary directory..."
+# Remove it first if it exists from a previous failed run to ensure a clean slate
+if [ -d "$TEMP_DIR" ]; then rm -rf "$TEMP_DIR"; fi
+mkdir "$TEMP_DIR"
+cd "$TEMP_DIR"
+
+# 3. Clone the repositories
+echo "Cloning repositories..."
+git clone https://github.com/vllm-project/vllm.git && cd vllm && git checkout 8c363ed6663f69b97c9f34b0be0091d8135f958c && cd ..
+git clone -b make-moe-work https://github.com/abhinavclemson/tunix.git
+git clone https://github.com/abhinavclemson/tpu-inference.git
+
+
+# Go back up to site-packages
+cd ..
+
+# 4. Copy files
+# We use 'cp -rf' to force overwrite existing files recursively.
+# We assume the destination folders (./tunix, ./vllm) already exist as installed packages.
+# If they don't exist, we create them.
+
+echo "Patching Tunix..."
+mkdir -p ./tunix
+cp -rf "$TEMP_DIR/tunix/tunix/"* ./tunix/
+
+echo "Patching TPU-Inference..."
+# Note: Verify if the installed package name is 'tpu_inference' (underscore) or 'tpu-inference' (dash).
+# The copy below assumes the installed layout uses 'tpu_inference' (underscore).
+mkdir -p ./tpu_inference
+cp -rf "$TEMP_DIR/tpu-inference/tpu_inference/"* ./tpu_inference/
+
+echo "Patching vLLM..."
+mkdir -p ./vllm
+cp -rf "$TEMP_DIR/vllm/vllm/"* ./vllm/
+
+# 5. Cleanup
+echo "Cleaning up temporary files..."
+rm -rf "$TEMP_DIR"
+
+echo "Done! Packages have been patched."
diff --git a/src/MaxText/configs/rl.yml b/src/MaxText/configs/rl.yml
index e5bead6e1..e0d993826 100644
--- a/src/MaxText/configs/rl.yml
+++ b/src/MaxText/configs/rl.yml
@@ -23,6 +23,10 @@ sampler_devices_fraction: 0.5
 chips_per_vm: 4 # depends on hardware, for v5p this is 4
 num_trainer_slices: -1
 num_samplers_slices: -1
+# Only specify rollout_data_parallelism when you would like to use more than one model
+# replica in rollout. If not specified, rollout_tensor_parallelism will be auto-determined.
+rollout_data_parallelism: -1
+rollout_tensor_parallelism: -1
 
 # ====== Reproducibility ======
 data_shuffle_seed: 42
@@ -83,13 +87,13 @@ debug:
   enable_tunix_perf_metrics: False
 
 # ====== Training ======
-batch_size: 1
+batch_size: 8
 # Increase `batch_size` and `MAX_STEPS` for better results.
 # num_batches: 3738
 num_batches: 4 # 200
 # A batch can be split into multiple micro batches for memory management
 # and/or async sampling and training.
-micro_batch_size: -1
+micro_batch_size: 8
 # Keep `num_test_batches` low so that evaluation runs quickly. It can be
 # increased to a max. of 330 (if batch size is 4).
 num_test_batches: 5 # 200
@@ -130,7 +134,7 @@ eval_make_lst: False # If True, return a list of (question, answer, responses) d
 max_prefill_predict_length: 256
 max_target_length: 1024
 kv_cache_buffer: 256
-hbm_utilization_vllm: 0.72
+hbm_utilization_vllm: 0.6
 swap_space_vllm_gb: 2
 # Generation Configuration During Training
 # Important to keep a high-ish temperature for varied, diverse responses during
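The two new rollout keys have to agree with the size of the sampler mesh: either leave both at -1, or make the data parallelism a divisor of the device count (tensor parallelism is then derived), or set both so their product equals the device count. A small worked example, with the 8-device sampler count assumed purely for illustration:

```python
# Enumerate the (rollout_data_parallelism, rollout_tensor_parallelism) splits
# that would be accepted for an assumed 8-device sampler mesh.
num_sampler_devices = 8  # assumption for illustration only
valid_splits = [(dp, num_sampler_devices // dp)
                for dp in range(1, num_sampler_devices + 1)
                if num_sampler_devices % dp == 0]
print(valid_splits)  # [(1, 8), (2, 4), (4, 2), (8, 1)]
```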
diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py
index 4f2406534..a520a4770 100644
--- a/src/MaxText/configs/types.py
+++ b/src/MaxText/configs/types.py
@@ -1264,7 +1264,14 @@ class RLHardware(BaseModel):
   use_pathways: bool = Field(True, description="Whether to use Pathways for multihost orchestration.")
   num_trainer_slices: int = Field(-1, description="Number of slices for the trainer.")
   num_samplers_slices: int = Field(-1, description="Number of slices for the samplers.")
-
+  rollout_data_parallelism: int = Field(
+      -1,
+      description="Total model replicas for rollout. It should only be specified when you would like to use more "
+      "than one model replica in rollout.",
+  )
+  rollout_tensor_parallelism: int = Field(
+      -1, description="Tensor parallelism per replica for rollout. If not specified, it will be auto-determined."
+  )
 
 class VLLM(BaseModel):
   """vLLM-specific configuration for rollouts."""
diff --git a/src/MaxText/integration/tunix/utils.py b/src/MaxText/integration/tunix/utils.py
index 2cf12c048..161202c52 100644
--- a/src/MaxText/integration/tunix/utils.py
+++ b/src/MaxText/integration/tunix/utils.py
@@ -14,8 +14,10 @@
 
 """Utils for Tunix integration."""
 
+import inspect
 import re
 
+
 import MaxText.integration.tunix.weight_mapping as weight_mapping  # pylint: disable=consider-using-from-import
 from MaxText.utils.ckpt_conversion.utils.param_mapping import PARAM_MAPPING
 from MaxText.utils.ckpt_conversion.utils.param_mapping import VLLM_HOOK_FNS
@@ -127,7 +129,17 @@ def __init__(self, model_name, config=None, use_standalone_mappings=False):
   def to_hf_mapping(self):
     """Returns a mapping from MaxText parameter names to HuggingFace parameter names."""
     if self.use_standalone_mappings:
-      return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_mapping()
+      mapping_fn = STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_mapping
+      total_num_layers = self.config["num_hidden_layers"]
+      print(f"total_num_layers: {total_num_layers} for model: {self.model_name}")
+      sig = inspect.signature(mapping_fn)
+      if len(sig.parameters) >= 1 and "total_num_layers" in sig.parameters:
+        mapping = mapping_fn(
+            total_num_layers=total_num_layers,
+        )
+        return mapping
+
+      return mapping_fn()
 
     config = self.config
     mapping = self.convert_hf_map_to_sharding_map(
diff --git a/src/MaxText/integration/tunix/weight_mapping/__init__.py b/src/MaxText/integration/tunix/weight_mapping/__init__.py
index d250ee2fe..9735998f5 100644
--- a/src/MaxText/integration/tunix/weight_mapping/__init__.py
+++ b/src/MaxText/integration/tunix/weight_mapping/__init__.py
@@ -19,6 +19,7 @@
 model name. This allows for easy extension to support new models.
 """
 
+from MaxText.integration.tunix.weight_mapping.gpt_oss import GptOssMaxTextMapping
 from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
 from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING
 
@@ -31,6 +32,8 @@ def __getattr__(self, name):
       return LLAMA3_VLLM_MAPPING
     elif name.startswith("qwen3"):
       return QWEN3_VLLM_MAPPING
+    elif name.startswith("gpt"):
+      return GptOssMaxTextMapping
     else:
       raise ValueError(f"{name} vLLM weight mapping not found.")
 
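The branch added to `to_hf_mapping` above forwards `total_num_layers` only when the selected mapping callable actually declares that parameter, so existing mappings keep working unchanged. A self-contained sketch of that pattern (the two mapping functions here are made up for illustration):

```python
# Standalone sketch of signature-aware dispatch; both mapping functions are hypothetical.
import inspect

def dense_mapping():
    return {"kind": "dense"}

def moe_mapping(total_num_layers=36):
    return {"kind": "moe", "total_num_layers": total_num_layers}

def call_mapping(mapping_fn, total_num_layers):
    # Only pass total_num_layers if the callable accepts it.
    if "total_num_layers" in inspect.signature(mapping_fn).parameters:
        return mapping_fn(total_num_layers=total_num_layers)
    return mapping_fn()

print(call_mapping(dense_mapping, 36))  # {'kind': 'dense'}
print(call_mapping(moe_mapping, 36))    # {'kind': 'moe', 'total_num_layers': 36}
```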
diff --git a/src/MaxText/integration/tunix/weight_mapping/gpt_oss.py b/src/MaxText/integration/tunix/weight_mapping/gpt_oss.py
new file mode 100644
index 000000000..adebce2d2
--- /dev/null
+++ b/src/MaxText/integration/tunix/weight_mapping/gpt_oss.py
@@ -0,0 +1,212 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines the weight mapping from MaxText's GPT-OSS model to a vLLM-compatible format."""
+
+from dataclasses import dataclass
+import logging
+from typing import Dict, Optional, Tuple
+
+import jax
+
+
+@dataclass
+class GptOssMaxTextMapping:
+  """Mapping definition from MaxText GPT-OSS (Scanned/Interleaved) to vLLM JAX NNX.
+
+  Supports:
+    - Modulo Interleaving (e.g., Block 0 -> Layers 0, 2, 4...)
+  """
+
+  @staticmethod
+  def lora_to_hf_mappings():
+    """Provides the mapping for LoRA (Low-Rank Adaptation) weights.
+
+    Returns:
+      None, as LoRA mappings are not defined for this model.
+    """
+    return None
+
+  @staticmethod
+  def to_hf_hook_fns():
+    def fuse_interleaved_gate(val, tgt_param):
+      """Fuse Gate (wi_0) with Multi-Host Sharding Support."""
+      current = tgt_param.value if hasattr(tgt_param, "value") else tgt_param
+
+      # Safety Check
+      if current.shape[-1] != val.shape[-1] * 2:
+        if current.shape[-1] == val.shape[-1]:
+          logging.debug(f"Gate Fusion Shape Warning: Src {val.shape} -> Tgt {current.shape}")
+          return val
+        logging.warning(f"Gate Fusion Shape Warning: Src {val.shape} -> Tgt {current.shape}")
+
+      # TODO: Enable multi-host sharding, if there is a mismatch in shapes.
+      # # MULTI-HOST case.
+      val = jax.device_put(val, current.sharding)
+      val.block_until_ready()
+
+      logging.debug("Hook: Interleaving Gate -> Even columns")
+      return current.at[..., 0::2].set(val)
+
+    def fuse_interleaved_up(val, tgt_param):
+      """Fuse Up (wi_1) with Multi-Host Sharding Support."""
+      current = tgt_param.value if hasattr(tgt_param, "value") else tgt_param
+
+      if current.shape[-1] != val.shape[-1] * 2:
+        if current.shape[-1] == val.shape[-1]:
+          logging.debug(f"Up Fusion Shape Warning: Src {val.shape} -> Tgt {current.shape}")
+          return val
+        logging.warning(f"Up Fusion Shape Warning: Src {val.shape} -> Tgt {current.shape}")
+
+      # TODO: Enable multi-host sharding, if there is a mismatch in shapes.
+      # # MULTI-HOST case.
+      val = jax.device_put(val, current.sharding)
+      val.block_until_ready()
+
+      logging.debug("Hook: Interleaving Up -> Odd columns")
+      return current.at[..., 1::2].set(val)
+
+    return {
+        r'.*GptOssMlp\.wi_0.*': fuse_interleaved_gate,
+        r'.*GptOssMlp\.wi_1.*': fuse_interleaved_up,
+    }
+
+  @staticmethod
+  def to_hf_transpose_keys():
+    return {}
+
+  @staticmethod
+  def to_hf_mapping(
+      layer_cycle_interval: int = 2,
+      total_num_layers: int = 36,
+      interleave_style: str = "modulo"
+  ) -> Dict[str, Tuple[str, Tuple[Optional[str], ...]]]:
+
+    mapping = {}
+
+    # --- 1. Global Parameters ---
+    mapping.update({
+        "base.token_embedder.embedding": ("embedder.input_embedding_table_VD", ("model", None)),
+        "base.decoder.decoder_norm.scale": ("final_norm.scale", (None,)),
+        "base.decoder.logits_dense.kernel": ("lm_head.input_embedding_table_DV", (None, "model")),
+    })
+
+    # --- 2. Layer Mapping Loop ---
+    layers_per_block = total_num_layers // layer_cycle_interval
+
+    for block_idx in range(layer_cycle_interval):
+      src_block = f"base.decoder.layers.layers_{block_idx}"
+      if interleave_style == "modulo":
+        target_indices = range(block_idx, total_num_layers, layer_cycle_interval)
+      else:
+        start = block_idx * layers_per_block
+        target_indices = range(start, start + layers_per_block)
+
+      regex_indices = "|".join(map(str, target_indices))
+      layer_regex = rf"layers\.({regex_indices})"
+
+      # --- 3. Block Mappings (Standard) ---
+      mapping.update({
+          f"{src_block}.pre_self_attention_layer_norm.scale": (
+              f"{layer_regex}.pre_attention_norm.scale", (None, "layer")
+          ),
+          f"{src_block}.post_self_attention_layer_norm.scale": (
+              f"{layer_regex}.pre_mlp_norm.scale", (None, "layer")
+          ),
+          f"{src_block}.GptOssAttention.query.kernel": (
+              f"{layer_regex}.attn.kernel_q_DNH", (None, "layer", "model", None)
+          ),
+          f"{src_block}.GptOssAttention.key.kernel": (
+              f"{layer_regex}.attn.kernel_k_DKH", (None, "layer", "model", None)
+          ),
+          f"{src_block}.GptOssAttention.value.kernel": (
+              f"{layer_regex}.attn.kernel_v_DKH", (None, "layer", "model", None)
+          ),
+          f"{src_block}.GptOssAttention.out.kernel": (
+              f"{layer_regex}.attn.kernel_o_proj_NHD", ("model", "layer", None, None)
+          ),
+          f"{src_block}.GptOssAttention.query.bias": (
+              f"{layer_regex}.attn.bias_q_NH", (None, "layer", None)
+          ),
+          f"{src_block}.GptOssAttention.key.bias": (
+              f"{layer_regex}.attn.bias_k_KH", (None, "layer", None)
+          ),
+          f"{src_block}.GptOssAttention.value.bias": (
+              f"{layer_regex}.attn.bias_v_KH", (None, "layer", None)
+          ),
+          f"{src_block}.GptOssAttention.out.bias": (
+              f"{layer_regex}.attn.bias_o_D", (None, "layer")
+          ),
+          f"{src_block}.GptOssAttention.sinks": (
+              f"{layer_regex}.attn.sinks_N", (None, "layer")
+          ),
+      })
+
+      # MoE Router
+      mapping.update({
+          f"{src_block}.GptOssMlp.gate.kernel": (
+              f"{layer_regex}.custom_module.router.kernel_DE", (None, "layer", "model")
+          ),
+          f"{src_block}.GptOssMlp.gate.bias": (
+              f"{layer_regex}.custom_module.router.bias_E", ("model", "layer")
+          ),
+      })
+
+      # --- MOE EXPERTS ---
+
+      # MLP1 BIASES
+      mapping.update({
+          f"{src_block}.GptOssMlp.wi_0_bias": (
+              f"{layer_regex}.custom_module.mlp1_bias_EF2", ("model", "layer")
+          ),
+          f"{src_block}.GptOssMlp.wi_1_bias": (
+              f"{layer_regex}.custom_module.mlp1_bias_EF2", ("model", "layer")
+          ),
+      })
+
+      # MLP1 WEIGHTS (Split -> Fused)
+      mapping.update({
+          f"{src_block}.GptOssMlp.wi_0": (
+              f"{layer_regex}.custom_module.mlp1_weight_EDF2", ("model", "layer", None)
+          ),
+          f"{src_block}.GptOssMlp.wi_1": (
+              f"{layer_regex}.custom_module.mlp1_weight_EDF2",
+              # Original: (None, "layer", "expert", "model", None)
+              ("model", "layer", None)
+          ),
+      })
+
+      # MLP2 (Down Projection)
+      mapping.update({
+          f"{src_block}.GptOssMlp.wo_bias": (
+              f"{layer_regex}.custom_module.mlp2_bias_ED", ("model", "layer")
+          ),
+          f"{src_block}.GptOssMlp.wo": (
+              f"{layer_regex}.custom_module.mlp2_weight_EFD", ("model", "layer", None)
+          ),
+      })
+
+    # --- 4. Additional Config ---
+    mapping.update({
+        "additional_config": {
+            "layer_cycle_interval": layer_cycle_interval,
+        }
+    })
+
+    return mapping
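For readers unfamiliar with the scanned-layer layout: with the defaults above (`layer_cycle_interval=2`, `total_num_layers=36`), block 0 maps onto the even unrolled layers and block 1 onto the odd ones, and each block gets a single alternation regex. A small standalone illustration of that index and regex construction:

```python
# Reproduce the modulo interleaving used by to_hf_mapping above (defaults: 2 blocks, 36 layers).
layer_cycle_interval = 2
total_num_layers = 36

for block_idx in range(layer_cycle_interval):
    target_indices = list(range(block_idx, total_num_layers, layer_cycle_interval))
    layer_regex = rf"layers\.({'|'.join(map(str, target_indices))})"
    print(f"block {block_idx}: layers {target_indices[:3]} ... {target_indices[-1]}")
    print(f"  regex prefix: {layer_regex[:24]}...")
# block 0 covers layers 0, 2, ..., 34; block 1 covers layers 1, 3, ..., 35
```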
diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py
index ceab65980..42e79466d 100644
--- a/src/MaxText/rl/train_rl.py
+++ b/src/MaxText/rl/train_rl.py
@@ -215,6 +215,36 @@ def setup_configs_and_devices(argv: Sequence[str]):
   return trainer_config, sampler_config, trainer_devices, sampler_devices
 
 
+def get_rollout_kwargs_for_data_parallelism(sampler_config, num_sampler_devices):
+  """Get rollout kwargs for vLLM rollout when using data parallelism."""
+  dp = sampler_config.rollout_data_parallelism
+  if dp == -1:
+    return {}
+
+  rollout_kwargs = {}
+  tp = sampler_config.rollout_tensor_parallelism
+
+  if tp == -1:
+    if num_sampler_devices % dp != 0:
+      raise ValueError(
+          f"num_sampler_devices({num_sampler_devices}) must be divisible by "
+          f"rollout_data_parallelism({dp}) "
+          f"when rollout_tensor_parallelism is -1."
+      )
+    tp = num_sampler_devices // dp
+  elif tp * dp != num_sampler_devices:
+    raise ValueError(
+        f"rollout_tensor_parallelism({tp}) * "
+        f"rollout_data_parallelism({dp}) "
+        f"!= len(sampler_devices)({num_sampler_devices})"
+    )
+  rollout_kwargs["tensor_parallel_size"] = tp
+  rollout_kwargs["data_parallel_size"] = dp
+  rollout_kwargs["rollout_vllm_async_scheduling"] = True
+
+  return rollout_kwargs
+
+
 def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
   """
   Run RL training with the provided configuration.
@@ -347,9 +377,9 @@
       eval_every_n_steps=trainer_config.eval_interval,
       max_steps=max_train_steps,
       # Micro batching
-      mini_batch_size=trainer_config.batch_size,
-      train_micro_batch_size=micro_batch_size,
-      rollout_micro_batch_size=micro_batch_size,
+      mini_batch_size=int(trainer_config.batch_size),
+      train_micro_batch_size=int(micro_batch_size),
+      rollout_micro_batch_size=int(micro_batch_size),
       # Metrics logging
       metrics_logging_options=metrics_logging_options,
       # Profiling
@@ -369,6 +399,7 @@
         rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm,
         rollout_vllm_tpu_backend_type="jax",
         rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb,
+        **get_rollout_kwargs_for_data_parallelism(sampler_config, len(sampler_devices)),
     ),
   )
   grpo_config = GrpoConfig(
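For reference, this is roughly how the new helper resolves the rollout kwargs that get splatted into the cluster config above. The snippet is a sketch, not a test: it builds a stand-in for `sampler_config` with `types.SimpleNamespace`, assumes 8 sampler devices, and assumes `MaxText.rl.train_rl` is importable in the current environment (which pulls in the full training dependencies).

```python
# Sketch of calling the new helper with a stand-in config; the 8-device sampler
# count and the SimpleNamespace stand-in are assumptions for illustration.
from types import SimpleNamespace

from MaxText.rl.train_rl import get_rollout_kwargs_for_data_parallelism

sampler_config = SimpleNamespace(rollout_data_parallelism=2, rollout_tensor_parallelism=-1)
print(get_rollout_kwargs_for_data_parallelism(sampler_config, 8))
# Expected: {'tensor_parallel_size': 4, 'data_parallel_size': 2, 'rollout_vllm_async_scheduling': True}
```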