Add stochastic mixer for supernet training #373
base: main
Conversation
Implements a stochastic mixer layer that randomly samples from multiple mixer options during training, enabling supernet training where different architecture variants (e.g., attention vs. Mamba) are trained with different data subsets.

Key components:
- StochasticMixerConfig: Configuration for stochastic sampling strategy (uniform or weighted) with configurable main_mixer_index for inference
- StochasticMixer: Layer implementation with distributed RNG support
- Checkpoint conversion: Apriel converter handles stochastic mixers
- Beam search tool: Hierarchical beam search for optimal mixer placement

The beam search tool finds which layers benefit most from expensive mixers (e.g., full attention) vs. efficient mixers (e.g., linear attention) by evaluating different configurations using Fast-LLM's evaluation system.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
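For readers skimming the diff, here is a minimal, self-contained sketch of the idea. The class name, signatures, and the in-forward sampling below are illustrative only, not the actual Fast-LLM layer (the review thread below discusses where sampling should happen):

```python
import torch
import torch.nn as nn


class StochasticMixerSketch(nn.Module):
    """Toy illustration: hold several interchangeable mixers, route each training
    forward pass through one randomly sampled mixer, and fall back to the main
    mixer at inference time."""

    def __init__(self, mixers: list[nn.Module], sampling_probs: list[float], main_mixer_index: int = 0):
        super().__init__()
        self.mixers = nn.ModuleList(mixers)
        self.register_buffer("sampling_probs", torch.tensor(sampling_probs), persistent=False)
        self.main_mixer_index = main_mixer_index

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Naive version: sample here, per layer and per forward pass.
            index = int(torch.multinomial(self.sampling_probs, num_samples=1).item())
        else:
            # Deterministic at inference: always use the main mixer.
            index = self.main_mixer_index
        return self.mixers[index](hidden_states)
```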
- Fix Assert.gt_len AttributeError by moving validation to _validate() method
- Add AttentionConfig import to models/auto.py for proper registration
- Mark all mixer parameters with allow_no_grad=True since only one mixer is active per forward pass

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Fixed nested config structure bug in AprielStochasticMixerConverter.import_config that was causing validation errors when loading Apriel checkpoints. The converter was returning the entire block config (with mixer, mlp, and normalization keys) instead of just the mixer config, causing these fields to be incorrectly nested under the mixer field during import.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
jlamypoirier
left a comment
Looks good, some minor comments
with set_generator(generator):
    # Sample from categorical distribution
    idx = torch.multinomial(self._sampling_probs, num_samples=1).item()
This requires a costly cuda sync. How about we sample for all layers at once during preprocessing?
now done during preprocessing
mixer_idx = self._sample_mixer_index()

if self._debug.enabled:
    logger.debug(f"StochasticMixer selecting mixer {mixer_idx}: {type(self.mixers[mixer_idx]).__name__}")
Ambiguous if multiple mixers share the same type. Use named mixers instead?
now using named mixers. we retrieve mixer_name from kwargs (line 151) and use it for logging (line 160) and accessing the correct mixer (line 163).
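A hedged sketch of the named-mixer pattern described in this thread; the kwargs key and class below are illustrative, not Fast-LLM's actual signatures:

```python
import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


class NamedMixerRouterSketch(nn.Module):
    """Sketch only: the preprocessing step records the chosen mixer name in the
    forward kwargs, and the layer uses that name both for logging and for
    selecting the sub-module."""

    def __init__(self, mixers: dict[str, nn.Module]):
        super().__init__()
        self.mixers = nn.ModuleDict(mixers)

    def forward(self, hidden_states: torch.Tensor, kwargs: dict) -> torch.Tensor:
        mixer_name = kwargs["mixer_name"]  # hypothetical key, written during preprocessing
        logger.debug("StochasticMixer selecting mixer %s", mixer_name)
        return self.mixers[mixer_name](hidden_states)
```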
we need to preprocess for all of them. This includes things like
attention masks, rotary embeddings, etc.
"""
for mixer in self.mixers:
There could be name conflicts. Consider namespace?
now namespaced. see lines 214-216 where we prefix with f"{mixer_name}/{loss_def.name}".
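A small sketch of the namespacing scheme; the LossDef fields and the helper are stand-ins, and only the `{mixer_name}/{loss_def.name}` prefix comes from this thread:

```python
from dataclasses import dataclass


@dataclass
class LossDef:
    # Simplified stand-in for Fast-LLM's LossDef; only the fields used here.
    name: str
    formatted_name: str
    count: int


def namespaced_loss_definitions(mixers: dict, count: int = 1) -> list[LossDef]:
    """Prefix each sub-mixer's loss names with its mixer name so that two mixers
    reporting the same loss do not collide."""
    definitions = []
    for mixer_name, mixer in mixers.items():
        for loss_def in mixer.get_loss_definitions(count=count):
            definitions.append(
                LossDef(
                    name=f"{mixer_name}/{loss_def.name}",
                    formatted_name=f"{mixer_name}/{loss_def.formatted_name}",
                    count=loss_def.count,
                )
            )
    return definitions
```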
return int(expected_usage)

def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
This is a bit dangerous, there could be name conflicts and counts will be wrong for averaging. Not sure how to fix though.
Acknowledged. The current approach ensures we allocate space for all possible losses, but you're right that counts won't match actual usage since only one mixer runs per forward pass. We could track which mixer was used and only record its losses, but that adds complexity. I think what we have is good enough for now.
mixer_converter_class.get_converters(
    mixer,
    f"{fast_llm_prefix}.mixers.{mixer_index}",
    hf_prefix if is_main_mixer else None,
Just pass hf_prefix; drop_on_export handles the rest.
now uses just hf_prefix without the mixer name prefix.
| f"{hf_prefix}.{block_index}", | ||
| drop_on_export, | ||
| ) | ||
| match config: |
I don't think match is warranted here, since it involves a (slow) initialization of configs.
uses if-else with instance type check now.
- Add _is_lossy_hf_conversion() utility to detect when HF conversion drops weights
- Skip incompatible tests (test_converted_round_trip, test_load_pretrained) for lossy conversions
- Check converters for IgnoreExportWeightConverter instances
- Factor out config loading into _load_config_from_test_dir() and _load_config_from_checkpoint()
- Export main_mixer_type in stochastic mixer config for HF compatibility
# Conflicts:
#   fast_llm/models/gpt/conversion/apriel.py
…ess only selected mixer, remove caching
# Set main_mixer_name to first mixer if not specified
if self.main_mixer_name is None:
    with self._set_implicit_default():
        self.main_mixer_name = next(iter(self.mixers.keys()))
We could just enforce this and make main_mixer_name a cached property instead?
I like it the way it is
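For reference, a sketch of the reviewer's alternative, under the assumption that the first configured mixer is always the main one (field names are illustrative, and an explicit user override is omitted for brevity):

```python
import functools


class StochasticMixerConfigSketch:
    """Sketch of the suggested alternative: derive the main mixer lazily from the
    first configured mixer instead of writing an implicit default into the config."""

    def __init__(self, mixers: dict[str, object]):
        self.mixers = mixers

    @functools.cached_property
    def main_mixer_name(self) -> str:
        # Computed once on first access, then cached.
        return next(iter(self.mixers))
```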
# Sample index in training mode
generator = self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator
# Move sampling_probs to the same device as the generator for multinomial
sampling_probs_device = self._sampling_probs.to(generator.device)
Could we just use a cpu generator instead? Or at the very least move to device in setup
return self._config.main_mixer_name

# Sample index in training mode
generator = self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator
I don't think that's right, tp_generator will result in different tensor ranks selecting different mixers.
Also do we actually want different DP ranks / micro-batches to select different sets? I guess this increases randomness but it will affect reproducibility and prevent distributed tests from working.
thanks for checking this. what we want/need is that all ranks sample the same mixer for each batch. How can that be done? I thought that's what the tp generator does. maybe it does the exact opposite, and all tp ranks do it differently?
The TP generator is meant for TP tensors, which need different random numbers for each TP rank. (Ex. we want different dropouts for different slices of a tensor.)
I had a second look, and I don't think any existing generator can provide consistent mixers for a given batch. pp_generator gives inconsistent results between DP ranks and gradient accumulation steps, but is probably still the best option. The CPU generator on the other hand is consistent between DP ranks (not grad accumulation), but is not reproducible (see Distributed.set_step). Getting consistency for a whole batch would require a custom seed/generator, and access to the current training step in preprocess. (I guess preprocess_batch could add it to the kwargs).
Another issue I'm seeing: preprocess is called only once for all layers, so I think the current approach will result in all layers choosing the same mixer. And fixing is non-trivial since the preprocessor doesn't know the number of layers and the mixer doesn't know about its layer index. So my suggestion would be to go with a custom generator, seed it in preprocess using the step index, then generate on the fly in forward.
- Extended AprielHybridSSMConfig to support nested lists in hybrid_block_layout for stochastic mixers
- Created AprielStochasticDecoderLayer that directly instantiates mixer modules (MistralAttention, Mamba2)
- Updated AprielHybridSSMModel to detect and instantiate stochastic layers from nested lists
- Updated AprielStochasticMixerConverter to export all mixer weights with correct HF prefixes:
  * Attention mixers → self_attn
  * Non-attention mixers → mixer
- Removed drop_on_export workaround - now properly exports all mixer weights
- Updated converter to generate nested lists in config and import them back correctly
- Fixed enum serialization for sampling_strategy in config export
- Updated test fixture to use HF layout names (t, m2) as mixer names
- Removed initialization workarounds (now exports full weights instead)

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Implement new Apriel2 HuggingFace checkpoint format that mirrors Fast-LLM's hierarchical config structure with declarative mixer/block definitions.

New features:
- Apriel2Config and Apriel2ForCausalLM with pattern decoder support
- Full conversion support for attention, mamba, and stochastic mixers
- get_block_config() method for per-layer configuration access

Fixes:
- Fix weight tying: add drop_on_export flag to skip lm_head.weight when tied
- Fix Apriel2Config.get_text_config() to return self for proper tie_word_embeddings access
- Remove stochastic mixer support from apriel_hybrid_ssm (HuggingFace side)

Testing:
- Add apriel2_mixed test config with tied_embedding_weight=True
- Add debug prints for weight comparison (to be removed)

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Performance optimization:
- Compute RoPE position embeddings once per unique block type (O(unique_blocks)) instead of per layer (O(num_layers))
- Compute causal masks once per unique block type instead of per layer
- For models with 32 layers and 2 unique blocks: 16x reduction in computation

Architecture changes:
- Build shared rotary_embs ModuleDict at Apriel2Model level (one per unique attention block)
- Use nested ModuleDicts for stochastic mixers instead of dot notation (PyTorch doesn't allow dots in module names)
- Separate top-level dicts for position_embeddings and attention_masks for cleaner API
- Each layer receives only the data it needs (direct value or nested dict for stochastic mixers)

Code improvements:
- Remove create_attention_from_config() indirection
- Remove all debug prints
- Use config._attn_implementation instead of hardcoding "eager"
- Add get_block_name() helper to Apriel2Config
- Factored out _create_rotary_emb_for_attention() and _build_attn_config_for_mask()
- Type annotations: dict[str, Any], Optional[Union[torch.Tensor, BlockMask]]

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
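A rough sketch of the sharing scheme, assuming a hypothetical make_rotary factory; it only illustrates keying shared modules by unique block name rather than building one per layer:

```python
import torch.nn as nn


class SharedBlockResourcesSketch(nn.Module):
    """Sketch only: one rotary-embedding module per unique attention block type,
    reused by every layer of that type."""

    def __init__(self, unique_block_names: list[str], make_rotary):
        super().__init__()
        # ModuleDict keys cannot contain dots, hence plain block names as keys.
        self.rotary_embs = nn.ModuleDict({name: make_rotary(name) for name in unique_block_names})

    def rotary_for(self, block_name: str) -> nn.Module:
        return self.rotary_embs[block_name]
```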
Add infrastructure for efficient generation:
- Apriel2PreTrainedModel base class with cache support flags (_supports_flash_attn_2, _supports_sdpa, _supports_flex_attn, _supports_cache_class, _supports_quantized_cache, _supports_static_cache)
- GenerationMixin inheritance for Apriel2ForCausalLM
- FlashAttentionKwargs support via Unpack[FlashAttentionKwargs]
- cache_position parameter throughout forward methods for efficient KV cache indexing
- logits_to_keep optimization (only compute logits for last N tokens during generation)

Implementation follows Mistral's pattern:
- slice_indices = slice(-logits_to_keep, None) for clean slicing
- Only upcast logits to float when computing loss
- Use outputs.last_hidden_state instead of outputs[0]

Note: Custom cache class for hybrid attention/SSM layers (Mamba, GatedDeltaNet) to be implemented in follow-up commit.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Infrastructure for incremental generation with mixed architectures:
- Apriel2DynamicCache class that handles both attention and linear attention layers
- Separate storage for attention (key_cache, value_cache) and SSM layers (conv_states, ssm_states)
- Automatically determines mixer type per layer (attention, mamba, gated_delta_net, etc.)
- For stochastic mixers, uses main_mixer type
- Implements get_seq_length(), reorder_cache() for beam search
- Auto-initialize cache in Apriel2Model.forward() when use_cache=True
- Follows Qwen3Next pattern: initialize in model forward, not in prepare_inputs_for_generation
- Cleaner than custom prepare_inputs_for_generation

Note: Mamba/GatedDeltaNet layers not yet updated to read/write cache states. Will be implemented in follow-up commit.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
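A hedged sketch of the per-layer storage idea behind such a hybrid cache; the class name, fields, and methods are simplified stand-ins, not the actual Apriel2DynamicCache API:

```python
import torch


class HybridCacheSketch:
    """Sketch only: attention layers keep growing key/value tensors, SSM layers
    keep fixed-size conv/ssm states, and a stochastic layer is treated as its
    main mixer's kind."""

    def __init__(self, layer_mixer_types: list[str]):
        self.layer_mixer_types = layer_mixer_types
        num_layers = len(layer_mixer_types)
        self.key_cache: list[torch.Tensor | None] = [None] * num_layers
        self.value_cache: list[torch.Tensor | None] = [None] * num_layers
        self.conv_states: list[torch.Tensor | None] = [None] * num_layers
        self.ssm_states: list[torch.Tensor | None] = [None] * num_layers

    def get_seq_length(self, layer_idx: int = 0) -> int:
        keys = self.key_cache[layer_idx]
        return 0 if keys is None else keys.shape[-2]

    def update(self, key: torch.Tensor, value: torch.Tensor, layer_idx: int):
        # Append along the sequence dimension for attention layers.
        if self.key_cache[layer_idx] is None:
            self.key_cache[layer_idx], self.value_cache[layer_idx] = key, value
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]
```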
Implement caching for Mamba layers to enable efficient incremental generation:
- Update Mamba.forward() to support both full sequence and incremental modes
- Add step() method for single-token generation using selective_state_update
- Add allocate_inference_cache() and _get_states_from_cache() helpers
- Update Apriel2DynamicCache: remove Cache inheritance, simplify __init__
- Add get_mask_sizes() and has_previous_state() for HuggingFace compatibility
- Auto-initialize cache states lazily during forward pass

Implementation follows Mamba2 pattern from apriel_hybrid_ssm for consistency.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Implement GatedDeltaNet by wrapping Qwen3NextGatedDeltaNet:
- Import Qwen3NextGatedDeltaNet at top level for consistency with Mistral imports
- Create GatedDeltaNet wrapper class to adapt interfaces
- Maps config_dict to Qwen3NextConfig format
- Maps past_key_value -> cache_params parameter
- Extracts cache_position from kwargs
- Add recurrent_states property to Apriel2DynamicCache
- Aliases ssm_states for Qwen3Next interface compatibility
- Allows direct use of Apriel2DynamicCache with Qwen3NextGatedDeltaNet

This enables gated_delta_net mixer type in apriel2 models.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
- Replace ssm_states with recurrent_states throughout
  - Aligns with Qwen3NextDynamicCache naming convention
  - Updates Apriel2DynamicCache and all Mamba cache access
  - Removes alias property in favor of direct naming
- Rename classes for consistency:
  - Mamba -> Apriel2Mamba
  - GatedDeltaNet -> Apriel2GatedDeltaNet
  - Matches Apriel2Attention, Apriel2StochasticMixer naming

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Revert debug prints added for troubleshooting:
- fast_llm/layers/attention/attention.py
- fast_llm/layers/attention/rotary/rotary.py
- fast_llm/layers/decoder/block.py
- fast_llm/layers/language_model/head.py

Revert irrelevant whitespace changes:
- .github/ISSUE_TEMPLATE/feature_request.md
- .github/workflows/manual-build.yml

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Stochastic mixer support is only available in apriel2, not apriel. Revert apriel.py to its original state.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Revert unnecessary formatting change (parameter line breaks).

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Test changes for lossy conversion (stochastic mixer export) are not needed for the core functionality. Can be added later if needed for testing stochastic mixer checkpoint conversion.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
- Revert import reordering and blank line changes in setup.py
- Add .eval() calls to 4 from_pretrained() calls in test_checkpoint.py for deterministic test behavior

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Use a more specific name to avoid potential conflicts and make the purpose clearer.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Replace manual weight sum validation with normalize_probabilities utility, consistent with dataset blending approach. Weights are now automatically normalized to sum to 1.0 during validation.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
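A sketch of the contract that normalization provides, under the assumption that it mirrors the dataset-blending behaviour; the real utility lives in Fast-LLM, and this stand-in only shows the effect:

```python
def normalize_probabilities(weights: list[float]) -> list[float]:
    """Scale sampling weights so they sum to 1.0 instead of rejecting them."""
    total = sum(weights)
    if total <= 0:
        raise ValueError("Sampling weights must sum to a positive value.")
    return [weight / total for weight in weights]


# Usage example: [2.0, 1.0, 1.0] -> [0.5, 0.25, 0.25]
print(normalize_probabilities([2.0, 1.0, 1.0]))
```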
This feature will be implemented differently in the future.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
- Add iteration to BlockKwargs and pass it through preprocess_batch
- Create torch CPU generator in preprocess, seeded with iteration
- Sample mixers in forward using torch.multinomial with CPU generator
- Store sampling probabilities on CPU to avoid device transfers
- Preprocess all mixers since we don't know which will be selected per layer
- Remove TP/PP generator usage which caused rank inconsistencies
- Remove debug validation check (no longer needed with deterministic sampling)

This ensures all DP/TP/PP ranks sample the same mixer sequence for each batch, while different layers can sample different mixers deterministically based on iteration.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
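A sketch of the deterministic property this commit aims for; the function name, seeding scheme, and data layout are assumptions. Seeding a CPU generator from the training iteration gives every DP/TP/PP rank the same per-layer draws for a given batch, while different layers and iterations can still differ:

```python
import torch


def sample_mixer_names(iteration: int, layer_mixers: list[dict[str, float]], seed: int = 0) -> list[str]:
    """Draw one mixer name per stochastic layer, deterministically in the iteration."""
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed + iteration)
    choices = []
    for mixers in layer_mixers:  # one {name: probability} dict per stochastic layer
        names = list(mixers)
        probs = torch.tensor([mixers[name] for name in names], dtype=torch.float32)
        index = int(torch.multinomial(probs, num_samples=1, generator=generator).item())
        choices.append(names[index])
    return choices


# Every process calling this with the same iteration gets the same choices.
print(sample_mixer_names(iteration=7, layer_mixers=[{"attention": 0.5, "mamba": 0.5}] * 4))
```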
Hi @jlamypoirier! Thanks for the thorough review! I've addressed all your feedback and cleaned up the code.

Stochastic mixer sampling fix (your main concern): sampling is now deterministic, using a CPU generator seeded with the training iteration so all DP/TP/PP ranks select the same mixer for each batch (see the commit above).

Code cleanup: debug prints, unrelated formatting changes, and the apriel hybrid SSM stochastic-mixer changes have been reverted.

Note: I added a new apriel2 conversion module for the newer model format. This doesn't need review as it's separate from the stochastic mixer changes. Apriel hybrid SSM export is not supported; stochastic layers are only in apriel2.

Ready for another look!
Implements modular, HuggingFace-compatible cache for Apriel2:
- Extends transformers.Cache base class for ecosystem integration
- Modular sub-caches for stochastic mixers (prevents cache corruption)
- Dual initialization: forward() fallback + _prepare_cache_for_generation()
- SSM direct access via property accessors (conv_states, recurrent_states)
- Sliding window optimization with roll() for 97% memory savings
- Active mixer routing for stochastic layers via set_active_mixer()
- Type hints use specific Apriel2Cache (not generic Cache)
- Fixed is_sliding to return list[bool] per HF spec
- Fixed cache return in forward() (was None, now returns past_key_values)

All Cache ABC methods implemented:
- update(), get_seq_length(), get_max_cache_shape(), get_mask_sizes()
- reorder_cache(), reset(), crop(), batch_repeat_interleave(), batch_select_indices()
- Properties: is_compileable, is_initialized, is_sliding, max_batch_size, max_cache_len

Model flags updated to match architecture:
- _supports_quantized_cache = False (custom modular cache incompatible)
- _supports_static_cache = False (only DynamicCache implemented)
- _supports_attention_backend = True (standard attention)

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Summary
Implements a stochastic mixer layer for supernet training, enabling random sampling from multiple mixer options (e.g., attention vs. Mamba) during training. Includes checkpoint conversion support and a hierarchical beam search tool for finding optimal mixer placement post-training.
Implementation Details
Stochastic Mixer (fast_llm/layers/decoder/stochastic_mixer.py)
- Uses main_mixer_index for deterministic behavior

Configuration (fast_llm/layers/decoder/config.py)
- StochasticMixerConfig: List-based mixer configuration with sampling strategy
- main_mixer_index: Specifies which mixer to use during inference and which receives pretrained weights during checkpoint conversion

Checkpoint Conversion (fast_llm/models/gpt/conversion/apriel.py)
- AprielStochasticMixerConverter: Handles conversion between Fast-LLM and Apriel formats
- main_mixer_index weights are exported/imported (other mixers randomly initialized during supernet training)

Beam Search Tool (tools/supernet_beam_search.py)
- Sets main_mixer_index in-place for each candidate

Tests (tests/utils/model_configs.py)
- stochastic_mixer test configuration with FA/Mamba mixers
- AprielHybridSSMCheckpointFormat

Use Case
Supernet Training: Train a model where each layer can be either full attention or Mamba, with random sampling at each step. After training, use beam search to find which specific layers benefit most from full attention vs. Mamba, given a budget constraint (e.g., "I can afford 4 FA layers").
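To make the search concrete, here is a hedged sketch of a hierarchical beam search over "which layers get the expensive mixer". It is not the actual tool; score_fn stands in for evaluating a candidate placement (e.g. via Fast-LLM's evaluation system):

```python
def beam_search_mixer_placement(
    num_layers: int, budget: int, beam_width: int, score_fn
) -> tuple[frozenset[int], float]:
    """Greedily grow sets of layers that use the expensive mixer, keeping only the
    `beam_width` highest-scoring placements at each size, up to `budget` layers."""
    assert 0 <= budget <= num_layers
    beam: list[frozenset[int]] = [frozenset()]
    for _ in range(budget):
        # Extend every kept placement by one more expensive layer.
        candidates = {
            placement | {layer}
            for placement in beam
            for layer in range(num_layers)
            if layer not in placement
        }
        # Score each candidate once and keep only the best `beam_width`.
        scores = {placement: score_fn(placement) for placement in candidates}
        beam = sorted(scores, key=scores.get, reverse=True)[:beam_width]
    return beam[0], score_fn(beam[0])
```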
Testing
Run the stochastic mixer tests:
pytest tests/models/test_checkpoint.py::test_checkpoint_and_eval tests/models/test_checkpoint.py::test_conversion -k "stochastic_mixer" -v

Example beam search usage:
fast-llm tools/supernet_beam_search.py \
    training_config=path/to/supernet_config.yaml \
    budgets=[4,8] \
    beam_width=12 \
    score_metric="lm_eval/accuracy" \
    output_path=results.json

🤖 Generated with Claude Code
Co-Authored-By: Claude [email protected]