9 changes: 9 additions & 0 deletions skyrl-tx/tests/conftest.py
@@ -0,0 +1,9 @@
import jax
import pytest


@pytest.fixture(scope="session", autouse=True)
def configure_jax_cpu_devices():
"""Configure JAX to use 2 CPU devices for testing parallelism."""
if not jax._src.xla_bridge.backends_are_initialized():
jax.config.update("jax_num_cpu_devices", 2)
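
For context, a minimal sketch (a hypothetical test, not part of this PR) of what the fixture provides: because jax_num_cpu_devices is set to 2 before the backend initializes, tests on a CPU-only machine can build multi-device meshes and exercise the tp=2 / ep=2 code paths below.

import jax


def test_fixture_provides_two_devices():
    # Relies on the session fixture above; assumes a CPU-only test environment.
    assert len(jax.devices()) == 2
    mesh = jax.make_mesh((1, 1, 2), ("fsdp", "ep", "tp"))
    assert mesh.shape["tp"] == 2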
3 changes: 0 additions & 3 deletions skyrl-tx/tests/models/test_llama3.py
@@ -17,9 +17,6 @@
@pytest.mark.parametrize("tp", [1, 2])
def test_llama3(tp: int):
"""Test LLama3 model against HuggingFace reference implementation."""
if not jax._src.xla_bridge.backends_are_initialized(): # type: ignore
jax.config.update("jax_num_cpu_devices", 2)

if os.getenv("CI"):
pytest.skip("Test currently runs out of memory in the CI")

19 changes: 9 additions & 10 deletions skyrl-tx/tests/models/test_qwen3.py
@@ -19,9 +19,6 @@

@pytest.mark.parametrize("tp", [1, 2])
def test_qwen3(tp: int):
if not jax._src.xla_bridge.backends_are_initialized(): # ty: ignore
jax.config.update("jax_num_cpu_devices", 2)

if tp > 1 and os.getenv("CI"):
pytest.skip("TP > 1 currently runs out of memory in the CI")

@@ -64,7 +61,8 @@ def load_moe_base_weights(jax_moe_layer: Qwen3MoeSparseMoeBlock, hf_moe_layer: H
jax_moe_layer.experts.down_proj.weight[i, :, :] = expert.down_proj.weight.detach().numpy().T


def test_qwen3_moe_layer():
@pytest.mark.parametrize("ep,tp", [(1, 1), (1, 2), (2, 1)])
def test_qwen3_moe_layer(ep: int, tp: int):
model_name = "trl-internal-testing/tiny-Qwen3MoeForCausalLM"
hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True)
base_config = PretrainedConfig.from_pretrained(model_name)
@@ -75,15 +73,15 @@ def test_qwen3_moe_layer():
with torch.no_grad():
hf_final_hidden_states, hf_router_logits = hf_moe_layer.forward(x)

mesh = jax.make_mesh((1, 1), ("fsdp", "tp"))
mesh = jax.make_mesh((1, ep, tp), ("fsdp", "ep", "tp"))
with jax.set_mesh(mesh):
moe_layer = Qwen3MoeSparseMoeBlock(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
load_moe_base_weights(moe_layer, hf_moe_layer)

final_hidden_states, router_logits = moe_layer(x.numpy(), return_router_logits=True)
final_hidden_states, router_logits = moe_layer(x.numpy(), return_router_logits=True)

assert np.allclose(hf_router_logits, router_logits, rtol=1e-4)
assert np.allclose(hf_final_hidden_states, final_hidden_states, rtol=1e-2, atol=1e-2)
assert np.allclose(hf_router_logits, router_logits, rtol=1e-4)
assert np.allclose(hf_final_hidden_states, final_hidden_states, rtol=1e-2, atol=1e-2)


def load_lora_weights(
@@ -107,7 +105,8 @@ def load_lora_weights(
jax_module.lora_ranks.value = jax_module.lora_ranks.value.at[adapter_idx].set(rank)


def test_qwen3_moe_layer_lora():
@pytest.mark.parametrize("ep,tp", [(1, 1), (1, 2), (2, 1)])
def test_qwen3_moe_layer_lora(ep: int, tp: int):
"""Test MoE LoRA by merging adapter into base weights and comparing outputs."""
model_name = "trl-internal-testing/tiny-Qwen3MoeForCausalLM"
hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True)
@@ -117,7 +116,7 @@ def test_qwen3_moe_layer_lora():
hf_moe_layer = hf_model.model.layers[0].mlp
x = torch.randn(3, 4, config.hidden_size)

mesh = jax.make_mesh((1, 1), ("fsdp", "tp"))
mesh = jax.make_mesh((1, ep, tp), ("fsdp", "ep", "tp"))
with jax.set_mesh(mesh):
moe_layer = Qwen3MoeSparseMoeBlock(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
load_moe_base_weights(moe_layer, hf_moe_layer)
32 changes: 21 additions & 11 deletions skyrl-tx/tx/layers/lora.py
@@ -3,7 +3,7 @@
from jax import numpy as jnp

from tx.utils.models import filter_lora
from tx.layers.util import Param, prepare_routing
from tx.layers.util import Param, prepare_routing, ragged_dot
from tx.models.types import ModelForCausalLM
from tx.tinker.types import LoraConfig

@@ -236,32 +236,42 @@ def __call__(
x: jax.Array,
group_sizes: jax.Array,
adapter_indices_sorted: jax.Array | None = None,
*,
group_offset: jax.Array | None = None,
) -> jax.Array:
base_out = jax.lax.ragged_dot(x, self.weight.value, group_sizes)
base_out = ragged_dot(x, self.weight.value, group_sizes, group_offset=group_offset)

if self.max_lora_adapters == 0 or adapter_indices_sorted is None:
return base_out

if self.lora_A is None or self.lora_B is None or self.lora_scaling is None:
raise RuntimeError("LoRA parameters are not initialized. `init_lora` must be called.")

# Reconstruct expert indices from group_sizes
# Reconstruct expert indices from group_sizes (global indices)
expert_indices = jnp.repeat(jnp.arange(self.num_experts), group_sizes, total_repeat_length=x.shape[0])

# Flatten (adapter, expert) into a single routing dimension.
flattened_indices = adapter_indices_sorted * self.num_experts + expert_indices
num_flattened_groups = self.max_lora_adapters * self.num_experts
# Expert-first flattening so local expert groups are contiguous
flattened_indices = expert_indices * self.max_lora_adapters + adapter_indices_sorted
num_flattened_groups = self.num_experts * self.max_lora_adapters
num_local_experts = self.lora_A.value.shape[1]

# Reshape lora_A and lora_B to merge (max_lora_adapters, num_experts) dimensions
lora_A_reshaped = self.lora_A.value.reshape(num_flattened_groups, self.in_features, self.max_lora_rank)
lora_B_reshaped = self.lora_B.value.reshape(num_flattened_groups, self.max_lora_rank, self.out_features)
# Reshape LoRA weights in expert-first order
lora_A = self.lora_A.value.transpose((1, 0, 2, 3)).reshape(
    self.max_lora_adapters * num_local_experts, self.in_features, self.max_lora_rank
)
lora_B = self.lora_B.value.transpose((1, 0, 2, 3)).reshape(
    self.max_lora_adapters * num_local_experts, self.max_lora_rank, self.out_features
)

Review comment on the reshape above (Collaborator, PR author): I also tried without the transpose, instead putting experts first in the weight tensors when initializing them -- the code is more complicated and also slower, e.g. 54s step time vs 40s with sl_loop.py and:

uv run --extra gpu --extra tinker -m tx.tinker.api --base-model Qwen/Qwen3-30B-A3B --backend-config '{"max_lora_adapters": 2, "max_lora_rank": 1, "expert_parallel_size": 8, "train_micro_batch_size": 1, "shard_attention_heads": false}'

This is somewhat surprising and there might be more optimization potential in the future, but for now it is best to keep it as simple as possible.
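
To make the contiguity point concrete, here is a small self-contained sketch (illustrative values only, not code from this PR) comparing the two flattening orders. With expert-first indexing, the groups belonging to one expert form one contiguous block of flattened ids, so an expert-parallel shard that owns a contiguous range of experts also owns a contiguous range of flattened groups.

import numpy as np

num_experts, max_lora_adapters = 4, 2
experts = np.array([0, 0, 1, 2, 3, 3])   # expert index per routed token
adapters = np.array([0, 1, 1, 0, 0, 1])  # adapter index per routed token

adapter_first = adapters * num_experts + experts       # previous flattening
expert_first = experts * max_lora_adapters + adapters  # flattening used above

# Expert e now owns flattened groups [e * max_lora_adapters, (e + 1) * max_lora_adapters),
# so a shard holding experts 2..3 owns the contiguous group range [4, 8).
print(adapter_first)  # [0 4 5 2 3 7]
print(expert_first)   # [0 1 3 4 6 7]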

# Sort tokens by combined index
x_sorted, combined_group_sizes, unsort_indices, _ = prepare_routing(x, flattened_indices, num_flattened_groups)

# Compute group_offset for LoRA (scaled by max_lora_adapters)
lora_group_offset = group_offset * self.max_lora_adapters if group_offset is not None else None

# Apply LoRA using ragged_dot: x @ A @ B
intermediate = jax.lax.ragged_dot(x_sorted, lora_A_reshaped, combined_group_sizes)
lora_output_sorted = jax.lax.ragged_dot(intermediate, lora_B_reshaped, combined_group_sizes)
intermediate = ragged_dot(x_sorted, lora_A, combined_group_sizes, group_offset=lora_group_offset)
lora_output_sorted = ragged_dot(intermediate, lora_B, combined_group_sizes, group_offset=lora_group_offset)

# Unsort and apply scaling
lora_output = lora_output_sorted[unsort_indices]
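
For readers unfamiliar with the grouped-matmul primitive used here, the following is a plain NumPy reference of the behavior the code above relies on. It is a sketch under the assumption that group_offset names the first global group whose weights are local to the shard; the actual tx ragged_dot wrapper may differ in detail. Rows belonging to non-local groups stay zero, which is what lets the psum over the 'ep' axis in the model code recombine the shards.

import numpy as np

def ragged_dot_reference(x, w_local, group_sizes, group_offset=0):
    """x: (m, k) tokens pre-sorted by group; w_local: (g_local, k, n) weights for the
    shard's groups; group_sizes: (g_total,) rows per global group."""
    out = np.zeros((x.shape[0], w_local.shape[-1]), dtype=x.dtype)
    starts = np.concatenate([[0], np.cumsum(group_sizes)])
    for local_g in range(w_local.shape[0]):
        g = group_offset + local_g            # global group handled locally
        s, e = starts[g], starts[g + 1]
        out[s:e] = x[s:e] @ w_local[local_g]  # rows of other groups remain zero
    return out

x = np.random.randn(6, 4).astype(np.float32)
w = np.random.randn(2, 4, 5).astype(np.float32)  # this shard owns 2 of the 3 groups
print(ragged_dot_reference(x, w, np.array([2, 3, 1]), group_offset=1).shape)  # (6, 5)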
26 changes: 26 additions & 0 deletions skyrl-tx/tx/layers/util.py
@@ -2,6 +2,7 @@
import jax
from jax import lax
from jax import numpy as jnp
from jax.sharding import get_abstract_mesh, PartitionSpec


def ragged_dot(
@@ -84,3 +85,28 @@ def prepare_routing(
group_sizes = jnp.bincount(indices, length=num_groups)
unsort_indices = jnp.argsort(sort_indices)
return sorted_tokens, group_sizes, unsort_indices, sorted_adapter_indices


def shard_map_ep(module: nnx.Module, func, *args):
"""Apply shard_map over the 'ep' axis for a stateful nnx.Module.

Args:
module: The NNX module (will be split into graph/state).
func: Function to run inside shard_map. Signature: (module, *args).
*args: Arguments to pass to func (replicated across shards).
"""
graphdef, state = nnx.split(module)
# Extract only 'ep' dims from PartitionSpecs, replacing others with None
state_specs = jax.tree.map(
lambda s: PartitionSpec(*(p if p == "ep" else None for p in s)) if isinstance(s, PartitionSpec) else s,
nnx.get_partition_spec(state),
is_leaf=lambda x: isinstance(x, PartitionSpec),
)
in_specs = (state_specs,) + (PartitionSpec(),) * len(args)

@jax.shard_map(mesh=get_abstract_mesh(), in_specs=in_specs, out_specs=PartitionSpec(), axis_names={"ep"})
def _body(state, *fn_args):
module_shard = nnx.merge(graphdef, state)
return func(module_shard, *fn_args)

return _body(state, *args)
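
As a rough usage sketch (hypothetical module and body, not code from this PR): the function handed to shard_map_ep receives a merged module whose 'ep'-partitioned parameters are this rank's local slice, and it can use collectives over the 'ep' axis to recombine results.

import jax


def local_forward(module_shard, x):
    # Inside shard_map, module_shard holds only this rank's slice along 'ep'.
    local_out = module_shard.some_local_projection(x)  # placeholder computation
    return jax.lax.psum(local_out, axis_name="ep")     # recombine across expert shards

# output = shard_map_ep(moe_experts, local_forward, x)  # x is replicated to every shard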
56 changes: 32 additions & 24 deletions skyrl-tx/tx/models/qwen3.py
@@ -4,7 +4,7 @@
from jax.sharding import get_abstract_mesh

from tx.layers.lora import LoRAEmbed, LoRAExpert, LoRALinear
from tx.layers.util import prepare_routing
from tx.layers.util import prepare_routing, shard_map_ep
from tx.layers.rotary_embedding import apply_rope
from tx.models.configs import Qwen3Config
from tx.layers.layernorm import RMSNorm
@@ -171,7 +171,7 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
dtype=dtype,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "fsdp", "tp")),
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
rngs=rngs,
)
self.up_proj = LoRAExpert(
Expand All @@ -181,7 +181,7 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
dtype=dtype,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "fsdp", "tp")),
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
rngs=rngs,
)
self.down_proj = LoRAExpert(
Expand All @@ -191,39 +191,47 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
max_lora_adapters=config.max_lora_adapters,
max_lora_rank=config.max_lora_rank,
dtype=dtype,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp", "fsdp")),
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "tp", "fsdp")),
rngs=rngs,
)

def __call__(
self, hidden_states: jax.Array, router_logits: jax.Array, adapter_indices: jax.Array | None = None
) -> jax.Array:
# Get top-k experts for each token and compute routing weights
routing_weights, selected_experts = jax.lax.top_k(router_logits, k=self.config.num_experts_per_tok)
routing_weights = nnx.softmax(routing_weights, axis=-1)

# Prepare for ragged_dot by sorting tokens based on their assigned expert
selected_experts_flat = selected_experts.ravel()
hidden_states_expanded = jnp.repeat(hidden_states, self.config.num_experts_per_tok, axis=0)
adapter_indices_expanded = (
jnp.repeat(adapter_indices, self.config.num_experts_per_tok) if adapter_indices is not None else None
)
hidden_states_sorted, group_sizes, unsort_indices, adapter_indices_sorted = prepare_routing(
hidden_states_expanded,
selected_experts_flat,
self.config.num_experts,
adapter_indices=adapter_indices_expanded,
num_experts = self.config.num_experts
num_experts_per_tok = self.config.num_experts_per_tok
hidden_size = self.config.hidden_size

ep = get_abstract_mesh().shape.get("ep", 1)
assert num_experts % ep == 0, f"num_experts={num_experts} must be divisible by ep={ep}"

# Prepare routing (inputs are replicated, so every rank generates the same sorted lists)
hidden_expanded = jnp.repeat(hidden_states, num_experts_per_tok, axis=0)
adapter_expanded = jnp.repeat(adapter_indices, num_experts_per_tok) if adapter_indices is not None else None
hidden_sorted, group_sizes, unsort_indices, adapter_sorted = prepare_routing(
hidden_expanded, selected_experts.ravel(), num_experts, adapter_indices=adapter_expanded
)

# Apply expert layers using LoRAExpert
gate_out = self.gate_proj(hidden_states_sorted, group_sizes, adapter_indices_sorted)
up_out = self.up_proj(hidden_states_sorted, group_sizes, adapter_indices_sorted)
down_out = self.down_proj(nnx.silu(gate_out) * up_out, group_sizes, adapter_indices_sorted)
def forward(experts, hidden_sorted, group_sizes, unsort_indices, adapter_sorted, routing_weights):
# Calculate local offset for this shard
ep_rank = jax.lax.axis_index("ep")
experts_per_rank = num_experts // jax.lax.axis_size("ep")

Review comment (Contributor, severity: high): This integer division assumes that num_experts is evenly divisible by the number of devices in the 'ep' mesh axis. If it's not, this could lead to an incorrect number of experts being assigned per rank, causing silent errors or incorrect model behavior. It would be much safer to add an assertion to validate this assumption, ideally during model initialization. For example, you could add this check in Qwen3Experts.__init__:

ep_size = get_abstract_mesh().shape.get("ep", 1)
assert self.config.num_experts % ep_size == 0, f"Number of experts ({self.config.num_experts}) must be divisible by expert parallel size ({ep_size})."

group_offset = jnp.array([ep_rank * experts_per_rank], dtype=jnp.int32)

# Expert computation
gate = experts.gate_proj(hidden_sorted, group_sizes, adapter_sorted, group_offset=group_offset)
up = experts.up_proj(hidden_sorted, group_sizes, adapter_sorted, group_offset=group_offset)
down = experts.down_proj(nnx.silu(gate) * up, group_sizes, adapter_sorted, group_offset=group_offset)

# Unsort and combine
out = down[unsort_indices].reshape(-1, num_experts_per_tok, hidden_size)
local_out = jnp.sum(out * routing_weights[..., None], axis=1)
return jax.lax.psum(local_out, axis_name="ep")

# Unsort and combine the expert outputs
unsorted_out = down_out[unsort_indices]
reshaped_out = unsorted_out.reshape(-1, self.config.num_experts_per_tok, self.config.hidden_size)
return jnp.sum(reshaped_out * routing_weights[..., None], axis=1)
return shard_map_ep(self, forward, hidden_sorted, group_sizes, unsort_indices, adapter_sorted, routing_weights)


class Qwen3MoeSparseMoeBlock(nnx.Module):
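
For intuition on the routing math in Qwen3Experts.__call__ above: each token keeps only its top-k router logits, the softmax is taken over just those k values, and the k selected expert outputs are combined with the resulting weights. A tiny sketch with illustrative shapes:

import jax
import jax.numpy as jnp
from flax import nnx

num_experts_per_tok = 2
router_logits = jnp.array([[0.1, 2.0, -1.0, 0.5]])  # one token, four experts

routing_weights, selected_experts = jax.lax.top_k(router_logits, k=num_experts_per_tok)
routing_weights = nnx.softmax(routing_weights, axis=-1)

# Stand-in for the chosen experts' outputs: (tokens, k, hidden)
expert_outputs = jnp.ones((1, num_experts_per_tok, 8))
combined = jnp.sum(expert_outputs * routing_weights[..., None], axis=1)
print(selected_experts)  # [[1 3]]
print(combined.shape)    # (1, 8)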
34 changes: 17 additions & 17 deletions skyrl-tx/tx/run/train.py
@@ -82,28 +82,28 @@ def train(
loader = get_loader(loader_name)

model_class = get_model_class(base_config)
mesh = jax.make_mesh((1, tp_size), ("fsdp", "tp"))
mesh = jax.make_mesh((1, 1, tp_size), ("fsdp", "ep", "tp"))
with jax.set_mesh(mesh):
model = model_class(config, dtype=get_dtype(config.dtype), rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, get_optimizer(optimizer_name, optimizer_args), wrt=nnx.Param)

if load_checkpoint_path:
load_safetensors(load_checkpoint_path, base_config, model)
if load_checkpoint_path:
load_safetensors(load_checkpoint_path, base_config, model)

num_steps = train_dataset.num_rows / batch_size
for step, (batch, metrics) in enumerate(loader(tokenizer, train_dataset, batch_size)):
if max_steps and step >= max_steps:
break
num_steps = train_dataset.num_rows / batch_size
for step, (batch, metrics) in enumerate(loader(tokenizer, train_dataset, batch_size)):
if max_steps and step >= max_steps:
break

model.train()
loss, gradnorm = train_step(model, optimizer, batch)
tracker.log({"epoch": step / num_steps, **metrics, "gradnorm": gradnorm.item(), "loss": loss.item()}, step)
model.train()
loss, gradnorm = train_step(model, optimizer, batch)
tracker.log({"epoch": step / num_steps, **metrics, "gradnorm": gradnorm.item(), "loss": loss.item()}, step)

if step % save_steps == 0:
logger.info(f"Saving checkpoint to {output_dir}")
save_safetensors(base_config, model, output_dir / "model.safetensors")
if step % save_steps == 0:
logger.info(f"Saving checkpoint to {output_dir}")
save_safetensors(base_config, model, output_dir / "model.safetensors")

logger.info(f"Saving final checkpoint to {output_dir}")
base_config.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
save_safetensors(base_config, model, output_dir / "model.safetensors")
logger.info(f"Saving final checkpoint to {output_dir}")
base_config.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
save_safetensors(base_config, model, output_dir / "model.safetensors")
8 changes: 7 additions & 1 deletion skyrl-tx/tx/tinker/backends/jax.py
@@ -62,6 +62,7 @@ class JaxBackendConfig(BaseModel, extra="forbid"):
max_lora_adapters: int = Field(default=32, description="Maximum number of LoRA adapters")
max_lora_rank: int = Field(default=32, description="Maximum LoRA rank")
tensor_parallel_size: int = Field(default=1, description="Tensor parallelism degree to use for the model")
expert_parallel_size: int = Field(default=1, description="Expert parallelism degree for MoE layers")
fully_sharded_data_parallel_size: int = Field(
default=1, description="Fully sharded data parallelism degree for the model"
)
@@ -168,7 +169,12 @@ def __init__(self, base_model: str, config: JaxBackendConfig):

# Create model and load weights
self.mesh = jax.make_mesh(
(config.fully_sharded_data_parallel_size, config.tensor_parallel_size), ("fsdp", "tp")
(
config.fully_sharded_data_parallel_size,
config.expert_parallel_size,
config.tensor_parallel_size,
),
("fsdp", "ep", "tp"),
)
with jax.set_mesh(self.mesh), nnx.use_eager_sharding(True):
self.model = model_class(self.model_config, dtype=get_dtype(self.model_config.dtype), rngs=nnx.Rngs(0))
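
One note on the mesh construction above, shown as a hypothetical sanity check rather than code from this PR: the configured parallelism degrees multiply together to give the number of devices the mesh occupies, so their product cannot exceed the number of visible devices.

import jax

# e.g. fully_sharded_data_parallel_size, expert_parallel_size, tensor_parallel_size
fsdp, ep, tp = 1, 2, 1
# Assumption: enough devices must be visible to lay out the requested axes.
assert fsdp * ep * tp <= len(jax.devices()), "not enough devices for the requested mesh"
mesh = jax.make_mesh((fsdp, ep, tp), ("fsdp", "ep", "tp"))  # axis sizes: fsdp=1, ep=2, tp=1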