Merged · Changes from 28 commits
22 changes: 15 additions & 7 deletions skyrl-tx/tests/models/test_qwen3.py
@@ -64,7 +64,11 @@ def load_moe_base_weights(jax_moe_layer: Qwen3MoeSparseMoeBlock, hf_moe_layer: H
jax_moe_layer.experts.down_proj.weight[i, :, :] = expert.down_proj.weight.detach().numpy().T


def test_qwen3_moe_layer():
@pytest.mark.parametrize("ep,tp", [(1, 1), (1, 2), (2, 1)])
def test_qwen3_moe_layer(ep: int, tp: int):
if not jax._src.xla_bridge.backends_are_initialized():
jax.config.update("jax_num_cpu_devices", ep * tp)
Contributor comment (medium): This device initialization logic is duplicated in test_qwen3_moe_layer_lora (lines 117-118). To improve maintainability and reduce redundancy, consider extracting this setup into a shared pytest fixture. The fixture could automatically run for tests that require ep and tp parameters, making the test suite cleaner and easier to manage.


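A minimal sketch of what such a shared fixture could look like (the fixture name, conftest.py placement, and the use of indirect parametrization are illustrative assumptions, not part of this PR):

# conftest.py -- illustrative sketch only
import jax
import pytest


@pytest.fixture
def cpu_devices(request):
    """Ensure enough CPU devices exist for the requested (ep, tp) layout."""
    ep, tp = request.param
    # jax_num_cpu_devices only takes effect before the backend is initialized
    if not jax._src.xla_bridge.backends_are_initialized():
        jax.config.update("jax_num_cpu_devices", ep * tp)
    return ep, tp

# Tests would then opt in via indirect parametrization:
# @pytest.mark.parametrize("cpu_devices", [(1, 1), (1, 2), (2, 1)], indirect=True)
# def test_qwen3_moe_layer(cpu_devices):
#     ep, tp = cpu_devices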
model_name = "trl-internal-testing/tiny-Qwen3MoeForCausalLM"
hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True)
base_config = PretrainedConfig.from_pretrained(model_name)
@@ -75,15 +79,15 @@ def test_qwen3_moe_layer():
with torch.no_grad():
hf_final_hidden_states, hf_router_logits = hf_moe_layer.forward(x)

mesh = jax.make_mesh((1, 1), ("fsdp", "tp"))
mesh = jax.make_mesh((1, ep, tp), ("fsdp", "ep", "tp"))
with jax.set_mesh(mesh):
moe_layer = Qwen3MoeSparseMoeBlock(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
load_moe_base_weights(moe_layer, hf_moe_layer)

final_hidden_states, router_logits = moe_layer(x.numpy(), return_router_logits=True)
final_hidden_states, router_logits = moe_layer(x.numpy(), return_router_logits=True)

assert np.allclose(hf_router_logits, router_logits, rtol=1e-4)
assert np.allclose(hf_final_hidden_states, final_hidden_states, rtol=1e-2, atol=1e-2)
assert np.allclose(hf_router_logits, router_logits, rtol=1e-4)
assert np.allclose(hf_final_hidden_states, final_hidden_states, rtol=1e-2, atol=1e-2)


def load_lora_weights(
@@ -107,8 +111,12 @@ def load_lora_weights(
jax_module.lora_ranks.value = jax_module.lora_ranks.value.at[adapter_idx].set(rank)


def test_qwen3_moe_layer_lora():
@pytest.mark.parametrize("ep,tp", [(1, 1), (1, 2), (2, 1)])
def test_qwen3_moe_layer_lora(ep: int, tp: int):
"""Test MoE LoRA by merging adapter into base weights and comparing outputs."""
if not jax._src.xla_bridge.backends_are_initialized():
jax.config.update("jax_num_cpu_devices", ep * tp)

model_name = "trl-internal-testing/tiny-Qwen3MoeForCausalLM"
hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True)
base_config = PretrainedConfig.from_pretrained(model_name)
@@ -117,7 +125,7 @@ def test_qwen3_moe_layer_lora():
hf_moe_layer = hf_model.model.layers[0].mlp
x = torch.randn(3, 4, config.hidden_size)

mesh = jax.make_mesh((1, 1), ("fsdp", "tp"))
mesh = jax.make_mesh((1, ep, tp), ("fsdp", "ep", "tp"))
with jax.set_mesh(mesh):
moe_layer = Qwen3MoeSparseMoeBlock(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
load_moe_base_weights(moe_layer, hf_moe_layer)
35 changes: 24 additions & 11 deletions skyrl-tx/tx/layers/lora.py
@@ -3,7 +3,7 @@
from jax import numpy as jnp

from tx.utils.models import filter_lora
from tx.layers.util import Param, prepare_routing
from tx.layers.util import Param, prepare_routing, ragged_dot
from tx.models.types import ModelForCausalLM
from tx.tinker.types import LoraConfig

@@ -236,32 +236,45 @@ def __call__(
x: jax.Array,
group_sizes: jax.Array,
adapter_indices_sorted: jax.Array | None = None,
*,
group_offset: jax.Array | None = None,
) -> jax.Array:
base_out = jax.lax.ragged_dot(x, self.weight.value, group_sizes)
# Inside shard_map, weights are already local shards
weight = self.weight.value
num_local_experts = weight.shape[0]

base_out = ragged_dot(x, weight, group_sizes, group_offset=group_offset)

if self.max_lora_adapters == 0 or adapter_indices_sorted is None:
return base_out

if self.lora_A is None or self.lora_B is None or self.lora_scaling is None:
raise RuntimeError("LoRA parameters are not initialized. `init_lora` must be called.")

# Reconstruct expert indices from group_sizes
# Reconstruct expert indices from group_sizes (global indices)
expert_indices = jnp.repeat(jnp.arange(self.num_experts), group_sizes, total_repeat_length=x.shape[0])

# Flatten (adapter, expert) into a single routing dimension.
flattened_indices = adapter_indices_sorted * self.num_experts + expert_indices
num_flattened_groups = self.max_lora_adapters * self.num_experts
# Expert-first flattening so local expert groups are contiguous
flattened_indices = expert_indices * self.max_lora_adapters + adapter_indices_sorted
num_flattened_groups = self.num_experts * self.max_lora_adapters

# Reshape lora_A and lora_B to merge (max_lora_adapters, num_experts) dimensions
lora_A_reshaped = self.lora_A.value.reshape(num_flattened_groups, self.in_features, self.max_lora_rank)
lora_B_reshaped = self.lora_B.value.reshape(num_flattened_groups, self.max_lora_rank, self.out_features)
# Reshape LoRA weights in expert-first order (already local shards)
lora_A = self.lora_A.value.transpose((1, 0, 2, 3)).reshape(
Collaborator (author) comment: I also tried without the transpose and just putting experts first in the weight tensors when initializing them -- the code is more complicated and also slower, e.g. 54s step time vs 40s with sl_loop.py and

uv run --extra gpu --extra tinker -m tx.tinker.api --base-model Qwen/Qwen3-30B-A3B --backend-config '{"max_lora_adapters": 2, "max_lora_rank": 1, "expert_parallel_size": 8, "train_micro_batch_size": 1, "shard_attention_heads": false}'

This is somewhat surprising and there might be more optimization potential here in the future, but for now it is best to keep it as simple as possible.

self.max_lora_adapters * num_local_experts, self.in_features, self.max_lora_rank
)
lora_B = self.lora_B.value.transpose((1, 0, 2, 3)).reshape(
self.max_lora_adapters * num_local_experts, self.max_lora_rank, self.out_features
)

# Sort tokens by combined index
x_sorted, combined_group_sizes, unsort_indices, _ = prepare_routing(x, flattened_indices, num_flattened_groups)

# Compute group_offset for LoRA (scaled by max_lora_adapters)
lora_group_offset = group_offset * self.max_lora_adapters if group_offset is not None else None

# Apply LoRA using ragged_dot: x @ A @ B
intermediate = jax.lax.ragged_dot(x_sorted, lora_A_reshaped, combined_group_sizes)
lora_output_sorted = jax.lax.ragged_dot(intermediate, lora_B_reshaped, combined_group_sizes)
intermediate = ragged_dot(x_sorted, lora_A, combined_group_sizes, group_offset=lora_group_offset)
lora_output_sorted = ragged_dot(intermediate, lora_B, combined_group_sizes, group_offset=lora_group_offset)

# Unsort and apply scaling
lora_output = lora_output_sorted[unsort_indices]
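As a quick illustration of the expert-first flattening used above, here is a tiny NumPy check (hypothetical shapes, not part of the PR) that transposing the adapter-first LoRA weights to (num_experts, max_lora_adapters, ...) and reshaping lines them up with the combined index expert * max_lora_adapters + adapter:

import numpy as np

A, E, IN, R = 3, 4, 8, 2               # max_lora_adapters, num_experts, in_features, max_lora_rank
lora_A = np.random.randn(A, E, IN, R)  # stored adapter-first, as in the reshape above

flat = lora_A.transpose(1, 0, 2, 3).reshape(E * A, IN, R)  # expert-first flattening
a, e = 2, 1                            # an (adapter, expert) pair
assert np.array_equal(flat[e * A + a], lora_A[a, e])       # matches flattened_indices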
26 changes: 26 additions & 0 deletions skyrl-tx/tx/layers/util.py
@@ -2,6 +2,7 @@
import jax
from jax import lax
from jax import numpy as jnp
from jax.sharding import get_abstract_mesh, PartitionSpec


def ragged_dot(
@@ -84,3 +85,28 @@ def prepare_routing(
group_sizes = jnp.bincount(indices, length=num_groups)
unsort_indices = jnp.argsort(sort_indices)
return sorted_tokens, group_sizes, unsort_indices, sorted_adapter_indices


def shard_map_ep(module: nnx.Module, func, *args):
"""Apply shard_map over the 'ep' axis for a stateful nnx.Module.

Args:
module: The NNX module (will be split into graph/state).
func: Function to run inside shard_map. Signature: (module, *args).
*args: Arguments to pass to func (replicated across shards).
"""
graphdef, state = nnx.split(module)
# Extract only 'ep' dims from PartitionSpecs, replacing others with None
state_specs = jax.tree.map(
lambda s: PartitionSpec(*(p if p == "ep" else None for p in s)) if isinstance(s, PartitionSpec) else s,
nnx.get_partition_spec(state),
is_leaf=lambda x: isinstance(x, PartitionSpec),
)
in_specs = (state_specs,) + (PartitionSpec(),) * len(args)

@jax.shard_map(mesh=get_abstract_mesh(), in_specs=in_specs, out_specs=PartitionSpec(), axis_names={"ep"})
def _body(state, *fn_args):
module_shard = nnx.merge(graphdef, state)
return func(module_shard, *fn_args)

return _body(state, *args)
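A minimal usage sketch of shard_map_ep (the toy module, shapes, and CPU device setup are illustrative assumptions; it mirrors the pattern Qwen3Experts uses below):

import jax

jax.config.update("jax_num_cpu_devices", 2)  # must run before the JAX backend is initialized

from flax import nnx
from jax import numpy as jnp

from tx.layers.util import shard_map_ep


class ToyExperts(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", None, None))
        self.weight = nnx.Param(init(rngs.params(), (4, 8, 8)))  # (num_experts, in, out)


mesh = jax.make_mesh((1, 2, 1), ("fsdp", "ep", "tp"))
with jax.set_mesh(mesh):
    experts = ToyExperts(nnx.Rngs(0))
    x = jnp.ones((3, 8))

    def body(shard, x):
        # Inside the shard_map body, shard.weight.value is the local
        # (num_experts // ep_size, in, out) slice of the expert weights.
        local = jnp.einsum("ti,eio->to", x, shard.weight.value)
        return jax.lax.psum(local, axis_name="ep")  # combine partial results across ranks

    out = shard_map_ep(experts, body, x)  # replicated (3, 8) output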
53 changes: 29 additions & 24 deletions skyrl-tx/tx/models/qwen3.py
@@ -4,7 +4,7 @@
from jax.sharding import get_abstract_mesh

from tx.layers.lora import LoRAEmbed, LoRAExpert, LoRALinear
from tx.layers.util import prepare_routing
from tx.layers.util import prepare_routing, shard_map_ep
from tx.layers.rotary_embedding import apply_rope
from tx.models.configs import Qwen3Config
from tx.layers.layernorm import RMSNorm
@@ -171,7 +171,7 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
dtype=dtype,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "fsdp", "tp")),
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
rngs=rngs,
)
self.up_proj = LoRAExpert(
@@ -181,7 +181,7 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
dtype=dtype,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "fsdp", "tp")),
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
rngs=rngs,
)
self.down_proj = LoRAExpert(
@@ -191,39 +191,44 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
dtype=dtype,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp", "fsdp")),
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "tp", "fsdp")),
rngs=rngs,
)

def __call__(
self, hidden_states: jax.Array, router_logits: jax.Array, adapter_indices: jax.Array | None = None
) -> jax.Array:
# Get top-k experts for each token and compute routing weights
routing_weights, selected_experts = jax.lax.top_k(router_logits, k=self.config.num_experts_per_tok)
routing_weights = nnx.softmax(routing_weights, axis=-1)

# Prepare for ragged_dot by sorting tokens based on their assigned expert
selected_experts_flat = selected_experts.ravel()
hidden_states_expanded = jnp.repeat(hidden_states, self.config.num_experts_per_tok, axis=0)
adapter_indices_expanded = (
jnp.repeat(adapter_indices, self.config.num_experts_per_tok) if adapter_indices is not None else None
)
hidden_states_sorted, group_sizes, unsort_indices, adapter_indices_sorted = prepare_routing(
hidden_states_expanded,
selected_experts_flat,
self.config.num_experts,
adapter_indices=adapter_indices_expanded,
num_experts = self.config.num_experts
num_experts_per_tok = self.config.num_experts_per_tok
hidden_size = self.config.hidden_size

# Prepare routing (inputs are replicated, so every rank generates the same sorted lists)
hidden_expanded = jnp.repeat(hidden_states, num_experts_per_tok, axis=0)
adapter_expanded = jnp.repeat(adapter_indices, num_experts_per_tok) if adapter_indices is not None else None
hidden_sorted, group_sizes, unsort_indices, adapter_sorted = prepare_routing(
hidden_expanded, selected_experts.ravel(), num_experts, adapter_indices=adapter_expanded
)

# Apply expert layers using LoRAExpert
gate_out = self.gate_proj(hidden_states_sorted, group_sizes, adapter_indices_sorted)
up_out = self.up_proj(hidden_states_sorted, group_sizes, adapter_indices_sorted)
down_out = self.down_proj(nnx.silu(gate_out) * up_out, group_sizes, adapter_indices_sorted)
def forward(experts, hidden_sorted, group_sizes, unsort_indices, adapter_sorted, routing_weights):
# Calculate local offset for this shard
ep_rank = jax.lax.axis_index("ep")
experts_per_rank = num_experts // jax.lax.axis_size("ep")
Contributor comment (high): This integer division assumes that num_experts is evenly divisible by the number of devices in the 'ep' mesh axis. If it's not, this could lead to an incorrect number of experts being assigned per rank, causing silent errors or incorrect model behavior. It would be much safer to add an assertion to validate this assumption, ideally during model initialization.

For example, you could add this check in Qwen3Experts.__init__:

ep_size = get_abstract_mesh().shape.get("ep", 1)
assert self.config.num_experts % ep_size == 0, f"Number of experts ({self.config.num_experts}) must be divisible by expert parallel size ({ep_size})."

group_offset = jnp.array([ep_rank * experts_per_rank], dtype=jnp.int32)

# Expert computation
gate = experts.gate_proj(hidden_sorted, group_sizes, adapter_sorted, group_offset=group_offset)
up = experts.up_proj(hidden_sorted, group_sizes, adapter_sorted, group_offset=group_offset)
down = experts.down_proj(nnx.silu(gate) * up, group_sizes, adapter_sorted, group_offset=group_offset)

# Unsort and combine
out = down[unsort_indices].reshape(-1, num_experts_per_tok, hidden_size)
local_out = jnp.sum(out * routing_weights[..., None], axis=1)
return jax.lax.psum(local_out, axis_name="ep")

# Unsort and combine the expert outputs
unsorted_out = down_out[unsort_indices]
reshaped_out = unsorted_out.reshape(-1, self.config.num_experts_per_tok, self.config.hidden_size)
return jnp.sum(reshaped_out * routing_weights[..., None], axis=1)
return shard_map_ep(self, forward, hidden_sorted, group_sizes, unsort_indices, adapter_sorted, routing_weights)


class Qwen3MoeSparseMoeBlock(nnx.Module):
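A plain-NumPy reference of the per-rank combination in the forward function above (assumed semantics, illustrative shapes; routing weights and top-k are omitted for brevity): each ep rank computes outputs only for tokens routed to its local expert range, leaves the rest at zero, and the psum over "ep" adds the per-rank partials back into the full result.

import numpy as np

num_experts, ep_size, hidden = 4, 2, 8
experts_per_rank = num_experts // ep_size
token_expert = np.array([0, 3, 1, 2])             # expert id per (sorted) token
x = np.random.randn(4, hidden)
w = np.random.randn(num_experts, hidden, hidden)  # one weight matrix per expert

partials = []
for rank in range(ep_size):
    offset = rank * experts_per_rank                 # plays the role of group_offset
    local = np.zeros_like(x)
    for t, e in enumerate(token_expert):
        if offset <= e < offset + experts_per_rank:  # token handled by this rank
            local[t] = x[t] @ w[e]
    partials.append(local)

combined = np.stack(partials).sum(axis=0)            # what psum over "ep" produces
expected = np.stack([x[t] @ w[e] for t, e in enumerate(token_expert)])
assert np.allclose(combined, expected)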
2 changes: 1 addition & 1 deletion skyrl-tx/tx/run/train.py
@@ -82,7 +82,7 @@ def train(
loader = get_loader(loader_name)

model_class = get_model_class(base_config)
mesh = jax.make_mesh((1, tp_size), ("fsdp", "tp"))
mesh = jax.make_mesh((1, 1, tp_size), ("fsdp", "ep", "tp"))
with jax.set_mesh(mesh):
model = model_class(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, get_optimizer(optimizer_name, optimizer_args), wrt=nnx.Param)
8 changes: 7 additions & 1 deletion skyrl-tx/tx/tinker/backends/jax.py
@@ -62,6 +62,7 @@ class JaxBackendConfig(BaseModel, extra="forbid"):
max_lora_adapters: int = Field(default=32, description="Maximum number of LoRA adapters")
max_lora_rank: int = Field(default=32, description="Maximum LoRA rank")
tensor_parallel_size: int = Field(default=1, description="Tensor parallelism degree to use for the model")
expert_parallel_size: int = Field(default=1, description="Expert parallelism degree for MoE layers")
fully_sharded_data_parallel_size: int = Field(
default=1, description="Fully sharded data parallelism degree for the model"
)
@@ -168,7 +169,12 @@ def __init__(self, base_model: str, config: JaxBackendConfig):

# Create model and load weights
self.mesh = jax.make_mesh(
(config.fully_sharded_data_parallel_size, config.tensor_parallel_size), ("fsdp", "tp")
(
config.fully_sharded_data_parallel_size,
config.expert_parallel_size,
config.tensor_parallel_size,
),
("fsdp", "ep", "tp"),
)
with jax.set_mesh(self.mesh), nnx.use_eager_sharding(True):
self.model = model_class(self.model_config, dtype=get_dtype(self.model_config.dtype), rngs=nnx.Rngs(0))