diff --git a/skyrl-tx/tests/conftest.py b/skyrl-tx/tests/conftest.py
new file mode 100644
index 000000000..0e6763a15
--- /dev/null
+++ b/skyrl-tx/tests/conftest.py
@@ -0,0 +1,9 @@
+import jax
+import pytest
+
+
+@pytest.fixture(scope="session", autouse=True)
+def configure_jax_cpu_devices():
+    """Configure JAX to use 2 CPU devices for testing parallelism."""
+    if not jax._src.xla_bridge.backends_are_initialized():
+        jax.config.update("jax_num_cpu_devices", 2)
diff --git a/skyrl-tx/tests/models/test_llama3.py b/skyrl-tx/tests/models/test_llama3.py
index a2fecc424..7666d5428 100644
--- a/skyrl-tx/tests/models/test_llama3.py
+++ b/skyrl-tx/tests/models/test_llama3.py
@@ -17,9 +17,6 @@
 @pytest.mark.parametrize("tp", [1, 2])
 def test_llama3(tp: int):
     """Test LLama3 model against HuggingFace reference implementation."""
-    if not jax._src.xla_bridge.backends_are_initialized():  # type: ignore
-        jax.config.update("jax_num_cpu_devices", 2)
-
     if os.getenv("CI"):
         pytest.skip("Test currently runs out of memory in the CI")
 
diff --git a/skyrl-tx/tests/models/test_qwen3.py b/skyrl-tx/tests/models/test_qwen3.py
index f1ce818b9..c450efbf8 100644
--- a/skyrl-tx/tests/models/test_qwen3.py
+++ b/skyrl-tx/tests/models/test_qwen3.py
@@ -19,9 +19,6 @@
 
 @pytest.mark.parametrize("tp", [1, 2])
 def test_qwen3(tp: int):
-    if not jax._src.xla_bridge.backends_are_initialized():  # ty: ignore
-        jax.config.update("jax_num_cpu_devices", 2)
-
     if tp > 1 and os.getenv("CI"):
         pytest.skip("TP > 1 currently runs out of memory in the CI")
 
@@ -64,7 +61,8 @@ def load_moe_base_weights(jax_moe_layer: Qwen3MoeSparseMoeBlock, hf_moe_layer: H
         jax_moe_layer.experts.down_proj.weight[i, :, :] = expert.down_proj.weight.detach().numpy().T
 
 
-def test_qwen3_moe_layer():
+@pytest.mark.parametrize("ep,tp", [(1, 1), (1, 2), (2, 1)])
+def test_qwen3_moe_layer(ep: int, tp: int):
     model_name = "trl-internal-testing/tiny-Qwen3MoeForCausalLM"
     hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True)
     base_config = PretrainedConfig.from_pretrained(model_name)
@@ -75,15 +73,15 @@ def test_qwen3_moe_layer():
     with torch.no_grad():
         hf_final_hidden_states, hf_router_logits = hf_moe_layer.forward(x)
 
-    mesh = jax.make_mesh((1, 1), ("fsdp", "tp"))
+    mesh = jax.make_mesh((1, ep, tp), ("fsdp", "ep", "tp"))
     with jax.set_mesh(mesh):
         moe_layer = Qwen3MoeSparseMoeBlock(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
         load_moe_base_weights(moe_layer, hf_moe_layer)
 
-    final_hidden_states, router_logits = moe_layer(x.numpy(), return_router_logits=True)
+        final_hidden_states, router_logits = moe_layer(x.numpy(), return_router_logits=True)
 
-    assert np.allclose(hf_router_logits, router_logits, rtol=1e-4)
-    assert np.allclose(hf_final_hidden_states, final_hidden_states, rtol=1e-2, atol=1e-2)
+        assert np.allclose(hf_router_logits, router_logits, rtol=1e-4)
+        assert np.allclose(hf_final_hidden_states, final_hidden_states, rtol=1e-2, atol=1e-2)
 
 
 def load_lora_weights(
@@ -107,7 +105,8 @@ def load_lora_weights(
     jax_module.lora_ranks.value = jax_module.lora_ranks.value.at[adapter_idx].set(rank)
 
 
-def test_qwen3_moe_layer_lora():
+@pytest.mark.parametrize("ep,tp", [(1, 1), (1, 2), (2, 1)])
+def test_qwen3_moe_layer_lora(ep: int, tp: int):
     """Test MoE LoRA by merging adapter into base weights and comparing outputs."""
     model_name = "trl-internal-testing/tiny-Qwen3MoeForCausalLM"
     hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True)
@@ -117,7 +116,7 @@
     hf_moe_layer = hf_model.model.layers[0].mlp
     x = torch.randn(3, 4, config.hidden_size)
 
-    mesh = jax.make_mesh((1, 1), ("fsdp", "tp"))
+    mesh = jax.make_mesh((1, ep, tp), ("fsdp", "ep", "tp"))
     with jax.set_mesh(mesh):
         moe_layer = Qwen3MoeSparseMoeBlock(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
         load_moe_base_weights(moe_layer, hf_moe_layer)
diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py
index f980a72c9..9bb1ac808 100644
--- a/skyrl-tx/tx/layers/lora.py
+++ b/skyrl-tx/tx/layers/lora.py
@@ -3,7 +3,7 @@
 from jax import numpy as jnp
 
 from tx.utils.models import filter_lora
-from tx.layers.util import Param, prepare_routing
+from tx.layers.util import Param, prepare_routing, ragged_dot
 from tx.models.types import ModelForCausalLM
 from tx.tinker.types import LoraConfig
 
@@ -236,8 +236,10 @@ def __call__(
         x: jax.Array,
         group_sizes: jax.Array,
         adapter_indices_sorted: jax.Array | None = None,
+        *,
+        group_offset: jax.Array | None = None,
     ) -> jax.Array:
-        base_out = jax.lax.ragged_dot(x, self.weight.value, group_sizes)
+        base_out = ragged_dot(x, self.weight.value, group_sizes, group_offset=group_offset)
 
         if self.max_lora_adapters == 0 or adapter_indices_sorted is None:
             return base_out
@@ -245,23 +247,31 @@ def __call__(
         if self.lora_A is None or self.lora_B is None or self.lora_scaling is None:
             raise RuntimeError("LoRA parameters are not initialized. `init_lora` must be called.")
 
-        # Reconstruct expert indices from group_sizes
+        # Reconstruct expert indices from group_sizes (global indices)
         expert_indices = jnp.repeat(jnp.arange(self.num_experts), group_sizes, total_repeat_length=x.shape[0])
 
-        # Flatten (adapter, expert) into a single routing dimension.
-        flattened_indices = adapter_indices_sorted * self.num_experts + expert_indices
-        num_flattened_groups = self.max_lora_adapters * self.num_experts
+        # Expert-first flattening so local expert groups are contiguous
+        flattened_indices = expert_indices * self.max_lora_adapters + adapter_indices_sorted
+        num_flattened_groups = self.num_experts * self.max_lora_adapters
+        num_local_experts = self.lora_A.value.shape[1]
 
-        # Reshape lora_A and lora_B to merge (max_lora_adapters, num_experts) dimensions
-        lora_A_reshaped = self.lora_A.value.reshape(num_flattened_groups, self.in_features, self.max_lora_rank)
-        lora_B_reshaped = self.lora_B.value.reshape(num_flattened_groups, self.max_lora_rank, self.out_features)
+        # Reshape LoRA weights in expert-first order
+        lora_A = self.lora_A.value.transpose((1, 0, 2, 3)).reshape(
+            self.max_lora_adapters * num_local_experts, self.in_features, self.max_lora_rank
+        )
+        lora_B = self.lora_B.value.transpose((1, 0, 2, 3)).reshape(
+            self.max_lora_adapters * num_local_experts, self.max_lora_rank, self.out_features
+        )
 
         # Sort tokens by combined index
         x_sorted, combined_group_sizes, unsort_indices, _ = prepare_routing(x, flattened_indices, num_flattened_groups)
 
+        # Compute group_offset for LoRA (scaled by max_lora_adapters)
+        lora_group_offset = group_offset * self.max_lora_adapters if group_offset is not None else None
+
         # Apply LoRA using ragged_dot: x @ A @ B
-        intermediate = jax.lax.ragged_dot(x_sorted, lora_A_reshaped, combined_group_sizes)
-        lora_output_sorted = jax.lax.ragged_dot(intermediate, lora_B_reshaped, combined_group_sizes)
+        intermediate = ragged_dot(x_sorted, lora_A, combined_group_sizes, group_offset=lora_group_offset)
+        lora_output_sorted = ragged_dot(intermediate, lora_B, combined_group_sizes, group_offset=lora_group_offset)
 
         # Unsort and apply scaling
         lora_output = lora_output_sorted[unsort_indices]
diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py
index 1d5213b5b..0030c604d 100644
--- a/skyrl-tx/tx/layers/util.py
+++ b/skyrl-tx/tx/layers/util.py
@@ -2,6 +2,7 @@
 import jax
 from jax import lax
 from jax import numpy as jnp
+from jax.sharding import get_abstract_mesh, PartitionSpec
 
 
 def ragged_dot(
@@ -84,3 +85,28 @@ def prepare_routing(
     group_sizes = jnp.bincount(indices, length=num_groups)
     unsort_indices = jnp.argsort(sort_indices)
     return sorted_tokens, group_sizes, unsort_indices, sorted_adapter_indices
+
+
+def shard_map_ep(module: nnx.Module, func, *args):
+    """Apply shard_map over the 'ep' axis for a stateful nnx.Module.
+
+    Args:
+        module: The NNX module (will be split into graph/state).
+        func: Function to run inside shard_map. Signature: (module, *args).
+        *args: Arguments to pass to func (replicated across shards).
+    """
+    graphdef, state = nnx.split(module)
+    # Extract only 'ep' dims from PartitionSpecs, replacing others with None
+    state_specs = jax.tree.map(
+        lambda s: PartitionSpec(*(p if p == "ep" else None for p in s)) if isinstance(s, PartitionSpec) else s,
+        nnx.get_partition_spec(state),
+        is_leaf=lambda x: isinstance(x, PartitionSpec),
+    )
+    in_specs = (state_specs,) + (PartitionSpec(),) * len(args)
+
+    @jax.shard_map(mesh=get_abstract_mesh(), in_specs=in_specs, out_specs=PartitionSpec(), axis_names={"ep"})
+    def _body(state, *fn_args):
+        module_shard = nnx.merge(graphdef, state)
+        return func(module_shard, *fn_args)
+
+    return _body(state, *args)
diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py
index 1f9424d40..cdc9c3a76 100644
--- a/skyrl-tx/tx/models/qwen3.py
+++ b/skyrl-tx/tx/models/qwen3.py
@@ -4,7 +4,7 @@
 from jax.sharding import get_abstract_mesh
 
 from tx.layers.lora import LoRAEmbed, LoRAExpert, LoRALinear
-from tx.layers.util import prepare_routing
+from tx.layers.util import prepare_routing, shard_map_ep
 from tx.layers.rotary_embedding import apply_rope
 from tx.models.configs import Qwen3Config
 from tx.layers.layernorm import RMSNorm
@@ -171,7 +171,7 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "fsdp", "tp")),
+            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
             rngs=rngs,
         )
         self.up_proj = LoRAExpert(
@@ -181,7 +181,7 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "fsdp", "tp")),
+            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
             rngs=rngs,
         )
         self.down_proj = LoRAExpert(
@@ -191,39 +191,47 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp", "fsdp")),
+            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "tp", "fsdp")),
             rngs=rngs,
         )
 
     def __call__(
         self, hidden_states: jax.Array, router_logits: jax.Array, adapter_indices: jax.Array | None = None
    ) -> jax.Array:
-        # Get top-k experts for each token and compute routing weights
         routing_weights, selected_experts = jax.lax.top_k(router_logits, k=self.config.num_experts_per_tok)
         routing_weights = nnx.softmax(routing_weights, axis=-1)
 
-        # Prepare for ragged_dot by sorting tokens based on their assigned expert
-        selected_experts_flat = selected_experts.ravel()
-        hidden_states_expanded = jnp.repeat(hidden_states, self.config.num_experts_per_tok, axis=0)
-        adapter_indices_expanded = (
-            jnp.repeat(adapter_indices, self.config.num_experts_per_tok) if adapter_indices is not None else None
-        )
-        hidden_states_sorted, group_sizes, unsort_indices, adapter_indices_sorted = prepare_routing(
-            hidden_states_expanded,
-            selected_experts_flat,
-            self.config.num_experts,
-            adapter_indices=adapter_indices_expanded,
+        num_experts = self.config.num_experts
+        num_experts_per_tok = self.config.num_experts_per_tok
+        hidden_size = self.config.hidden_size
+
+        ep = get_abstract_mesh().shape.get("ep", 1)
+        assert num_experts % ep == 0, f"num_experts={num_experts} must be divisible by ep={ep}"
+
+        # Prepare routing (inputs are replicated, so every rank generates the same sorted lists)
+        hidden_expanded = jnp.repeat(hidden_states, num_experts_per_tok, axis=0)
+        adapter_expanded = jnp.repeat(adapter_indices, num_experts_per_tok) if adapter_indices is not None else None
+        hidden_sorted, group_sizes, unsort_indices, adapter_sorted = prepare_routing(
+            hidden_expanded, selected_experts.ravel(), num_experts, adapter_indices=adapter_expanded
         )
 
-        # Apply expert layers using LoRAExpert
-        gate_out = self.gate_proj(hidden_states_sorted, group_sizes, adapter_indices_sorted)
-        up_out = self.up_proj(hidden_states_sorted, group_sizes, adapter_indices_sorted)
-        down_out = self.down_proj(nnx.silu(gate_out) * up_out, group_sizes, adapter_indices_sorted)
+        def forward(experts, hidden_sorted, group_sizes, unsort_indices, adapter_sorted, routing_weights):
+            # Calculate local offset for this shard
+            ep_rank = jax.lax.axis_index("ep")
+            experts_per_rank = num_experts // jax.lax.axis_size("ep")
+            group_offset = jnp.array([ep_rank * experts_per_rank], dtype=jnp.int32)
+
+            # Expert computation
+            gate = experts.gate_proj(hidden_sorted, group_sizes, adapter_sorted, group_offset=group_offset)
+            up = experts.up_proj(hidden_sorted, group_sizes, adapter_sorted, group_offset=group_offset)
+            down = experts.down_proj(nnx.silu(gate) * up, group_sizes, adapter_sorted, group_offset=group_offset)
+
+            # Unsort and combine
+            out = down[unsort_indices].reshape(-1, num_experts_per_tok, hidden_size)
+            local_out = jnp.sum(out * routing_weights[..., None], axis=1)
+            return jax.lax.psum(local_out, axis_name="ep")
 
-        # Unsort and combine the expert outputs
-        unsorted_out = down_out[unsort_indices]
-        reshaped_out = unsorted_out.reshape(-1, self.config.num_experts_per_tok, self.config.hidden_size)
-        return jnp.sum(reshaped_out * routing_weights[..., None], axis=1)
+        return shard_map_ep(self, forward, hidden_sorted, group_sizes, unsort_indices, adapter_sorted, routing_weights)
 
 
 class Qwen3MoeSparseMoeBlock(nnx.Module):
diff --git a/skyrl-tx/tx/run/train.py b/skyrl-tx/tx/run/train.py
index 6ded5cbbf..1d6ff22d0 100644
--- a/skyrl-tx/tx/run/train.py
+++ b/skyrl-tx/tx/run/train.py
@@ -82,28 +82,28 @@ def train(
     loader = get_loader(loader_name)
     model_class = get_model_class(base_config)
 
-    mesh = jax.make_mesh((1, tp_size), ("fsdp", "tp"))
+    mesh = jax.make_mesh((1, 1, tp_size), ("fsdp", "ep", "tp"))
     with jax.set_mesh(mesh):
         model = model_class(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0))
         optimizer = nnx.Optimizer(model, get_optimizer(optimizer_name, optimizer_args), wrt=nnx.Param)
 
-    if load_checkpoint_path:
-        load_safetensors(load_checkpoint_path, base_config, model)
+        if load_checkpoint_path:
+            load_safetensors(load_checkpoint_path, base_config, model)
 
-    num_steps = train_dataset.num_rows / batch_size
-    for step, (batch, metrics) in enumerate(loader(tokenizer, train_dataset, batch_size)):
-        if max_steps and step >= max_steps:
-            break
+        num_steps = train_dataset.num_rows / batch_size
+        for step, (batch, metrics) in enumerate(loader(tokenizer, train_dataset, batch_size)):
+            if max_steps and step >= max_steps:
+                break
 
-        model.train()
-        loss, gradnorm = train_step(model, optimizer, batch)
-        tracker.log({"epoch": step / num_steps, **metrics, "gradnorm": gradnorm.item(), "loss": loss.item()}, step)
+            model.train()
+            loss, gradnorm = train_step(model, optimizer, batch)
+            tracker.log({"epoch": step / num_steps, **metrics, "gradnorm": gradnorm.item(), "loss": loss.item()}, step)
 
-        if step % save_steps == 0:
-            logger.info(f"Saving checkpoint to {output_dir}")
-            save_safetensors(base_config, model, output_dir / "model.safetensors")
+            if step % save_steps == 0:
+                logger.info(f"Saving checkpoint to {output_dir}")
+                save_safetensors(base_config, model, output_dir / "model.safetensors")
 
-    logger.info(f"Saving final checkpoint to {output_dir}")
-    base_config.save_pretrained(output_dir)
-    tokenizer.save_pretrained(output_dir)
-    save_safetensors(base_config, model, output_dir / "model.safetensors")
+        logger.info(f"Saving final checkpoint to {output_dir}")
+        base_config.save_pretrained(output_dir)
+        tokenizer.save_pretrained(output_dir)
+        save_safetensors(base_config, model, output_dir / "model.safetensors")
diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py
index 7a03f9cf6..720b760eb 100644
--- a/skyrl-tx/tx/tinker/backends/jax.py
+++ b/skyrl-tx/tx/tinker/backends/jax.py
@@ -62,6 +62,7 @@ class JaxBackendConfig(BaseModel, extra="forbid"):
     max_lora_adapters: int = Field(default=32, description="Maximum number of LoRA adapters")
     max_lora_rank: int = Field(default=32, description="Maximum LoRA rank")
     tensor_parallel_size: int = Field(default=1, description="Tensor parallelism degree to use for the model")
+    expert_parallel_size: int = Field(default=1, description="Expert parallelism degree for MoE layers")
     fully_sharded_data_parallel_size: int = Field(
         default=1, description="Fully sharded data parallelism degree for the model"
     )
@@ -168,7 +169,12 @@ def __init__(self, base_model: str, config: JaxBackendConfig):
 
         # Create model and load weights
         self.mesh = jax.make_mesh(
-            (config.fully_sharded_data_parallel_size, config.tensor_parallel_size), ("fsdp", "tp")
+            (
+                config.fully_sharded_data_parallel_size,
+                config.expert_parallel_size,
+                config.tensor_parallel_size,
+            ),
+            ("fsdp", "ep", "tp"),
         )
         with jax.set_mesh(self.mesh), nnx.use_eager_sharding(True):
             self.model = model_class(self.model_config, dtype=get_dtype(self.model_config.dtype), rngs=nnx.Rngs(0))
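
Reviewer note (not part of the patch): the expert-parallel path relies on two pieces working together: tx.layers.util.ragged_dot honoring a per-shard group_offset so each "ep" rank applies only the experts it owns, and the final jax.lax.psum over the "ep" axis recombining the per-shard partial outputs. The standalone NumPy sketch below illustrates that mechanism with made-up shapes; reference_ragged_dot is a hypothetical stand-in written for this note, not the library implementation.

# Hypothetical, self-contained sketch of the group_offset mechanism (NumPy only).
# Sizes and the reference_ragged_dot helper are made up for illustration.
import numpy as np

num_experts, ep, hidden, inter = 8, 2, 4, 6
experts_per_rank = num_experts // ep              # experts owned by each ep rank

rng = np.random.default_rng(0)
tokens = rng.normal(size=(10, hidden))            # token copies already sorted by expert
group_sizes = np.array([2, 1, 0, 3, 1, 0, 2, 1])  # tokens per global expert, sums to 10
weights = rng.normal(size=(num_experts, hidden, inter))

def reference_ragged_dot(x, w_local, group_sizes, group_offset):
    """Grouped matmul over the local expert slice: rows routed to experts outside
    [group_offset, group_offset + w_local.shape[0]) are left as zeros."""
    out = np.zeros((x.shape[0], w_local.shape[-1]))
    starts = np.concatenate([[0], np.cumsum(group_sizes)])
    for local_e in range(w_local.shape[0]):
        g = group_offset + local_e                # global expert id
        out[starts[g]:starts[g + 1]] = x[starts[g]:starts[g + 1]] @ w_local[local_e]
    return out

# Each ep rank sees the full sorted token buffer and the global group_sizes, but only
# its slice of the expert weights; summing the shard outputs (the psum in the diff)
# recovers the dense result because every row is nonzero on exactly one rank.
shard_outputs = [
    reference_ragged_dot(
        tokens,
        weights[r * experts_per_rank:(r + 1) * experts_per_rank],
        group_sizes,
        r * experts_per_rank,
    )
    for r in range(ep)
]

dense = np.zeros((tokens.shape[0], inter))
starts = np.concatenate([[0], np.cumsum(group_sizes)])
for g in range(num_experts):
    dense[starts[g]:starts[g + 1]] = tokens[starts[g]:starts[g + 1]] @ weights[g]

assert np.allclose(sum(shard_outputs), dense)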
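
Reviewer note (not part of the patch): the LoRA change in tx/layers/lora.py switches the (expert, adapter) routing dimension from adapter-first to expert-first flattening. The small check below, again with made-up sizes, shows why: under expert-first indexing the groups belonging to one ep rank's experts form a single contiguous block starting at group_offset * max_lora_adapters, which matches the lora_group_offset the patch passes to ragged_dot, whereas adapter-first indexing scatters them across ranks.

# Hypothetical sketch with made-up sizes: expert-first vs. adapter-first flattening
# of the (expert, adapter) routing dimension used by the LoRA path.
num_experts, max_lora_adapters, ep = 4, 3, 2
experts_per_rank = num_experts // ep

def expert_first(expert, adapter):                 # new flattening in the patch
    return expert * max_lora_adapters + adapter

def adapter_first(expert, adapter):                # old flattening (removed lines)
    return adapter * num_experts + expert

for ep_rank in range(ep):
    group_offset = ep_rank * experts_per_rank
    owned = range(group_offset, group_offset + experts_per_rank)

    groups = sorted(expert_first(e, a) for e in owned for a in range(max_lora_adapters))
    start = group_offset * max_lora_adapters       # corresponds to lora_group_offset
    # Expert-first: one contiguous block per rank, addressable by a single offset.
    assert groups == list(range(start, start + experts_per_rank * max_lora_adapters))

    interleaved = sorted(adapter_first(e, a) for e in owned for a in range(max_lora_adapters))
    # Adapter-first: the same groups are not contiguous, so no single offset covers them.
    assert interleaved != list(range(min(interleaved), min(interleaved) + len(interleaved)))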