Add stochastic mixer for supernet training #373
Open: tscholak wants to merge 28 commits into main from stochastic-mixer
+2,615 −19
Changes from 9 of 28 commits.

Commits:
All commits by tscholak:

- 43729b1 Add stochastic mixer for supernet training
- 8b1eb08 Fix stochastic mixer test failures
- 8ada30b Fix stochastic mixer checkpoint conversion
- cd1dbf8 Handle lossy HF conversions for stochastic mixer
- d0fd648 Merge remote-tracking branch 'origin/main' into stochastic-mixer
- d693f74 Clean up extra blank line in huggingface.py
- 6962de9 Apply pre-commit formatting
- a96c0cb Refactor stochastic mixer: set main_mixer_name in validation, preproc…
- 735ee3f wip
- aed779c resolve merge conflicts
- 982d409 Implement full stochastic mixer support in Apriel HuggingFace format
- 0d8ab4d Add Apriel2 checkpoint format and fix weight tying
- bcd93b2 Optimize Apriel2: compute position embeddings and masks per unique block
- ebe75c4 Add HuggingFace generation and caching improvements to Apriel2
- ffd55e5 Add Apriel2DynamicCache for hybrid attention/SSM layer support
- fe259c3 Add Mamba incremental generation support to Apriel2
- 708917d Add GatedDeltaNet support via Qwen3NextGatedDeltaNet wrapper
- 77ceae2 Standardize naming: recurrent_states and Apriel2 prefixes
- ec95ccc Remove debug print statements and irrelevant changes
- 571fede Remove stochastic mixer support from apriel conversion
- 8e7c154 Remove trivial formatting change from apriel_hybrid_ssm config
- 4d0a01b Remove test changes for lossy HF conversion
- 71cf778 Revert trivial setup.py formatting and restore .eval() calls in tests
- 75847d0 Rename SamplingStrategy to StochasticMixerSamplingStrategy
- eacdf61 Use normalize_probabilities for sampling weights validation
- 192e985 Remove tools/supernet_beam_search.py
- 2fe9596 Fix stochastic mixer sampling to be consistent across all ranks
- acb4751 Add Apriel2Cache with JetNemotron pattern and HF Cache compliance
```diff
@@ -1,16 +1,24 @@
 import enum
 import typing
 
 from fast_llm.config import Field, FieldHint, check_field, config_class
 from fast_llm.engine.config_utils.parameter import combine_lr_scales
 from fast_llm.engine.config_utils.tensor_dim import TensorDim
 from fast_llm.engine.distributed.config import DistributedConfig
-from fast_llm.layers.block.config import BlockConfig
+from fast_llm.layers.block.config import BlockConfig, BlockKwargs
 from fast_llm.layers.common.normalization.config import NormalizationConfig
 from fast_llm.layers.common.peft.config import PeftConfig
 from fast_llm.utils import Assert
 
 if typing.TYPE_CHECKING:
     from fast_llm.layers.decoder.block import BlockWithBias, DecoderBlock
+    from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer
+
+
+class StochasticMixerKwargs(BlockKwargs):
+    """Kwargs keys for stochastic mixer."""
+
+    mixer_name = "stochastic_mixer_name"
 
 
 @config_class()
@@ -55,6 +63,13 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi
     return super()._from_dict(default, strict=strict)
 
 
+class SamplingStrategy(str, enum.Enum):
+    """Strategy for sampling mixers in a stochastic mixer."""
+
+    uniform = "uniform"
+    weighted = "weighted"
+
+
 @config_class(registry=True)
 class MixerConfig(BlockWithBiasConfig):
     """
@@ -71,6 +86,79 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi
     return super()._from_dict(default, strict=strict)
 
 
+@config_class(dynamic_type={MixerConfig: "stochastic"})
+class StochasticMixerConfig(MixerConfig):
+    """
+    Stochastic mixer that uniformly samples from multiple mixer options during training.
+
+    For supernet training, each forward pass randomly selects one mixer to execute,
+    training all mixers with different subsets of data.
+    """
+
+    _abstract = False
+
+    mixers: dict[str, MixerConfig] = Field(
+        desc="Dict of mixer options to sample from (must contain at least 1). "
+        "Keys are mixer names used for debugging and namespacing.",
+        hint=FieldHint.architecture,
+    )
+
+    sampling_strategy: SamplingStrategy = Field(
+        default=SamplingStrategy.uniform,
+        desc="Strategy for sampling mixers during training.",
+        hint=FieldHint.feature,
+    )
+
+    sampling_weights: dict[str, float] | None = Field(
+        default=None,
+        desc="Sampling probability for each mixer by name (must sum to 1.0). "
+        "Only used when sampling_strategy='weighted'. "
+        "If None with uniform strategy, all mixers have equal probability.",
+        hint=FieldHint.feature,
+    )
+
+    main_mixer_name: str | None = Field(
+        default=None,
+        desc="Name of the main mixer. "
+        "Used for inference/eval, checkpoint loading (receives pretrained weights), "
+        "and checkpoint saving (only this mixer is exported). "
+        "If None, uses the first mixer in the dict.",
+        hint=FieldHint.feature,
+    )
+
+    def _validate(self) -> None:
+        super()._validate()
+
+        # Validate mixers dict is not empty
+        Assert.gt(len(self.mixers), 0)
+
+        # Set main_mixer_name to first mixer if not specified
+        if self.main_mixer_name is None:
+            with self._set_implicit_default():
+                self.main_mixer_name = next(iter(self.mixers.keys()))
+
+        # Validate main mixer name exists
+        if self.main_mixer_name not in self.mixers:
+            raise ValueError(f"main_mixer_name '{self.main_mixer_name}' not found in mixers")
+
+        # Validate sampling weights
+        if self.sampling_weights is not None:
+            Assert.eq(set(self.sampling_weights.keys()), set(self.mixers.keys()))
+            # Check sum is close to 1.0
+            weight_sum = sum(self.sampling_weights.values())
+            if abs(weight_sum - 1.0) > 1e-5:
+                raise ValueError(f"Sampling weights must sum to 1.0, got {weight_sum}")
+            # Check all weights are non-negative
+            if any(w < 0 for w in self.sampling_weights.values()):
+                raise ValueError("All sampling weights must be non-negative")
+
+    @property
+    def layer_class(self) -> "type[StochasticMixer]":
+        from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer
+
+        return StochasticMixer
+
+
 @config_class(dynamic_type={BlockConfig: "decoder"})
 class DecoderBlockConfig(BlockConfig):
     _abstract = False
```

Review discussion on the `main_mixer_name` default:

> Collaborator: We could just enforce this and make
>
> Author (tscholak): I like it the way it is