Commit ecac731

auto-enable lora kernels where possible (axolotl-ai-cloud#2589)
* auto-enable lora kernels where possible
* test
* revert change to example yaml
* naming
* remove print
* slight logic change
1 parent 742fef4 commit ecac731

File tree

2 files changed: +91 -0 lines changed

* src/axolotl/utils/schemas/config.py (+48)
* tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py (+43)

src/axolotl/utils/schemas/config.py

Lines changed: 48 additions & 0 deletions
```diff
@@ -1315,6 +1315,54 @@ def check_multigpu_lora_kernels(cls, data):
             )
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_auto_enable_lora_kernels(cls, data):
+        # Only proceed if using LoRA or QLoRA adapter
+        if data.get("adapter") in ["lora", "qlora"]:
+            # Skip if already set, using unsloth optimizations, or using 8-bit
+            unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
+            kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
+            if (
+                any(data.get(k) is not None for k in kernel_fields)
+                or any(data.get(k) for k in unsloth_fields)
+                or data.get("adapter") == "lora"
+                and data.get("load_in_8bit")
+            ):
+                return data
+
+            # Check multi-GPU compatibility
+            capabilities = data.get("capabilities")
+            is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
+            is_fsdp = data.get("fsdp") is not None
+            is_fsdp2 = (
+                data.get("fsdp_config") is not None
+                and str(data.get("fsdp_config").get("fsdp_version")) == "2"
+            )
+
+            if (
+                not is_multi_gpu
+                or (is_multi_gpu and not is_fsdp)
+                or (is_multi_gpu and is_fsdp2)
+            ):
+                # Auto-enable kernels if not explicitly set by user
+                if data.get("lora_mlp_kernel") is None:
+                    data["lora_mlp_kernel"] = True
+
+                if data.get("lora_qkv_kernel") is None:
+                    data["lora_qkv_kernel"] = True
+
+                if data.get("lora_o_kernel") is None:
+                    data["lora_o_kernel"] = True
+
+                LOG.warning(
+                    "Auto-enabling LoRA kernel optimizations for faster training. "
+                    + "Please explicitly set `lora_*_kernel` config values to `false` to disable. "
+                    + "See https://docs.axolotl.ai/docs/lora_optims.html for more info."
+                )
+
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_adopt_torch_version(cls, data):
```
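
To make the multi-GPU gating easier to read, here is a minimal standalone sketch of the same decision logic (illustrative only; `kernels_would_auto_enable` is a hypothetical helper, not an axolotl function, and it takes the FSDP version directly rather than reading it from `fsdp_config`): kernels are auto-enabled on a single GPU, on multi-GPU without FSDP, and on multi-GPU with FSDP version 2, while multi-GPU with FSDP version 1 is left alone.

```python
# Illustrative sketch of the gating in check_auto_enable_lora_kernels.
# kernels_would_auto_enable is a hypothetical helper, not axolotl API.
def kernels_would_auto_enable(n_gpu, fsdp, fsdp_version) -> bool:
    is_multi_gpu = n_gpu > 1
    is_fsdp = fsdp is not None
    is_fsdp2 = fsdp_version is not None and str(fsdp_version) == "2"
    return (
        not is_multi_gpu
        or (is_multi_gpu and not is_fsdp)
        or (is_multi_gpu and is_fsdp2)
    )


assert kernels_would_auto_enable(1, None, None)             # single GPU
assert kernels_would_auto_enable(2, None, None)             # multi-GPU, no FSDP
assert kernels_would_auto_enable(2, "full_shard", "2")      # multi-GPU + FSDP2
assert not kernels_would_auto_enable(2, "full_shard", "1")  # multi-GPU + FSDP1: skipped
```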

tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py

Lines changed: 43 additions & 0 deletions
```diff
@@ -2,15 +2,19 @@
 
 # pylint: disable=redefined-outer-name
 
+from pathlib import Path
+
 import pytest
 import torch
+import yaml
 from accelerate.state import PartialState
 from peft import PeftModelForCausalLM, get_peft_config
 from transformers import AutoModelForCausalLM, LlamaForCausalLM
 from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaAttention
 from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeAttention
 
+from axolotl.cli.config import load_cfg
 from axolotl.kernels.lora import (
     apply_lora_mlp_geglu,
     apply_lora_mlp_swiglu,
@@ -421,3 +425,42 @@ def test_kernel_training_integration():
     # Verify correct activation function
     layer = model.model.model.layers[0]
     assert layer.mlp.forward.__func__ is apply_lora_mlp_swiglu
+
+
+def test_kernel_training_integration_auto_enable(temp_dir):
+    """Test model loading with auto-enabled kernel patches."""
+    # Create minimal config without explicitly setting kernel options
+    cfg = DictDefault(
+        {
+            "base_model": "HuggingFaceTB/SmolLM2-135M",
+            "tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
+            "learning_rate": 0.000001,
+            "datasets": [
+                {
+                    "path": "mhenrichsen/alpaca_2k_test",
+                    "type": "alpaca",
+                }
+            ],
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 1,
+            "adapter": "lora",
+            "lora_r": 8,
+            "lora_alpha": 16,
+            "lora_dropout": 0.0,
+            "lora_target_linear": True,
+            "sequence_len": 1024,
+        }
+    )
+
+    # Write cfg to yaml file
+    path = Path(temp_dir) / "config.yaml"
+    with open(path, "w", encoding="utf-8") as fout:
+        fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+    # Load config
+    cfg = load_cfg(str(path))
+
+    # Verify kernel options were auto-enabled in the config
+    assert cfg.lora_mlp_kernel is True
+    assert cfg.lora_qkv_kernel is True
+    assert cfg.lora_o_kernel is True
```
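
A subtlety implied by the early return in the validator: explicitly setting any single `lora_*_kernel` field skips auto-enablement for all three. Below is a hedged sketch of a companion check (hypothetical, not part of this commit; it assumes the same `temp_dir` fixture and imports as the test above, and that kernel fields left unset remain falsy after config loading):

```python
# Hypothetical companion check (not in this commit): setting any one
# lora_*_kernel field makes the validator return early, so the other
# kernel fields are left unset rather than auto-enabled.
explicit = {
    "base_model": "HuggingFaceTB/SmolLM2-135M",
    "tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
    "learning_rate": 0.000001,
    "datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
    "micro_batch_size": 1,
    "gradient_accumulation_steps": 1,
    "adapter": "lora",
    "lora_r": 8,
    "lora_alpha": 16,
    "lora_dropout": 0.0,
    "lora_target_linear": True,
    "sequence_len": 1024,
    "lora_mlp_kernel": False,  # user explicitly opts out of one kernel
}

path = Path(temp_dir) / "config_explicit.yaml"
with open(path, "w", encoding="utf-8") as fout:
    fout.write(yaml.dump(explicit, Dumper=yaml.Dumper))

cfg_explicit = load_cfg(str(path))
assert cfg_explicit.lora_mlp_kernel is False  # explicit choice is respected
assert not cfg_explicit.lora_qkv_kernel       # early return: not auto-enabled
assert not cfg_explicit.lora_o_kernel
```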
