Commit bde2422

Merge pull request #13 from neph1/update-v0.9.1
add resume from checkpoint
2 parents a1293e0 + b0031c5

6 files changed: +29, -6 lines
config/config_categories.yaml

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 Dataset: data_root, video_column, caption_column, id_token, video_resolution_buckets, caption_dropout_p
-Training: training_type, seed, mixed_precision, train_steps, rank, lora_alpha, target_modules, gradient_accumulation_steps, checkpointing_steps, checkpointing_limit, enable_slicing, enable_tiling, batch_size
+Training: training_type, seed, mixed_precision, train_steps, rank, lora_alpha, target_modules, gradient_accumulation_steps, checkpointing_steps, checkpointing_limit, enable_slicing, enable_tiling, batch_size, resume_from_checkpoint
 Optimizer: optimizer, lr, beta1, beta2, epsilon, weight_decay, max_grad_norm, lr_scheduler, lr_num_cycles, lr_warmup_steps
 Validation: validation_steps, validation_epochs, num_validation_videos, validation_prompts, validation_prompt_separator
 Accelerate: gpu_ids, nccl_timeout, gradient_checkpointing, allow_tf32, dataloader_num_workers, report_to, accelerate_config
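This file appears to drive how the UI groups settings by category; each key maps to a comma-separated field list, and the new resume_from_checkpoint field joins the Training group. A minimal parsing sketch (the UI's actual loading code is not part of this commit):

import yaml

# Sketch: parse config_categories.yaml into {category: [field, ...]}.
# Each value in the file is a comma-separated string, as shown above.
with open('config/config_categories.yaml') as f:
    raw = yaml.safe_load(f)

categories = {name: [field.strip() for field in value.split(',')]
              for name, value in raw.items()}

assert 'resume_from_checkpoint' in categories['Training']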

config/config_template.yaml

Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@ precompute_conditions: false
 pretrained_model_name_or_path: ''
 rank: 128
 report_to: none
+resume_from_checkpoint: ''
 seed: 42
 target_modules: to_q to_k to_v to_out.0
 text_encoder_dtype: [bf16, fp16, fp32]
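The new key defaults to an empty string, so resumption stays off unless a path is supplied. A minimal sketch of filling it in, assuming a diffusers/accelerate-style checkpoint-<step> directory (the exact layout finetrainers writes is not shown in this diff):

import yaml

# Load the shipped template and point the new key at a saved checkpoint.
with open('config/config_template.yaml') as f:
    config = yaml.safe_load(f)

config['resume_from_checkpoint'] = '/path/to/output/checkpoint-500'  # hypothetical path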

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "finetrainers-ui"
-version = "0.8.0"
+version = "0.9.1"
 dependencies = [
     "gradio",
     "torch>=2.4.1"

run_trainer.py

Lines changed: 4 additions & 1 deletion

@@ -56,7 +56,10 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
         --checkpointing_steps {config.get('checkpointing_steps')} \
         --checkpointing_limit {config.get('checkpointing_limit')} \
         {'--enable_slicing' if config.get('enable_slicing') else ''} \
-        {'--enable_tiling' if config.get('enable_tiling') else ''}"
+        {'--enable_tiling' if config.get('enable_tiling') else ''} "
+
+        if config.get('resume_from_checkpoint'):
+            training_cmd += f"--resume_from_checkpoint {config.get('resume_from_checkpoint')}"

         # Optimizer arguments
         optimizer_cmd = f"--optimizer {config.get('optimizer')} \
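Two details worth noting: the f-string's closing quote gained a trailing space so that a conditionally appended flag does not fuse with the --enable_tiling token, and the resume flag is only added when the config value is non-empty. A stripped-down repro with hypothetical values:

# Stripped-down version of the assembly above; config values are hypothetical.
config = {'enable_tiling': True, 'resume_from_checkpoint': '/path/to/checkpoint-500'}

training_cmd = f"{'--enable_tiling' if config.get('enable_tiling') else ''} "
if config.get('resume_from_checkpoint'):
    training_cmd += f"--resume_from_checkpoint {config.get('resume_from_checkpoint')}"

print(training_cmd)  # --enable_tiling --resume_from_checkpoint /path/to/checkpoint-500

Building an argument list for subprocess would sidestep spacing bugs like this entirely, but the string form matches the existing command assembly.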

scripts/rename_keys.py

Lines changed: 3 additions & 3 deletions

@@ -7,10 +7,10 @@
 def rename_keys(file, outfile: str)-> bool:
     sd, metadata = load_state_dict(file, torch.float32)

-    keys_to_normalize = [key for key in sd.keys()]
-    values_to_normalize = [sd[key].to(torch.float32) for key in keys_to_normalize]
+    keys_to_rename = [key for key in sd.keys()]
+    values = [sd[key].to(torch.float32) for key in keys_to_rename]
     new_sd = dict()
-    for key, value in zip(keys_to_normalize, values_to_normalize):
+    for key, value in zip(keys_to_rename, values):
         new_sd[key.replace("transformer.", "")] = value

     save_to_file(outfile, new_sd, torch.float16, metadata)
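load_state_dict and save_to_file here are project helpers. For context, a self-contained sketch of the same renaming pass written directly against safetensors; metadata handling and the fp32-then-fp16 round-trip are collapsed into a single cast, which is an assumption about what the helpers do:

import torch
from safetensors.torch import load_file, save_file

def rename_keys(file: str, outfile: str) -> None:
    # Strip the "transformer." prefix from every key; save in fp16 as above.
    sd = load_file(file)
    new_sd = {key.replace("transformer.", ""): tensor.to(torch.float16)
              for key, tensor in sd.items()}
    save_file(new_sd, outfile)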

test/test_trainer_config_validator.py

Lines changed: 19 additions & 0 deletions

@@ -2,6 +2,8 @@
 import pytest
 from unittest.mock import patch

+import yaml
+
 from trainer_config_validator import TrainerValidator

 @pytest.fixture
@@ -55,6 +57,23 @@ def test_valid_config(valid_config):
     with patch('os.path.isfile', return_value=True), patch('os.path.exists', return_value=True), patch('os.path.isdir', return_value=True):
         trainer_validator.validate()

+def test_config_template():
+    config = None
+    with open('config/config_template.yaml', "r") as file:
+        config = yaml.safe_load(file)
+    config['path_to_finetrainers'] = '/path/to/finetrainers'
+    config['data_root'] = '/path/to/data'
+    config['pretrained_model_name_or_path'] = 'pretrained_model'
+
+    trainer_validator = TrainerValidator(config)
+    with patch('os.path.isfile', return_value=True), patch('os.path.exists', return_value=True), patch('os.path.isdir', return_value=True):
+        trainer_validator.validate()
+
+def test_validate_data_root_not_set(trainer_validator):
+    trainer_validator.config['data_root'] = ''
+    with pytest.raises(ValueError, match="data_root is required"):
+        trainer_validator.validate()
+
 def test_validate_data_root_invalid(trainer_validator):
     trainer_validator.config['data_root'] = '/invalid/path'
     with pytest.raises(ValueError, match="data_root path /invalid/path does not exist"):
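The new test_config_template doubles as a drift check: every key shipped in the template, including the new resume_from_checkpoint default, must pass the validator. Since it opens config/config_template.yaml by a relative path, the suite needs to run from the repository root, e.g.:

# Run from the repo root; the relative open() in test_config_template depends on it.
import pytest
pytest.main(['test/test_trainer_config_validator.py', '-v'])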
