Changes from all commits (28 commits)
43729b1 - Add stochastic mixer for supernet training (tscholak, Oct 12, 2025)
8b1eb08 - Fix stochastic mixer test failures (tscholak, Oct 14, 2025)
8ada30b - Fix stochastic mixer checkpoint conversion (tscholak, Oct 14, 2025)
cd1dbf8 - Handle lossy HF conversions for stochastic mixer (tscholak, Nov 12, 2025)
d0fd648 - Merge remote-tracking branch 'origin/main' into stochastic-mixer (tscholak, Nov 12, 2025)
d693f74 - Clean up extra blank line in huggingface.py (tscholak, Nov 12, 2025)
6962de9 - Apply pre-commit formatting (tscholak, Nov 12, 2025)
a96c0cb - Refactor stochastic mixer: set main_mixer_name in validation, preproc… (tscholak, Nov 12, 2025)
735ee3f - wip (tscholak, Nov 15, 2025)
aed779c - resolve merge conflicts (tscholak, Nov 20, 2025)
982d409 - Implement full stochastic mixer support in Apriel HuggingFace format (tscholak, Nov 20, 2025)
0d8ab4d - Add Apriel2 checkpoint format and fix weight tying (tscholak, Nov 21, 2025)
bcd93b2 - Optimize Apriel2: compute position embeddings and masks per unique block (tscholak, Nov 21, 2025)
ebe75c4 - Add HuggingFace generation and caching improvements to Apriel2 (tscholak, Nov 21, 2025)
ffd55e5 - Add Apriel2DynamicCache for hybrid attention/SSM layer support (tscholak, Nov 21, 2025)
fe259c3 - Add Mamba incremental generation support to Apriel2 (tscholak, Nov 21, 2025)
708917d - Add GatedDeltaNet support via Qwen3NextGatedDeltaNet wrapper (tscholak, Nov 21, 2025)
77ceae2 - Standardize naming: recurrent_states and Apriel2 prefixes (tscholak, Nov 21, 2025)
ec95ccc - Remove debug print statements and irrelevant changes (tscholak, Nov 21, 2025)
571fede - Remove stochastic mixer support from apriel conversion (tscholak, Nov 21, 2025)
8e7c154 - Remove trivial formatting change from apriel_hybrid_ssm config (tscholak, Nov 21, 2025)
4d0a01b - Remove test changes for lossy HF conversion (tscholak, Nov 21, 2025)
71cf778 - Revert trivial setup.py formatting and restore .eval() calls in tests (tscholak, Nov 21, 2025)
75847d0 - Rename SamplingStrategy to StochasticMixerSamplingStrategy (tscholak, Nov 21, 2025)
eacdf61 - Use normalize_probabilities for sampling weights validation (tscholak, Nov 21, 2025)
192e985 - Remove tools/supernet_beam_search.py (tscholak, Nov 21, 2025)
2fe9596 - Fix stochastic mixer sampling to be consistent across all ranks (tscholak, Nov 22, 2025)
acb4751 - Add Apriel2Cache with JetNemotron pattern and HF Cache compliance (tscholak, Nov 22, 2025)
1 change: 1 addition & 0 deletions fast_llm/layers/block/config.py
@@ -37,6 +37,7 @@ class BlockKwargs:
sequence_lengths = "sequence_lengths"
# TODO: Belongs elsewhere?
grad_output = "grad_output"
iteration = "iteration"


@config_class(registry=True)
89 changes: 87 additions & 2 deletions fast_llm/layers/decoder/config.py
@@ -1,16 +1,25 @@
import enum
import typing

from fast_llm.config import Field, FieldHint, check_field, config_class
from fast_llm.engine.config_utils.parameter import combine_lr_scales
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.layers.block.config import BlockConfig
from fast_llm.layers.block.config import BlockConfig, BlockKwargs
from fast_llm.layers.common.normalization.config import NormalizationConfig
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.utils import Assert
from fast_llm.utils import Assert, normalize_probabilities

if typing.TYPE_CHECKING:
from fast_llm.layers.decoder.block import BlockWithBias, DecoderBlock
from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer


class StochasticMixerKwargs(BlockKwargs):
"""Kwargs keys for stochastic mixer."""

mixer_name = "stochastic_mixer_name"
generator = "stochastic_mixer_generator"


@config_class()
@@ -55,6 +64,13 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi
return super()._from_dict(default, strict=strict)


class StochasticMixerSamplingStrategy(str, enum.Enum):
"""Strategy for sampling mixers in a stochastic mixer."""

uniform = "uniform"
weighted = "weighted"


@config_class(registry=True)
class MixerConfig(BlockWithBiasConfig):
"""
@@ -71,6 +87,75 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi
return super()._from_dict(default, strict=strict)


@config_class(dynamic_type={MixerConfig: "stochastic"})
class StochasticMixerConfig(MixerConfig):
"""
Stochastic mixer that uniformly samples from multiple mixer options during training.

For supernet training, each forward pass randomly selects one mixer to execute,
training all mixers with different subsets of data.
"""

_abstract = False

mixers: dict[str, MixerConfig] = Field(
desc="Dict of mixer options to sample from (must contain at least 1). "
"Keys are mixer names used for debugging and namespacing.",
hint=FieldHint.architecture,
)

sampling_strategy: StochasticMixerSamplingStrategy = Field(
default=StochasticMixerSamplingStrategy.uniform,
desc="Strategy for sampling mixers during training.",
hint=FieldHint.feature,
)

sampling_weights: dict[str, float] | None = Field(
default=None,
desc="Sampling probability for each mixer by name (will be normalized to sum to 1.0). "
"Only used when sampling_strategy='weighted'. "
"If None with uniform strategy, all mixers have equal probability.",
hint=FieldHint.feature,
)

main_mixer_name: str | None = Field(
default=None,
desc="Name of the main mixer. "
"Used for inference/eval, checkpoint loading (receives pretrained weights), "
"and checkpoint saving (only this mixer is exported). "
"If None, uses the first mixer in the dict.",
hint=FieldHint.feature,
)

def _validate(self) -> None:
super()._validate()

# Validate mixers dict is not empty
Assert.gt(len(self.mixers), 0)

# Set main_mixer_name to first mixer if not specified
if self.main_mixer_name is None:
with self._set_implicit_default():
self.main_mixer_name = next(iter(self.mixers.keys()))
Collaborator: We could just enforce this and make main_mixer_name a cached property instead?

Collaborator (author): I like it the way it is.

# Validate main mixer name exists
if self.main_mixer_name not in self.mixers:
raise ValueError(f"main_mixer_name '{self.main_mixer_name}' not found in mixers")

# Validate and normalize sampling weights
if self.sampling_weights is not None:
Assert.eq(set(self.sampling_weights.keys()), set(self.mixers.keys()))
# Normalize weights to sum to 1.0 (also validates non-negative and positive sum)
normalized_values = normalize_probabilities(list(self.sampling_weights.values()))
self.sampling_weights = dict(zip(self.sampling_weights.keys(), normalized_values))

@property
def layer_class(self) -> "type[StochasticMixer]":
from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer

return StochasticMixer


@config_class(dynamic_type={BlockConfig: "decoder"})
class DecoderBlockConfig(BlockConfig):
_abstract = False
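For reference, a decoder block could opt into this mixer with a config fragment along the following lines. This is an illustrative sketch, not part of the diff: the outer field names come from StochasticMixerConfig above, while the "type" selector strings and the nested mixer entries are placeholders.

# Hypothetical config fragment for a stochastic mixer with two options.
# Outer keys mirror StochasticMixerConfig; nested mixer configs are placeholders.
stochastic_mixer_config = {
    "type": "stochastic",
    "mixers": {
        "attention": {"type": "attention"},
        "mamba": {"type": "mamba_2"},
    },
    "sampling_strategy": "weighted",
    # _validate() normalizes these to sum to 1.0 via normalize_probabilities.
    "sampling_weights": {"attention": 3.0, "mamba": 1.0},
    # Used in eval mode and as the mixer that receives/exports pretrained weights.
    "main_mixer_name": "attention",
}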
167 changes: 167 additions & 0 deletions fast_llm/layers/decoder/stochastic_mixer.py
@@ -0,0 +1,167 @@
import logging
import typing

import torch

from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.decoder.block import BlockWithBias
from fast_llm.layers.decoder.config import StochasticMixerConfig, StochasticMixerKwargs, StochasticMixerSamplingStrategy
from fast_llm.tensor import TensorMeta

logger = logging.getLogger(__name__)


class StochasticMixer[ConfigType: StochasticMixerConfig](BlockWithBias[ConfigType]):
"""
A mixer that stochastically samples from multiple mixer options during training.

In training mode, each forward pass randomly selects one mixer according to
the sampling strategy. In eval mode, uses the configured inference mixer.

This is useful for supernet training where you want to train multiple
architecture variants (e.g., attention vs. Mamba) with different data subsets.
"""

_config: ConfigType

def __init__(
self,
config: ConfigType,
distributed_config: DistributedConfig,
*,
hidden_dim: TensorDim,
lr_scale: float | None,
peft: PeftConfig | None,
return_bias: bool = True,
):
super().__init__(
config,
distributed_config,
hidden_dim=hidden_dim,
lr_scale=lr_scale,
peft=peft,
return_bias=return_bias,
)

# Initialize all mixers
self.mixers = torch.nn.ModuleDict(
{
name: mixer_config.get_layer(
distributed_config,
hidden_dim,
lr_scale,
peft=peft,
return_bias=return_bias,
)
for name, mixer_config in self._config.mixers.items()
}
)

if self._config.sampling_strategy == StochasticMixerSamplingStrategy.uniform:
self._sampling_probs = torch.ones(len(self.mixers), device="cpu") / len(self.mixers)
elif self._config.sampling_strategy == StochasticMixerSamplingStrategy.weighted:
if self._config.sampling_weights is None:
raise ValueError("sampling_weights must be provided when using weighted sampling strategy")
self._sampling_probs = torch.tensor(
[self._config.sampling_weights[name] for name in self.mixers.keys()],
dtype=torch.float32,
device="cpu",
)
else:
raise NotImplementedError(f"Sampling strategy {self._config.sampling_strategy} not implemented")

logger.info(
f"Initialized StochasticMixer with {len(self.mixers)} mixers: "
f"{', '.join(f'{name}={type(mixer).__name__}' for name, mixer in self.mixers.items())} "
f"(main={self._config.main_mixer_name})"
)

# Mark all mixer parameters with allow_no_grad since only one mixer
# is active per forward pass during training. Even though all mixers
# will eventually be trained, on any single forward pass, the non-selected
# mixers won't receive gradients.
for mixer in self.mixers.values():
for param in mixer.parameters(recurse=True):
if hasattr(param, "allow_no_grad"):
param.allow_no_grad = True

def setup(self, distributed: Distributed) -> None:
"""Setup all mixers with the distributed context."""
super().setup(distributed)
for mixer in self.mixers.values():
mixer.setup(distributed)

def _sample_mixer_name(self, kwargs: dict[str, typing.Any]) -> str:
if not self.training:
return self._config.main_mixer_name

generator = kwargs[StochasticMixerKwargs.generator]
mixer_idx = torch.multinomial(self._sampling_probs, num_samples=1, generator=generator).item()
return list(self.mixers.keys())[mixer_idx]

def _forward(
self,
input_: torch.Tensor,
kwargs: dict[str, typing.Any],
losses: dict[str, typing.Any] | None = None,
metrics: dict[str, typing.Any] | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
mixer_name = self._sample_mixer_name(kwargs)

if self._debug.enabled:
logger.debug(f"StochasticMixer selecting mixer {mixer_name}: {type(self.mixers[mixer_name]).__name__}")

return self.mixers[mixer_name]._forward(input_, kwargs, losses, metrics)

def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None:
from fast_llm.layers.block.config import BlockKwargs

iteration = kwargs[BlockKwargs.iteration]
generator = torch.Generator(device="cpu")
generator.manual_seed(iteration)
kwargs[StochasticMixerKwargs.generator] = generator

for mixer in self.mixers.values():
mixer.preprocess(batch, kwargs)

def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
"""
Return expected compute usage (weighted average of all mixers).

This gives a more accurate estimate than just using one mixer,
since during training we'll be using all of them according to
their sampling probabilities.
"""
usages = [mixer.get_compute_usage(input_, kwargs, config) for mixer in self.mixers.values()]

# Weight by sampling probability and return the expected value
expected_usage = sum(usage * prob.item() for usage, prob in zip(usages, self._sampling_probs))

return int(expected_usage)

def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
Collaborator: This is a bit dangerous, there could be name conflicts and counts will be wrong for averaging. Not sure how to fix though.

Collaborator (author): Acknowledged. The current approach ensures we allocate space for all possible losses, but you're right that counts won't match actual usage since only one mixer runs per forward pass. We could track which mixer was used and only record its losses, but that adds complexity. I think what we have is good enough for now.
"""
Merge loss definitions from all mixers with namespacing.

Each mixer's losses are namespaced with the mixer name to avoid conflicts.
This ensures we allocate space for any auxiliary losses that any
of the mixers might need, even if multiple mixers have losses with the same name.
"""
all_losses = []
for mixer_name, mixer in self.mixers.items():
mixer_losses = mixer.get_loss_definitions(count=count)
# Namespace each loss with the mixer name to avoid conflicts
for loss_def in mixer_losses:
namespaced_loss = LossDef(
name=f"{mixer_name}/{loss_def.name}",
formatted_name=f"{mixer_name}/{loss_def.formatted_name}",
count=loss_def.count,
dtype=loss_def.dtype,
)
all_losses.append(namespaced_loss)

return all_losses
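Note on determinism: preprocess seeds a CPU torch.Generator with the training iteration, so every rank draws the same mixer index for a given step. A minimal standalone sketch of that property (mixer names and probabilities here are illustrative, not taken from the diff):

import torch

# Illustrative names and probabilities; in the layer these come from the config.
mixer_names = ["attention", "mamba"]
sampling_probs = torch.tensor([0.75, 0.25], dtype=torch.float32)

def sample_mixer(iteration: int) -> str:
    # Same seed on every rank -> same multinomial draw -> same mixer everywhere.
    generator = torch.Generator(device="cpu")
    generator.manual_seed(iteration)
    index = torch.multinomial(sampling_probs, num_samples=1, generator=generator).item()
    return mixer_names[index]

# Identical choice for the same iteration, regardless of which rank calls this.
assert sample_mixer(42) == sample_mixer(42)
print([sample_mixer(step) for step in range(4)])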
1 change: 1 addition & 0 deletions fast_llm/models/auto.py
@@ -2,6 +2,7 @@
Import these submodules to ensure classes are added to the dynamic class registry.
"""

from fast_llm.layers.attention.config import AttentionConfig # isort: skip
from fast_llm.layers.ssm.config import MambaConfig, Mamba2Config, DiscreteMamba2Config # isort: skip
from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip
from fast_llm.engine.evaluation.evaluators import EvaluatorsConfig # isort: skip
2 changes: 2 additions & 0 deletions fast_llm/models/gpt/config.py
@@ -13,6 +13,7 @@
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.language_model.config import LanguageModelConfig, MultiTokenPredictionConfig
from fast_llm.models.gpt.conversion.config import (
Apriel2CheckpointFormat,
AprielHybridSSMCheckpointFormat,
AutoGPTHuggingfaceCheckpointFormat,
DiffusionDreamCheckpointFormat,
@@ -117,6 +118,7 @@ class GPTModelConfig(FastLLMModelConfig):
DiffusionDreamCheckpointFormat,
DiffusionLlamaCheckpointFormat,
AprielHybridSSMCheckpointFormat,
Apriel2CheckpointFormat,
)

@classmethod