[tx] Implement expert parallelism #842
First changed file (the LoRA expert layer):

```diff
@@ -3,7 +3,7 @@
 from jax import numpy as jnp

 from tx.utils.models import filter_lora
-from tx.layers.util import Param, prepare_routing
+from tx.layers.util import Param, prepare_routing, ragged_dot
 from tx.models.types import ModelForCausalLM
 from tx.tinker.types import LoraConfig
@@ -236,32 +236,45 @@ def __call__(
         x: jax.Array,
         group_sizes: jax.Array,
         adapter_indices_sorted: jax.Array | None = None,
+        *,
+        group_offset: jax.Array | None = None,
     ) -> jax.Array:
-        base_out = jax.lax.ragged_dot(x, self.weight.value, group_sizes)
+        # Inside shard_map, weights are already local shards
+        weight = self.weight.value
+        num_local_experts = weight.shape[0]
+
+        base_out = ragged_dot(x, weight, group_sizes, group_offset=group_offset)

         if self.max_lora_adapters == 0 or adapter_indices_sorted is None:
             return base_out

         if self.lora_A is None or self.lora_B is None or self.lora_scaling is None:
             raise RuntimeError("LoRA parameters are not initialized. `init_lora` must be called.")

-        # Reconstruct expert indices from group_sizes
+        # Reconstruct expert indices from group_sizes (global indices)
         expert_indices = jnp.repeat(jnp.arange(self.num_experts), group_sizes, total_repeat_length=x.shape[0])

-        # Flatten (adapter, expert) into a single routing dimension.
-        flattened_indices = adapter_indices_sorted * self.num_experts + expert_indices
-        num_flattened_groups = self.max_lora_adapters * self.num_experts
+        # Expert-first flattening so local expert groups are contiguous
+        flattened_indices = expert_indices * self.max_lora_adapters + adapter_indices_sorted
+        num_flattened_groups = self.num_experts * self.max_lora_adapters

-        # Reshape lora_A and lora_B to merge (max_lora_adapters, num_experts) dimensions
-        lora_A_reshaped = self.lora_A.value.reshape(num_flattened_groups, self.in_features, self.max_lora_rank)
-        lora_B_reshaped = self.lora_B.value.reshape(num_flattened_groups, self.max_lora_rank, self.out_features)
+        # Reshape LoRA weights in expert-first order (already local shards)
+        lora_A = self.lora_A.value.transpose((1, 0, 2, 3)).reshape(
```
Collaborator (author): I also tried skipping the transpose and putting experts first in the weight tensors when initializing them -- the resulting code is more complicated and also slower, e.g. 54 s step time vs. 40 s with the transpose. This is somewhat surprising and there might be more optimization potential here in the future, but for now it is best to keep it as simple as possible.
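As a quick standalone check (shapes are made up for illustration, but the storage layout `(max_lora_adapters, num_experts, in_features, rank)` matches the old reshape), the transpose-then-reshape above produces exactly the expert-major group order implied by `flattened_indices = expert_indices * max_lora_adapters + adapter_indices_sorted`:

```python
# Illustrative shapes only; lora_A is stored as (max_lora_adapters, num_experts, in_features, rank).
import jax.numpy as jnp

max_lora_adapters, num_experts, in_features, rank = 3, 4, 5, 2
lora_A = jnp.arange(max_lora_adapters * num_experts * in_features * rank, dtype=jnp.float32).reshape(
    max_lora_adapters, num_experts, in_features, rank
)

# Expert-first flattening: group index = expert * max_lora_adapters + adapter
flat = lora_A.transpose((1, 0, 2, 3)).reshape(num_experts * max_lora_adapters, in_features, rank)

adapter, expert = 1, 2
assert jnp.array_equal(flat[expert * max_lora_adapters + adapter], lora_A[adapter, expert])
```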
The hunk continues:

```diff
+            self.max_lora_adapters * num_local_experts, self.in_features, self.max_lora_rank
+        )
+        lora_B = self.lora_B.value.transpose((1, 0, 2, 3)).reshape(
+            self.max_lora_adapters * num_local_experts, self.max_lora_rank, self.out_features
+        )

         # Sort tokens by combined index
         x_sorted, combined_group_sizes, unsort_indices, _ = prepare_routing(x, flattened_indices, num_flattened_groups)

+        # Compute group_offset for LoRA (scaled by max_lora_adapters)
+        lora_group_offset = group_offset * self.max_lora_adapters if group_offset is not None else None

         # Apply LoRA using ragged_dot: x @ A @ B
-        intermediate = jax.lax.ragged_dot(x_sorted, lora_A_reshaped, combined_group_sizes)
-        lora_output_sorted = jax.lax.ragged_dot(intermediate, lora_B_reshaped, combined_group_sizes)
+        intermediate = ragged_dot(x_sorted, lora_A, combined_group_sizes, group_offset=lora_group_offset)
+        lora_output_sorted = ragged_dot(intermediate, lora_B, combined_group_sizes, group_offset=lora_group_offset)

         # Unsort and apply scaling
         lora_output = lora_output_sorted[unsort_indices]
```
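The `ragged_dot` wrapper imported from `tx.layers.util` is not part of this diff, so its implementation isn't visible here. Judging from how it is called above (a local weight shard, the global `group_sizes`, and a `group_offset` pointing at the first group owned by this shard), its behaviour should be roughly the following eager-mode reference sketch; rows belonging to non-local groups stay zero so the later `psum` over `"ep"` reassembles the full output:

```python
# Hypothetical reference sketch -- not the tx.layers.util implementation.
# Eager-mode only (uses Python ints for slicing); the real kernel would be traceable.
import jax.numpy as jnp
import numpy as np


def ragged_dot_reference(x, weight_local, group_sizes, *, group_offset=None):
    """x: (tokens, in), weight_local: (local_groups, in, out), group_sizes: global per-group token counts."""
    num_local_groups = weight_local.shape[0]
    offset = 0 if group_offset is None else int(group_offset[0])
    starts = np.concatenate([[0], np.cumsum(np.asarray(group_sizes))])
    out = jnp.zeros((x.shape[0], weight_local.shape[-1]), dtype=x.dtype)
    for local_g in range(num_local_groups):
        g = offset + local_g
        s, e = int(starts[g]), int(starts[g + 1])
        # Only rows routed to this shard's groups are written; the rest stay zero.
        out = out.at[s:e].set(x[s:e] @ weight_local[local_g])
    return out
```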
Second changed file (the Qwen3 MoE model):

```diff
@@ -4,7 +4,7 @@
 from jax.sharding import get_abstract_mesh

 from tx.layers.lora import LoRAEmbed, LoRAExpert, LoRALinear
-from tx.layers.util import prepare_routing
+from tx.layers.util import prepare_routing, shard_map_ep
 from tx.layers.rotary_embedding import apply_rope
 from tx.models.configs import Qwen3Config
 from tx.layers.layernorm import RMSNorm
@@ -171,7 +171,7 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "fsdp", "tp")),
+            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
             rngs=rngs,
         )
         self.up_proj = LoRAExpert(
@@ -181,7 +181,7 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "fsdp", "tp")),
+            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "fsdp", "tp")),
             rngs=rngs,
         )
         self.down_proj = LoRAExpert(
```
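For context, the `"ep"` axis referenced by these partition specs has to exist in the device mesh. The mesh construction is outside this diff; a minimal sketch (axis sizes and device count are illustrative) could look like:

```python
# Illustrative only -- the actual mesh setup is not shown in this PR diff.
import jax
import numpy as np
from jax.sharding import Mesh

devices = np.array(jax.devices()[:8]).reshape(2, 2, 2)  # assumes at least 8 devices
mesh = Mesh(devices, axis_names=("ep", "fsdp", "tp"))   # ep=2, fsdp=2, tp=2
```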
```diff
@@ -191,39 +191,44 @@ def __init__(self, config: Qwen3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
             max_lora_adapters=config.max_lora_adapters,
             max_lora_rank=config.max_lora_rank,
             dtype=dtype,
-            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "tp", "fsdp")),
+            kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("ep", "tp", "fsdp")),
             rngs=rngs,
         )

     def __call__(
         self, hidden_states: jax.Array, router_logits: jax.Array, adapter_indices: jax.Array | None = None
     ) -> jax.Array:
         # Get top-k experts for each token and compute routing weights
         routing_weights, selected_experts = jax.lax.top_k(router_logits, k=self.config.num_experts_per_tok)
         routing_weights = nnx.softmax(routing_weights, axis=-1)

-        # Prepare for ragged_dot by sorting tokens based on their assigned expert
-        selected_experts_flat = selected_experts.ravel()
-        hidden_states_expanded = jnp.repeat(hidden_states, self.config.num_experts_per_tok, axis=0)
-        adapter_indices_expanded = (
-            jnp.repeat(adapter_indices, self.config.num_experts_per_tok) if adapter_indices is not None else None
-        )
-        hidden_states_sorted, group_sizes, unsort_indices, adapter_indices_sorted = prepare_routing(
-            hidden_states_expanded,
-            selected_experts_flat,
-            self.config.num_experts,
-            adapter_indices=adapter_indices_expanded,
+        num_experts = self.config.num_experts
+        num_experts_per_tok = self.config.num_experts_per_tok
+        hidden_size = self.config.hidden_size
+
+        # Prepare routing (inputs are replicated, so every rank generates the same sorted lists)
+        hidden_expanded = jnp.repeat(hidden_states, num_experts_per_tok, axis=0)
+        adapter_expanded = jnp.repeat(adapter_indices, num_experts_per_tok) if adapter_indices is not None else None
+        hidden_sorted, group_sizes, unsort_indices, adapter_sorted = prepare_routing(
+            hidden_expanded, selected_experts.ravel(), num_experts, adapter_indices=adapter_expanded
         )

-        # Apply expert layers using LoRAExpert
-        gate_out = self.gate_proj(hidden_states_sorted, group_sizes, adapter_indices_sorted)
-        up_out = self.up_proj(hidden_states_sorted, group_sizes, adapter_indices_sorted)
-        down_out = self.down_proj(nnx.silu(gate_out) * up_out, group_sizes, adapter_indices_sorted)
+        def forward(experts, hidden_sorted, group_sizes, unsort_indices, adapter_sorted, routing_weights):
+            # Calculate local offset for this shard
+            ep_rank = jax.lax.axis_index("ep")
+            experts_per_rank = num_experts // jax.lax.axis_size("ep")
```
Contributor: This integer division assumes that the number of experts is evenly divisible by the expert-parallel axis size. For example, you could add this check:

```python
ep_size = get_abstract_mesh().shape.get("ep", 1)
assert self.config.num_experts % ep_size == 0, (
    f"Number of experts ({self.config.num_experts}) must be divisible by expert parallel size ({ep_size})."
)
```
The hunk continues:

```diff
+            group_offset = jnp.array([ep_rank * experts_per_rank], dtype=jnp.int32)
+
+            # Expert computation
+            gate = experts.gate_proj(hidden_sorted, group_sizes, adapter_sorted, group_offset=group_offset)
+            up = experts.up_proj(hidden_sorted, group_sizes, adapter_sorted, group_offset=group_offset)
+            down = experts.down_proj(nnx.silu(gate) * up, group_sizes, adapter_sorted, group_offset=group_offset)
+
+            # Unsort and combine
+            out = down[unsort_indices].reshape(-1, num_experts_per_tok, hidden_size)
+            local_out = jnp.sum(out * routing_weights[..., None], axis=1)
+            return jax.lax.psum(local_out, axis_name="ep")

-        # Unsort and combine the expert outputs
-        unsorted_out = down_out[unsort_indices]
-        reshaped_out = unsorted_out.reshape(-1, self.config.num_experts_per_tok, self.config.hidden_size)
-        return jnp.sum(reshaped_out * routing_weights[..., None], axis=1)
+        return shard_map_ep(self, forward, hidden_sorted, group_sizes, unsort_indices, adapter_sorted, routing_weights)


 class Qwen3MoeSparseMoeBlock(nnx.Module):
```
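`shard_map_ep` is defined in `tx.layers.util` and its body isn't shown in this hunk. As a rough mental model of the pattern `forward` relies on (expert weights sharded over `"ep"`, activations replicated, each rank computing only its local experts and `psum`-reducing the result), here is a self-contained toy version using plain `shard_map`; it is not the repo's implementation and uses a simplified one-expert-per-token routing:

```python
# Toy illustration of the expert-parallel pattern; not the repo's shard_map_ep.
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

num_experts, d_model, n_tokens, ep = 4, 8, 16, 2
mesh = Mesh(np.array(jax.devices()[:ep]), axis_names=("ep",))  # assumes >= 2 devices
weights = jax.random.normal(jax.random.key(0), (num_experts, d_model, d_model))
x = jax.random.normal(jax.random.key(1), (n_tokens, d_model))
expert_of_token = jnp.arange(n_tokens) % num_experts  # toy routing, one expert per token


@partial(shard_map, mesh=mesh, in_specs=(P("ep"), P(), P()), out_specs=P())
def moe_forward(w_local, x, expert_of_token):
    ep_rank = jax.lax.axis_index("ep")
    experts_per_rank = num_experts // jax.lax.axis_size("ep")
    # Each rank contributes only the tokens routed to its local experts; the
    # psum over "ep" then reassembles the full output, as in forward() above.
    local_expert = expert_of_token - ep_rank * experts_per_rank
    is_local = (local_expert >= 0) & (local_expert < experts_per_rank)
    w_tok = w_local[jnp.clip(local_expert, 0, experts_per_rank - 1)]
    local_out = jnp.einsum("td,tdo->to", x, w_tok) * is_local[:, None]
    return jax.lax.psum(local_out, axis_name="ep")


out = moe_forward(weights, x, expert_of_token)
```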
Review comment: This device initialization logic is duplicated in `test_qwen3_moe_layer_lora` (lines 117-118). To improve maintainability and reduce redundancy, consider extracting the setup into a shared pytest fixture. The fixture could run automatically for tests that require `ep` and `tp` parameters, making the test suite cleaner and easier to manage.
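A shared fixture along these lines could work (the fixture name, default axis sizes, and parametrization are assumptions, since the test file itself isn't shown in this diff):

```python
# Hypothetical fixture sketch -- the actual test setup is not shown here.
import jax
import numpy as np
import pytest
from jax.sharding import Mesh


@pytest.fixture
def ep_tp_mesh(request):
    # Axis sizes can be supplied via indirect parametrization, e.g.
    # @pytest.mark.parametrize("ep_tp_mesh", [(2, 2)], indirect=True)
    ep, tp = getattr(request, "param", (2, 1))
    devices = jax.devices()
    if len(devices) < ep * tp:
        pytest.skip(f"requires at least {ep * tp} devices")
    mesh = Mesh(np.array(devices[: ep * tp]).reshape(ep, tp), axis_names=("ep", "tp"))
    with mesh:  # make it the active mesh for the test body
        yield mesh
```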