Commits (28)
- 43729b1 Add stochastic mixer for supernet training (tscholak, Oct 12, 2025)
- 8b1eb08 Fix stochastic mixer test failures (tscholak, Oct 14, 2025)
- 8ada30b Fix stochastic mixer checkpoint conversion (tscholak, Oct 14, 2025)
- cd1dbf8 Handle lossy HF conversions for stochastic mixer (tscholak, Nov 12, 2025)
- d0fd648 Merge remote-tracking branch 'origin/main' into stochastic-mixer (tscholak, Nov 12, 2025)
- d693f74 Clean up extra blank line in huggingface.py (tscholak, Nov 12, 2025)
- 6962de9 Apply pre-commit formatting (tscholak, Nov 12, 2025)
- a96c0cb Refactor stochastic mixer: set main_mixer_name in validation, preproc… (tscholak, Nov 12, 2025)
- 735ee3f wip (tscholak, Nov 15, 2025)
- aed779c resolve merge conflicts (tscholak, Nov 20, 2025)
- 982d409 Implement full stochastic mixer support in Apriel HuggingFace format (tscholak, Nov 20, 2025)
- 0d8ab4d Add Apriel2 checkpoint format and fix weight tying (tscholak, Nov 21, 2025)
- bcd93b2 Optimize Apriel2: compute position embeddings and masks per unique block (tscholak, Nov 21, 2025)
- ebe75c4 Add HuggingFace generation and caching improvements to Apriel2 (tscholak, Nov 21, 2025)
- ffd55e5 Add Apriel2DynamicCache for hybrid attention/SSM layer support (tscholak, Nov 21, 2025)
- fe259c3 Add Mamba incremental generation support to Apriel2 (tscholak, Nov 21, 2025)
- 708917d Add GatedDeltaNet support via Qwen3NextGatedDeltaNet wrapper (tscholak, Nov 21, 2025)
- 77ceae2 Standardize naming: recurrent_states and Apriel2 prefixes (tscholak, Nov 21, 2025)
- ec95ccc Remove debug print statements and irrelevant changes (tscholak, Nov 21, 2025)
- 571fede Remove stochastic mixer support from apriel conversion (tscholak, Nov 21, 2025)
- 8e7c154 Remove trivial formatting change from apriel_hybrid_ssm config (tscholak, Nov 21, 2025)
- 4d0a01b Remove test changes for lossy HF conversion (tscholak, Nov 21, 2025)
- 71cf778 Revert trivial setup.py formatting and restore .eval() calls in tests (tscholak, Nov 21, 2025)
- 75847d0 Rename SamplingStrategy to StochasticMixerSamplingStrategy (tscholak, Nov 21, 2025)
- eacdf61 Use normalize_probabilities for sampling weights validation (tscholak, Nov 21, 2025)
- 192e985 Remove tools/supernet_beam_search.py (tscholak, Nov 21, 2025)
- 2fe9596 Fix stochastic mixer sampling to be consistent across all ranks (tscholak, Nov 22, 2025)
- acb4751 Add Apriel2Cache with JetNemotron pattern and HF Cache compliance (tscholak, Nov 22, 2025)
20 changes: 10 additions & 10 deletions .github/ISSUE_TEMPLATE/feature_request.md
@@ -8,26 +8,26 @@ assignees: ''
---

# 🎯 **Goal (What & Why)**
> **Clearly state the purpose of this feature.**
> _(Example: Add FP8 support using torchao to improve training throughput by 1.5x.)_

# 🚀 **Execution Plan**
> _(This section may start as an incomplete draft but must be defined before implementation begins.)_

### **Step 1: What is the smallest working version?**
> _(Describe the simplest way to implement this feature with minimal effort.)_

### **Step 2: What additional optimizations are possible (but optional)?**
> _(List potential refinements that can be added in later PRs if needed.)_

# 📌 **Acceptance Criteria** (Must-Haves for Completion)
* The feature must be **functional and tested**.
* The implementation must be **documented in practical terms**.
* The PR must include a **performance/impact summary**.
* **No refactors unless directly necessary** for feature completion.

# 🛠️ **Project Management**
- [ ] **Assign the project to the Fast-LLM project.**
- [ ] **Set the `Estimate` field (in days) in the GitHub project.**
- [ ] **Use the `Size` field to categorize the PR size (Small/Medium/Large).**
- [ ] **Assign an owner when opening the issue.**
14 changes: 7 additions & 7 deletions .github/workflows/manual-build.yml
@@ -34,12 +34,12 @@ jobs:
sudo rm -rf /usr/share/dotnet || true
sudo rm -rf /opt/ghc || true
sudo rm -rf /usr/local/.ghcup || true

- name: Checkout repository
uses: actions/checkout@v4
with:
ref: ${{ inputs.commit_sha != '' && inputs.commit_sha || inputs.branch }}

- name: Get commit info
id: commit_info
run: |
@@ -48,7 +48,7 @@
echo "full_sha=${COMMIT_SHA}" >> $GITHUB_OUTPUT
echo "short_sha=${COMMIT_SHORT}" >> $GITHUB_OUTPUT
echo "Building from commit: ${COMMIT_SHA}"

- name: Docker meta
id: meta
uses: docker/metadata-action@v5
@@ -59,18 +59,18 @@
type=raw,value=${{ inputs.branch }}-${{ inputs.tag_suffix }}
type=raw,value=${{ inputs.branch }}-${{ inputs.tag_suffix }}-${{ steps.commit_info.outputs.short_sha }}
type=raw,value=latest-${{ inputs.tag_suffix }},enable=${{ inputs.branch == 'main' && inputs.commit_sha == '' }}

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

- name: Login to GHCR
if: ${{ inputs.push_image }}
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Build and push
uses: docker/build-push-action@v6
with:
@@ -80,7 +80,7 @@
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=registry,ref=ghcr.io/servicenow/fast-llm:cache
cache-to: type=registry,ref=ghcr.io/servicenow/fast-llm:cache,mode=max

- name: Output build info
run: |
echo "Built Docker image with tags:"
75 changes: 75 additions & 0 deletions fast_llm/layers/decoder/config.py
@@ -1,3 +1,4 @@
import enum
import typing

from fast_llm.config import Field, FieldHint, check_field, config_class
@@ -11,6 +12,7 @@

if typing.TYPE_CHECKING:
from fast_llm.layers.decoder.block import BlockWithBias, DecoderBlock
from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer


@config_class()
@@ -55,6 +57,13 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi
return super()._from_dict(default, strict=strict)


class SamplingStrategy(str, enum.Enum):
"""Strategy for sampling mixers in a stochastic mixer."""

uniform = "uniform"
weighted = "weighted"


@config_class(registry=True)
class MixerConfig(BlockWithBiasConfig):
"""
@@ -71,6 +80,72 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi
return super()._from_dict(default, strict=strict)


@config_class(dynamic_type={MixerConfig: "stochastic"})
class StochasticMixerConfig(MixerConfig):
"""
Stochastic mixer that randomly samples from multiple mixer options during training.

For supernet training, each forward pass randomly selects one mixer to execute,
training all mixers with different subsets of data.
"""

_abstract = False

mixers: list[MixerConfig] = Field(
desc="List of mixer options to sample from (must contain at least 1).",
hint=FieldHint.architecture,
)

sampling_strategy: SamplingStrategy = Field(
default=SamplingStrategy.uniform,
desc="Strategy for sampling mixers during training.",
hint=FieldHint.feature,
)

sampling_weights: list[float] | None = Field(
default=None,
desc="Sampling probability for each mixer (must sum to 1.0). "
"Only used when sampling_strategy='weighted'. "
"If None with uniform strategy, all mixers have equal probability.",
hint=FieldHint.feature,
)

main_mixer_index: int = Field(
default=0,
desc="Index of the main mixer. "
"Used for inference/eval, checkpoint loading (receives pretrained weights), "
"and checkpoint saving (only this mixer is exported).",
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)

def _validate(self) -> None:
super()._validate()

# Validate mixers list is not empty
Assert.gt(len(self.mixers), 0)

# Validate sampling weights
if self.sampling_weights is not None:
Assert.eq(len(self.sampling_weights), len(self.mixers))
# Check sum is close to 1.0
weight_sum = sum(self.sampling_weights)
if abs(weight_sum - 1.0) > 1e-5:
raise ValueError(f"Sampling weights must sum to 1.0, got {weight_sum}")
# Check all weights are non-negative
if any(w < 0 for w in self.sampling_weights):
raise ValueError("All sampling weights must be non-negative")

# Validate main mixer index
Assert.lt(self.main_mixer_index, len(self.mixers))

@property
def layer_class(self) -> "type[StochasticMixer]":
from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer

return StochasticMixer


@config_class(dynamic_type={BlockConfig: "decoder"})
class DecoderBlockConfig(BlockConfig):
_abstract = False
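For illustration, a stochastic mixer would be selected through the `dynamic_type` registry with `type: stochastic`. Below is a minimal sketch of building such a config from a plain dict; the public `from_dict`/`validate` helpers and the registered option names `"attention"` and `"mamba"` are assumptions, since only `"stochastic"` is confirmed by this diff:

```python
from fast_llm.layers.decoder.config import MixerConfig

# Hypothetical config dict; the nested mixer type names are assumptions.
mixer_config = MixerConfig.from_dict(
    {
        "type": "stochastic",
        "sampling_strategy": "weighted",
        "sampling_weights": [0.7, 0.3],
        "main_mixer_index": 0,  # used for eval, checkpoint loading, and export
        "mixers": [
            {"type": "attention"},
            {"type": "mamba"},
        ],
    }
)
mixer_config.validate()  # checks weight count, sum within 1e-5 of 1.0, and index bounds
```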
193 changes: 193 additions & 0 deletions fast_llm/layers/decoder/stochastic_mixer.py
@@ -0,0 +1,193 @@
import logging
import typing

import torch

from fast_llm.core.distributed import set_generator
from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.decoder.block import BlockWithBias
from fast_llm.layers.decoder.config import SamplingStrategy, StochasticMixerConfig
from fast_llm.tensor import TensorMeta

logger = logging.getLogger(__name__)


class StochasticMixer[ConfigType: StochasticMixerConfig](BlockWithBias[ConfigType]):
"""
A mixer that stochastically samples from multiple mixer options during training.
In training mode, each forward pass randomly selects one mixer according to
the sampling strategy. In eval mode, uses the configured main mixer.
This is useful for supernet training where you want to train multiple
architecture variants (e.g., attention vs. Mamba) with different data subsets.
"""

_config: ConfigType

def __init__(
self,
config: ConfigType,
distributed_config: DistributedConfig,
*,
hidden_dim: TensorDim,
lr_scale: float | None,
peft: PeftConfig | None,
return_bias: bool = True,
):
super().__init__(
config,
distributed_config,
hidden_dim=hidden_dim,
lr_scale=lr_scale,
peft=peft,
return_bias=return_bias,
)

# Initialize all mixers
self.mixers = torch.nn.ModuleList(
[
mixer_config.get_layer(
distributed_config,
hidden_dim,
lr_scale,
peft=peft,
return_bias=return_bias,
)
for mixer_config in self._config.mixers
]
)

# Precompute sampling probabilities as a tensor
if self._config.sampling_strategy == SamplingStrategy.uniform:
self._sampling_probs = torch.ones(len(self.mixers)) / len(self.mixers)
elif self._config.sampling_strategy == SamplingStrategy.weighted:
if self._config.sampling_weights is None:
raise ValueError("sampling_weights must be provided when using weighted sampling strategy")
self._sampling_probs = torch.tensor(self._config.sampling_weights, dtype=torch.float32)
else:
raise NotImplementedError(f"Sampling strategy {self._config.sampling_strategy} not implemented")

logger.info(
f"Initialized StochasticMixer with {len(self.mixers)} mixers: "
f"{[type(m).__name__ for m in self.mixers]}"
)

# Mark all mixer parameters with allow_no_grad since only one mixer
# is active per forward pass during training. Even though all mixers
# will eventually be trained, on any single forward pass, the non-selected
# mixers won't receive gradients.
for mixer in self.mixers:
for param in mixer.parameters(recurse=True):
if hasattr(param, 'allow_no_grad'):
param.allow_no_grad = True

def setup(self, distributed: Distributed) -> None:
"""Setup all mixers with the distributed context."""
super().setup(distributed)
for mixer in self.mixers:
mixer.setup(distributed)

def _sample_mixer_index(self) -> int:
"""
Sample a mixer index according to the configured strategy.
Returns:
Index of the mixer to use for this forward pass.
"""
if not self.training:
# Inference mode: use the configured main mixer
return self._config.main_mixer_index

# Training mode: stochastic sampling
# Use distributed RNG to ensure consistency across TP/PP ranks
# This ensures all ranks in a TP/PP group use the same mixer
generator = self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator
Review thread:
Collaborator: I don't think that's right; tp_generator will result in different tensor ranks selecting different mixers.
Collaborator: Also, do we actually want different DP ranks / micro-batches to select different sets? I guess this increases randomness, but it will affect reproducibility and prevent distributed tests from working.
Author: Thanks for checking this. What we want is for all ranks to sample the same mixer for each batch. How can that be done? I thought that's what the tp generator does; maybe it does the exact opposite, and all TP ranks sample differently?
Collaborator: The TP generator is meant for TP tensors, which need different random numbers on each TP rank (e.g., we want different dropouts for different slices of a tensor). I had a second look, and I don't think any existing generator can provide consistent mixers for a given batch. pp_generator gives inconsistent results between DP ranks and gradient accumulation steps, but is probably still the best option. The CPU generator, on the other hand, is consistent between DP ranks (though not across gradient accumulation steps) but is not reproducible (see Distributed.set_step). Getting consistency for a whole batch would require a custom seed/generator and access to the current training step in preprocess (preprocess_batch could add it to the kwargs). Another issue I'm seeing: preprocess is called only once for all layers, so the current approach will result in all layers choosing the same mixer. Fixing that is non-trivial, since the preprocessor doesn't know the number of layers and the mixer doesn't know its layer index. My suggestion: go with a custom generator, seed it in preprocess using the step index, then generate on the fly in forward.


with set_generator(generator):
# Sample from categorical distribution
idx = torch.multinomial(self._sampling_probs, num_samples=1).item()
Review thread:
Collaborator: This requires a costly CUDA sync. How about we sample for all layers at once during preprocessing?
Author: Now done during preprocessing.


return idx
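Following the review thread above, one way to make sampling rank-consistent and reproducible without a CUDA sync is a dedicated CPU generator seeded from the training step, with all layers drawn at once during preprocessing. A minimal sketch under those assumptions; the helper name and the availability of the step index during preprocessing are hypothetical:

```python
import torch

def sample_mixers_for_step(
    sampling_probs: torch.Tensor,  # 1-D CPU tensor, e.g. the precomputed _sampling_probs
    num_layers: int,
    base_seed: int,
    step: int,
) -> list[int]:
    """Pick one mixer index per layer for this step, identically on every rank."""
    generator = torch.Generator()  # CPU generator: no CUDA sync involved
    generator.manual_seed(base_seed + step)  # same seed on every rank for a given step
    draws = torch.multinomial(
        sampling_probs, num_samples=num_layers, replacement=True, generator=generator
    )
    return draws.tolist()
```

Each layer would then read its precomputed index from `kwargs` in forward instead of sampling locally.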

def _forward(
self,
input_: torch.Tensor,
kwargs: dict[str, typing.Any],
losses: dict[str, typing.Any] | None = None,
metrics: dict[str, typing.Any] | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Forward pass through a randomly selected mixer.
Args:
input_: Input tensor
kwargs: Forward pass arguments
losses: Optional dictionary to store losses
metrics: Optional dictionary to store metrics
Returns:
Tuple of (output tensor, bias tensor or None)
"""
# Sample which mixer to use
mixer_idx = self._sample_mixer_index()

if self._debug.enabled:
logger.debug(f"StochasticMixer selecting mixer {mixer_idx}: {type(self.mixers[mixer_idx]).__name__}")
Review thread:
Collaborator: Ambiguous if multiple mixers share the same type. Use named mixers instead?
Author: Now using named mixers. We retrieve mixer_name from kwargs (line 151) and use it for logging (line 160) and for accessing the correct mixer (line 163).


# Forward through selected mixer
return self.mixers[mixer_idx]._forward(input_, kwargs, losses, metrics)

def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None:
"""
Preprocess for all mixers.
Since we don't know which mixer will be selected during training,
we need to preprocess for all of them. This includes things like
attention masks, rotary embeddings, etc.
"""
for mixer in self.mixers:
Review thread:
Collaborator: There could be name conflicts. Consider a namespace?
Author: Now namespaced; see lines 214-216, where we prefix with f"{mixer_name}/{loss_def.name}".

mixer.preprocess(batch, kwargs)

def get_compute_usage(
self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig
) -> int:
"""
Return expected compute usage (weighted average of all mixers).
This gives a more accurate estimate than just using one mixer,
since during training we'll be using all of them according to
their sampling probabilities.
"""
usages = [mixer.get_compute_usage(input_, kwargs, config) for mixer in self.mixers]

# Weight by sampling probability and return the expected value
expected_usage = sum(usage * prob.item() for usage, prob in zip(usages, self._sampling_probs))

return int(expected_usage)
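As a quick sanity check of the expected-value computation, with invented numbers:

```python
# Two mixers costing 100 and 40 (arbitrary units), sampled with
# probabilities 0.7 and 0.3, give an expected cost of 0.7*100 + 0.3*40 = 82.
usages = [100, 40]
probs = [0.7, 0.3]
expected = sum(usage * prob for usage, prob in zip(usages, probs))
assert int(expected) == 82
```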

def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
Review thread:
Collaborator: This is a bit dangerous; there could be name conflicts, and counts will be wrong for averaging. Not sure how to fix, though.
Author: Acknowledged. The current approach ensures we allocate space for all possible losses, but you're right that counts won't match actual usage, since only one mixer runs per forward pass. We could track which mixer was used and only record its losses, but that adds complexity. I think what we have is good enough for now.

"""
Merge loss definitions from all mixers.
Returns the union of all loss definitions, deduplicated by name.
This ensures we allocate space for any auxiliary losses that any
of the mixers might need.
"""
all_losses = []
for mixer in self.mixers:
all_losses.extend(mixer.get_loss_definitions(count=count))

# Deduplicate by loss name
seen = set()
unique_losses = []
for loss_def in all_losses:
if loss_def.name not in seen:
seen.add(loss_def.name)
unique_losses.append(loss_def)

return unique_losses
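The review threads above settle on prefixing loss names per mixer. A standalone sketch of that namespacing follows; the `LossDef` fields here are assumed, and only the `f"{mixer_name}/{loss_def.name}"` prefix comes from the discussion:

```python
import dataclasses

@dataclasses.dataclass
class LossDef:  # assumed minimal shape, not the actual Fast-LLM class
    name: str
    formatted_name: str
    count: int

def namespace_losses(mixer_name: str, loss_defs: list[LossDef]) -> list[LossDef]:
    # Prefix each loss with its mixer's name so two options defining the
    # same auxiliary loss (e.g. "z_loss") no longer collide.
    return [
        dataclasses.replace(loss_def, name=f"{mixer_name}/{loss_def.name}")
        for loss_def in loss_defs
    ]
```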
1 change: 1 addition & 0 deletions fast_llm/models/auto.py
@@ -2,6 +2,7 @@
Import these submodules to ensure classes are added to the dynamic class registry.
"""

from fast_llm.layers.attention.config import AttentionConfig # isort: skip
from fast_llm.layers.ssm.config import MambaConfig, Mamba2Config, DiscreteMamba2Config # isort: skip
from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip
from fast_llm.engine.evaluation.evaluators import EvaluatorsConfig # isort: skip