Skip to content

Commit 1447beb

Browse files
authored
make sure to validate the config before normalizing so defaults get set (axolotl-ai-cloud#2554)
* make sure to validate the config before normalizing so defaults get set
* validation not needed for particular test
* remove duplicate validations
* set qlora correctly
1 parent 66f41ec commit 1447beb

18 files changed

+47
-17
lines changed

tests/e2e/integrations/test_cut_cross_entropy.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from axolotl.common.datasets import load_datasets
99
from axolotl.train import train
1010
from axolotl.utils import get_pytorch_version
11-
from axolotl.utils.config import normalize_config, prepare_plugins
11+
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
1212
from axolotl.utils.dict import DictDefault
1313

1414
from ..utils import check_model_output_exists
@@ -56,6 +56,7 @@ class TestCutCrossEntropyIntegration:
5656
# pylint: disable=redefined-outer-name
5757
def test_llama_w_cce(self, min_cfg, temp_dir):
5858
cfg = DictDefault(min_cfg)
59+
cfg = validate_config(cfg)
5960
prepare_plugins(cfg)
6061
normalize_config(cfg)
6162
cli_args = TrainerCliArgs()
@@ -101,6 +102,7 @@ def test_qwen2_w_cce(self, temp_dir):
101102
"bf16": "auto",
102103
}
103104
)
105+
cfg = validate_config(cfg)
104106
prepare_plugins(cfg)
105107
normalize_config(cfg)
106108
cli_args = TrainerCliArgs()
@@ -129,6 +131,7 @@ def test_llama_w_cce_and_attention(self, min_cfg, temp_dir, attention_type):
129131
attention_type: True,
130132
}
131133
)
134+
cfg = validate_config(cfg)
132135
prepare_plugins(cfg)
133136
normalize_config(cfg)
134137
cli_args = TrainerCliArgs()

tests/e2e/integrations/test_liger.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from axolotl.cli.args import TrainerCliArgs
66
from axolotl.common.datasets import load_datasets
77
from axolotl.train import train
8-
from axolotl.utils.config import normalize_config, prepare_plugins
8+
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
99
from axolotl.utils.dict import DictDefault
1010

1111
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
@@ -54,6 +54,7 @@ def test_llama_wo_flce(self, temp_dir):
5454
}
5555
)
5656
# pylint: disable=duplicate-code
57+
cfg = validate_config(cfg)
5758
prepare_plugins(cfg)
5859
normalize_config(cfg)
5960
cli_args = TrainerCliArgs()
@@ -100,6 +101,7 @@ def test_llama_w_flce(self, temp_dir):
100101
}
101102
)
102103
# pylint: disable=duplicate-code
104+
cfg = validate_config(cfg)
103105
prepare_plugins(cfg)
104106
normalize_config(cfg)
105107
cli_args = TrainerCliArgs()

tests/e2e/patched/test_4d_multipack_llama.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from axolotl.cli.args import TrainerCliArgs
1010
from axolotl.common.datasets import load_datasets
1111
from axolotl.train import train
12-
from axolotl.utils.config import normalize_config
12+
from axolotl.utils.config import normalize_config, validate_config
1313
from axolotl.utils.dict import DictDefault
1414

1515
from ..utils import check_model_output_exists, with_temp_dir
@@ -60,6 +60,7 @@ def test_sdp_lora_packing(self, temp_dir):
6060
"fp16": True,
6161
}
6262
)
63+
cfg = validate_config(cfg)
6364
normalize_config(cfg)
6465
cli_args = TrainerCliArgs()
6566
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -104,6 +105,7 @@ def test_torch_lora_packing(self, temp_dir):
104105
"fp16": True,
105106
}
106107
)
108+
cfg = validate_config(cfg)
107109
normalize_config(cfg)
108110
cli_args = TrainerCliArgs()
109111
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

tests/e2e/patched/test_falcon_samplepack.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from axolotl.cli.args import TrainerCliArgs
1010
from axolotl.common.datasets import load_datasets
1111
from axolotl.train import train
12-
from axolotl.utils.config import normalize_config
12+
from axolotl.utils.config import normalize_config, validate_config
1313
from axolotl.utils.dict import DictDefault
1414

1515
from ..utils import check_model_output_exists, with_temp_dir
@@ -63,6 +63,7 @@ def test_qlora(self, temp_dir):
6363
"bf16": "auto",
6464
}
6565
)
66+
cfg = validate_config(cfg)
6667
normalize_config(cfg)
6768
cli_args = TrainerCliArgs()
6869
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -103,6 +104,7 @@ def test_ft(self, temp_dir):
103104
"bf16": "auto",
104105
}
105106
)
107+
cfg = validate_config(cfg)
106108
normalize_config(cfg)
107109
cli_args = TrainerCliArgs()
108110
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

tests/e2e/patched/test_fused_llama.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from axolotl.cli.args import TrainerCliArgs
1313
from axolotl.common.datasets import load_datasets
1414
from axolotl.train import train
15-
from axolotl.utils.config import normalize_config
15+
from axolotl.utils.config import normalize_config, validate_config
1616
from axolotl.utils.dict import DictDefault
1717

1818
from ..utils import check_model_output_exists, with_temp_dir
@@ -67,6 +67,7 @@ def test_fft_packing(self, temp_dir):
6767
cfg.bf16 = True
6868
else:
6969
cfg.fp16 = True
70+
cfg = validate_config(cfg)
7071
normalize_config(cfg)
7172
cli_args = TrainerCliArgs()
7273
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

tests/e2e/patched/test_llama_s2_attention.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from axolotl.cli.args import TrainerCliArgs
1212
from axolotl.common.datasets import load_datasets
1313
from axolotl.train import train
14-
from axolotl.utils.config import normalize_config
14+
from axolotl.utils.config import normalize_config, validate_config
1515
from axolotl.utils.dict import DictDefault
1616

1717
from ..utils import check_model_output_exists, with_temp_dir
@@ -65,6 +65,7 @@ def test_lora_s2_attn(self, temp_dir):
6565
}
6666
)
6767

68+
cfg = validate_config(cfg)
6869
normalize_config(cfg)
6970
cli_args = TrainerCliArgs()
7071
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -105,6 +106,7 @@ def test_fft_s2_attn(self, temp_dir):
105106
}
106107
)
107108

109+
cfg = validate_config(cfg)
108110
normalize_config(cfg)
109111
cli_args = TrainerCliArgs()
110112
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

tests/e2e/patched/test_lora_llama_multipack.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from axolotl.cli.args import TrainerCliArgs
1313
from axolotl.common.datasets import load_datasets
1414
from axolotl.train import train
15-
from axolotl.utils.config import normalize_config
15+
from axolotl.utils.config import normalize_config, validate_config
1616
from axolotl.utils.dict import DictDefault
1717

1818
from ..utils import check_model_output_exists, with_temp_dir
@@ -70,6 +70,7 @@ def test_lora_packing(self, temp_dir):
7070
else:
7171
cfg.fp16 = True
7272

73+
cfg = validate_config(cfg)
7374
normalize_config(cfg)
7475
cli_args = TrainerCliArgs()
7576
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -120,6 +121,7 @@ def test_lora_gptq_packed(self, temp_dir):
120121
"lr_scheduler": "cosine",
121122
}
122123
)
124+
cfg = validate_config(cfg)
123125
normalize_config(cfg)
124126
cli_args = TrainerCliArgs()
125127
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

tests/e2e/patched/test_mistral_samplepack.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from axolotl.cli.args import TrainerCliArgs
1010
from axolotl.common.datasets import load_datasets
1111
from axolotl.train import train
12-
from axolotl.utils.config import normalize_config
12+
from axolotl.utils.config import normalize_config, validate_config
1313
from axolotl.utils.dict import DictDefault
1414

1515
from ..utils import check_model_output_exists, with_temp_dir
@@ -63,6 +63,7 @@ def test_lora_packing(self, temp_dir):
6363
"bf16": "auto",
6464
}
6565
)
66+
cfg = validate_config(cfg)
6667
normalize_config(cfg)
6768
cli_args = TrainerCliArgs()
6869
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -104,6 +105,7 @@ def test_ft_packing(self, temp_dir):
104105
"bf16": "auto",
105106
}
106107
)
108+
cfg = validate_config(cfg)
107109
normalize_config(cfg)
108110
cli_args = TrainerCliArgs()
109111
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

tests/e2e/patched/test_mixtral_samplepack.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from axolotl.cli.args import TrainerCliArgs
1010
from axolotl.common.datasets import load_datasets
1111
from axolotl.train import train
12-
from axolotl.utils.config import normalize_config
12+
from axolotl.utils.config import normalize_config, validate_config
1313
from axolotl.utils.dict import DictDefault
1414

1515
from ..utils import check_model_output_exists, with_temp_dir
@@ -60,6 +60,7 @@ def test_qlora(self, temp_dir):
6060
"bf16": "auto",
6161
}
6262
)
63+
cfg = validate_config(cfg)
6364
normalize_config(cfg)
6465
cli_args = TrainerCliArgs()
6566
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

tests/e2e/patched/test_model_patches.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import transformers
88

9-
from axolotl.utils.config import normalize_config
9+
from axolotl.utils.config import normalize_config, validate_config
1010
from axolotl.utils.dict import DictDefault
1111
from axolotl.utils.models import load_model, load_tokenizer
1212

@@ -47,6 +47,7 @@ def test_mixtral_multipack(self, temp_dir):
4747
"eval_steps": 10,
4848
}
4949
)
50+
cfg = validate_config(cfg)
5051
normalize_config(cfg)
5152
tokenizer = load_tokenizer(cfg)
5253
load_model(cfg, tokenizer, inference=False)
@@ -79,6 +80,7 @@ def test_mistral_multipack(self, temp_dir):
7980
"eval_steps": 10,
8081
}
8182
)
83+
cfg = validate_config(cfg)
8284
normalize_config(cfg)
8385
tokenizer = load_tokenizer(cfg)
8486
load_model(cfg, tokenizer, inference=False)

0 commit comments

Comments (0)