
Conversation

@tyler-griggs (Member) commented Nov 10, 2025

Implements a torch definition of Qwen3 and multi-LoRA.

The intent was to closely match the interface and implementation of the jax-based code, with exceptions where the canonical torch library differs (e.g., the attention interface expects tensor dimensions in a different order).

This PR focuses on dense Qwen3 with LoRA applied to the linear layers. This PR does not add support for Qwen3 MoE or LoRA in the embedding or expert layers.

Some performance improvements are also not included. For example, apply_lora currently loops over adapter indices and performs individual matmuls instead of a single ragged dot / grouped GEMM. These will be addressed in follow-up PRs, as sketched below.
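For context, a minimal sketch of the loop-based approach described above; the names, shapes, and `scaling` factor are illustrative rather than the exact code in this PR:

```python
import torch

def apply_lora_loop(x, lora_A, lora_B, adapter_indices, scaling):
    """Loop over adapters and run two small matmuls per adapter.

    x: [num_tokens, in_features]
    lora_A: [num_adapters, in_features, rank]
    lora_B: [num_adapters, rank, out_features]
    adapter_indices: [num_tokens] adapter id per token
    A ragged-dot / grouped GEMM would fuse these per-adapter matmuls into one call.
    """
    out = x.new_zeros(x.shape[0], lora_B.shape[-1])
    for adapter_id in torch.unique(adapter_indices):
        mask = adapter_indices == adapter_id
        out[mask] = ((x[mask] @ lora_A[adapter_id]) @ lora_B[adapter_id]) * scaling
    return out
```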

@tyler-griggs tyler-griggs marked this pull request as ready for review November 10, 2025 02:45
A = torch.empty(*shape_A, dtype=dtype, device=device)
B = torch.zeros(*shape_B, dtype=dtype, device=device)
nn.init.kaiming_uniform_(A, a=math.sqrt(5)) # He-uniform A
self.lora_A = nn.Parameter(A, requires_grad=True)
@pcmoritz (Collaborator) commented Nov 10, 2025

I wonder if we can / if it makes sense to have something a little more like our Param class so we can do Param(*shape, init=..., sharding=...) going forward (we can also experiment with this as a follow-up later). The imperative nn.init.kaiming_uniform_ initialization is slightly ugly :D

@tyler-griggs (Member, Author)

Gave this a shot!
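For reference, a hypothetical sketch of the declarative style suggested above. The `param` helper, its arguments, and the usage comments are illustrative only (sharding is omitted), not the project's actual Param class:

```python
import math
import torch
from torch import nn

def param(*shape, init, dtype, device=None, requires_grad=True):
    """Allocate, initialize in place with the given callable, wrap as a Parameter."""
    data = torch.empty(*shape, dtype=dtype, device=device)
    init(data)
    return nn.Parameter(data, requires_grad=requires_grad)

# Hypothetical usage inside a module's __init__:
# self.lora_A = param(max_adapters, in_features, rank, dtype=dtype,
#                     init=lambda t: nn.init.kaiming_uniform_(t, a=math.sqrt(5)))
# self.lora_B = param(max_adapters, rank, out_features, dtype=dtype,
#                     init=nn.init.zeros_)
```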


if x.dim() != 3:
raise ValueError("x must be [B, T, in_features].")
B, T, in_features = x.shape
@pcmoritz (Collaborator)

We will need to make sure this can support #511

@tyler-griggs (Member, Author)

Yes, it won't right now, but it will be updated to support embedding-layer adapters.

in_features: int,
out_features: int,
*,
max_lora_adapters: int = 0,
@pcmoritz (Collaborator)

It would probably be best if we got rid of all the defaults here going forward (I know the current code has them, but I don't think that's good: it can only lead to errors if somebody forgets to pass the parameter and the default wasn't the right one).

@tyler-griggs (Member, Author)

Good point, this is a good chance to clean it up. I set all optional arguments to default to None but otherwise removed the defaults.
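A small sketch of the resulting pattern (the class name and signature are illustrative, not the exact code in the PR): required configuration is keyword-only with no defaults, and only genuinely optional arguments default to None.

```python
from typing import Optional

import torch
from torch import nn

class LoRALinear(nn.Module):  # illustrative name
    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        max_lora_adapters: int,                 # required: caller must pass explicitly
        max_lora_rank: int,                     # required: caller must pass explicitly
        dtype: torch.dtype,                     # required: no silent default
        device: Optional[torch.device] = None,  # genuinely optional
    ) -> None:
        super().__init__()
        ...
```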

sorted_adapter_indices = None if adapter_indices is None else adapter_indices[sort_idx]

# Compute group sizes (minlength guarantees output length)
sorted_indices = indices[sort_idx]
@pcmoritz (Collaborator)

Hmm, this is not needed, right? bincount is permutation invariant, so we can just pass indices in there, and otherwise we don't need sorted_indices at all.

@tyler-griggs (Member, Author)

Oh good catch
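A minimal sketch of the simplification, assuming `indices` is a 1-D tensor of group ids and the number of groups is known (names are illustrative):

```python
import torch

def group_sizes(indices: torch.Tensor, num_groups: int) -> torch.Tensor:
    # bincount is permutation invariant, so sorting first is unnecessary;
    # minlength guarantees one entry per group, including empty groups.
    return torch.bincount(indices, minlength=num_groups)

# group_sizes(torch.tensor([2, 0, 2, 1]), num_groups=4) -> tensor([1, 1, 2, 0])
```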

sorted_adapter_indices: Adapter indices sorted with tokens (or None if not provided)
"""
# Sort by group index
sort_idx = torch.argsort(indices)
@pcmoritz (Collaborator) commented Nov 10, 2025

Let's call this sort_indices so it is consistent with unsort_indices below and makes clear it holds multiple indices (or alternatively sort_perm and unsort_perm, for "permutation", if you prefer)?
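For illustration, the sort/unsort permutation pair the naming refers to (variable names follow the suggestion; this is a sketch, not the PR's code):

```python
import torch

indices = torch.tensor([2, 0, 1, 0])
sort_perm = torch.argsort(indices)      # permutation that groups tokens by index
unsort_perm = torch.argsort(sort_perm)  # inverse permutation: restores original order

sorted_indices = indices[sort_perm]     # tensor([0, 0, 1, 2])
assert torch.equal(sorted_indices[unsort_perm], indices)
```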


updated_cache = (k, v)

# Attention (causal only during prefill, GQA handled via repeat)
@pcmoritz (Collaborator)

You should be able to just set enable_gqa=True in scaled_dot_product_attention, right?
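A hedged sketch of that simplification, assuming a PyTorch version where scaled_dot_product_attention accepts enable_gqa (shapes below are illustrative):

```python
import torch
import torch.nn.functional as F

B, T, head_dim, num_q_heads, num_kv_heads = 1, 8, 64, 16, 4
q = torch.randn(B, num_q_heads, T, head_dim)
k = torch.randn(B, num_kv_heads, T, head_dim)
v = torch.randn(B, num_kv_heads, T, head_dim)

# enable_gqa=True broadcasts the KV heads across query-head groups internally,
# so k and v do not need to be repeated up to num_q_heads beforehand.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
```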

def __init__(self, config: Qwen3Config, *, dtype: torch.dtype):
super().__init__()
self.config = config
max_lora_adapters = getattr(config, "max_lora_adapters", 0)
@pcmoritz (Collaborator)

Since #636, you don't need these any more; you can just do config.{max_lora_adapters, max_lora_rank}, and it is probably easiest to just inline it in the Qwen3DecoderLayer constructor below.


# Load all safetensors files
state_dict = {}
from pathlib import Path
@pcmoritz (Collaborator)

Let's move this to the top?
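For reference, a minimal sketch with the import at module level; it assumes the checkpoint directory holds one or more *.safetensors shards and uses safetensors.torch.load_file (the helper name and signature here are illustrative, not the PR's actual loader):

```python
from pathlib import Path

import torch
from safetensors.torch import load_file

def load_sharded_state_dict(checkpoint_dir: str) -> dict[str, torch.Tensor]:
    """Merge all *.safetensors shards in a directory into a single state dict."""
    state_dict: dict[str, torch.Tensor] = {}
    for shard in sorted(Path(checkpoint_dir).glob("*.safetensors")):
        state_dict.update(load_file(shard))
    return state_dict
```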

@pcmoritz (Collaborator) left a comment

Looks great! Before merging it, let's move it to tx/extra/torch/{....} and tests/extra/torch/... until we replace the jax model definitions, to make it clear to contributors that this is not the main code path yet and avoid confusion?

I also made some small comments :)

@tyler-griggs tyler-griggs changed the title [WIP] Torch definition for qwen3 and LoRA [WIP] [tx] Torch definition for qwen3 and LoRA Nov 10, 2025