77@pytest .fixture
88def valid_config ():
99 return {
10- 'finetrainers_path ' : '/path/to/finetrainers' ,
10+ 'path_to_finetrainers ' : '/path/to/finetrainers' ,
1111 'accelerate_config' : 'config1' ,
1212 'batch_size' : 32 ,
1313 'beta1' : 0.9 ,
@@ -17,7 +17,7 @@ def valid_config():
1717 'checkpointing_limit' : 5 ,
1818 'checkpointing_steps' : 1000 ,
1919 'data_root' : '/path/to/data' ,
20- 'dataloader_num_workers' : 4 ,
20+ 'dataloader_num_workers' : 0 ,
2121 'epsilon' : 1e-8 ,
2222 'gpu_ids' : '0,1' ,
2323 'gradient_accumulation_steps' : 2 ,
@@ -34,7 +34,7 @@ def valid_config():
3434 'nccl_timeout' : 60 ,
3535 'optimizer' : 'adam' ,
3636 'pretrained_model_name_or_path' : 'pretrained_model' ,
37- 'rank' : 0 ,
37+ 'rank' : 64 ,
3838 'seed' : 42 ,
3939 'target_modules' : 'module1' ,
4040 'tracker_name' : 'tracker' ,
@@ -55,56 +55,6 @@ def test_valid_config(valid_config):
5555 with patch ('os.path.isfile' , return_value = True ), patch ('os.path.exists' , return_value = True ), patch ('os.path.isdir' , return_value = True ):
5656 trainer_validator .validate ()
5757
58- def test_validate_finetrainers_path_invalid (trainer_validator ):
59- trainer_validator .config ['finetrainers_path' ] = 123
60- with pytest .raises (ValueError , match = "finetrainers_path must be a string" ):
61- trainer_validator .validate_finetrainers_path ()
62-
63- def test_validate_finetrainers_path_valid (trainer_validator ):
64- with patch ('os.path.isfile' , return_value = True ):
65- trainer_validator .config ['finetrainers_path' ] = '/path/to/finetrainers'
66- trainer_validator .validate_finetrainers_path ()
67-
68- def test_validate_accelerate_config_invalid (trainer_validator ):
69- trainer_validator .config ['accelerate_config' ] = []
70- with pytest .raises (ValueError , match = "accelerate_config must be a string" ):
71- trainer_validator .validate_accelerate_config ()
72-
73- def test_validate_batch_size_invalid (trainer_validator ):
74- trainer_validator .config ['batch_size' ] = 'not_an_int'
75- with pytest .raises (ValueError , match = "batch_size must be an integer" ):
76- trainer_validator .validate_batch_size ()
77-
78- def test_validate_beta1_invalid (trainer_validator ):
79- trainer_validator .config ['beta1' ] = 'not_a_float'
80- with pytest .raises (ValueError , match = "beta1 must be a float" ):
81- trainer_validator .validate_beta1 ()
82-
83- def test_validate_beta2_invalid (trainer_validator ):
84- trainer_validator .config ['beta2' ] = 'not_a_float'
85- with pytest .raises (ValueError , match = "beta2 must be a float" ):
86- trainer_validator .validate_beta2 ()
87-
88- def test_validate_caption_column_invalid (trainer_validator ):
89- trainer_validator .config ['caption_column' ] = 123
90- with pytest .raises (ValueError , match = "caption_column must be a string" ):
91- trainer_validator .validate_caption_column ()
92-
93- def test_validate_caption_dropout_p_invalid (trainer_validator ):
94- trainer_validator .config ['caption_dropout_p' ] = 'not_a_float'
95- with pytest .raises (ValueError , match = "caption_dropout_p must be a float" ):
96- trainer_validator .validate_caption_dropout_p ()
97-
98- def test_validate_checkpointing_limit_invalid (trainer_validator ):
99- trainer_validator .config ['checkpointing_limit' ] = 'not_an_int'
100- with pytest .raises (ValueError , match = "checkpointing_limit must be an integer" ):
101- trainer_validator .validate_checkpointing_limit ()
102-
103- def test_validate_checkpointing_steps_invalid (trainer_validator ):
104- trainer_validator .config ['checkpointing_steps' ] = 'not_an_int'
105- with pytest .raises (ValueError , match = "checkpointing_steps must be an integer" ):
106- trainer_validator .validate_checkpointing_steps ()
107-
10858def test_validate_data_root_invalid (trainer_validator ):
10959 trainer_validator .config ['data_root' ] = '/invalid/path'
11060 with pytest .raises (ValueError , match = "data_root path /invalid/path does not exist" ):
@@ -115,130 +65,7 @@ def test_validate_data_root_valid(trainer_validator):
11565 trainer_validator .config ['data_root' ] = '/path/to/data'
11666 trainer_validator .validate_data_root ()
11767
118- def test_validate_dataloader_num_workers_invalid (trainer_validator ):
119- trainer_validator .config ['dataloader_num_workers' ] = 'not_an_int'
120- with pytest .raises (ValueError , match = "dataloader_num_workers must be an integer" ):
121- trainer_validator .validate_dataloader_num_workers ()
122-
123- def test_validate_epsilon_invalid (trainer_validator ):
124- trainer_validator .config ['epsilon' ] = 'not_a_float'
125- with pytest .raises (ValueError , match = "epsilon must be a float" ):
126- trainer_validator .validate_epsilon ()
127-
128- def test_validate_gpu_ids_invalid (trainer_validator ):
129- trainer_validator .config ['gpu_ids' ] = 123
130- with pytest .raises (ValueError , match = "gpu_ids must be a string" ):
131- trainer_validator .validate_gpu_ids ()
132-
133- def test_validate_gradient_accumulation_steps_invalid (trainer_validator ):
134- trainer_validator .config ['gradient_accumulation_steps' ] = 'not_an_int'
135- with pytest .raises (ValueError , match = "gradient_accumulation_steps must be an integer" ):
136- trainer_validator .validate_gradient_accumulation_steps ()
137-
138- def test_validate_id_token_invalid (trainer_validator ):
139- trainer_validator .config ['id_token' ] = 123
140- with pytest .raises (ValueError , match = "id_token must be a string" ):
141- trainer_validator .validate_id_token ()
142-
143- def test_validate_lora_alpha_invalid (trainer_validator ):
144- trainer_validator .config ['lora_alpha' ] = 'not_an_int'
145- with pytest .raises (ValueError , match = "lora_alpha must be an integer" ):
146- trainer_validator .validate_lora_alpha ()
147-
148- def test_validate_lr_invalid (trainer_validator ):
149- trainer_validator .config ['lr' ] = 'not_a_float'
150- with pytest .raises (ValueError , match = "lr must be a float" ):
151- trainer_validator .validate_lr ()
152-
153- def test_validate_lr_num_cycles_invalid (trainer_validator ):
154- trainer_validator .config ['lr_num_cycles' ] = 'not_an_int'
155- with pytest .raises (ValueError , match = "lr_num_cycles must be an integer" ):
156- trainer_validator .validate_lr_num_cycles ()
157-
158- def test_validate_lr_scheduler_invalid (trainer_validator ):
159- trainer_validator .config ['lr_scheduler' ] = ''
160- with pytest .raises (ValueError , match = "lr_scheduler must be a string" ):
161- trainer_validator .validate_lr_scheduler ()
162-
163- def test_validate_lr_warmup_steps_invalid (trainer_validator ):
164- trainer_validator .config ['lr_warmup_steps' ] = 'not_an_int'
165- with pytest .raises (ValueError , match = "lr_warmup_steps must be an integer" ):
166- trainer_validator .validate_lr_warmup_steps ()
167-
168- def test_validate_max_grad_norm_invalid (trainer_validator ):
169- trainer_validator .config ['max_grad_norm' ] = 'not_a_float'
170- with pytest .raises (ValueError , match = "max_grad_norm must be a float" ):
171- trainer_validator .validate_max_grad_norm ()
172-
173- def test_validate_mixed_precision_invalid (trainer_validator ):
174- trainer_validator .config ['mixed_precision' ] = 123
175- with pytest .raises (ValueError , match = "mixed_precision must be a string" ):
176- trainer_validator .validate_mixed_precision ()
177-
178- def test_validate_model_name_invalid (trainer_validator ):
179- trainer_validator .config ['model_name' ] = 123
180- with pytest .raises (ValueError , match = "model_name must be a string" ):
181- trainer_validator .validate_model_name ()
182-
183- def test_validate_nccl_timeout_invalid (trainer_validator ):
184- trainer_validator .config ['nccl_timeout' ] = 'not_an_int'
185- with pytest .raises (ValueError , match = "nccl_timeout must be an integer" ):
186- trainer_validator .validate_nccl_timeout ()
187-
188- def test_validate_optimizer_invalid (trainer_validator ):
189- trainer_validator .config ['optimizer' ] = 123
190- with pytest .raises (ValueError , match = "optimizer must be a string" ):
191- trainer_validator .validate_optimizer ()
192-
193- def test_validate_pretrained_model_name_or_path_invalid (trainer_validator ):
194- trainer_validator .config ['pretrained_model_name_or_path' ] = 123
195- with pytest .raises (ValueError , match = "pretrained_model_name_or_path must be set" ):
196- trainer_validator .validate_pretrained_model_name_or_path ()
197-
198- def test_validate_rank_invalid (trainer_validator ):
199- trainer_validator .config ['rank' ] = 'not_an_int'
200- with pytest .raises (ValueError , match = "rank must be an integer" ):
201- trainer_validator .validate_rank ()
202-
203- def test_validate_seed_invalid (trainer_validator ):
204- trainer_validator .config ['seed' ] = 'not_an_int'
205- with pytest .raises (ValueError , match = "seed must be an integer" ):
206- trainer_validator .validate_seed ()
207-
208- def test_validate_target_modules_invalid (trainer_validator ):
209- trainer_validator .config ['target_modules' ] = 123
210- with pytest .raises (ValueError , match = "target_modules must be a string" ):
211- trainer_validator .validate_target_modules ()
212-
213- def test_validate_tracker_name_invalid (trainer_validator ):
214- trainer_validator .config ['tracker_name' ] = 123
215- with pytest .raises (ValueError , match = "tracker_name must be a string" ):
216- trainer_validator .validate_tracker_name ()
217-
218- def test_validate_train_steps_invalid (trainer_validator ):
219- trainer_validator .config ['train_steps' ] = 'not_an_int'
220- with pytest .raises (ValueError , match = "train_steps must be an integer" ):
221- trainer_validator .validate_train_steps ()
222-
223- def test_validate_training_type_invalid (trainer_validator ):
224- trainer_validator .config ['training_type' ] = 123
225- with pytest .raises (ValueError , match = "training_type must be a string" ):
226- trainer_validator .validate_training_type ()
227-
228- def test_validate_validation_steps_invalid (trainer_validator ):
229- trainer_validator .config ['validation_steps' ] = 'not_an_int'
230- with pytest .raises (ValueError , match = "validation_steps must be an integer" ):
231- trainer_validator .validate_validation_steps ()
232-
233- def test_validate_video_column_invalid (trainer_validator ):
234- trainer_validator .config ['video_column' ] = 123
235- with pytest .raises (ValueError , match = "video_column must be a string" ):
236- trainer_validator .validate_video_column ()
237-
23868def test_validate_video_resolution_buckets_invalid (trainer_validator ):
239- trainer_validator .config ['video_resolution_buckets' ] = 123
240- with pytest .raises (ValueError , match = "video_resolution_buckets must be a string" ):
241- trainer_validator .validate_video_resolution_buckets ()
24269 trainer_validator .config ['video_resolution_buckets' ] = '720p,1080p,4k'
24370 with pytest .raises (ValueError , match = f"Each bucket must have the format '<frames>x<height>x<width>', but got { trainer_validator .config ['video_resolution_buckets' ]} " ):
24471 trainer_validator .validate_video_resolution_buckets ()
@@ -249,8 +76,3 @@ def test_validate_video_resolution_buckets_valid(trainer_validator):
24976
25077 trainer_validator .config ['video_resolution_buckets' ] = '8x320x512 24x480x720 30x720x1280'
25178 trainer_validator .validate_video_resolution_buckets ()
252-
253- def test_validate_weight_decay_invalid (trainer_validator ):
254- trainer_validator .config ['weight_decay' ] = 'not_a_float'
255- with pytest .raises (ValueError , match = "weight_decay must be a float" ):
256- trainer_validator .validate_weight_decay ()
0 commit comments