Skip to content

Commit 5df35a7

Browse files
committed
fix validation
1 parent a028548 commit 5df35a7

File tree

3 files changed

+49
-353
lines changed

3 files changed

+49
-353
lines changed

tabs/tab.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import OrderedDict
33
import gradio as gr
44
import yaml
5-
import editor_factory
65

76
class Tab(ABC):
87

test/test_trainer_config_validator.py

Lines changed: 3 additions & 181 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
@pytest.fixture
88
def valid_config():
99
return {
10-
'finetrainers_path': '/path/to/finetrainers',
10+
'path_to_finetrainers': '/path/to/finetrainers',
1111
'accelerate_config': 'config1',
1212
'batch_size': 32,
1313
'beta1': 0.9,
@@ -17,7 +17,7 @@ def valid_config():
1717
'checkpointing_limit': 5,
1818
'checkpointing_steps': 1000,
1919
'data_root': '/path/to/data',
20-
'dataloader_num_workers': 4,
20+
'dataloader_num_workers': 0,
2121
'epsilon': 1e-8,
2222
'gpu_ids': '0,1',
2323
'gradient_accumulation_steps': 2,
@@ -34,7 +34,7 @@ def valid_config():
3434
'nccl_timeout': 60,
3535
'optimizer': 'adam',
3636
'pretrained_model_name_or_path': 'pretrained_model',
37-
'rank': 0,
37+
'rank': 64,
3838
'seed': 42,
3939
'target_modules': 'module1',
4040
'tracker_name': 'tracker',
@@ -55,56 +55,6 @@ def test_valid_config(valid_config):
5555
with patch('os.path.isfile', return_value=True), patch('os.path.exists', return_value=True), patch('os.path.isdir', return_value=True):
5656
trainer_validator.validate()
5757

58-
def test_validate_finetrainers_path_invalid(trainer_validator):
59-
trainer_validator.config['finetrainers_path'] = 123
60-
with pytest.raises(ValueError, match="finetrainers_path must be a string"):
61-
trainer_validator.validate_finetrainers_path()
62-
63-
def test_validate_finetrainers_path_valid(trainer_validator):
64-
with patch('os.path.isfile', return_value=True):
65-
trainer_validator.config['finetrainers_path'] = '/path/to/finetrainers'
66-
trainer_validator.validate_finetrainers_path()
67-
68-
def test_validate_accelerate_config_invalid(trainer_validator):
69-
trainer_validator.config['accelerate_config'] = []
70-
with pytest.raises(ValueError, match="accelerate_config must be a string"):
71-
trainer_validator.validate_accelerate_config()
72-
73-
def test_validate_batch_size_invalid(trainer_validator):
74-
trainer_validator.config['batch_size'] = 'not_an_int'
75-
with pytest.raises(ValueError, match="batch_size must be an integer"):
76-
trainer_validator.validate_batch_size()
77-
78-
def test_validate_beta1_invalid(trainer_validator):
79-
trainer_validator.config['beta1'] = 'not_a_float'
80-
with pytest.raises(ValueError, match="beta1 must be a float"):
81-
trainer_validator.validate_beta1()
82-
83-
def test_validate_beta2_invalid(trainer_validator):
84-
trainer_validator.config['beta2'] = 'not_a_float'
85-
with pytest.raises(ValueError, match="beta2 must be a float"):
86-
trainer_validator.validate_beta2()
87-
88-
def test_validate_caption_column_invalid(trainer_validator):
89-
trainer_validator.config['caption_column'] = 123
90-
with pytest.raises(ValueError, match="caption_column must be a string"):
91-
trainer_validator.validate_caption_column()
92-
93-
def test_validate_caption_dropout_p_invalid(trainer_validator):
94-
trainer_validator.config['caption_dropout_p'] = 'not_a_float'
95-
with pytest.raises(ValueError, match="caption_dropout_p must be a float"):
96-
trainer_validator.validate_caption_dropout_p()
97-
98-
def test_validate_checkpointing_limit_invalid(trainer_validator):
99-
trainer_validator.config['checkpointing_limit'] = 'not_an_int'
100-
with pytest.raises(ValueError, match="checkpointing_limit must be an integer"):
101-
trainer_validator.validate_checkpointing_limit()
102-
103-
def test_validate_checkpointing_steps_invalid(trainer_validator):
104-
trainer_validator.config['checkpointing_steps'] = 'not_an_int'
105-
with pytest.raises(ValueError, match="checkpointing_steps must be an integer"):
106-
trainer_validator.validate_checkpointing_steps()
107-
10858
def test_validate_data_root_invalid(trainer_validator):
10959
trainer_validator.config['data_root'] = '/invalid/path'
11060
with pytest.raises(ValueError, match="data_root path /invalid/path does not exist"):
@@ -115,130 +65,7 @@ def test_validate_data_root_valid(trainer_validator):
11565
trainer_validator.config['data_root'] = '/path/to/data'
11666
trainer_validator.validate_data_root()
11767

118-
def test_validate_dataloader_num_workers_invalid(trainer_validator):
119-
trainer_validator.config['dataloader_num_workers'] = 'not_an_int'
120-
with pytest.raises(ValueError, match="dataloader_num_workers must be an integer"):
121-
trainer_validator.validate_dataloader_num_workers()
122-
123-
def test_validate_epsilon_invalid(trainer_validator):
124-
trainer_validator.config['epsilon'] = 'not_a_float'
125-
with pytest.raises(ValueError, match="epsilon must be a float"):
126-
trainer_validator.validate_epsilon()
127-
128-
def test_validate_gpu_ids_invalid(trainer_validator):
129-
trainer_validator.config['gpu_ids'] = 123
130-
with pytest.raises(ValueError, match="gpu_ids must be a string"):
131-
trainer_validator.validate_gpu_ids()
132-
133-
def test_validate_gradient_accumulation_steps_invalid(trainer_validator):
134-
trainer_validator.config['gradient_accumulation_steps'] = 'not_an_int'
135-
with pytest.raises(ValueError, match="gradient_accumulation_steps must be an integer"):
136-
trainer_validator.validate_gradient_accumulation_steps()
137-
138-
def test_validate_id_token_invalid(trainer_validator):
139-
trainer_validator.config['id_token'] = 123
140-
with pytest.raises(ValueError, match="id_token must be a string"):
141-
trainer_validator.validate_id_token()
142-
143-
def test_validate_lora_alpha_invalid(trainer_validator):
144-
trainer_validator.config['lora_alpha'] = 'not_an_int'
145-
with pytest.raises(ValueError, match="lora_alpha must be an integer"):
146-
trainer_validator.validate_lora_alpha()
147-
148-
def test_validate_lr_invalid(trainer_validator):
149-
trainer_validator.config['lr'] = 'not_a_float'
150-
with pytest.raises(ValueError, match="lr must be a float"):
151-
trainer_validator.validate_lr()
152-
153-
def test_validate_lr_num_cycles_invalid(trainer_validator):
154-
trainer_validator.config['lr_num_cycles'] = 'not_an_int'
155-
with pytest.raises(ValueError, match="lr_num_cycles must be an integer"):
156-
trainer_validator.validate_lr_num_cycles()
157-
158-
def test_validate_lr_scheduler_invalid(trainer_validator):
159-
trainer_validator.config['lr_scheduler'] = ''
160-
with pytest.raises(ValueError, match="lr_scheduler must be a string"):
161-
trainer_validator.validate_lr_scheduler()
162-
163-
def test_validate_lr_warmup_steps_invalid(trainer_validator):
164-
trainer_validator.config['lr_warmup_steps'] = 'not_an_int'
165-
with pytest.raises(ValueError, match="lr_warmup_steps must be an integer"):
166-
trainer_validator.validate_lr_warmup_steps()
167-
168-
def test_validate_max_grad_norm_invalid(trainer_validator):
169-
trainer_validator.config['max_grad_norm'] = 'not_a_float'
170-
with pytest.raises(ValueError, match="max_grad_norm must be a float"):
171-
trainer_validator.validate_max_grad_norm()
172-
173-
def test_validate_mixed_precision_invalid(trainer_validator):
174-
trainer_validator.config['mixed_precision'] = 123
175-
with pytest.raises(ValueError, match="mixed_precision must be a string"):
176-
trainer_validator.validate_mixed_precision()
177-
178-
def test_validate_model_name_invalid(trainer_validator):
179-
trainer_validator.config['model_name'] = 123
180-
with pytest.raises(ValueError, match="model_name must be a string"):
181-
trainer_validator.validate_model_name()
182-
183-
def test_validate_nccl_timeout_invalid(trainer_validator):
184-
trainer_validator.config['nccl_timeout'] = 'not_an_int'
185-
with pytest.raises(ValueError, match="nccl_timeout must be an integer"):
186-
trainer_validator.validate_nccl_timeout()
187-
188-
def test_validate_optimizer_invalid(trainer_validator):
189-
trainer_validator.config['optimizer'] = 123
190-
with pytest.raises(ValueError, match="optimizer must be a string"):
191-
trainer_validator.validate_optimizer()
192-
193-
def test_validate_pretrained_model_name_or_path_invalid(trainer_validator):
194-
trainer_validator.config['pretrained_model_name_or_path'] = 123
195-
with pytest.raises(ValueError, match="pretrained_model_name_or_path must be set"):
196-
trainer_validator.validate_pretrained_model_name_or_path()
197-
198-
def test_validate_rank_invalid(trainer_validator):
199-
trainer_validator.config['rank'] = 'not_an_int'
200-
with pytest.raises(ValueError, match="rank must be an integer"):
201-
trainer_validator.validate_rank()
202-
203-
def test_validate_seed_invalid(trainer_validator):
204-
trainer_validator.config['seed'] = 'not_an_int'
205-
with pytest.raises(ValueError, match="seed must be an integer"):
206-
trainer_validator.validate_seed()
207-
208-
def test_validate_target_modules_invalid(trainer_validator):
209-
trainer_validator.config['target_modules'] = 123
210-
with pytest.raises(ValueError, match="target_modules must be a string"):
211-
trainer_validator.validate_target_modules()
212-
213-
def test_validate_tracker_name_invalid(trainer_validator):
214-
trainer_validator.config['tracker_name'] = 123
215-
with pytest.raises(ValueError, match="tracker_name must be a string"):
216-
trainer_validator.validate_tracker_name()
217-
218-
def test_validate_train_steps_invalid(trainer_validator):
219-
trainer_validator.config['train_steps'] = 'not_an_int'
220-
with pytest.raises(ValueError, match="train_steps must be an integer"):
221-
trainer_validator.validate_train_steps()
222-
223-
def test_validate_training_type_invalid(trainer_validator):
224-
trainer_validator.config['training_type'] = 123
225-
with pytest.raises(ValueError, match="training_type must be a string"):
226-
trainer_validator.validate_training_type()
227-
228-
def test_validate_validation_steps_invalid(trainer_validator):
229-
trainer_validator.config['validation_steps'] = 'not_an_int'
230-
with pytest.raises(ValueError, match="validation_steps must be an integer"):
231-
trainer_validator.validate_validation_steps()
232-
233-
def test_validate_video_column_invalid(trainer_validator):
234-
trainer_validator.config['video_column'] = 123
235-
with pytest.raises(ValueError, match="video_column must be a string"):
236-
trainer_validator.validate_video_column()
237-
23868
def test_validate_video_resolution_buckets_invalid(trainer_validator):
239-
trainer_validator.config['video_resolution_buckets'] = 123
240-
with pytest.raises(ValueError, match="video_resolution_buckets must be a string"):
241-
trainer_validator.validate_video_resolution_buckets()
24269
trainer_validator.config['video_resolution_buckets'] = '720p,1080p,4k'
24370
with pytest.raises(ValueError, match=f"Each bucket must have the format '<frames>x<height>x<width>', but got {trainer_validator.config['video_resolution_buckets']}"):
24471
trainer_validator.validate_video_resolution_buckets()
@@ -249,8 +76,3 @@ def test_validate_video_resolution_buckets_valid(trainer_validator):
24976

25077
trainer_validator.config['video_resolution_buckets'] = '8x320x512 24x480x720 30x720x1280'
25178
trainer_validator.validate_video_resolution_buckets()
252-
253-
def test_validate_weight_decay_invalid(trainer_validator):
254-
trainer_validator.config['weight_decay'] = 'not_a_float'
255-
with pytest.raises(ValueError, match="weight_decay must be a float"):
256-
trainer_validator.validate_weight_decay()

0 commit comments

Comments
 (0)