Skip to content

Commit a028548

Browse files
committed
config validator and tests
1 parent 5a3672b commit a028548

File tree

4 files changed

+480
-4
lines changed

4 files changed

+480
-4
lines changed

tabs/training_tab.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import os
44
import gradio as gr
55
from typing import OrderedDict
6-
from config import Config, global_config
6+
from config import Config
77

88
from run_trainer import RunTrainer
99
from tabs import general_tab
1010
from tabs.tab import Tab
11+
from trainer_config_validator import TrainerValidator
1112

1213
properties = OrderedDict()
1314

@@ -69,16 +70,21 @@ def run_trainer(self, *args):
6970
key = keys_list[index]
7071
properties[key].value = properties_values[index]
7172
config.set(key, properties_values[index])
73+
config.set('path_to_finetrainers', general_tab.properties['path_to_finetrainers'].value)
74+
75+
config_validator = TrainerValidator(config)
76+
try:
77+
config_validator.validate()
78+
except Exception as e:
79+
return str(e), None
7280

7381
output_path = os.path.join(properties['output_dir'].value, "config")
7482
os.makedirs(output_path, exist_ok=True)
7583
self.save_edits(os.path.join(output_path, "config_{}.yaml".format(time)), *properties_values)
7684

7785
log_file = os.path.join(output_path, "log_{}.txt".format(time))
7886

79-
if not general_tab.properties['path_to_finetrainers'].value:
80-
return "Please set the path to finetrainers in General Settings"
81-
result = self.trainer.run(config, general_tab.properties['path_to_finetrainers'].value, log_file)
87+
result = self.trainer.run(config, config.get('path_to_finetrainers'), log_file)
8288
self.trainer.running = False
8389
if isinstance(result, str):
8490
return result, log_file

test/__init__.py

Whitespace-only changes.
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
import os
2+
import pytest
3+
from unittest.mock import patch
4+
5+
from trainer_config_validator import TrainerValidator
6+
7+
@pytest.fixture
8+
def valid_config():
9+
return {
10+
'finetrainers_path': '/path/to/finetrainers',
11+
'accelerate_config': 'config1',
12+
'batch_size': 32,
13+
'beta1': 0.9,
14+
'beta2': 0.999,
15+
'caption_column': 'captions.txt',
16+
'caption_dropout_p': 0.1,
17+
'checkpointing_limit': 5,
18+
'checkpointing_steps': 1000,
19+
'data_root': '/path/to/data',
20+
'dataloader_num_workers': 4,
21+
'epsilon': 1e-8,
22+
'gpu_ids': '0,1',
23+
'gradient_accumulation_steps': 2,
24+
'gradient_checkpointing': True,
25+
'id_token': 'token123',
26+
'lora_alpha': 128,
27+
'lr': 0.001,
28+
'lr_num_cycles': 10,
29+
'lr_scheduler': 'scheduler1',
30+
'lr_warmup_steps': 500,
31+
'max_grad_norm': 1.0,
32+
'mixed_precision': 'fp16',
33+
'model_name': 'model_v1',
34+
'nccl_timeout': 60,
35+
'optimizer': 'adam',
36+
'pretrained_model_name_or_path': 'pretrained_model',
37+
'rank': 0,
38+
'seed': 42,
39+
'target_modules': 'module1',
40+
'tracker_name': 'tracker',
41+
'train_steps': 10000,
42+
'training_type': 'type1',
43+
'validation_steps': 100,
44+
'video_column': 'videos.txt',
45+
'video_resolution_buckets': '24x480x720',
46+
'weight_decay': 0.01
47+
}
48+
49+
@pytest.fixture
50+
def trainer_validator(valid_config):
51+
return TrainerValidator(valid_config)
52+
53+
def test_valid_config(valid_config):
54+
trainer_validator = TrainerValidator(valid_config)
55+
with patch('os.path.isfile', return_value=True), patch('os.path.exists', return_value=True), patch('os.path.isdir', return_value=True):
56+
trainer_validator.validate()
57+
58+
def test_validate_finetrainers_path_invalid(trainer_validator):
59+
trainer_validator.config['finetrainers_path'] = 123
60+
with pytest.raises(ValueError, match="finetrainers_path must be a string"):
61+
trainer_validator.validate_finetrainers_path()
62+
63+
def test_validate_finetrainers_path_valid(trainer_validator):
64+
with patch('os.path.isfile', return_value=True):
65+
trainer_validator.config['finetrainers_path'] = '/path/to/finetrainers'
66+
trainer_validator.validate_finetrainers_path()
67+
68+
def test_validate_accelerate_config_invalid(trainer_validator):
69+
trainer_validator.config['accelerate_config'] = []
70+
with pytest.raises(ValueError, match="accelerate_config must be a string"):
71+
trainer_validator.validate_accelerate_config()
72+
73+
def test_validate_batch_size_invalid(trainer_validator):
74+
trainer_validator.config['batch_size'] = 'not_an_int'
75+
with pytest.raises(ValueError, match="batch_size must be an integer"):
76+
trainer_validator.validate_batch_size()
77+
78+
def test_validate_beta1_invalid(trainer_validator):
79+
trainer_validator.config['beta1'] = 'not_a_float'
80+
with pytest.raises(ValueError, match="beta1 must be a float"):
81+
trainer_validator.validate_beta1()
82+
83+
def test_validate_beta2_invalid(trainer_validator):
84+
trainer_validator.config['beta2'] = 'not_a_float'
85+
with pytest.raises(ValueError, match="beta2 must be a float"):
86+
trainer_validator.validate_beta2()
87+
88+
def test_validate_caption_column_invalid(trainer_validator):
89+
trainer_validator.config['caption_column'] = 123
90+
with pytest.raises(ValueError, match="caption_column must be a string"):
91+
trainer_validator.validate_caption_column()
92+
93+
def test_validate_caption_dropout_p_invalid(trainer_validator):
94+
trainer_validator.config['caption_dropout_p'] = 'not_a_float'
95+
with pytest.raises(ValueError, match="caption_dropout_p must be a float"):
96+
trainer_validator.validate_caption_dropout_p()
97+
98+
def test_validate_checkpointing_limit_invalid(trainer_validator):
99+
trainer_validator.config['checkpointing_limit'] = 'not_an_int'
100+
with pytest.raises(ValueError, match="checkpointing_limit must be an integer"):
101+
trainer_validator.validate_checkpointing_limit()
102+
103+
def test_validate_checkpointing_steps_invalid(trainer_validator):
104+
trainer_validator.config['checkpointing_steps'] = 'not_an_int'
105+
with pytest.raises(ValueError, match="checkpointing_steps must be an integer"):
106+
trainer_validator.validate_checkpointing_steps()
107+
108+
def test_validate_data_root_invalid(trainer_validator):
109+
trainer_validator.config['data_root'] = '/invalid/path'
110+
with pytest.raises(ValueError, match="data_root path /invalid/path does not exist"):
111+
trainer_validator.validate_data_root()
112+
113+
def test_validate_data_root_valid(trainer_validator):
114+
with patch('os.path.exists', return_value=True), patch('os.path.isdir', return_value=True):
115+
trainer_validator.config['data_root'] = '/path/to/data'
116+
trainer_validator.validate_data_root()
117+
118+
def test_validate_dataloader_num_workers_invalid(trainer_validator):
119+
trainer_validator.config['dataloader_num_workers'] = 'not_an_int'
120+
with pytest.raises(ValueError, match="dataloader_num_workers must be an integer"):
121+
trainer_validator.validate_dataloader_num_workers()
122+
123+
def test_validate_epsilon_invalid(trainer_validator):
124+
trainer_validator.config['epsilon'] = 'not_a_float'
125+
with pytest.raises(ValueError, match="epsilon must be a float"):
126+
trainer_validator.validate_epsilon()
127+
128+
def test_validate_gpu_ids_invalid(trainer_validator):
129+
trainer_validator.config['gpu_ids'] = 123
130+
with pytest.raises(ValueError, match="gpu_ids must be a string"):
131+
trainer_validator.validate_gpu_ids()
132+
133+
def test_validate_gradient_accumulation_steps_invalid(trainer_validator):
134+
trainer_validator.config['gradient_accumulation_steps'] = 'not_an_int'
135+
with pytest.raises(ValueError, match="gradient_accumulation_steps must be an integer"):
136+
trainer_validator.validate_gradient_accumulation_steps()
137+
138+
def test_validate_id_token_invalid(trainer_validator):
139+
trainer_validator.config['id_token'] = 123
140+
with pytest.raises(ValueError, match="id_token must be a string"):
141+
trainer_validator.validate_id_token()
142+
143+
def test_validate_lora_alpha_invalid(trainer_validator):
144+
trainer_validator.config['lora_alpha'] = 'not_an_int'
145+
with pytest.raises(ValueError, match="lora_alpha must be an integer"):
146+
trainer_validator.validate_lora_alpha()
147+
148+
def test_validate_lr_invalid(trainer_validator):
149+
trainer_validator.config['lr'] = 'not_a_float'
150+
with pytest.raises(ValueError, match="lr must be a float"):
151+
trainer_validator.validate_lr()
152+
153+
def test_validate_lr_num_cycles_invalid(trainer_validator):
154+
trainer_validator.config['lr_num_cycles'] = 'not_an_int'
155+
with pytest.raises(ValueError, match="lr_num_cycles must be an integer"):
156+
trainer_validator.validate_lr_num_cycles()
157+
158+
def test_validate_lr_scheduler_invalid(trainer_validator):
159+
trainer_validator.config['lr_scheduler'] = ''
160+
with pytest.raises(ValueError, match="lr_scheduler must be a string"):
161+
trainer_validator.validate_lr_scheduler()
162+
163+
def test_validate_lr_warmup_steps_invalid(trainer_validator):
164+
trainer_validator.config['lr_warmup_steps'] = 'not_an_int'
165+
with pytest.raises(ValueError, match="lr_warmup_steps must be an integer"):
166+
trainer_validator.validate_lr_warmup_steps()
167+
168+
def test_validate_max_grad_norm_invalid(trainer_validator):
169+
trainer_validator.config['max_grad_norm'] = 'not_a_float'
170+
with pytest.raises(ValueError, match="max_grad_norm must be a float"):
171+
trainer_validator.validate_max_grad_norm()
172+
173+
def test_validate_mixed_precision_invalid(trainer_validator):
174+
trainer_validator.config['mixed_precision'] = 123
175+
with pytest.raises(ValueError, match="mixed_precision must be a string"):
176+
trainer_validator.validate_mixed_precision()
177+
178+
def test_validate_model_name_invalid(trainer_validator):
179+
trainer_validator.config['model_name'] = 123
180+
with pytest.raises(ValueError, match="model_name must be a string"):
181+
trainer_validator.validate_model_name()
182+
183+
def test_validate_nccl_timeout_invalid(trainer_validator):
184+
trainer_validator.config['nccl_timeout'] = 'not_an_int'
185+
with pytest.raises(ValueError, match="nccl_timeout must be an integer"):
186+
trainer_validator.validate_nccl_timeout()
187+
188+
def test_validate_optimizer_invalid(trainer_validator):
189+
trainer_validator.config['optimizer'] = 123
190+
with pytest.raises(ValueError, match="optimizer must be a string"):
191+
trainer_validator.validate_optimizer()
192+
193+
def test_validate_pretrained_model_name_or_path_invalid(trainer_validator):
194+
trainer_validator.config['pretrained_model_name_or_path'] = 123
195+
with pytest.raises(ValueError, match="pretrained_model_name_or_path must be set"):
196+
trainer_validator.validate_pretrained_model_name_or_path()
197+
198+
def test_validate_rank_invalid(trainer_validator):
199+
trainer_validator.config['rank'] = 'not_an_int'
200+
with pytest.raises(ValueError, match="rank must be an integer"):
201+
trainer_validator.validate_rank()
202+
203+
def test_validate_seed_invalid(trainer_validator):
204+
trainer_validator.config['seed'] = 'not_an_int'
205+
with pytest.raises(ValueError, match="seed must be an integer"):
206+
trainer_validator.validate_seed()
207+
208+
def test_validate_target_modules_invalid(trainer_validator):
209+
trainer_validator.config['target_modules'] = 123
210+
with pytest.raises(ValueError, match="target_modules must be a string"):
211+
trainer_validator.validate_target_modules()
212+
213+
def test_validate_tracker_name_invalid(trainer_validator):
214+
trainer_validator.config['tracker_name'] = 123
215+
with pytest.raises(ValueError, match="tracker_name must be a string"):
216+
trainer_validator.validate_tracker_name()
217+
218+
def test_validate_train_steps_invalid(trainer_validator):
219+
trainer_validator.config['train_steps'] = 'not_an_int'
220+
with pytest.raises(ValueError, match="train_steps must be an integer"):
221+
trainer_validator.validate_train_steps()
222+
223+
def test_validate_training_type_invalid(trainer_validator):
224+
trainer_validator.config['training_type'] = 123
225+
with pytest.raises(ValueError, match="training_type must be a string"):
226+
trainer_validator.validate_training_type()
227+
228+
def test_validate_validation_steps_invalid(trainer_validator):
229+
trainer_validator.config['validation_steps'] = 'not_an_int'
230+
with pytest.raises(ValueError, match="validation_steps must be an integer"):
231+
trainer_validator.validate_validation_steps()
232+
233+
def test_validate_video_column_invalid(trainer_validator):
234+
trainer_validator.config['video_column'] = 123
235+
with pytest.raises(ValueError, match="video_column must be a string"):
236+
trainer_validator.validate_video_column()
237+
238+
def test_validate_video_resolution_buckets_invalid(trainer_validator):
239+
trainer_validator.config['video_resolution_buckets'] = 123
240+
with pytest.raises(ValueError, match="video_resolution_buckets must be a string"):
241+
trainer_validator.validate_video_resolution_buckets()
242+
trainer_validator.config['video_resolution_buckets'] = '720p,1080p,4k'
243+
with pytest.raises(ValueError, match=f"Each bucket must have the format '<frames>x<height>x<width>', but got {trainer_validator.config['video_resolution_buckets']}"):
244+
trainer_validator.validate_video_resolution_buckets()
245+
246+
def test_validate_video_resolution_buckets_valid(trainer_validator):
247+
trainer_validator.config['video_resolution_buckets'] = '24x480x720'
248+
trainer_validator.validate_video_resolution_buckets()
249+
250+
trainer_validator.config['video_resolution_buckets'] = '8x320x512 24x480x720 30x720x1280'
251+
trainer_validator.validate_video_resolution_buckets()
252+
253+
def test_validate_weight_decay_invalid(trainer_validator):
254+
trainer_validator.config['weight_decay'] = 'not_a_float'
255+
with pytest.raises(ValueError, match="weight_decay must be a float"):
256+
trainer_validator.validate_weight_decay()

0 commit comments

Comments
 (0)