feat: add support and end-to-end tests for multiple custom optimizers… #3457
OnePunchMonk wants to merge 8 commits into axolotl-ai-cloud:main from …
Conversation
… including Optimi AdamW, ADOPT AdamW, Muon, Dion, Schedule-Free AdamW, CAME PyTorch, and Flash AdamW.
📝 Walkthrough

Adds five Flash-based optimizers (flash_adamw, flash_adam, flash_sgd, flash_sgdw, flash_lion): maps the identifiers to flashoptim classes in the core builder, extends the optimizer enum, enforces DeepSpeed/FSDP validation for flash optimizers, and adds an end-to-end parametrized test covering each optimizer.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: 2 passed ✅, 1 failed (1 warning) ❌
@coderabbitai review
✅ Actions performed: Review triggered.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/core/builders/base.py`:
- Around line 356-379: Add a validation in the same validation module that
contains the existing Muon optimizer check to reject FlashOptimizers when FSDP1
is selected: detect optimizer values "flash_adamw", "flash_adam", "flash_sgd",
"flash_sgdw", "flash_lion" (or any optimizer name starting with "flash_") and if
the fsdp configuration/type equals "fsdp1" raise the same kind of
ValidationError used by the Muon check, with a clear message that FlashOptim
does not support FSDP1 and only supports DDP and FSDP2; follow the exact pattern
and placement used by the Muon optimizer validation block so the new check runs
during schema validation.
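A minimal sketch of what that check could look like, mirroring the Muon validation pattern the prompt describes. The function name `check_flashoptim_fsdp` and the `fsdp_config`/`fsdp_version` field names are assumptions for illustration, not the actual axolotl schema code, and a plain `ValueError` stands in for the project's validation error type.

```python
FLASH_OPTIMIZERS = {"flash_adamw", "flash_adam", "flash_sgd", "flash_sgdw", "flash_lion"}

def check_flashoptim_fsdp(data: dict) -> dict:
    # Coerce None to "" so an `optimizer: null` config does not raise.
    optimizer = str(data.get("optimizer") or "")
    if optimizer.startswith("flash_"):
        fsdp_config = data.get("fsdp_config") or {}
        # Resolve the FSDP version from both top-level and nested keys.
        fsdp_version = (
            data.get("fsdp_version")
            or fsdp_config.get("version")
            or fsdp_config.get("fsdp_version")
            or 1
        )
        if fsdp_config and str(fsdp_version) != "2":
            raise ValueError(
                "FlashOptim does not support FSDP1; only DDP and FSDP2 are supported."
            )
    return data
```

This version also folds in the None-guard and dual-key version lookup that a later review comment requests, so FSDP2 configs that only set fsdp_config.version still pass.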
In `@tests/e2e/test_optimizers.py`:
- Around line 286-328: Add the `@require_torch_2_7_0` decorator to
test_flash_adamw and add four new tests mirroring its structure for the other
flash optimizers: test_flash_adam, test_flash_sgd, test_flash_sgdw, and
test_flash_lion. For each test (same pattern as test_flash_adamw) set
cfg["optimizer"] to the appropriate string ("flash_adam", "flash_sgd",
"flash_sgdw", "flash_lion"), run
validate_config/normalize_config/load_datasets/train, and assert the
trainer.optimizer.optimizer.__class__.__name__ contains the corresponding class
identifier (e.g., "FlashAdam", "FlashSGD", "FlashSGDW", "FlashLion") to verify
the correct optimizer is instantiated. Ensure each new test uses `@with_temp_dir`
and the `@require_torch_2_7_0` decorator consistent with existing
version-dependent tests like test_dion and test_came_pytorch.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: c99e1e60-b323-476d-a422-5a8f77a5a892
📒 Files selected for processing (3)
src/axolotl/core/builders/base.py
src/axolotl/utils/schemas/enums.py
tests/e2e/test_optimizers.py
…rious custom optimizers including FlashAdamW, FlashAdam, FlashSGD, FlashSGDW, FlashLion, optimi_adamw, adopt_adamw, muon, dion, and schedule_free_adamw.
@coderabbitai review
✅ Actions performed: Review triggered.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/utils/schemas/validation.py`:
- Around line 811-821: In check_flashoptim_deepspeed_fsdp, guard against
optimizer being None by coercing it to a string before calling startswith (e.g.,
optimizer = data.get("optimizer") or ""), and resolve FSDP version by checking
both top-level "fsdp_version" and "fsdp_config" keys (look for
data.get("fsdp_version") or data.get("fsdp_config", {}).get("version") or
data.get("fsdp_config", {}).get("fsdp_version")), then compare str(fsdp_version)
to "2"; update the logic in check_flashoptim_deepspeed_fsdp to use these safe
lookups so valid FSDP2 configs with only fsdp_config.version pass and optimizer:
null doesn't raise.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: b891219e-2d85-439d-a8cf-b8e32b7ad40b
📒 Files selected for processing (2)
src/axolotl/utils/schemas/validation.py
tests/e2e/test_optimizers.py
…d training configurations.
@coderabbitai review
✅ Actions performed: Review triggered.
Actionable comments posted: 1
🧹 Nitpick comments (1)
src/axolotl/utils/schemas/validation.py (1)
819-824: Keep FSDP version resolution consistent across optimizer validators.

This block now accepts fsdp_config.version, but check_muon_deepspeed_fsdp still only reads the top-level fsdp_version / fsdp_config.fsdp_version. Equivalent FSDP2 configs will therefore validate for FlashOptim and fail for Muon depending on the selected optimizer. Please extract the version lookup into a shared helper, or normalize once before both validators, so these compatibility rules do not drift.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/utils/schemas/validation.py` around lines 819 - 824, The FSDP version lookup is inconsistent between validators; extract the logic used around fsdp_config/fsdp_version into a shared helper (e.g., get_fsdp_version(data) or normalize_fsdp_in_config(data)) and use it from both check_muon_deepspeed_fsdp and the FlashOptim validator so both read the same resolved version (checking data["fsdp_version"], fsdp_config["version"], fsdp_config["fsdp_version"], then defaulting to 1). Update places that currently compute fsdp_version inline to call the helper or run the normalizer once before validation to ensure consistent behavior across optimizers.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/e2e/test_optimizers.py`:
- Around line 289-301: The test method test_flash_optimizers is a
unittest.TestCase-style method decorated with pytest.mark.parametrize, which
pytest will not apply to TestCase methods; update the test to be a plain pytest
function or remove parametrize and iterate manually inside
test_flash_optimizers. Specifically, either convert the method
test_flash_optimizers (and its decorators `@with_temp_dir`, `@require_torch_2_7_0`)
into a top-level function test_flash_optimizers(...) so pytest can inject
optimizer_name/expected_class/learning_rate, or keep it as a TestCase method and
replace the `@pytest.mark.parametrize` block by a for-loop over the list of tuples
(optimizer_name, expected_class, learning_rate) inside test_flash_optimizers to
run each case.
---
Nitpick comments:
In `@src/axolotl/utils/schemas/validation.py`:
- Around line 819-824: The FSDP version lookup is inconsistent between
validators; extract the logic used around fsdp_config/fsdp_version into a shared
helper (e.g., get_fsdp_version(data) or normalize_fsdp_in_config(data)) and use
it from both check_muon_deepspeed_fsdp and the FlashOptim validator so both read
the same resolved version (checking data["fsdp_version"],
fsdp_config["version"], fsdp_config["fsdp_version"], then defaulting to 1).
Update places that currently compute fsdp_version inline to call the helper or
run the normalizer once before validation to ensure consistent behavior across
optimizers.
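The shared helper the nitpick asks for could look like this; the lookup order follows the prompt (top-level fsdp_version, then fsdp_config.version, then fsdp_config.fsdp_version, defaulting to 1), but the function name and exact placement in the validation module are assumptions.

```python
def get_fsdp_version(data: dict) -> str:
    """Resolve the configured FSDP version from any of the accepted keys."""
    fsdp_config = data.get("fsdp_config") or {}
    version = (
        data.get("fsdp_version")
        or fsdp_config.get("version")
        or fsdp_config.get("fsdp_version")
        or 1  # default when no version is specified anywhere
    )
    return str(version)
```

Both check_muon_deepspeed_fsdp and the FlashOptim validator would then compare `get_fsdp_version(data) == "2"`, so the two compatibility rules cannot drift apart.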
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: e42e998f-f1b1-4a27-a054-622c81b62d2b
📒 Files selected for processing (2)
src/axolotl/utils/schemas/validation.py
tests/e2e/test_optimizers.py
…pt_adamw, muon, dion, schedule_free_adamw, came_pytorch, and flash optimizers.
…p_path and skipif
@CodeRabbit review
✅ Actions performed: Review triggered.
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/e2e/test_optimizers.py (1)
308-340: Consider a small config factory for these optimizer E2Es.

This config block is now another near-copy of the setup used by the other optimizer tests, so future schema changes will keep fanning out across the file. A helper like build_optimizer_cfg(output_dir, optimizer, learning_rate, **overrides) would make these additions cheaper to maintain.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/test_optimizers.py` around lines 308 - 340, The cfg dictionary in the optimizer E2E test is duplicated across tests and should be factored into a small factory to reduce maintenance; create a helper function (e.g., build_optimizer_cfg(output_dir, optimizer_name, learning_rate, **overrides)) and replace the inline cfg assignments with calls to build_optimizer_cfg in tests like tests/e2e/test_optimizers.py, ensuring it sets the same keys (base_model, model_type, tokenizer_type, sequence_len, load_in_8bit, adapter, lora_* params, val_set_size, special_tokens, datasets, num_epochs, micro_batch_size, gradient_accumulation_steps, output_dir, learning_rate, optimizer, optim_args, max_steps, lr_scheduler, save_first_step) and accepts overrides to customize per-test values.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/e2e/test_optimizers.py`:
- Around line 347-349: The assertion currently performs substring matching
against trainer.optimizer.optimizer.__class__.__name__; change it to require
exact equality so the test fails if a different class (e.g., FlashAdamW vs
flash_adam) is returned. Update the assertion to compare
trainer.optimizer.optimizer.__class__.__name__ directly to expected_class (use
==) and ensure expected_class holds the exact class name string produced by the
builder used in train(cfg=cfg, dataset_meta=dataset_meta); keep the surrounding
calls (train and check_model_output_exists) unchanged.
---
Nitpick comments:
In `@tests/e2e/test_optimizers.py`:
- Around line 308-340: The cfg dictionary in the optimizer E2E test is
duplicated across tests and should be factored into a small factory to reduce
maintenance; create a helper function (e.g., build_optimizer_cfg(output_dir,
optimizer_name, learning_rate, **overrides)) and replace the inline cfg
assignments with calls to build_optimizer_cfg in tests like
tests/e2e/test_optimizers.py, ensuring it sets the same keys (base_model,
model_type, tokenizer_type, sequence_len, load_in_8bit, adapter, lora_* params,
val_set_size, special_tokens, datasets, num_epochs, micro_batch_size,
gradient_accumulation_steps, output_dir, learning_rate, optimizer, optim_args,
max_steps, lr_scheduler, save_first_step) and accepts overrides to customize
per-test values.
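A sketch of the suggested factory; the keys shown mirror a subset of those listed in the prompt, but the default values here are illustrative placeholders, not the repository's actual test settings.

```python
def build_optimizer_cfg(output_dir, optimizer, learning_rate, **overrides):
    """Build a baseline optimizer E2E test config; overrides customize per test."""
    cfg = {
        "sequence_len": 1024,
        "val_set_size": 0.05,
        "num_epochs": 1,
        "micro_batch_size": 4,
        "gradient_accumulation_steps": 1,
        "max_steps": 5,
        "save_first_step": False,
        "output_dir": output_dir,
        "optimizer": optimizer,
        "learning_rate": learning_rate,
    }
    # Per-test customization without copying the whole dict.
    cfg.update(overrides)
    return cfg
```

Each test body then shrinks to something like `cfg = build_optimizer_cfg(temp_dir, "flash_adamw", 1e-4, max_steps=3)` followed by the usual validate/train calls.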
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: ccdc158f-f13c-4d31-8fea-b9ab828e87fb
📒 Files selected for processing (1)
tests/e2e/test_optimizers.py
|
@CodeRabbit review |
✅ Actions performed: Review triggered.
🧹 Nitpick comments (1)
tests/e2e/test_optimizers.py (1)
292-349: Well-structured parametrized test addressing all previous review concerns.

The test correctly:
- Uses a standalone pytest function (not a unittest.TestCase method), allowing @pytest.mark.parametrize to work properly
- Uses exact class-name matching (==) instead of substring matching
- Includes the torch>=2.7.0 version requirement

One minor consistency consideration: the require_torch_2_7_0 decorator is already imported (line 20) and used by test_dion. You could reuse it for consistency, though the inline skipif is functionally equivalent.

♻️ Optional: Use existing decorator for consistency

+@require_torch_2_7_0
-@pytest.mark.skipif(
-    version.parse(torch.__version__) < version.parse("2.7.0"),
-    reason="test requires torch>=2.7.0",
-)
 @pytest.mark.parametrize(
     "optimizer_name,expected_class,learning_rate",

If you apply this change, you could also remove the now-unused imports torch and from packaging import version (lines 8-9).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/test_optimizers.py` around lines 292 - 349, The test uses an inline pytest.mark.skipif to require torch>=2.7.0 but for consistency reuse the existing require_torch_2_7_0 decorator already imported and used by test_dion: replace the skipif on test_flash_optimizers with `@require_torch_2_7_0` and then remove the now-unused imports (torch and version) at the top of the file; keep the rest of test_flash_optimizers, including the parametrization and assertion, unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@tests/e2e/test_optimizers.py`:
- Around line 292-349: The test uses an inline pytest.mark.skipif to require
torch>=2.7.0 but for consistency reuse the existing require_torch_2_7_0
decorator already imported and used by test_dion: replace the skipif on
test_flash_optimizers with `@require_torch_2_7_0` and then remove the now-unused
imports (torch and version) at the top of the file; keep the rest of
test_flash_optimizers, including the parametrization and assertion, unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: ffa481ba-bfd3-4f5b-8206-396bff8e4cbc
📒 Files selected for processing (1)
tests/e2e/test_optimizers.py
…fsdp_config.version check, extract shared FSDP version helper, remove unused imports and optim_args
Hey @OnePunchMonk, since we won't be installing the package by default (and hence not on our CI), I believe your test would fail due to an import error. If the test runs and passes locally for you, I'm good with that. I also verified that your changes look good.
Feature support, with test cases, for flashoptim's FlashAdamW, FlashAdam, FlashSGD, FlashSGDW, and FlashLion optimizers
Description
Added support for using the 'flashoptim' optimizers.
Motivation and Context
#3451
How has this been tested?
Still untested
AI Usage Disclaimer
Claude
Summary by CodeRabbit

New Features
- Support for five Flash-based optimizers: flash_adamw, flash_adam, flash_sgd, flash_sgdw, flash_lion.

Validation
- DeepSpeed/FSDP compatibility checks for the flash optimizers.

Tests
- End-to-end parametrized test covering each flash optimizer.