Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
214 commits
Select commit Hold shift + click to select a range
7e17715
Refactor Apriel2 configuration and preprocessing architecture
tscholak Nov 27, 2025
4496e2a
Fix cache validation test to properly test both empty and corrupted c…
tscholak Nov 27, 2025
c2c17e7
Fix Apriel2 config and converter issues
tscholak Nov 27, 2025
98a5d25
Clean up Apriel2 converters with stratified inheritance
tscholak Nov 27, 2025
c4a7709
Add Llava-to-Apriel2 HuggingFace converter with comprehensive tests
tscholak Nov 28, 2025
f3992bf
Separate model conversion from surgery for Apriel2
tscholak Nov 28, 2025
935f595
Replace legacy converters with expression-based plan system
tscholak Nov 28, 2025
c95b899
Add DIL conversion, stochastic mixer support, and fix tree collapsing
tscholak Nov 29, 2025
255be1b
Add streaming I/O for memory-efficient weight conversion
tscholak Nov 29, 2025
10a4f38
Refactor conversion into modular subpackage with source-agnostic conv…
tscholak Nov 30, 2025
31513b2
Add gated_delta_net mixer to stochastic supernet example
tscholak Nov 30, 2025
b9bd43a
Add surgery chains, Apriel2 source format, and clean up docstrings
tscholak Dec 1, 2025
3eb8bfb
Merge remote-tracking branch 'origin/main' into tscholak/apriel2-conv…
tscholak Dec 1, 2025
e135f00
Rename patch_convolution to embeddings for consistency with Fast-LLM
tscholak Dec 1, 2025
8445aaf
add non-approximated gelu
RaymondLi0 Dec 2, 2025
da3786b
Fix vision encoder numerical equivalence and add comprehensive test s…
tscholak Dec 2, 2025
aa46283
remove projector_intermediate_size
RaymondLi0 Dec 2, 2025
17c9970
fix llava hf weight prefixes
RaymondLi0 Dec 2, 2025
bd321bd
Fix Apriel2 converter weight paths after external model refactor
tscholak Dec 3, 2025
1eb07a2
Merge origin/main into tscholak/apriel2-conversion
tscholak Dec 3, 2025
249250b
Add 2D rotary embedding equivalence tests for FastLLM vs Pixtral
tscholak Dec 3, 2025
6e5da16
fix vision tower hf prefix
RaymondLi0 Dec 3, 2025
f260277
fix intermediate size import
RaymondLi0 Dec 3, 2025
98b6283
remove gelu_gaussian
RaymondLi0 Dec 3, 2025
2ab1825
Fix rotary 2d
jlamypoirier Dec 4, 2025
8305dd5
stuff
jlamypoirier Dec 4, 2025
b6e38b8
stuff
jlamypoirier Dec 4, 2025
c5aeb31
Inline GDN implementation in Apriel2 with Fast-LLM aligned naming
tscholak Dec 4, 2025
d90cb86
Fix llava converter to use explicit head_dim when available
tscholak Dec 4, 2025
9c4152a
handle pil images
RaymondLi0 Dec 4, 2025
f4d3ed6
Add mixer equivalence tests for Apriel2
tscholak Dec 4, 2025
eb9dfc2
add dict format
RaymondLi0 Dec 4, 2025
e85b573
handle large images
RaymondLi0 Dec 4, 2025
d075a16
Improve mixer equivalence test fixtures and cleanup
tscholak Dec 4, 2025
7312ea9
missing import
RaymondLi0 Dec 4, 2025
faf9cba
Use conversion machinery for mixer equivalence tests
tscholak Dec 4, 2025
ee92862
Add multi-seed verification for GDN layout conversion
tscholak Dec 4, 2025
97b10f8
Update DIL conversion to produce flat layout for in_proj_qkvz
tscholak Dec 4, 2025
4bbe459
Add test_mode fixture for coherent dtype/attn_impl/tolerance bundling
tscholak Dec 4, 2025
e88cf2e
fallback empty patch batch
oleksost Dec 4, 2025
8ad60a4
Fix Apriel2 config converter format mismatches and add training examples
tscholak Dec 4, 2025
4a7a283
Merge branch 'text_only_multimodal' into tscholak/apriel2-conversion
oleksost Dec 4, 2025
d27cef1
Merge branch 'tscholak/apriel2-conversion' of https://github.com/Serv…
oleksost Dec 4, 2025
72f915d
Merge branch 'main' into jlp/consistent_preprocessing
jlamypoirier Dec 4, 2025
350fb3d
stuff
jlamypoirier Dec 5, 2025
26cdd2b
Merge branch 'main' into tscholak/apriel2-conversion
oleksost Dec 5, 2025
a95b7bd
merging all together
oleksost Dec 5, 2025
8c9a0bf
Merge branch 'raymond/gelu_act' into oo/apriel2
oleksost Dec 5, 2025
8de4180
wip
oleksost Dec 5, 2025
d27a815
fix
jlamypoirier Dec 5, 2025
72f3a31
Merge branch 'main' into jlp/consistent_preprocessing
jlamypoirier Dec 5, 2025
5ab6cd0
fixes
jlamypoirier Dec 6, 2025
2bb330a
multimodal batch
oleksost Dec 7, 2025
1e74469
Merge remote-tracking branch 'origin/main' into jlp/consistent_prepro…
jlamypoirier Dec 8, 2025
6681e6f
Merge branch 'main' into oo/apriel2
oleksost Dec 8, 2025
310c311
Add KDA mixer and refactor Apriel2 conversion architecture
tscholak Dec 9, 2025
8c7a1ca
Merge remote-tracking branch 'origin/jlp/consistent_preprocessing' in…
oleksost Dec 9, 2025
d44244c
Fix unused variables and add CUDA skip for KDA test
tscholak Dec 9, 2025
933be9f
Improve StochasticMixer debug logging and increase bf16 test tolerance
tscholak Dec 9, 2025
68d1516
Fix vision encoder debug logging crash with model_debug_level
tscholak Dec 9, 2025
f4e9560
Add activation-level distillation and freeze non-mixer components
tscholak Dec 10, 2025
b2a2470
fixed kda test
oleksost Dec 10, 2025
ac94659
Merge branch 'main' into raymond/image_format
RaymondLi0 Dec 10, 2025
2ddaa23
Merge branch 'main' into oo/apriel2
oleksost Dec 10, 2025
95d9780
Merge branch 'tscholak/apriel2-kda' into oo/apriel2
oleksost Dec 10, 2025
ce58463
merged from apriel 2 kda
oleksost Dec 10, 2025
d598fc9
Merge remote-tracking branch 'origin/main' into oo/apriel2
oleksost Dec 10, 2025
de9b523
token sample int
oleksost Dec 10, 2025
f9a3bdf
wip
oleksost Dec 10, 2025
06f9a9a
empty image patches fix
oleksost Dec 10, 2025
e7129cb
Merge branch 'main' into raymond/image_format
RaymondLi0 Dec 11, 2025
66d8641
Merge remote-tracking branch 'origin/main' into oo/apriel2
oleksost Dec 11, 2025
617914c
Merge remote-tracking branch 'origin/raymond/image_format' into oo/ap…
oleksost Dec 11, 2025
5461900
fixes masked loss distillation
oleksost Dec 11, 2025
b8751c4
wip
oleksost Dec 11, 2025
00a3327
fixes masked loss distillation
oleksost Dec 11, 2025
c98cfed
test forward with loss masks
oleksost Dec 12, 2025
ba2c061
test forward with loss masks
oleksost Dec 12, 2025
29d0fdb
Merge remote-tracking branch 'origin/main' into fixes_masked_distilla…
oleksost Dec 12, 2025
493fe87
fix kda test
oleksost Dec 12, 2025
c68a742
varlen test fix
oleksost Dec 12, 2025
daba344
manual kl grad computation
oleksost Dec 12, 2025
1d0df17
comment
oleksost Dec 12, 2025
9ae4e73
clean
oleksost Dec 12, 2025
ab84273
Merge branch 'reverse_kl_fixes' into fixes_masked_distillation
oleksost Dec 12, 2025
3a3d06e
tests
oleksost Dec 12, 2025
bc2c525
test device
oleksost Dec 13, 2025
44c5f63
grad fix
oleksost Dec 13, 2025
0111e9f
fixes
oleksost Dec 13, 2025
fe946b7
Merge remote-tracking branch 'origin/main' into fixes_masked_distilla…
oleksost Dec 13, 2025
f6238c0
clean
oleksost Dec 13, 2025
f28b241
clean
oleksost Dec 13, 2025
e41c040
nvm
oleksost Dec 13, 2025
ed6b793
Refactor Apriel2 cache and add Qwen2 converter
tscholak Dec 14, 2025
843a355
fix qwen converted to correctly load qkv biases
bigximik Nov 28, 2025
33b6d31
fix converters
bigximik Dec 2, 2025
7822975
Add per-layer bias support, surgery improvements, and integration tests
tscholak Dec 14, 2025
00cc8a9
Merge origin/main into feature/cache-refactor-and-qwen2
tscholak Dec 14, 2025
6946a98
Merge branch 'main' into oo/apriel2
oleksost Dec 15, 2025
a0123b8
Merge branch 'fixes_masked_distillation' into oo/apriel2
oleksost Dec 15, 2025
4efcb25
clean warning
oleksost Dec 15, 2025
b6e8775
clean warnings
oleksost Dec 15, 2025
1f84e55
log selected mixer and activation loss per layer
RaymondLi0 Dec 15, 2025
6095317
handle padding in activation-distillation
RaymondLi0 Dec 15, 2025
7053d8c
Add conversation format support for SFT data preparation
tscholak Dec 15, 2025
53d6570
Cleanup: remove private method indirection, revert test changes
tscholak Dec 15, 2025
8002e96
Merge remote-tracking branch 'origin/raymond/mixer_metrics' into oo/a…
oleksost Dec 15, 2025
f5e4d93
train with only layer distillation losses
oleksost Dec 16, 2025
c335f6e
train with only layer distillation losses
oleksost Dec 16, 2025
d053d47
Refactor test organization: rename modules and remove duplication
tscholak Dec 16, 2025
0779c63
unscaled loss llogging + training with distillation loss factor = 0
oleksost Dec 16, 2025
e06a4b2
unscaled loss llogging + training with distillation loss factor = 0
oleksost Dec 16, 2025
dcd55a5
clean up
oleksost Dec 16, 2025
6fef1fb
loss mask transposition was missing
oleksost Dec 16, 2025
8da6f10
loss masking fixes: cross entropy averaging & training with minibatches
oleksost Dec 16, 2025
8c958d8
fix log selected mixer
RaymondLi0 Dec 16, 2025
b6dd6dc
Fix O(nΒ²) tokenization and add Qwen2 training examples
tscholak Dec 16, 2025
f61a6d1
Improve Apriel2 conversion config composition and documentation
tscholak Dec 16, 2025
e2032f5
added loss comparison
oleksost Dec 16, 2025
9beda77
Merge branch 'fixes_masked_distillation' into oo/apriel2
oleksost Dec 16, 2025
8671b68
Merge remote-tracking branch 'origin/raymond/mixer_metrics' into oo/a…
oleksost Dec 16, 2025
1fa2461
clean
oleksost Dec 16, 2025
d4baaff
nvm
oleksost Dec 17, 2025
8933953
Fix RangeSample.from_documents and loss mask distillation bugs
tscholak Dec 17, 2025
cfa3663
Merge origin/main into feature/cache-refactor-and-qwen2
tscholak Dec 17, 2025
711495e
refactor loss logging
oleksost Dec 17, 2025
179ae25
make logging more explicit
oleksost Dec 17, 2025
af456f0
Merge remote-tracking branch 'origin/main' into train_only_layer_losses
oleksost Dec 17, 2025
9968aac
clean + tests
oleksost Dec 17, 2025
945c5a7
nvm
oleksost Dec 17, 2025
7bb54f4
Merge branch 'train_only_layer_losses' into oo/apriel2
oleksost Dec 17, 2025
4712744
lm head
oleksost Dec 17, 2025
3c3f597
nvm
oleksost Dec 18, 2025
d9e5e08
nvm
oleksost Dec 18, 2025
efa9b61
optimize
oleksost Dec 18, 2025
1a8f107
fuse
oleksost Dec 18, 2025
22ecfb0
gitignore
oleksost Dec 19, 2025
4a6be98
manual kl + memory savings
oleksost Dec 19, 2025
1277894
Skip roundtrip integration tests on CPU-only CI
tscholak Dec 19, 2025
eed426a
average by seq. length
oleksost Dec 19, 2025
9588fe3
Merge branch 'rev_kl_improvements' into oo/apriel2
oleksost Dec 19, 2025
1205c81
forward KL
oleksost Dec 19, 2025
490893f
empty ranges
oleksost Dec 19, 2025
e4ec34b
Merge branch 'main' into oo/apriel2
oleksost Dec 19, 2025
4b6e3d7
forward KL
oleksost Dec 19, 2025
c5fefa0
test forward kl
oleksost Dec 19, 2025
4119596
wip: report unscaled + kl loss
oleksost Dec 19, 2025
ae1e48b
layer distillation loss with masking and sequence parallelism
oleksost Dec 19, 2025
37a0be9
clean
oleksost Dec 20, 2025
9273966
Refactor conversation format handling and tokenize_chat
tscholak Dec 20, 2025
ac5da3c
Merge remote-tracking branch 'origin/main' into feature/cache-refacto…
tscholak Dec 20, 2025
b55a0a4
loss config
oleksost Dec 22, 2025
097baeb
wip
oleksost Dec 22, 2025
d773d98
tests
oleksost Dec 22, 2025
35400c1
Merge remote-tracking branch 'origin/main' into train_only_layer_losses
oleksost Dec 22, 2025
282925c
test
oleksost Dec 22, 2025
0f73ea2
tests
oleksost Dec 22, 2025
04a0193
Merge branch 'main' into train_only_layer_losses
oleksost Dec 22, 2025
fa85c41
wip
oleksost Dec 22, 2025
feb416e
Merge branch 'train_only_layer_losses' of https://github.com/ServiceN…
oleksost Dec 22, 2025
31cfb84
wip
oleksost Dec 23, 2025
24fe67b
no grad if factor 0
oleksost Dec 23, 2025
00f6118
Merge remote-tracking branch 'origin/main' into train_only_layer_losses
oleksost Dec 23, 2025
0cadf98
Merge branch 'main' into train_only_layer_losses
oleksost Dec 23, 2025
0e562e9
addressed comments
oleksost Dec 23, 2025
2a474e2
Merge branch 'train_only_layer_losses' of https://github.com/ServiceN…
oleksost Dec 23, 2025
52c1c11
addressed comments
oleksost Dec 23, 2025
406d0a2
Removed Targets class
oleksost Dec 30, 2025
f25380a
fixes
oleksost Dec 30, 2025
8adb7dd
imports
oleksost Dec 30, 2025
1a12917
Merge branch 'main' into oo/apriel2
oleksost Dec 30, 2025
c34bd7e
wip
oleksost Dec 30, 2025
41e2521
modeling checkout
oleksost Dec 30, 2025
0c37f80
double negation bug
oleksost Jan 5, 2026
ccbec88
config assertion bug
oleksost Jan 5, 2026
eacff5f
Merge remote-tracking branch 'origin/main' into feature/cache-refacto…
tscholak Jan 5, 2026
e5172d5
Fix GDN mixer dtype mismatches in Apriel2 model
tscholak Jan 5, 2026
ef990a5
Run code formatters (black, isort, autoflake, pyupgrade)
tscholak Jan 5, 2026
b1b0c31
Add forward KL evaluator for teacher trace evaluation
tscholak Dec 21, 2025
c774cec
Refactor ForwardKLEvaluator to use InferenceRunner
tscholak Dec 21, 2025
565d137
Add sequence length handling and global_logits support
tscholak Dec 21, 2025
90e3200
Make max_sequence_length mandatory with default 2048
tscholak Dec 21, 2025
66ceee1
Add distributed training support to ForwardKLEvaluator
tscholak Dec 21, 2025
fd7670b
Fix global_logits storage during distillation and clean up evaluator
tscholak Dec 22, 2025
10e24ca
Refactor ForwardKLEvaluator to compute IS accuracy and ESS metrics
tscholak Dec 24, 2025
54c5f9c
Fix eval mode for StochasticMixer and add diagnostics
tscholak Jan 5, 2026
ebf1174
empty buffer skip
oleksost Jan 5, 2026
1836bbc
remove double negation
oleksost Jan 5, 2026
b3653d0
undo skip empty buffer
oleksost Jan 5, 2026
3b9d367
Merge branch 'main' into oo/apriel2
oleksost Jan 5, 2026
cbebaa8
evoid padding overlap in state loading
oleksost Jan 6, 2026
b420290
debugging padding
oleksost Jan 6, 2026
d87f825
debugging
oleksost Jan 6, 2026
a9d146e
padding correction
oleksost Jan 6, 2026
2d23387
remove unnecessary logging
oleksost Jan 6, 2026
80c40af
Revert debugging commits
oleksost Jan 6, 2026
1ce641d
polish naming
oleksost Jan 6, 2026
9b4e287
test lm head
oleksost Jan 7, 2026
7a2142d
test ssm
oleksost Jan 7, 2026
9d6d61a
Merge branch 'bug_fixing' into oo/apriel2
oleksost Jan 7, 2026
574b1d4
tests and cross entropy loss averaging over all tokens
oleksost Jan 7, 2026
78311c9
Merge branch 'add-forward-kl-evaluator' into oo/apriel2
oleksost Jan 7, 2026
27ce285
set test time mixer type
oleksost Jan 7, 2026
28d90de
progress bar
oleksost Jan 8, 2026
44c9a6e
distributed bug (fsdp)
oleksost Jan 8, 2026
95f14af
addresseing comments
oleksost Jan 8, 2026
5ad4c0c
explicit z_loss grads
oleksost Jan 8, 2026
0a66e14
removed z_loss as aux loss
oleksost Jan 8, 2026
f8f7041
move loss configs to the lm config
oleksost Jan 8, 2026
ab9c917
tests
oleksost Jan 8, 2026
2199f51
Merge branch 'train_only_layer_losses' into oo/apriel2
oleksost Jan 9, 2026
b700470
nvm
oleksost Jan 9, 2026
2c27adb
no reference models at inference
oleksost Jan 9, 2026
66078fb
add padding and image placeholder into loss mask
oleksost Jan 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,6 @@ devenv.*

# direnv
.direnv

# wandb
wandb/
25 changes: 23 additions & 2 deletions fast_llm/data/sample/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,21 +100,41 @@ def __init__(
chosen_spans: RangeBatch | None = None,
rejected_spans: RangeBatch | None = None,
image_patches: PatchBatch | None = None,
valid_tokens: int | None = None,
):
self.tokens = tokens
self.loss_masking_spans = loss_masking_spans
self.chosen_spans = chosen_spans
self.rejected_spans = rejected_spans
self.image_patches = image_patches
self.valid_tokens = valid_tokens

@classmethod
def from_samples(cls, samples: typing.Iterable[LanguageModelSample]) -> typing.Self:
samples = list(samples)
token_batch = TokenBatch.from_samples([sample.tokens for sample in samples])
loss_masking_spans = _merge_optional(
RangeBatch.from_samples, [sample.loss_masking_spans for sample in samples]
)

# Calculate valid tokens for this batch (used for gradient accumulation weighting)
valid_tokens = None
if loss_masking_spans is not None:
batch_size, sequence_length = token_batch.tokens.shape
# Start with all tokens
valid_tokens = batch_size * sequence_length
# Subtract masked tokens
for sample_ranges in loss_masking_spans.ranges:
for begin, end in sample_ranges:
valid_tokens -= end - begin

return cls(
TokenBatch.from_samples([sample.tokens for sample in samples]),
_merge_optional(RangeBatch.from_samples, [sample.loss_masking_spans for sample in samples]),
token_batch,
loss_masking_spans,
_merge_optional(RangeBatch.from_samples, [sample.chosen_spans for sample in samples]),
_merge_optional(RangeBatch.from_samples, [sample.rejected_spans for sample in samples]),
_merge_optional(PatchBatch.from_samples, [sample.image_patches for sample in samples]),
valid_tokens,
)

def crop(self, begin: int, end: int) -> typing.Self:
Expand All @@ -124,6 +144,7 @@ def crop(self, begin: int, end: int) -> typing.Self:
_crop_optional(self.chosen_spans, begin, end),
_crop_optional(self.rejected_spans, begin, end),
_crop_optional(self.image_patches, begin, end),
valid_tokens=None, # Cropped batches don't have valid token counts
)

def to_device_(self, device: "torch.device | str"):
Expand Down
2 changes: 1 addition & 1 deletion fast_llm/data/sample/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def get_document(self, index: int, begin: int, end: int) -> Sample:
begin_ = self._size_cumsums[index].item()
# Torch doesn't support type promotion between signed and unsigned types, so we convert here to avoid issues.
# Convert begin and end to int to avoid numpy dtype overflow when adding to begin_
return TokenSample(self._tokens[begin_ + begin : begin_ + end].to(torch.int64), [end - begin])
return TokenSample(self._tokens[begin_ + int(begin) : begin_ + int(end)].to(torch.int64), [end - begin])

def get_document_sizes(self) -> torch.Tensor:
return self._size_cumsums[1:] - self._size_cumsums[:-1]
Expand Down
56 changes: 56 additions & 0 deletions fast_llm/engine/evaluation/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

if typing.TYPE_CHECKING:
from fast_llm.engine.evaluation.evaluator import Evaluator, EvaluatorLmEval, LossEvaluator
from fast_llm.engine.evaluation.forward_kl.evaluator import ForwardKLEvaluator


@config_class()
Expand Down Expand Up @@ -119,3 +120,58 @@ def get_evaluator(
from fast_llm.engine.evaluation.lm_eval.evaluator import LmEvalEvaluator

return LmEvalEvaluator(name, self, batch_config, data_load_num_proc, train_iters)


@config_class(dynamic_type={EvaluatorConfig: "forward_kl"})
class ForwardKLEvaluatorConfig(EvaluatorConfig):
_abstract: typing.ClassVar[bool] = False

dataset_path: str | None = Field(
default=None,
desc="HuggingFace dataset path containing teacher traces.",
hint=FieldHint.core,
)
split: str = Field(
default="validation",
desc="Dataset split to evaluate on. Use 'train+validation' syntax to combine multiple splits.",
hint=FieldHint.optional,
)
seed: int = Field(
default=42,
desc="Random seed for shuffling traces. Ensures reproducible evaluation across runs.",
hint=FieldHint.optional,
)
num_samples: int | None = Field(
default=None,
desc="Maximum number of traces to evaluate (after shuffling). None for all.",
hint=FieldHint.optional,
valid=skip_valid_if_none(check_field(Assert.gt, 0)),
)
batch_size: int = Field(
default=8,
desc="Batch size for forward passes.",
hint=FieldHint.performance,
valid=check_field(Assert.gt, 0),
)
trust_remote_code: bool = Field(
default=False,
desc="Trust remote code when loading dataset.",
hint=FieldHint.optional,
)
inference_mixer: str | None = Field(
default=None,
desc="Name of the mixer to use during evaluation (for StochasticMixer models). "
"If None, uses the model's default main_mixer_name.",
hint=FieldHint.optional,
)

def get_evaluator(
self,
name: str,
batch_config: BatchConfig,
data_load_num_proc: int,
train_iters: int | None = None,
) -> "ForwardKLEvaluator":
from fast_llm.engine.evaluation.forward_kl.evaluator import ForwardKLEvaluator

return ForwardKLEvaluator(name, self, batch_config, data_load_num_proc, train_iters)
Empty file.
Loading