
Commit 244de8b

Merge branch 'main' into update-v0.10.0
2 parents d5bdba2 + d065821

8 files changed (+172, -89 lines)

.github/workflows/python-app.yml

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+# This workflow will install Python dependencies, run tests and lint with a single version of Python
+# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
+
+name: Python application
+
+on:
+  push:
+    branches: [ "main" ]
+  pull_request:
+    branches: [ "main" ]
+
+permissions:
+  contents: read
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v4
+    - name: Set up Python 3.10
+      uses: actions/setup-python@v3
+      with:
+        python-version: "3.10"
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install flake8 pytest
+        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+    - name: Lint with flake8
+      run: |
+        # stop the build if there are Python syntax errors or undefined names
+        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+    - name: Test with pytest
+      run: |
+        pytest
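The same gate can be reproduced locally before pushing. A minimal sketch in Python, assuming flake8 and pytest are installed in the current environment (flags copied from the workflow above; the advisory exit-zero pass is omitted):

```python
# Local stand-in for the CI gate above: hard lint errors first, then the
# test suite. Assumes flake8 and pytest are installed in this environment.
import subprocess

checks = [
    ["flake8", ".", "--count", "--select=E9,F63,F7,F82", "--show-source", "--statistics"],
    ["pytest"],
]

for cmd in checks:
    subprocess.run(cmd, check=True)  # stop at the first failing check, as CI does
```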

README.md

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ I try to keep main stable, but if it fails, step back one version and try that.
 
 ![Screenshot from 2024-12-30 07-56-37](https://github.com/user-attachments/assets/91b947db-1e50-42e0-8d12-28b436bf837d)
 
+v0.9.1: Add 'resume_from_checkpoint' setting
+
 v0.9.0: Add missing precompute_condition settings
 
 v0.8.0: Configuration validator. Fail early.

config/config_categories.yaml

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
-Dataset: data_root, video_column, caption_column, id_token, video_resolution_buckets, caption_dropout_p, precompute_conditions
-Training: training_type, seed, train_steps, rank, lora_alpha, target_modules, gradient_accumulation_steps, checkpointing_steps, checkpointing_limit, enable_slicing, enable_tiling, batch_size
+Dataset: data_root, video_column, caption_column, dataset_file, id_token, image_resolution_buckets, video_resolution_buckets, caption_dropout_p, precompute_conditions
+Training: training_type, seed, train_steps, rank, lora_alpha, target_modules, gradient_accumulation_steps, checkpointing_steps, checkpointing_limit, enable_slicing, enable_tiling, batch_size, resume_from_checkpoint
 Optimizer: optimizer, lr, beta1, beta2, epsilon, weight_decay, max_grad_norm, lr_scheduler, lr_num_cycles, lr_warmup_steps
 Validation: validation_steps, validation_epochs, num_validation_videos, validation_prompts, validation_prompt_separator
 Accelerate: gpu_ids, nccl_timeout, gradient_checkpointing, allow_tf32, dataloader_num_workers, report_to, accelerate_config
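Each line of this file maps a settings category to the config keys grouped under it. A hypothetical parsing sketch (the function name and dict shape are illustrative; the project's actual loader may differ):

```python
# Hypothetical sketch: parse "Category: key1, key2, ..." lines into a
# {category: [keys]} dict. Not necessarily how the project loads this file.
def parse_categories(path: str) -> dict[str, list[str]]:
    categories: dict[str, list[str]] = {}
    with open(path) as f:
        for line in f:
            if ":" not in line:
                continue  # skip blank or malformed lines
            name, keys = line.split(":", 1)
            categories[name.strip()] = [k.strip() for k in keys.split(",")]
    return categories
```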

config/config_template.yaml

Lines changed: 4 additions & 0 deletions
@@ -10,7 +10,9 @@ checkpointing_limit: 102
 checkpointing_steps: 500
 data_root: ''
 dataloader_num_workers: 0
+dataset_file: ''
 diffusion_options: ''
+enable_model_cpu_offload: false
 enable_slicing: true
 enable_tiling: true
 epsilon: 1e-8
@@ -21,6 +23,7 @@ id_token: afkx
 layerwise_upcasting_modules: [none, transformer]
 layerwise_upcasting_granularity: [pytorch_layer, diffusers_layer]
 layerwise_upcasting_storage_dtype: [float8_e4m3fn, float8_e5m2]
+image_resolution_buckets: 512x768
 lora_alpha: 128
 lr: 0.0001
 lr_num_cycles: 1
@@ -37,6 +40,7 @@ precompute_conditions: false
 pretrained_model_name_or_path: ''
 rank: 128
 report_to: none
+resume_from_checkpoint: ''
 seed: 42
 target_modules: to_q to_k to_v to_out.0
 text_encoder_dtype: [bf16, fp16, fp32, fp8]
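The template doubles as a loadable set of defaults; the new test_config_template test further down consumes it exactly this way. A minimal sketch (the paths are placeholders):

```python
# Load the template defaults, then override the values that have no
# sensible default. Paths below are placeholders.
import yaml

with open("config/config_template.yaml") as f:
    config = yaml.safe_load(f)

config["data_root"] = "/path/to/data"  # placeholder
# '' (the default) means the resume flag is simply not passed to the trainer
print(config["resume_from_checkpoint"])
print(config["image_resolution_buckets"])  # -> 512x768
```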

pyproject.toml

Lines changed: 7 additions & 1 deletion
@@ -5,7 +5,13 @@ dependencies = [
     "gradio",
     "torch>=2.4.1"
 ]
+description = "A gradio based ui for training video transformer models with finetrainers as backend"
+readme = "README.md"
+license = {file = "LICENSE"}
 
 
 [project.urls]
-Homepage = "https://github.com/neph1/finetrainers-ui"
+Repository = "https://github.com/neph1/finetrainers-ui"
+
+[tool.setuptools]
+packages = ["tabs", "config"]

run_trainer.py

Lines changed: 96 additions & 83 deletions
@@ -1,7 +1,8 @@
 import os
 import signal
 import subprocess
-import time
+
+import psutil
 
 from config import Config
 
@@ -18,88 +19,92 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
         assert config.get('data_root'), "Data root required"
         assert config.get('pretrained_model_name_or_path'), "pretrained_model_name_or_path required"
 
-        # Model arguments
-        model_cmd = f"--model_name {config.get('model_name')} \
-            --pretrained_model_name_or_path {config.get('pretrained_model_name_or_path')} \
-            --text_encoder_dtype {config.get('text_encoder_dtype')} \
-            --text_encoder_2_dtype {config.get('text_encoder_2_dtype')} \
-            --text_encoder_3_dtype {config.get('text_encoder_3_dtype')} \
-            --vae_dtype {config.get('vae_dtype')} "
-
+        model_cmd = ["--model_name", config.get('model_name'),
+                     "--pretrained_model_name_or_path", config.get('pretrained_model_name_or_path'),
+                     "--text_encoder_dtype", config.get('text_encoder_dtype'),
+                     "--text_encoder_2_dtype", config.get('text_encoder_2_dtype'),
+                     "--text_encoder_3_dtype", config.get('text_encoder_3_dtype'),
+                     "--vae_dtype", config.get('vae_dtype')]
+
         if config.get('layerwise_upcasting_modules') != 'none':
-            model_cmd += f"--layerwise_upcasting_modules {config.get('layerwise_upcasting_modules')} \
-                --layerwise_upcasting_storage_dtype {config.get('layerwise_upcasting_storage_dtype')} \
-                --layerwise_upcasting_granularity {config.get('layerwise_upcasting_granularity')} "
-
-        # Dataset arguments
-        dataset_cmd = f"--data_root {config.get('data_root')} \
-            --video_column {config.get('video_column')} \
-            --caption_column {config.get('caption_column')} \
-            --id_token {config.get('id_token')} \
-            --video_resolution_buckets {config.get('video_resolution_buckets')} \
-            --caption_dropout_p {config.get('caption_dropout_p')} \
-            --caption_dropout_technique {config.get('caption_dropout_technique')} \
-            {'--precompute_conditions' if config.get('precompute_conditions') else ''} "
-
-        # Dataloader arguments
-        dataloader_cmd = f"--dataloader_num_workers {config.get('dataloader_num_workers')}"
+            model_cmd += ["--layerwise_upcasting_modules", config.get('layerwise_upcasting_modules'),
+                          "--layerwise_upcasting_storage_dtype", config.get('layerwise_upcasting_storage_dtype'),
+                          "--layerwise_upcasting_granularity", config.get('layerwise_upcasting_granularity')]
+
+        dataset_cmd = ["--data_root", config.get('data_root'),
+                       "--video_column", config.get('video_column'),
+                       "--caption_column", config.get('caption_column'),
+                       "--id_token", config.get('id_token'),
+                       "--video_resolution_buckets"]
+        dataset_cmd += config.get('video_resolution_buckets').split(' ')
+        dataset_cmd += ["--image_resolution_buckets"]
+        dataset_cmd += config.get('image_resolution_buckets').split(' ')
+        dataset_cmd += ["--caption_dropout_p", config.get('caption_dropout_p'),
+                        "--caption_dropout_technique", config.get('caption_dropout_technique'),
+                        "--text_encoder_dtype", config.get('text_encoder_dtype'),
+                        "--text_encoder_2_dtype", config.get('text_encoder_2_dtype'),
+                        "--text_encoder_3_dtype", config.get('text_encoder_3_dtype'),
+                        "--vae_dtype", config.get('vae_dtype'),
+                        '--precompute_conditions' if config.get('precompute_conditions') else '']
+        if config.get('dataset_file'):
+            dataset_cmd += ["--dataset_file", config.get('dataset_file')]
+
+        dataloader_cmd = ["--dataloader_num_workers", config.get('dataloader_num_workers')]
 
         # Diffusion arguments TODO: replace later
-        diffusion_cmd = f"{config.get('diffusion_options')}"
-
-        # Training arguments
-        training_cmd = f"--training_type {config.get('training_type')} \
-            --seed {config.get('seed')} \
-            --batch_size {config.get('batch_size')} \
-            --train_steps {config.get('train_steps')} \
-            --rank {config.get('rank')} \
-            --lora_alpha {config.get('lora_alpha')} \
-            --target_modules {config.get('target_modules')} \
-            --gradient_accumulation_steps {config.get('gradient_accumulation_steps')} \
-            {'--gradient_checkpointing' if config.get('gradient_checkpointing') else ''} \
-            --checkpointing_steps {config.get('checkpointing_steps')} \
-            --checkpointing_limit {config.get('checkpointing_limit')} \
-            {'--enable_slicing' if config.get('enable_slicing') else ''} \
-            {'--enable_tiling' if config.get('enable_tiling') else ''}"
-
-        # Optimizer arguments
-        optimizer_cmd = f"--optimizer {config.get('optimizer')} \
-            --lr {config.get('lr')} \
-            --lr_scheduler {config.get('lr_scheduler')} \
-            --lr_warmup_steps {config.get('lr_warmup_steps')} \
-            --lr_num_cycles {config.get('lr_num_cycles')} \
-            --beta1 {config.get('beta1')} \
-            --beta2 {config.get('beta2')} \
-            --weight_decay {config.get('weight_decay')} \
-            --epsilon {config.get('epsilon')} \
-            --max_grad_norm {config.get('max_grad_norm')} \
-            {'--use_8bit_bnb' if config.get('use_8bit_bnb') else ''}"
-
-        # Validation arguments
-        validation_cmd = f"--validation_prompts \"{config.get('validation_prompts')}\" \
-            --num_validation_videos {config.get('num_validation_videos')} \
-            --validation_steps {config.get('validation_steps')}"
-
-        # Miscellaneous arguments
-        miscellaneous_cmd = f"--tracker_name {config.get('tracker_name')} \
-            --output_dir {config.get('output_dir')} \
-            --nccl_timeout {config.get('nccl_timeout')} \
-            --report_to {config.get('report_to')}"
-
-        cmd = f"accelerate launch --config_file {finetrainers_path}/accelerate_configs/{config.get('accelerate_config')} --gpu_ids {config.get('gpu_ids')} {finetrainers_path}/train.py \
-            {model_cmd} \
-            {dataset_cmd} \
-            {dataloader_cmd} \
-            {diffusion_cmd} \
-            {training_cmd} \
-            {optimizer_cmd} \
-            {validation_cmd} \
-            {miscellaneous_cmd}"
-
-        print(cmd)
+        diffusion_cmd = [config.get('diffusion_options')]
+
+        training_cmd = ["--training_type", config.get('training_type'),
+                        "--seed", config.get('seed'),
+                        "--mixed_precision", config.get('mixed_precision'),
+                        "--batch_size", config.get('batch_size'),
+                        "--train_steps", config.get('train_steps'),
+                        "--rank", config.get('rank'),
+                        "--lora_alpha", config.get('lora_alpha'),
+                        "--target_modules"]
+        training_cmd += config.get('target_modules').split(' ')
+        training_cmd += ["--gradient_accumulation_steps", config.get('gradient_accumulation_steps'),
+                         '--gradient_checkpointing' if config.get('gradient_checkpointing') else '',
+                         "--checkpointing_steps", config.get('checkpointing_steps'),
+                         "--checkpointing_limit", config.get('checkpointing_limit'),
+                         '--enable_slicing' if config.get('enable_slicing') else '',
+                         '--enable_tiling' if config.get('enable_tiling') else '']
+        if config.get('enable_model_cpu_offload'):
+            training_cmd += ["--enable_model_cpu_offload"]
+
+        if config.get('resume_from_checkpoint'):
+            training_cmd += ["--resume_from_checkpoint", config.get('resume_from_checkpoint')]
+
+        optimizer_cmd = ["--optimizer", config.get('optimizer'),
+                         "--lr", config.get('lr'),
+                         "--lr_scheduler", config.get('lr_scheduler'),
+                         "--lr_warmup_steps", config.get('lr_warmup_steps'),
+                         "--lr_num_cycles", config.get('lr_num_cycles'),
+                         "--beta1", config.get('beta1'),
+                         "--beta2", config.get('beta2'),
+                         "--weight_decay", config.get('weight_decay'),
+                         "--epsilon", config.get('epsilon'),
+                         "--max_grad_norm", config.get('max_grad_norm'),
+                         '--use_8bit_bnb' if config.get('use_8bit_bnb') else '']
+
+        validation_cmd = ["--validation_prompts" if config.get('validation_prompts') else '', config.get('validation_prompts') or '',
+                          "--num_validation_videos", config.get('num_validation_videos'),
+                          "--validation_steps", config.get('validation_steps')]
+
+        miscellaneous_cmd = ["--tracker_name", config.get('tracker_name'),
+                             "--output_dir", config.get('output_dir'),
+                             "--nccl_timeout", config.get('nccl_timeout'),
+                             "--report_to", config.get('report_to')]
+        accelerate_cmd = ["accelerate", "launch", "--config_file", f"{finetrainers_path}/accelerate_configs/{config.get('accelerate_config')}", "--gpu_ids", config.get('gpu_ids')]
+        cmd = accelerate_cmd + [f"{finetrainers_path}/train.py"] + model_cmd + dataset_cmd + dataloader_cmd + diffusion_cmd + training_cmd + optimizer_cmd + validation_cmd + miscellaneous_cmd
+        fixed_cmd = []
+        for i in range(len(cmd)):
+            if cmd[i] != '':
+                fixed_cmd.append(f"{cmd[i]}")
+        print(' '.join(fixed_cmd))
         self.running = True
         with open(log_file, "w") as output_file:
-            self.process = subprocess.Popen(cmd, shell=True, stdout=output_file, stderr=output_file, text=True)
+            self.process = subprocess.Popen(fixed_cmd, shell=False, stdout=output_file, stderr=output_file, text=True, preexec_fn=os.setsid)
             self.process.communicate()
         return self.process
 
@@ -109,12 +114,20 @@ def stop(self):
         try:
             self.running = False
             if self.process:
-                self.process.terminate()
-                time.sleep(3)
-                if self.process.poll() is None:
-                    self.process.kill()
+                os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
+                self.terminate_process_tree(self.process.pid)
         except Exception as e:
             return f"Error stopping training: {e}"
         finally:
             self.process.wait()
-            return "Training forcibly stopped"
+            return "Training forcibly stopped"
+
+    def terminate_process_tree(pid):
+        try:
+            parent = psutil.Process(pid)
+            children = parent.children(recursive=True)  # Get child processes
+            for child in children:
+                child.terminate()
+            parent.terminate()
+        except psutil.NoSuchProcess:
+            pass
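Two changes here do the heavy lifting: building the command as an argv list with shell=False removes a layer of shell quoting and word splitting, and preexec_fn=os.setsid starts the trainer as the leader of a new process group, so accelerate and all of its workers can be signalled together. A standalone POSIX sketch of the same launch/stop pattern (the argv is a placeholder, not the project's code; requires psutil):

```python
# Standalone sketch of the launch/terminate pattern above (POSIX only).
# The argv list is a placeholder; requires `pip install psutil`.
import os
import signal
import subprocess

import psutil

proc = subprocess.Popen(
    ["sleep", "300"],      # placeholder argv; no shell parsing involved
    preexec_fn=os.setsid,  # child leads a fresh process group
)

# Signal the whole group in one call...
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)

# ...then walk the tree with psutil as a fallback for children that
# moved themselves into other process groups.
try:
    parent = psutil.Process(proc.pid)
    for child in parent.children(recursive=True):
        child.terminate()
    parent.terminate()
except psutil.NoSuchProcess:
    pass  # already gone

proc.wait()  # reap the group leader
```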

scripts/rename_keys.py

Lines changed: 3 additions & 3 deletions
@@ -7,10 +7,10 @@
 def rename_keys(file, outfile: str)-> bool:
     sd, metadata = load_state_dict(file, torch.float32)
 
-    keys_to_normalize = [key for key in sd.keys()]
-    values_to_normalize = [sd[key].to(torch.float32) for key in keys_to_normalize]
+    keys_to_rename = [key for key in sd.keys()]
+    values = [sd[key].to(torch.float32) for key in keys_to_rename]
     new_sd = dict()
-    for key, value in zip(keys_to_normalize, values_to_normalize):
+    for key, value in zip(keys_to_rename, values):
         new_sd[key.replace("transformer.", "")] = value
 
     save_to_file(outfile, new_sd, torch.float16, metadata)
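The rename itself is just a prefix strip on every state-dict key; a toy illustration with dummy values:

```python
# Toy illustration of the key rewrite in rename_keys: strip the
# "transformer." prefix from each state-dict key, keeping the values.
sd = {"transformer.to_q.weight": 1, "transformer.to_k.weight": 2}
new_sd = {k.replace("transformer.", ""): v for k, v in sd.items()}
assert new_sd == {"to_q.weight": 1, "to_k.weight": 2}
```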

test/test_trainer_config_validator.py

Lines changed: 19 additions & 0 deletions
@@ -2,6 +2,8 @@
 import pytest
 from unittest.mock import patch
 
+import yaml
+
 from trainer_config_validator import TrainerValidator
 
 @pytest.fixture
@@ -55,6 +57,23 @@ def test_valid_config(valid_config):
     with patch('os.path.isfile', return_value=True), patch('os.path.exists', return_value=True), patch('os.path.isdir', return_value=True):
         trainer_validator.validate()
 
+def test_config_template():
+    config = None
+    with open('config/config_template.yaml', "r") as file:
+        config = yaml.safe_load(file)
+    config['path_to_finetrainers'] = '/path/to/finetrainers'
+    config['data_root'] = '/path/to/data'
+    config['pretrained_model_name_or_path'] = 'pretrained_model'
+
+    trainer_validator = TrainerValidator(config)
+    with patch('os.path.isfile', return_value=True), patch('os.path.exists', return_value=True), patch('os.path.isdir', return_value=True):
+        trainer_validator.validate()
+
+def test_validate_data_root_not_set(trainer_validator):
+    trainer_validator.config['data_root'] = ''
+    with pytest.raises(ValueError, match="data_root is required"):
+        trainer_validator.validate()
+
 def test_validate_data_root_invalid(trainer_validator):
     trainer_validator.config['data_root'] = '/invalid/path'
     with pytest.raises(ValueError, match="data_root path /invalid/path does not exist"):
