[WIP] [tx] Torch definition for qwen3 and LoRA #649
Conversation
skyrl-tx/tx/torch/layers/lora.py (Outdated)
A = torch.empty(*shape_A, dtype=dtype, device=device)
B = torch.zeros(*shape_B, dtype=dtype, device=device)
nn.init.kaiming_uniform_(A, a=math.sqrt(5))  # He-uniform A
self.lora_A = nn.Parameter(A, requires_grad=True)
I wonder if we can / if it makes sense to have something a little more like our Param class so we can do Param(*shape, init=..., sharding=...) going forward (we can also experiment with this as a follow-up / later). The imperative nn.init.kaiming_uniform_ initialization is slightly ugly :D
Gave this a shot!
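A minimal sketch of what such a declarative helper could look like. The Param name and the init= keyword follow the reviewer's suggestion; the helper signatures and shapes below are illustrative assumptions rather than the repo's actual API, and sharding is omitted:

```python
import math
import torch
from torch import nn

def kaiming_uniform(shape, *, dtype, device):
    # He-uniform init, equivalent to nn.init.kaiming_uniform_(t, a=math.sqrt(5))
    t = torch.empty(*shape, dtype=dtype, device=device)
    nn.init.kaiming_uniform_(t, a=math.sqrt(5))
    return t

def zeros(shape, *, dtype, device):
    return torch.zeros(*shape, dtype=dtype, device=device)

def Param(*shape, init, dtype, device, requires_grad=True):
    # Hypothetical declarative parameter factory: shape plus init fn in one call.
    return nn.Parameter(init(shape, dtype=dtype, device=device), requires_grad=requires_grad)

# Illustrative usage inside a LoRA layer (adapters=4, rank=8, features=64):
lora_A = Param(4, 8, 64, init=kaiming_uniform, dtype=torch.float32, device="cpu")
lora_B = Param(4, 64, 8, init=zeros, dtype=torch.float32, device="cpu")
```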
if x.dim() != 3:
    raise ValueError("x must be [B, T, in_features].")
B, T, in_features = x.shape
We will need to make sure this can support #511
Yes, it won't right now, but it will be updated to support embedding-layer adapters.
skyrl-tx/tx/torch/layers/lora.py (Outdated)
in_features: int,
out_features: int,
*,
max_lora_adapters: int = 0,
It would probably be best if we get rid of all the defaults here going forward (I know the current code has them, but I don't think it is a good idea; it can only lead to errors if somebody forgets to pass the parameter and the default was not the right value).
Good point, this is a good chance to clean it up. I set all optional arguments to default to None, but otherwise removed the defaults.
skyrl-tx/tx/torch/layers/util.py (Outdated)
sorted_adapter_indices = None if adapter_indices is None else adapter_indices[sort_idx]

# Compute group sizes (minlength guarantees output length)
sorted_indices = indices[sort_idx]
Hmm, this is not needed, right? E.g. bincount is permutation-invariant, so we can just pass indices in there, and otherwise we don't need sorted_indices.
Oh good catch
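A quick illustration of the point, using hypothetical tensors: bincount is permutation-invariant, so pre-sorting is unnecessary for computing group sizes.

```python
import torch

indices = torch.tensor([2, 0, 2, 1, 0, 2])
sort_idx = torch.argsort(indices)

# Group sizes are the same whether or not we sort first.
sizes_unsorted = torch.bincount(indices, minlength=4)          # tensor([2, 1, 3, 0])
sizes_sorted = torch.bincount(indices[sort_idx], minlength=4)  # tensor([2, 1, 3, 0])
assert torch.equal(sizes_unsorted, sizes_sorted)
```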
sorted_adapter_indices: Adapter indices sorted with tokens (or None if not provided)
"""
# Sort by group index
sort_idx = torch.argsort(indices)
Let's call this sort_indices so it is consistent with unsort_indices below and makes clear it holds multiple indices (or alternatively sort_perm and unsort_perm for "permutation", if you prefer)?
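For reference, the usual pairing of the two permutations, with the names from the suggestion above (this is only a sketch of the idiom, not the file's actual code):

```python
import torch

indices = torch.tensor([2, 0, 1, 0])

sort_indices = torch.argsort(indices)         # permutation that sorts tokens by group
unsort_indices = torch.argsort(sort_indices)  # inverse permutation restoring original order

assert torch.equal(indices[sort_indices][unsort_indices], indices)
```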
updated_cache = (k, v)

# Attention (causal only during prefill, GQA handled via repeat)
You should be able to just set enable_gqa=True in scaled_dot_product_attention, right?
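A hedged sketch of what that would look like. enable_gqa is available in recent PyTorch releases of F.scaled_dot_product_attention; the tensor names and shapes below are illustrative, not the PR's code:

```python
import torch
import torch.nn.functional as F

B, T, head_dim = 2, 16, 64
num_q_heads, num_kv_heads = 8, 2  # GQA: 4 query heads share each kv head

q = torch.randn(B, num_q_heads, T, head_dim)
k = torch.randn(B, num_kv_heads, T, head_dim)
v = torch.randn(B, num_kv_heads, T, head_dim)

# With enable_gqa=True, SDPA broadcasts the kv heads internally,
# so the explicit repeat of k and v can be dropped.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
assert out.shape == (B, num_q_heads, T, head_dim)
```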
def __init__(self, config: Qwen3Config, *, dtype: torch.dtype):
    super().__init__()
    self.config = config
    max_lora_adapters = getattr(config, "max_lora_adapters", 0)
Since #636, you don't need these any more; you can just do config.{max_lora_adapters, max_lora_rank}, and it is probably easiest to just inline it in the Qwen3DecoderLayer constructor below.
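Roughly what that simplification looks like; the config stand-in below is only there to make the snippet self-contained, and the field names follow the comment above:

```python
from dataclasses import dataclass

@dataclass
class ConfigStub:  # stand-in with only the fields relevant here
    max_lora_adapters: int = 4
    max_lora_rank: int = 8

config = ConfigStub()

# Before: defensive getattr with a fallback default.
max_lora_adapters = getattr(config, "max_lora_adapters", 0)

# After #636 the config always carries these fields, so they can be read
# directly (and inlined at the call site in the Qwen3DecoderLayer constructor).
max_lora_adapters = config.max_lora_adapters
max_lora_rank = config.max_lora_rank
```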
# Load all safetensors files
state_dict = {}
from pathlib import Path
Let's move this to the top?
pcmoritz left a comment
Looks great! Before merging it, let's move it to tx/extra/torch/{....} and tests/extra/torch/... until we replace the jax model definitions, to make it clear to contributors that this is not the main code path yet and to avoid confusion?
I also made some small comments :)
Implements a torch definition of Qwen3 and multi-LoRA.
The intent was to closely match the interface and implementation of the jax-based code, with exceptions for differences in the canonical torch library (e.g., the interface to attention is a little different with tensor dimensions in a different order).
This PR focuses on dense Qwen3 with LoRA applied to the linear layers. This PR does not add support for Qwen3 MoE or LoRA in the embedding or expert layers.
There are also performance improvements not included here. For example, apply_lora loops over adapter indices and performs individual matmuls instead of a ragged dot/grouped gemm. These will be addressed in follow-up PRs.
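For context, a minimal sketch of the kind of per-adapter loop described above; the names, shapes, and scaling argument are illustrative assumptions, not the PR's actual apply_lora:

```python
import torch

def apply_lora(x, lora_A, lora_B, adapter_indices, scaling=1.0):
    """Add per-token LoRA deltas by looping over the adapters present in the batch.

    x:               [N, in_features]   (tokens flattened over batch and sequence)
    lora_A:          [num_adapters, rank, in_features]
    lora_B:          [num_adapters, out_features, rank]
    adapter_indices: [N] adapter id per token
    """
    out = torch.zeros(x.shape[0], lora_B.shape[1], dtype=x.dtype, device=x.device)
    for adapter in torch.unique(adapter_indices):
        mask = adapter_indices == adapter
        # Two small matmuls per adapter; a ragged/grouped gemm would fuse these.
        delta = (x[mask] @ lora_A[adapter].T) @ lora_B[adapter].T
        out[mask] = scaling * delta
    return out
```

A grouped or ragged gemm would compute all adapters' contributions in a single kernel instead of one pair of matmuls per adapter, which is the follow-up the description refers to.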