
feat: add support and end-to-end tests for multiple custom optimizers…#3457

Open
OnePunchMonk wants to merge 8 commits into axolotl-ai-cloud:main from OnePunchMonk:main

Conversation

@OnePunchMonk

@OnePunchMonk OnePunchMonk commented Mar 5, 2026

… including Optimi AdamW, ADOPT AdamW, Muon, Dion, Schedule-Free AdamW, CAME PyTorch, and Flash AdamW.
Feature support, with test cases, for flashoptim's FlashAdamW, FlashAdam, FlashSGD, FlashSGDW, and FlashLion.

Description

Added features for 'flashoptim' usage

Motivation and Context

#3451

How has this been tested?

Still untested

AI Usage Disclaimer

Claude


Summary by CodeRabbit

  • New Features

    • Added support for five Flash-based optimizers for training: FlashAdamW, FlashAdam, FlashSGD, FlashSGDW, and FlashLion; identifiers added to supported options.
  • Validation

    • Prevents using Flash optimizers with incompatible distributed backends and enforces FSDP version compatibility when FSDP is enabled.
  • Tests

    • Added end-to-end tests verifying each Flash optimizer is selectable and integrated during training.

… including Optimi AdamW, ADOPT AdamW, Muon, Dion, Schedule-Free AdamW, CAME PyTorch, and Flash AdamW.
@coderabbitai
Contributor

coderabbitai bot commented Mar 5, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: f9f89e18-973a-423e-9e93-6d7517d150a8

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.


Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough


Adds five Flash-based optimizers (flash_adamw, flash_adam, flash_sgd, flash_sgdw, flash_lion): maps identifiers to flashoptim classes in the core builder, extends the optimizer enum, enforces DeepSpeed/FSDP validation for flash optimizers, and adds an end-to-end parametrized test covering each optimizer.

Changes

  • Optimizer configuration (src/axolotl/core/builders/base.py): added flash_* branches in _configure_custom_optimizer to import the corresponding flashoptim classes, set optimizer_cls, and forward/adapt existing adam-related kwargs (propagates betas to flash_lion).
  • Enums (src/axolotl/utils/schemas/enums.py): extended CustomSupportedOptimizers with flash_adamw, flash_adam, flash_sgd, flash_sgdw, and flash_lion.
  • Validation (src/axolotl/utils/schemas/validation.py): added a check_flashoptim_deepspeed_fsdp validator to reject DeepSpeed with flash optimizers and require fsdp_version == 2 when FSDP is used with flash optimizers (defaults fsdp_version to 1 if absent).
  • End-to-end tests (tests/e2e/test_optimizers.py): added a parametrized test, test_flash_optimizers, for the five flash_* optimizers (guarded on torch>=2.7.0) and added imports for pytest, torch, and packaging.version.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the title references 'multiple custom optimizers' and covers the main additions (support and end-to-end tests for Flash optimizers), though it is slightly truncated with an ellipsis and somewhat broad.



@OnePunchMonk
Author

@coderabbitai review

@coderabbitai
Contributor

coderabbitai bot commented Mar 5, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/axolotl/core/builders/base.py`:
- Around line 356-379: Add a validation in the same validation module that
contains the existing Muon optimizer check to reject FlashOptimizers when FSDP1
is selected: detect optimizer values "flash_adamw", "flash_adam", "flash_sgd",
"flash_sgdw", "flash_lion" (or any optimizer name starting with "flash_") and if
the fsdp configuration/type equals "fsdp1" raise the same kind of
ValidationError used by the Muon check, with a clear message that FlashOptim
does not support FSDP1 and only supports DDP and FSDP2; follow the exact pattern
and placement used by the Muon optimizer validation block so the new check runs
during schema validation.
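The requested check can be sketched as a plain function over a dict-shaped config. The field names (optimizer, deepspeed, fsdp_config, fsdp_version) are assumptions about the schema, and ValueError stands in for whatever ValidationError the Muon check uses.

```python
def check_flashoptim_deepspeed_fsdp(data: dict) -> dict:
    # Sketch of the validation requested above, mirroring the described
    # Muon-check pattern. Not the repo's actual implementation.
    optimizer = str(data.get("optimizer") or "")
    if not optimizer.startswith("flash_"):
        return data
    if data.get("deepspeed"):
        raise ValueError("FlashOptim optimizers do not support DeepSpeed.")
    if data.get("fsdp_config"):
        fsdp_version = data.get("fsdp_version") or 1  # absent means FSDP1
        if str(fsdp_version) != "2":
            raise ValueError(
                "FlashOptim does not support FSDP1; only DDP and FSDP2 are supported."
            )
    return data
```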

In `@tests/e2e/test_optimizers.py`:
- Around line 286-328: Add the `@require_torch_2_7_0` decorator to
test_flash_adamw and add four new tests mirroring its structure for the other
flash optimizers: test_flash_adam, test_flash_sgd, test_flash_sgdw, and
test_flash_lion. For each test (same pattern as test_flash_adamw) set
cfg["optimizer"] to the appropriate string ("flash_adam", "flash_sgd",
"flash_sgdw", "flash_lion"), run
validate_config/normalize_config/load_datasets/train, and assert the
trainer.optimizer.optimizer.__class__.__name__ contains the corresponding class
identifier (e.g., "FlashAdam", "FlashSGD", "FlashSGDW", "FlashLion") to verify
the correct optimizer is instantiated. Ensure each new test uses `@with_temp_dir`
and the `@require_torch_2_7_0` decorator consistent with existing
version-dependent tests like test_dion and test_came_pytorch.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: c99e1e60-b323-476d-a422-5a8f77a5a892

📥 Commits

Reviewing files that changed from the base of the PR and between 4b8bc52 and c6cc0c5.

📒 Files selected for processing (3)
  • src/axolotl/core/builders/base.py
  • src/axolotl/utils/schemas/enums.py
  • tests/e2e/test_optimizers.py

…rious custom optimizers including FlashAdamW, FlashAdam, FlashSGD, FlashSGDW, FlashLion, optimi_adamw, adopt_adamw, muon, dion, and schedule_free_adamw.
@OnePunchMonk
Author

@coderabbitai review

@coderabbitai
Contributor

coderabbitai bot commented Mar 5, 2026

✅ Actions performed

Review triggered.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/axolotl/utils/schemas/validation.py`:
- Around line 811-821: In check_flashoptim_deepspeed_fsdp, guard against
optimizer being None by coercing it to a string before calling startswith (e.g.,
optimizer = data.get("optimizer") or ""), and resolve FSDP version by checking
both top-level "fsdp_version" and "fsdp_config" keys (look for
data.get("fsdp_version") or data.get("fsdp_config", {}).get("version") or
data.get("fsdp_config", {}).get("fsdp_version")), then compare str(fsdp_version)
to "2"; update the logic in check_flashoptim_deepspeed_fsdp to use these safe
lookups so valid FSDP2 configs with only fsdp_config.version pass and optimizer:
null doesn't raise.
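The two safe lookups requested above (None-tolerant optimizer access and nested FSDP version resolution) can be sketched together. The key names follow the comment; the dict-shaped config is an assumption.

```python
def resolve_validator_inputs(data: dict):
    # Coerce optimizer to a string so `optimizer: null` cannot crash startswith,
    # and resolve the FSDP version from both the top-level key and fsdp_config.
    optimizer = str(data.get("optimizer") or "")
    fsdp_config = data.get("fsdp_config") or {}
    fsdp_version = (
        data.get("fsdp_version")
        or fsdp_config.get("version")
        or fsdp_config.get("fsdp_version")
        or 1  # absent means FSDP1, matching the validator's default
    )
    return optimizer, str(fsdp_version)
```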

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b891219e-2d85-439d-a8cf-b8e32b7ad40b

📥 Commits

Reviewing files that changed from the base of the PR and between c6cc0c5 and f7a27ef.

📒 Files selected for processing (2)
  • src/axolotl/utils/schemas/validation.py
  • tests/e2e/test_optimizers.py

@OnePunchMonk
Author

@coderabbitai review

@coderabbitai
Contributor

coderabbitai bot commented Mar 7, 2026

✅ Actions performed

Review triggered.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
src/axolotl/utils/schemas/validation.py (1)

819-824: Keep FSDP version resolution consistent across optimizer validators.

This block now accepts fsdp_config.version, but check_muon_deepspeed_fsdp still only reads top-level fsdp_version / fsdp_config.fsdp_version. Equivalent FSDP2 configs will therefore validate for FlashOptim and fail for Muon depending on the selected optimizer. Please extract the version lookup into a shared helper, or normalize once before both validators, so these compatibility rules do not drift.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/utils/schemas/validation.py` around lines 819 - 824, The FSDP
version lookup is inconsistent between validators; extract the logic used around
fsdp_config/fsdp_version into a shared helper (e.g., get_fsdp_version(data) or
normalize_fsdp_in_config(data)) and use it from both check_muon_deepspeed_fsdp
and the FlashOptim validator so both read the same resolved version (checking
data["fsdp_version"], fsdp_config["version"], fsdp_config["fsdp_version"], then
defaulting to 1). Update places that currently compute fsdp_version inline to
call the helper or run the normalizer once before validation to ensure
consistent behavior across optimizers.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/e2e/test_optimizers.py`:
- Around line 289-301: The test method test_flash_optimizers is a
unittest.TestCase-style method decorated with pytest.mark.parametrize, which
pytest will not apply to TestCase methods; update the test to be a plain pytest
function or remove parametrize and iterate manually inside
test_flash_optimizers. Specifically, either convert the method
test_flash_optimizers (and its decorators `@with_temp_dir`, `@require_torch_2_7_0`)
into a top-level function test_flash_optimizers(...) so pytest can inject
optimizer_name/expected_class/learning_rate, or keep it as a TestCase method and
replace the `@pytest.mark.parametrize` block by a for-loop over the list of tuples
(optimizer_name, expected_class, learning_rate) inside test_flash_optimizers to
run each case.
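The plain-function form the comment recommends can be sketched as follows. The training pipeline calls are elided; the identifier-to-class pairs come from the PR summary, and the trivial body only demonstrates the shape that lets pytest inject parameters (which it cannot do for unittest.TestCase methods).

```python
import pytest

# Expected identifier -> class-name pairs, per the PR summary (assumed names).
FLASH_CASES = [
    ("flash_adamw", "FlashAdamW"),
    ("flash_adam", "FlashAdam"),
    ("flash_sgd", "FlashSGD"),
    ("flash_sgdw", "FlashSGDW"),
    ("flash_lion", "FlashLion"),
]


@pytest.mark.parametrize("optimizer_name,expected_class", FLASH_CASES)
def test_flash_optimizers(optimizer_name, expected_class):
    # In the real test: set cfg["optimizer"] = optimizer_name, run
    # validate_config/normalize_config/load_datasets/train, then assert on
    # trainer.optimizer.optimizer.__class__.__name__.
    assert optimizer_name.startswith("flash_")
    assert expected_class.startswith("Flash")
```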


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: e42e998f-f1b1-4a27-a054-622c81b62d2b

📥 Commits

Reviewing files that changed from the base of the PR and between f7a27ef and 2980117.

📒 Files selected for processing (2)
  • src/axolotl/utils/schemas/validation.py
  • tests/e2e/test_optimizers.py

…pt_adamw, muon, dion, schedule_free_adamw, came_pytorch, and flash optimizers.
@OnePunchMonk
Author

@CodeRabbit review

@coderabbitai
Contributor

coderabbitai bot commented Mar 8, 2026

✅ Actions performed

Review triggered.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/e2e/test_optimizers.py (1)

308-340: Consider a small config factory for these optimizer E2Es.

This config block is now another near-copy of the setup used by the other optimizer tests, so future schema changes will keep fanning out across the file. A helper like build_optimizer_cfg(output_dir, optimizer, learning_rate, **overrides) would make these additions cheaper to maintain.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/e2e/test_optimizers.py` around lines 308 - 340, The cfg dictionary in
the optimizer E2E test is duplicated across tests and should be factored into a
small factory to reduce maintenance; create a helper function (e.g.,
build_optimizer_cfg(output_dir, optimizer_name, learning_rate, **overrides)) and
replace the inline cfg assignments with calls to build_optimizer_cfg in tests
like tests/e2e/test_optimizers.py, ensuring it sets the same keys (base_model,
model_type, tokenizer_type, sequence_len, load_in_8bit, adapter, lora_* params,
val_set_size, special_tokens, datasets, num_epochs, micro_batch_size,
gradient_accumulation_steps, output_dir, learning_rate, optimizer, optim_args,
max_steps, lr_scheduler, save_first_step) and accepts overrides to customize
per-test values.
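The factory the nitpick suggests can be sketched like this. The key names mirror the review comment; the values (including the base model placeholder) are illustrative, not the repo's actual choices.

```python
def build_optimizer_cfg(output_dir, optimizer, learning_rate, **overrides):
    """Hypothetical factory consolidating the shared optimizer E2E config."""
    cfg = {
        "base_model": "<tiny-test-model>",  # placeholder, not the repo's model
        "sequence_len": 1024,
        "val_set_size": 0.0,
        "num_epochs": 1,
        "micro_batch_size": 2,
        "gradient_accumulation_steps": 1,
        "max_steps": 5,
        "lr_scheduler": "cosine",
        "save_first_step": False,
        "output_dir": output_dir,
        "learning_rate": learning_rate,
        "optimizer": optimizer,
    }
    cfg.update(overrides)  # per-test customization, e.g. optim_args
    return cfg
```

A test would then call build_optimizer_cfg(temp_dir, "flash_adamw", 1e-4, optim_args={...}) instead of repeating the full dict.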
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/e2e/test_optimizers.py`:
- Around line 347-349: The assertion currently performs substring matching
against trainer.optimizer.optimizer.__class__.__name__; change it to require
exact equality so the test fails if a different class (e.g., FlashAdamW vs
flash_adam) is returned. Update the assertion to compare
trainer.optimizer.optimizer.__class__.__name__ directly to expected_class (use
==) and ensure expected_class holds the exact class name string produced by the
builder used in train(cfg=cfg, dataset_meta=dataset_meta); keep the surrounding
calls (train and check_model_output_exists) unchanged.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: ccdc158f-f13c-4d31-8fea-b9ab828e87fb

📥 Commits

Reviewing files that changed from the base of the PR and between 2980117 and 9b36b14.

📒 Files selected for processing (1)
  • tests/e2e/test_optimizers.py

@OnePunchMonk
Author

@CodeRabbit review

@OnePunchMonk OnePunchMonk marked this pull request as ready for review March 8, 2026 10:00
@coderabbitai
Contributor

coderabbitai bot commented Mar 8, 2026

✅ Actions performed

Review triggered.


Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tests/e2e/test_optimizers.py (1)

292-349: Well-structured parametrized test addressing all previous review concerns.

The test correctly:

  • Uses a standalone pytest function (not a unittest.TestCase method), allowing @pytest.mark.parametrize to work properly
  • Uses exact class-name matching (==) instead of substring matching
  • Includes the torch>=2.7.0 version requirement

One minor consistency consideration: the require_torch_2_7_0 decorator is already imported (line 20) and used by test_dion. You could reuse it for consistency, though the inline skipif is functionally equivalent.

♻️ Optional: Use existing decorator for consistency

```diff
-@pytest.mark.skipif(
-    version.parse(torch.__version__) < version.parse("2.7.0"),
-    reason="test requires torch>=2.7.0",
-)
+@require_torch_2_7_0
 @pytest.mark.parametrize(
     "optimizer_name,expected_class,learning_rate",
```
If you apply this change, you could also remove the unused imports torch and from packaging import version (lines 8-9).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/e2e/test_optimizers.py` around lines 292 - 349, The test uses an inline
pytest.mark.skipif to require torch>=2.7.0 but for consistency reuse the
existing require_torch_2_7_0 decorator already imported and used by test_dion:
replace the skipif on test_flash_optimizers with `@require_torch_2_7_0` and then
remove the now-unused imports (torch and version) at the top of the file; keep
the rest of test_flash_optimizers, including the parametrization and assertion,
unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: ffa481ba-bfd3-4f5b-8206-396bff8e4cbc

📥 Commits

Reviewing files that changed from the base of the PR and between 9b36b14 and fc05aa0.

📒 Files selected for processing (1)
  • tests/e2e/test_optimizers.py

…fsdp_config.version check, extract shared FSDP version helper, remove unused imports and optim_args
@NanoCode012
Collaborator

Hey @OnePunchMonk, since we won't be installing the package by default (and hence not on our CI), I believe your test would fail due to an import error. If the test runs and passes locally for you, I'm good with that. I also verified your changes look good.

