Changes from 9 commits (28 commits total)
43729b1 Add stochastic mixer for supernet training (tscholak, Oct 12, 2025)
8b1eb08 Fix stochastic mixer test failures (tscholak, Oct 14, 2025)
8ada30b Fix stochastic mixer checkpoint conversion (tscholak, Oct 14, 2025)
cd1dbf8 Handle lossy HF conversions for stochastic mixer (tscholak, Nov 12, 2025)
d0fd648 Merge remote-tracking branch 'origin/main' into stochastic-mixer (tscholak, Nov 12, 2025)
d693f74 Clean up extra blank line in huggingface.py (tscholak, Nov 12, 2025)
6962de9 Apply pre-commit formatting (tscholak, Nov 12, 2025)
a96c0cb Refactor stochastic mixer: set main_mixer_name in validation, preproc… (tscholak, Nov 12, 2025)
735ee3f wip (tscholak, Nov 15, 2025)
aed779c resolve merge conflicts (tscholak, Nov 20, 2025)
982d409 Implement full stochastic mixer support in Apriel HuggingFace format (tscholak, Nov 20, 2025)
0d8ab4d Add Apriel2 checkpoint format and fix weight tying (tscholak, Nov 21, 2025)
bcd93b2 Optimize Apriel2: compute position embeddings and masks per unique block (tscholak, Nov 21, 2025)
ebe75c4 Add HuggingFace generation and caching improvements to Apriel2 (tscholak, Nov 21, 2025)
ffd55e5 Add Apriel2DynamicCache for hybrid attention/SSM layer support (tscholak, Nov 21, 2025)
fe259c3 Add Mamba incremental generation support to Apriel2 (tscholak, Nov 21, 2025)
708917d Add GatedDeltaNet support via Qwen3NextGatedDeltaNet wrapper (tscholak, Nov 21, 2025)
77ceae2 Standardize naming: recurrent_states and Apriel2 prefixes (tscholak, Nov 21, 2025)
ec95ccc Remove debug print statements and irrelevant changes (tscholak, Nov 21, 2025)
571fede Remove stochastic mixer support from apriel conversion (tscholak, Nov 21, 2025)
8e7c154 Remove trivial formatting change from apriel_hybrid_ssm config (tscholak, Nov 21, 2025)
4d0a01b Remove test changes for lossy HF conversion (tscholak, Nov 21, 2025)
71cf778 Revert trivial setup.py formatting and restore .eval() calls in tests (tscholak, Nov 21, 2025)
75847d0 Rename SamplingStrategy to StochasticMixerSamplingStrategy (tscholak, Nov 21, 2025)
eacdf61 Use normalize_probabilities for sampling weights validation (tscholak, Nov 21, 2025)
192e985 Remove tools/supernet_beam_search.py (tscholak, Nov 21, 2025)
2fe9596 Fix stochastic mixer sampling to be consistent across all ranks (tscholak, Nov 22, 2025)
acb4751 Add Apriel2Cache with JetNemotron pattern and HF Cache compliance (tscholak, Nov 22, 2025)
20 changes: 10 additions & 10 deletions .github/ISSUE_TEMPLATE/feature_request.md
@@ -8,26 +8,26 @@ assignees: ''
---

# 🎯 **Goal (What & Why)**
> **Clearly state the purpose of this feature.**
> _(Example: Add FP8 support using torchao to improve training throughput by 1.5x.)_

# 🚀 **Execution Plan**
> _(This section may start as an incomplete draft but must be defined before implementation begins.)_

### **Step 1: What is the smallest working version?**
> _(Describe the simplest way to implement this feature with minimal effort.)_

### **Step 2: What additional optimizations are possible (but optional)?**
> _(List potential refinements that can be added in later PRs if needed.)_

# 📌 **Acceptance Criteria** (Must-Haves for Completion)
* The feature must be **functional and tested**.
* The implementation must be **documented in practical terms**.
* The PR must include a **performance/impact summary**.
* **No refactors unless directly necessary** for feature completion.

# 🛠️ **Project Management**
- [ ] **Assign the project to the Fast-LLM project.**
- [ ] **Set the `Estimate` field (in days) in the GitHub project.**
- [ ] **Use the `Size` field to categorize the PR size (Small/Medium/Large).**
- [ ] **Assign an owner when opening the issue.**
14 changes: 7 additions & 7 deletions .github/workflows/manual-build.yml
@@ -34,12 +34,12 @@ jobs:
sudo rm -rf /usr/share/dotnet || true
sudo rm -rf /opt/ghc || true
sudo rm -rf /usr/local/.ghcup || true

- name: Checkout repository
uses: actions/checkout@v4
with:
ref: ${{ inputs.commit_sha != '' && inputs.commit_sha || inputs.branch }}

- name: Get commit info
id: commit_info
run: |
@@ -48,7 +48,7 @@
echo "full_sha=${COMMIT_SHA}" >> $GITHUB_OUTPUT
echo "short_sha=${COMMIT_SHORT}" >> $GITHUB_OUTPUT
echo "Building from commit: ${COMMIT_SHA}"

- name: Docker meta
id: meta
uses: docker/metadata-action@v5
@@ -59,18 +59,18 @@
type=raw,value=${{ inputs.branch }}-${{ inputs.tag_suffix }}
type=raw,value=${{ inputs.branch }}-${{ inputs.tag_suffix }}-${{ steps.commit_info.outputs.short_sha }}
type=raw,value=latest-${{ inputs.tag_suffix }},enable=${{ inputs.branch == 'main' && inputs.commit_sha == '' }}

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

- name: Login to GHCR
if: ${{ inputs.push_image }}
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Build and push
uses: docker/build-push-action@v6
with:
@@ -80,7 +80,7 @@
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=registry,ref=ghcr.io/servicenow/fast-llm:cache
cache-to: type=registry,ref=ghcr.io/servicenow/fast-llm:cache,mode=max

- name: Output build info
run: |
echo "Built Docker image with tags:"
28 changes: 28 additions & 0 deletions fast_llm/engine/checkpoint/huggingface.py
@@ -73,10 +73,38 @@ def _serialize_metadata(self, config: CheckpointSaveMetadataConfig, metadata: Ch
"format": "pt",
}

def _initialize_missing_parameters(self) -> None:
# Parameters that exist in the model but not in the checkpoint import converters
missing_params = set(self._export_converters.keys()) - {
weight_converter.fast_llm_name[0]
for weight_converter in self._import_converters.values()
if weight_converter.fast_llm_name
}

print(f"[INIT DEBUG] Checking for missing parameters in HuggingFace checkpoint...")
print(f"[INIT DEBUG] Model has {len(self._export_converters)} parameters")
print(f"[INIT DEBUG] Checkpoint has {len(self._import_converters)} parameters")
print(f"[INIT DEBUG] Missing: {len(missing_params)} parameters")

if missing_params:
logger.warning(
f"Initializing {len(missing_params)} parameters not in HuggingFace checkpoint"
)
print(f"[INIT DEBUG] Initializing {len(missing_params)} parameters:")
for param in sorted(missing_params)[:5]: # Show first 5
print(f"[INIT DEBUG] {param}")
if len(missing_params) > 5:
print(f"[INIT DEBUG] ... and {len(missing_params) - 5} more")
for stage in self._model._stages:
stage.initialize_weights_for_parameters(missing_params)

def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None:
print(f"[INIT DEBUG] HuggingfaceStateDictCheckpointHandler.load() called")
assert not config.optimizer_state
metadata = self._model.config.load_metadata(config)
self._model.config.base_model.compare_architecture(metadata.config.base_model, logger.warning)
# Initialize parameters not covered by import converters
self._initialize_missing_parameters()
super().load(config)

def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None:
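For context on the converter bookkeeping in `_initialize_missing_parameters` above, here is a minimal, self-contained sketch of the set-difference logic using plain string sets; the parameter names are made up, and the real handler works on converter objects rather than bare names:

```python
# Sketch only: the real handler derives these sets from its export/import converters.

def find_missing_parameters(model_parameters: set[str], checkpoint_parameters: set[str]) -> set[str]:
    """Parameters the model defines but the imported checkpoint does not provide."""
    return model_parameters - checkpoint_parameters


# Hypothetical names: a supernet block whose main mixer (attention) came from the
# HuggingFace checkpoint, while an alternative mixer was never exported.
model_parameters = {
    "decoder.0.mixer.attention.query.weight",
    "decoder.0.mixer.gated_delta_net.in_proj.weight",
}
checkpoint_parameters = {"decoder.0.mixer.attention.query.weight"}

missing = find_missing_parameters(model_parameters, checkpoint_parameters)
print(missing)  # {'decoder.0.mixer.gated_delta_net.in_proj.weight'}
# This set is what gets handed to stage.initialize_weights_for_parameters(...),
# so the mixers absent from the checkpoint receive fresh random weights
# instead of staying uninitialized.
```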
11 changes: 11 additions & 0 deletions fast_llm/engine/multi_stage/stage_base.py
@@ -159,7 +159,15 @@ def _replace(module: torch.nn.Module):
        Assert.eq(i, len(self._parameter_metas))
        assert not tied_parameter_duplicate_buffers, tied_parameter_duplicate_buffers.keys()

    def initialize_weights_for_parameters(self, parameter_names: set[str]) -> None:
        """Initialize only the specified parameters. Used for partial initialization after checkpoint load."""
        self._initialize_weights_internal(lambda meta: meta.tensor_name in parameter_names)

    def initialize_weights(self) -> None:
        """Initialize all weights."""
        self._initialize_weights_internal(lambda meta: True)

    def _initialize_weights_internal(self, should_initialize: typing.Callable) -> None:
        # TODO: Avoid all the _on_device checks
        assert self._is_setup
        with torch.no_grad():
@@ -180,6 +188,9 @@ def initialize_weights(self) -> None:
            ]

            for meta in metas:
                # Skip parameters we shouldn't initialize
                if not should_initialize(meta):
                    continue
                if meta.tensor_name in self._tied_parameter_duplicates:
                    # Initialization is not managed by this stage.
                    continue
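A toy illustration of the predicate-based split introduced in `stage_base.py` above, with an assumed stand-in type instead of Fast-LLM's real parameter metas: `initialize_weights` passes an always-true predicate, while `initialize_weights_for_parameters` passes a name-membership test.

```python
import dataclasses
import typing


@dataclasses.dataclass
class ToyMeta:
    # Stand-in for a parameter meta; only the name matters for this sketch.
    tensor_name: str


def initialize(metas: list[ToyMeta], should_initialize: typing.Callable[[ToyMeta], bool]) -> list[str]:
    """Return the names that would be (re)initialized under the given predicate."""
    initialized = []
    for meta in metas:
        if not should_initialize(meta):
            continue  # e.g. the parameter was already loaded from a checkpoint
        initialized.append(meta.tensor_name)
    return initialized


metas = [ToyMeta("mixer.attention.query.weight"), ToyMeta("mixer.mamba.in_proj.weight")]

# Full initialization, as in initialize_weights():
print(initialize(metas, lambda meta: True))

# Partial initialization, as in initialize_weights_for_parameters(...):
missing = {"mixer.mamba.in_proj.weight"}
print(initialize(metas, lambda meta: meta.tensor_name in missing))
```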
90 changes: 89 additions & 1 deletion fast_llm/layers/decoder/config.py
@@ -1,16 +1,24 @@
import enum
import typing

from fast_llm.config import Field, FieldHint, check_field, config_class
from fast_llm.engine.config_utils.parameter import combine_lr_scales
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.layers.block.config import BlockConfig, BlockKwargs
from fast_llm.layers.common.normalization.config import NormalizationConfig
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
    from fast_llm.layers.decoder.block import BlockWithBias, DecoderBlock
    from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer


class StochasticMixerKwargs(BlockKwargs):
    """Kwargs keys for stochastic mixer."""

    mixer_name = "stochastic_mixer_name"


@config_class()
@@ -55,6 +63,13 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi
        return super()._from_dict(default, strict=strict)


class SamplingStrategy(str, enum.Enum):
    """Strategy for sampling mixers in a stochastic mixer."""

    uniform = "uniform"
    weighted = "weighted"


@config_class(registry=True)
class MixerConfig(BlockWithBiasConfig):
"""
Expand All @@ -71,6 +86,79 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi
return super()._from_dict(default, strict=strict)


@config_class(dynamic_type={MixerConfig: "stochastic"})
class StochasticMixerConfig(MixerConfig):
    """
    Stochastic mixer that uniformly samples from multiple mixer options during training.

    For supernet training, each forward pass randomly selects one mixer to execute,
    training all mixers with different subsets of data.
    """

    _abstract = False

    mixers: dict[str, MixerConfig] = Field(
        desc="Dict of mixer options to sample from (must contain at least 1). "
        "Keys are mixer names used for debugging and namespacing.",
        hint=FieldHint.architecture,
    )

    sampling_strategy: SamplingStrategy = Field(
        default=SamplingStrategy.uniform,
        desc="Strategy for sampling mixers during training.",
        hint=FieldHint.feature,
    )

    sampling_weights: dict[str, float] | None = Field(
        default=None,
        desc="Sampling probability for each mixer by name (must sum to 1.0). "
        "Only used when sampling_strategy='weighted'. "
        "If None with uniform strategy, all mixers have equal probability.",
        hint=FieldHint.feature,
    )

    main_mixer_name: str | None = Field(
        default=None,
        desc="Name of the main mixer. "
        "Used for inference/eval, checkpoint loading (receives pretrained weights), "
        "and checkpoint saving (only this mixer is exported). "
        "If None, uses the first mixer in the dict.",
        hint=FieldHint.feature,
    )

    def _validate(self) -> None:
        super()._validate()

        # Validate mixers dict is not empty
        Assert.gt(len(self.mixers), 0)

        # Set main_mixer_name to first mixer if not specified
        if self.main_mixer_name is None:
            with self._set_implicit_default():
                self.main_mixer_name = next(iter(self.mixers.keys()))
Review comment (Collaborator): We could just enforce this and make main_mixer_name a cached property instead?

Reply (Collaborator Author): I like it the way it is.
        # Validate main mixer name exists
        if self.main_mixer_name not in self.mixers:
            raise ValueError(f"main_mixer_name '{self.main_mixer_name}' not found in mixers")

        # Validate sampling weights
        if self.sampling_weights is not None:
            Assert.eq(set(self.sampling_weights.keys()), set(self.mixers.keys()))
            # Check sum is close to 1.0
            weight_sum = sum(self.sampling_weights.values())
            if abs(weight_sum - 1.0) > 1e-5:
                raise ValueError(f"Sampling weights must sum to 1.0, got {weight_sum}")
            # Check all weights are non-negative
            if any(w < 0 for w in self.sampling_weights.values()):
                raise ValueError("All sampling weights must be non-negative")

    @property
    def layer_class(self) -> "type[StochasticMixer]":
        from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer

        return StochasticMixer


@config_class(dynamic_type={BlockConfig: "decoder"})
class DecoderBlockConfig(BlockConfig):
    _abstract = False
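To make the new config concrete, here is a hedged example of a mixer section that `StochasticMixerConfig` could parse. The inner mixer type names ("attention", "mamba"), the `type` key spelling, and the exact nesting inside a decoder block are assumptions, not taken from this diff:

```python
# Assumed config fragment for a stochastic (supernet) mixer.
stochastic_mixer = {
    "type": "stochastic",  # selects StochasticMixerConfig via the MixerConfig registry
    "mixers": {
        "attention": {"type": "attention"},  # hypothetical mixer option
        "mamba": {"type": "mamba"},          # hypothetical mixer option
    },
    "sampling_strategy": "weighted",
    "sampling_weights": {"attention": 0.75, "mamba": 0.25},  # must sum to 1.0
    "main_mixer_name": "attention",  # used for eval and for checkpoint import/export
}

# The checks in _validate() enforce, among other things:
assert set(stochastic_mixer["sampling_weights"]) == set(stochastic_mixer["mixers"])
assert abs(sum(stochastic_mixer["sampling_weights"].values()) - 1.0) <= 1e-5
assert stochastic_mixer["main_mixer_name"] in stochastic_mixer["mixers"]
```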