1+ import os
2+ import pytest
3+ from unittest .mock import patch
4+
5+ from trainer_config_validator import TrainerValidator
6+
7+ @pytest .fixture
8+ def valid_config ():
9+ return {
10+ 'finetrainers_path' : '/path/to/finetrainers' ,
11+ 'accelerate_config' : 'config1' ,
12+ 'batch_size' : 32 ,
13+ 'beta1' : 0.9 ,
14+ 'beta2' : 0.999 ,
15+ 'caption_column' : 'captions.txt' ,
16+ 'caption_dropout_p' : 0.1 ,
17+ 'checkpointing_limit' : 5 ,
18+ 'checkpointing_steps' : 1000 ,
19+ 'data_root' : '/path/to/data' ,
20+ 'dataloader_num_workers' : 4 ,
21+ 'epsilon' : 1e-8 ,
22+ 'gpu_ids' : '0,1' ,
23+ 'gradient_accumulation_steps' : 2 ,
24+ 'gradient_checkpointing' : True ,
25+ 'id_token' : 'token123' ,
26+ 'lora_alpha' : 128 ,
27+ 'lr' : 0.001 ,
28+ 'lr_num_cycles' : 10 ,
29+ 'lr_scheduler' : 'scheduler1' ,
30+ 'lr_warmup_steps' : 500 ,
31+ 'max_grad_norm' : 1.0 ,
32+ 'mixed_precision' : 'fp16' ,
33+ 'model_name' : 'model_v1' ,
34+ 'nccl_timeout' : 60 ,
35+ 'optimizer' : 'adam' ,
36+ 'pretrained_model_name_or_path' : 'pretrained_model' ,
37+ 'rank' : 0 ,
38+ 'seed' : 42 ,
39+ 'target_modules' : 'module1' ,
40+ 'tracker_name' : 'tracker' ,
41+ 'train_steps' : 10000 ,
42+ 'training_type' : 'type1' ,
43+ 'validation_steps' : 100 ,
44+ 'video_column' : 'videos.txt' ,
45+ 'video_resolution_buckets' : '24x480x720' ,
46+ 'weight_decay' : 0.01
47+ }
48+
49+ @pytest .fixture
50+ def trainer_validator (valid_config ):
51+ return TrainerValidator (valid_config )
52+
53+ def test_valid_config (valid_config ):
54+ trainer_validator = TrainerValidator (valid_config )
55+ with patch ('os.path.isfile' , return_value = True ), patch ('os.path.exists' , return_value = True ), patch ('os.path.isdir' , return_value = True ):
56+ trainer_validator .validate ()
57+
58+ def test_validate_finetrainers_path_invalid (trainer_validator ):
59+ trainer_validator .config ['finetrainers_path' ] = 123
60+ with pytest .raises (ValueError , match = "finetrainers_path must be a string" ):
61+ trainer_validator .validate_finetrainers_path ()
62+
63+ def test_validate_finetrainers_path_valid (trainer_validator ):
64+ with patch ('os.path.isfile' , return_value = True ):
65+ trainer_validator .config ['finetrainers_path' ] = '/path/to/finetrainers'
66+ trainer_validator .validate_finetrainers_path ()
67+
68+ def test_validate_accelerate_config_invalid (trainer_validator ):
69+ trainer_validator .config ['accelerate_config' ] = []
70+ with pytest .raises (ValueError , match = "accelerate_config must be a string" ):
71+ trainer_validator .validate_accelerate_config ()
72+
73+ def test_validate_batch_size_invalid (trainer_validator ):
74+ trainer_validator .config ['batch_size' ] = 'not_an_int'
75+ with pytest .raises (ValueError , match = "batch_size must be an integer" ):
76+ trainer_validator .validate_batch_size ()
77+
78+ def test_validate_beta1_invalid (trainer_validator ):
79+ trainer_validator .config ['beta1' ] = 'not_a_float'
80+ with pytest .raises (ValueError , match = "beta1 must be a float" ):
81+ trainer_validator .validate_beta1 ()
82+
83+ def test_validate_beta2_invalid (trainer_validator ):
84+ trainer_validator .config ['beta2' ] = 'not_a_float'
85+ with pytest .raises (ValueError , match = "beta2 must be a float" ):
86+ trainer_validator .validate_beta2 ()
87+
88+ def test_validate_caption_column_invalid (trainer_validator ):
89+ trainer_validator .config ['caption_column' ] = 123
90+ with pytest .raises (ValueError , match = "caption_column must be a string" ):
91+ trainer_validator .validate_caption_column ()
92+
93+ def test_validate_caption_dropout_p_invalid (trainer_validator ):
94+ trainer_validator .config ['caption_dropout_p' ] = 'not_a_float'
95+ with pytest .raises (ValueError , match = "caption_dropout_p must be a float" ):
96+ trainer_validator .validate_caption_dropout_p ()
97+
98+ def test_validate_checkpointing_limit_invalid (trainer_validator ):
99+ trainer_validator .config ['checkpointing_limit' ] = 'not_an_int'
100+ with pytest .raises (ValueError , match = "checkpointing_limit must be an integer" ):
101+ trainer_validator .validate_checkpointing_limit ()
102+
103+ def test_validate_checkpointing_steps_invalid (trainer_validator ):
104+ trainer_validator .config ['checkpointing_steps' ] = 'not_an_int'
105+ with pytest .raises (ValueError , match = "checkpointing_steps must be an integer" ):
106+ trainer_validator .validate_checkpointing_steps ()
107+
108+ def test_validate_data_root_invalid (trainer_validator ):
109+ trainer_validator .config ['data_root' ] = '/invalid/path'
110+ with pytest .raises (ValueError , match = "data_root path /invalid/path does not exist" ):
111+ trainer_validator .validate_data_root ()
112+
113+ def test_validate_data_root_valid (trainer_validator ):
114+ with patch ('os.path.exists' , return_value = True ), patch ('os.path.isdir' , return_value = True ):
115+ trainer_validator .config ['data_root' ] = '/path/to/data'
116+ trainer_validator .validate_data_root ()
117+
118+ def test_validate_dataloader_num_workers_invalid (trainer_validator ):
119+ trainer_validator .config ['dataloader_num_workers' ] = 'not_an_int'
120+ with pytest .raises (ValueError , match = "dataloader_num_workers must be an integer" ):
121+ trainer_validator .validate_dataloader_num_workers ()
122+
123+ def test_validate_epsilon_invalid (trainer_validator ):
124+ trainer_validator .config ['epsilon' ] = 'not_a_float'
125+ with pytest .raises (ValueError , match = "epsilon must be a float" ):
126+ trainer_validator .validate_epsilon ()
127+
128+ def test_validate_gpu_ids_invalid (trainer_validator ):
129+ trainer_validator .config ['gpu_ids' ] = 123
130+ with pytest .raises (ValueError , match = "gpu_ids must be a string" ):
131+ trainer_validator .validate_gpu_ids ()
132+
133+ def test_validate_gradient_accumulation_steps_invalid (trainer_validator ):
134+ trainer_validator .config ['gradient_accumulation_steps' ] = 'not_an_int'
135+ with pytest .raises (ValueError , match = "gradient_accumulation_steps must be an integer" ):
136+ trainer_validator .validate_gradient_accumulation_steps ()
137+
138+ def test_validate_id_token_invalid (trainer_validator ):
139+ trainer_validator .config ['id_token' ] = 123
140+ with pytest .raises (ValueError , match = "id_token must be a string" ):
141+ trainer_validator .validate_id_token ()
142+
143+ def test_validate_lora_alpha_invalid (trainer_validator ):
144+ trainer_validator .config ['lora_alpha' ] = 'not_an_int'
145+ with pytest .raises (ValueError , match = "lora_alpha must be an integer" ):
146+ trainer_validator .validate_lora_alpha ()
147+
148+ def test_validate_lr_invalid (trainer_validator ):
149+ trainer_validator .config ['lr' ] = 'not_a_float'
150+ with pytest .raises (ValueError , match = "lr must be a float" ):
151+ trainer_validator .validate_lr ()
152+
153+ def test_validate_lr_num_cycles_invalid (trainer_validator ):
154+ trainer_validator .config ['lr_num_cycles' ] = 'not_an_int'
155+ with pytest .raises (ValueError , match = "lr_num_cycles must be an integer" ):
156+ trainer_validator .validate_lr_num_cycles ()
157+
158+ def test_validate_lr_scheduler_invalid (trainer_validator ):
159+ trainer_validator .config ['lr_scheduler' ] = ''
160+ with pytest .raises (ValueError , match = "lr_scheduler must be a string" ):
161+ trainer_validator .validate_lr_scheduler ()
162+
163+ def test_validate_lr_warmup_steps_invalid (trainer_validator ):
164+ trainer_validator .config ['lr_warmup_steps' ] = 'not_an_int'
165+ with pytest .raises (ValueError , match = "lr_warmup_steps must be an integer" ):
166+ trainer_validator .validate_lr_warmup_steps ()
167+
168+ def test_validate_max_grad_norm_invalid (trainer_validator ):
169+ trainer_validator .config ['max_grad_norm' ] = 'not_a_float'
170+ with pytest .raises (ValueError , match = "max_grad_norm must be a float" ):
171+ trainer_validator .validate_max_grad_norm ()
172+
173+ def test_validate_mixed_precision_invalid (trainer_validator ):
174+ trainer_validator .config ['mixed_precision' ] = 123
175+ with pytest .raises (ValueError , match = "mixed_precision must be a string" ):
176+ trainer_validator .validate_mixed_precision ()
177+
178+ def test_validate_model_name_invalid (trainer_validator ):
179+ trainer_validator .config ['model_name' ] = 123
180+ with pytest .raises (ValueError , match = "model_name must be a string" ):
181+ trainer_validator .validate_model_name ()
182+
183+ def test_validate_nccl_timeout_invalid (trainer_validator ):
184+ trainer_validator .config ['nccl_timeout' ] = 'not_an_int'
185+ with pytest .raises (ValueError , match = "nccl_timeout must be an integer" ):
186+ trainer_validator .validate_nccl_timeout ()
187+
188+ def test_validate_optimizer_invalid (trainer_validator ):
189+ trainer_validator .config ['optimizer' ] = 123
190+ with pytest .raises (ValueError , match = "optimizer must be a string" ):
191+ trainer_validator .validate_optimizer ()
192+
193+ def test_validate_pretrained_model_name_or_path_invalid (trainer_validator ):
194+ trainer_validator .config ['pretrained_model_name_or_path' ] = 123
195+ with pytest .raises (ValueError , match = "pretrained_model_name_or_path must be set" ):
196+ trainer_validator .validate_pretrained_model_name_or_path ()
197+
198+ def test_validate_rank_invalid (trainer_validator ):
199+ trainer_validator .config ['rank' ] = 'not_an_int'
200+ with pytest .raises (ValueError , match = "rank must be an integer" ):
201+ trainer_validator .validate_rank ()
202+
203+ def test_validate_seed_invalid (trainer_validator ):
204+ trainer_validator .config ['seed' ] = 'not_an_int'
205+ with pytest .raises (ValueError , match = "seed must be an integer" ):
206+ trainer_validator .validate_seed ()
207+
208+ def test_validate_target_modules_invalid (trainer_validator ):
209+ trainer_validator .config ['target_modules' ] = 123
210+ with pytest .raises (ValueError , match = "target_modules must be a string" ):
211+ trainer_validator .validate_target_modules ()
212+
213+ def test_validate_tracker_name_invalid (trainer_validator ):
214+ trainer_validator .config ['tracker_name' ] = 123
215+ with pytest .raises (ValueError , match = "tracker_name must be a string" ):
216+ trainer_validator .validate_tracker_name ()
217+
218+ def test_validate_train_steps_invalid (trainer_validator ):
219+ trainer_validator .config ['train_steps' ] = 'not_an_int'
220+ with pytest .raises (ValueError , match = "train_steps must be an integer" ):
221+ trainer_validator .validate_train_steps ()
222+
223+ def test_validate_training_type_invalid (trainer_validator ):
224+ trainer_validator .config ['training_type' ] = 123
225+ with pytest .raises (ValueError , match = "training_type must be a string" ):
226+ trainer_validator .validate_training_type ()
227+
228+ def test_validate_validation_steps_invalid (trainer_validator ):
229+ trainer_validator .config ['validation_steps' ] = 'not_an_int'
230+ with pytest .raises (ValueError , match = "validation_steps must be an integer" ):
231+ trainer_validator .validate_validation_steps ()
232+
233+ def test_validate_video_column_invalid (trainer_validator ):
234+ trainer_validator .config ['video_column' ] = 123
235+ with pytest .raises (ValueError , match = "video_column must be a string" ):
236+ trainer_validator .validate_video_column ()
237+
238+ def test_validate_video_resolution_buckets_invalid (trainer_validator ):
239+ trainer_validator .config ['video_resolution_buckets' ] = 123
240+ with pytest .raises (ValueError , match = "video_resolution_buckets must be a string" ):
241+ trainer_validator .validate_video_resolution_buckets ()
242+ trainer_validator .config ['video_resolution_buckets' ] = '720p,1080p,4k'
243+ with pytest .raises (ValueError , match = f"Each bucket must have the format '<frames>x<height>x<width>', but got { trainer_validator .config ['video_resolution_buckets' ]} " ):
244+ trainer_validator .validate_video_resolution_buckets ()
245+
246+ def test_validate_video_resolution_buckets_valid (trainer_validator ):
247+ trainer_validator .config ['video_resolution_buckets' ] = '24x480x720'
248+ trainer_validator .validate_video_resolution_buckets ()
249+
250+ trainer_validator .config ['video_resolution_buckets' ] = '8x320x512 24x480x720 30x720x1280'
251+ trainer_validator .validate_video_resolution_buckets ()
252+
253+ def test_validate_weight_decay_invalid (trainer_validator ):
254+ trainer_validator .config ['weight_decay' ] = 'not_a_float'
255+ with pytest .raises (ValueError , match = "weight_decay must be a float" ):
256+ trainer_validator .validate_weight_decay ()
0 commit comments