[tx] Implement expert parallelism #842
New test fixture:

```diff
@@ -0,0 +1,9 @@
+import jax
+import pytest
+
+
+@pytest.fixture(scope="session", autouse=True)
+def configure_jax_cpu_devices():
+    """Configure JAX to use 2 CPU devices for testing parallelism."""
+    if not jax._src.xla_bridge.backends_are_initialized():
+        jax.config.update("jax_num_cpu_devices", 2)
```
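As a usage illustration (not part of this PR), a test that relies on this fixture might look like the sketch below; the test name, mesh construction, and assertions are hypothetical:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def test_ep_mesh_has_two_devices():
    # The session fixture requests 2 CPU devices, so a 1-D "ep" mesh can be
    # built even on a machine without accelerators.
    devices = jax.devices()
    assert len(devices) == 2

    mesh = Mesh(np.array(devices), ("ep",))
    x = jax.device_put(jnp.arange(8), NamedSharding(mesh, P("ep")))
    assert len(x.addressable_shards) == 2  # half of the array lives on each device
```

The fixture guards on `backends_are_initialized()` because `jax_num_cpu_devices` only takes effect if it is set before the JAX backend initializes.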
Changes to the Qwen3 model:

```diff
@@ -4,7 +4,7 @@
 from jax.sharding import get_abstract_mesh
 from tx.layers.lora import LoRAEmbed, LoRAExpert, LoRALinear
-from tx.layers.util import prepare_routing
+from tx.layers.util import prepare_routing, shard_map_ep
 from tx.layers.rotary_embedding import apply_rope
 from tx.models.configs import Qwen3Config
 from tx.layers.layernorm import RMSNorm
@@ -171,7 +171,7 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "fsdp", "tp")),
+            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
             rngs=rngs,
         )
         self.up_proj = LoRAExpert(
@@ -181,7 +181,7 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "fsdp", "tp")),
+            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
             rngs=rngs,
         )
         self.down_proj = LoRAExpert(
@@ -191,39 +191,47 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp", "fsdp")),
+            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "tp", "fsdp")),
             rngs=rngs,
         )

     def __call__(
         self, hidden_states: jax.Array, router_logits: jax.Array, adapter_indices: jax.Array | None = None
     ) -> jax.Array:
         # Get top-k experts for each token and compute routing weights
         routing_weights, selected_experts = jax.lax.top_k(router_logits, k=self.config.num_experts_per_tok)
         routing_weights = nnx.softmax(routing_weights, axis=-1)

-        # Prepare for ragged_dot by sorting tokens based on their assigned expert
-        selected_experts_flat = selected_experts.ravel()
-        hidden_states_expanded = jnp.repeat(hidden_states, self.config.num_experts_per_tok, axis=0)
-        adapter_indices_expanded = (
-            jnp.repeat(adapter_indices, self.config.num_experts_per_tok) if adapter_indices is not None else None
-        )
-        hidden_states_sorted, group_sizes, unsort_indices, adapter_indices_sorted = prepare_routing(
-            hidden_states_expanded,
-            selected_experts_flat,
-            self.config.num_experts,
-            adapter_indices=adapter_indices_expanded,
-        )
+        num_experts = self.config.num_experts
+        num_experts_per_tok = self.config.num_experts_per_tok
+        hidden_size = self.config.hidden_size
+
+        ep = get_abstract_mesh().shape.get("ep", 1)
+        assert num_experts % ep == 0, f"num_experts={num_experts} must be divisible by ep={ep}"
+
+        # Prepare routing (inputs are replicated, so every rank generates the same sorted lists)
+        hidden_expanded = jnp.repeat(hidden_states, num_experts_per_tok, axis=0)
+        adapter_expanded = jnp.repeat(adapter_indices, num_experts_per_tok) if adapter_indices is not None else None
+        hidden_sorted, group_sizes, unsort_indices, adapter_sorted = prepare_routing(
+            hidden_expanded, selected_experts.ravel(), num_experts, adapter_indices=adapter_expanded
+        )

-        # Apply expert layers using LoRAExpert
-        gate_out = self.gate_proj(hidden_states_sorted, group_sizes, adapter_indices_sorted)
-        up_out = self.up_proj(hidden_states_sorted, group_sizes, adapter_indices_sorted)
-        down_out = self.down_proj(nnx.silu(gate_out) * up_out, group_sizes, adapter_indices_sorted)
+        def forward(experts, hidden_sorted, group_sizes, unsort_indices, adapter_sorted, routing_weights):
+            # Calculate local offset for this shard
+            ep_rank = jax.lax.axis_index("ep")
+            experts_per_rank = num_experts // jax.lax.axis_size("ep")
```
Contributor: This integer division assumes that `num_experts` is evenly divisible by the expert-parallel size. For example, you could add this check:

```python
ep_size = get_abstract_mesh().shape.get("ep", 1)
assert self.config.num_experts % ep_size == 0, f"Number of experts ({self.config.num_experts}) must be divisible by expert parallel size ({ep_size})."
```
```diff
+            group_offset = jnp.array([ep_rank * experts_per_rank], dtype=jnp.int32)
+
+            # Expert computation
+            gate = experts.gate_proj(hidden_sorted, group_sizes, adapter_sorted, group_offset=group_offset)
+            up = experts.up_proj(hidden_sorted, group_sizes, adapter_sorted, group_offset=group_offset)
+            down = experts.down_proj(nnx.silu(gate) * up, group_sizes, adapter_sorted, group_offset=group_offset)
+
+            # Unsort and combine
+            out = down[unsort_indices].reshape(-1, num_experts_per_tok, hidden_size)
+            local_out = jnp.sum(out * routing_weights[..., None], axis=1)
+            return jax.lax.psum(local_out, axis_name="ep")

-        # Unsort and combine the expert outputs
-        unsorted_out = down_out[unsort_indices]
-        reshaped_out = unsorted_out.reshape(-1, self.config.num_experts_per_tok, self.config.hidden_size)
-        return jnp.sum(reshaped_out * routing_weights[..., None], axis=1)
+        return shard_map_ep(self, forward, hidden_sorted, group_sizes, unsort_indices, adapter_sorted, routing_weights)


 class Qwen3MoeSparseMoeBlock(nnx.Module):
```
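To make the expert-parallel pattern in the new `__call__` concrete, here is a minimal, self-contained sketch. It is not the PR's code: it calls `shard_map` directly rather than the `shard_map_ep` helper, uses plain top-1 routing, and replaces `LoRAExpert`'s grouped (ragged) matmuls with a dense per-expert matmul plus a mask; the function name, shapes, and constants are illustrative assumptions. What it does share with the PR is the structure: expert weights sharded along the `"ep"` mesh axis, tokens replicated on every rank, each rank computing only the experts it owns, and `jax.lax.psum` combining the partial outputs.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

jax.config.update("jax_num_cpu_devices", 2)  # same trick as the test fixture above

NUM_EXPERTS, HIDDEN, TOKENS = 4, 8, 16
mesh = Mesh(np.array(jax.devices()), ("ep",))


def moe_forward(x, expert_w, expert_ids):
    # x: [tokens, hidden], replicated on every "ep" rank.
    # expert_w: [local_experts, hidden, hidden], this rank's shard of the experts.
    # expert_ids: [tokens], top-1 expert assignment, replicated.
    ep_rank = jax.lax.axis_index("ep")
    local = expert_w.shape[0]
    local_ids = ep_rank * local + jnp.arange(local)  # global expert ids owned by this rank

    # [tokens, local_experts] mask: 1 where this rank owns the token's expert.
    mask = (expert_ids[:, None] == local_ids[None, :]).astype(x.dtype)

    # Dense stand-in for the ragged expert matmuls: every token against every
    # local expert, then masked. (The PR instead sorts tokens by expert and
    # uses grouped matmuls with a group_offset, avoiding the wasted work.)
    per_expert = jnp.einsum("th,ehd->ted", x, expert_w)
    local_out = jnp.einsum("ted,te->td", per_expert, mask)

    # Tokens routed to remote experts contribute zero here; psum fills them in.
    return jax.lax.psum(local_out, axis_name="ep")


moe = shard_map(
    moe_forward,
    mesh=mesh,
    in_specs=(P(), P("ep"), P()),  # replicate tokens and ids, shard experts over "ep"
    out_specs=P(),
)

x = jnp.ones((TOKENS, HIDDEN))
w = jnp.ones((NUM_EXPERTS, HIDDEN, HIDDEN))
ids = jnp.arange(TOKENS) % NUM_EXPERTS  # toy top-1 routing
print(moe(x, w, ids).shape)  # (16, 8)
```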
I also tried without the transpose, just putting experts first in the weight tensors when initializing them -- the code is more complicated and also slower, e.g. 54s step time vs 40s with sl_loop.py. Though this is somewhat surprising and there might be more optimization potential in the future, for now it is best to keep it as simple as possible.