
Commit d5bdba2: settings for fp8 training

1 parent: 22fbe95

File tree

10 files changed: +29 -23 lines

config/config_categories.yaml

Lines changed: 4 additions & 3 deletions

@@ -1,5 +1,6 @@
-Dataset: data_root, video_column, caption_column, id_token, video_resolution_buckets, caption_dropout_p
-Training: training_type, seed, mixed_precision, train_steps, rank, lora_alpha, target_modules, gradient_accumulation_steps, checkpointing_steps, checkpointing_limit, enable_slicing, enable_tiling, batch_size
+Dataset: data_root, video_column, caption_column, id_token, video_resolution_buckets, caption_dropout_p, precompute_conditions
+Training: training_type, seed, train_steps, rank, lora_alpha, target_modules, gradient_accumulation_steps, checkpointing_steps, checkpointing_limit, enable_slicing, enable_tiling, batch_size
 Optimizer: optimizer, lr, beta1, beta2, epsilon, weight_decay, max_grad_norm, lr_scheduler, lr_num_cycles, lr_warmup_steps
 Validation: validation_steps, validation_epochs, num_validation_videos, validation_prompts, validation_prompt_separator
-Accelerate: gpu_ids, nccl_timeout, gradient_checkpointing, allow_tf32, dataloader_num_workers, report_to, accelerate_config
+Accelerate: gpu_ids, nccl_timeout, gradient_checkpointing, allow_tf32, dataloader_num_workers, report_to, accelerate_config
+Model: model_name, pretrained_model_name_or_path, text_encoder_dtype, text_encoder_2_dtype, text_encoder_3_dtype, vae_dtype, layerwise_upcasting_modules, layerwise_upcasting_storage_dtype, layerwise_upcasting_granularity

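For reference, each entry in this file maps a UI section to a comma-separated list of config keys; keys that appear in no category fall back to 'Other' (see the tabs/tab.py change below). A minimal sketch of that grouping logic follows; the helper name and file path are assumptions for illustration, not code from the repo.

from collections import OrderedDict
import yaml

def group_by_category(config, categories_path="config/config_categories.yaml"):
    # Each YAML value is a comma-separated string of config keys, as in the diff above.
    with open(categories_path) as f:
        raw = yaml.safe_load(f)
    categories = {name: [k.strip() for k in keys.split(",")] for name, keys in raw.items()}
    grouped = OrderedDict((name, {}) for name in list(categories) + ["Other"])
    for key, value in config.items():
        target = next((name for name, keys in categories.items() if key in keys), "Other")
        grouped[target][key] = value
    return grouped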
config/config_template.yaml

Lines changed: 7 additions & 5 deletions

@@ -18,13 +18,15 @@ gpu_ids: '0'
 gradient_accumulation_steps: 4
 gradient_checkpointing: true
 id_token: afkx
+layerwise_upcasting_modules: [none, transformer]
+layerwise_upcasting_granularity: [pytorch_layer, diffusers_layer]
+layerwise_upcasting_storage_dtype: [float8_e4m3fn, float8_e5m2]
 lora_alpha: 128
 lr: 0.0001
 lr_num_cycles: 1
 lr_scheduler: ['linear', 'cosine', 'cosine_with_restarts', 'polynomial', 'constant', 'constant_with_warmup']
 lr_warmup_steps: 400
 max_grad_norm: 1.0
-mixed_precision: [bf16, fp16, 'no']
 model_name: ltx_video
 nccl_timeout: 1800
 num_validation_videos: 0
@@ -37,14 +39,14 @@ rank: 128
 report_to: none
 seed: 42
 target_modules: to_q to_k to_v to_out.0
-text_encoder_dtype: [bf16, fp16, fp32]
-text_encoder_2_dtype: [bf16, fp16, fp32]
-text_encoder_3_dtype: [bf16, fp16, fp32]
+text_encoder_dtype: [bf16, fp16, fp32, fp8]
+text_encoder_2_dtype: [bf16, fp16, fp32, fp8]
+text_encoder_3_dtype: [bf16, fp16, fp32, fp8]
 tracker_name: finetrainers
 train_steps: 3000
 training_type: lora
 use_8bit_bnb: false
-vae_dtype: [bf16, fp16, fp32]
+vae_dtype: [bf16, fp16, fp32, fp8]
 validation_epochs: 0
 validation_prompt_separator: ':::'
 validation_prompts: ''

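The template lists the allowed options for each setting; at training time a single value is chosen from each list. A hedged example, not part of the commit, of concrete values one might pick when enabling fp8 layerwise upcasting:

# Example selection expressed as a plain dict; the keys match the template above,
# but the specific values are illustrative choices, not defaults from the repo.
fp8_settings = {
    "layerwise_upcasting_modules": "transformer",          # 'none' leaves the feature off
    "layerwise_upcasting_storage_dtype": "float8_e4m3fn",  # or 'float8_e5m2'
    "layerwise_upcasting_granularity": "pytorch_layer",    # or 'diffusers_layer'
    "text_encoder_dtype": "fp8",
    "text_encoder_2_dtype": "fp8",
    "text_encoder_3_dtype": "fp8",
    "vae_dtype": "bf16",
}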
pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "finetrainers-ui"
-version = "0.8.0"
+version = "0.10.0"
 dependencies = [
     "gradio",
     "torch>=2.4.1"

run_trainer.py

Lines changed: 11 additions & 7 deletions

@@ -20,7 +20,16 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):

         # Model arguments
         model_cmd = f"--model_name {config.get('model_name')} \
-        --pretrained_model_name_or_path {config.get('pretrained_model_name_or_path')}"
+        --pretrained_model_name_or_path {config.get('pretrained_model_name_or_path')} \
+        --text_encoder_dtype {config.get('text_encoder_dtype')} \
+        --text_encoder_2_dtype {config.get('text_encoder_2_dtype')} \
+        --text_encoder_3_dtype {config.get('text_encoder_3_dtype')} \
+        --vae_dtype {config.get('vae_dtype')} "
+
+        if config.get('layerwise_upcasting_modules') != 'none':
+            model_cmd += f"--layerwise_upcasting_modules {config.get('layerwise_upcasting_modules')} \
+            --layerwise_upcasting_storage_dtype {config.get('layerwise_upcasting_storage_dtype')} \
+            --layerwise_upcasting_granularity {config.get('layerwise_upcasting_granularity')} "

         # Dataset arguments
         dataset_cmd = f"--data_root {config.get('data_root')} \
@@ -30,11 +39,7 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
         --video_resolution_buckets {config.get('video_resolution_buckets')} \
         --caption_dropout_p {config.get('caption_dropout_p')} \
         --caption_dropout_technique {config.get('caption_dropout_technique')} \
-        {'--precompute_conditions' if config.get('precompute_conditions') else ''} \
-        --text_encoder_dtype {config.get('text_encoder_dtype')} \
-        --text_encoder_2_dtype {config.get('text_encoder_2_dtype')} \
-        --text_encoder_3_dtype {config.get('text_encoder_3_dtype')} \
-        --vae_dtype {config.get('vae_dtype')} "
+        {'--precompute_conditions' if config.get('precompute_conditions') else ''} "

         # Dataloader arguments
         dataloader_cmd = f"--dataloader_num_workers {config.get('dataloader_num_workers')}"
@@ -45,7 +50,6 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
         # Training arguments
         training_cmd = f"--training_type {config.get('training_type')} \
         --seed {config.get('seed')} \
-        --mixed_precision {config.get('mixed_precision')} \
         --batch_size {config.get('batch_size')} \
         --train_steps {config.get('train_steps')} \
         --rank {config.get('rank')} \

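A self-contained sketch of the argument assembly above, showing that the dtype flags are now part of the model arguments and that the layerwise-upcasting flags are only appended when the feature is enabled. The sample config values, including the model path, are placeholders rather than values taken from the commit.

# Standalone paraphrase of the new model-argument assembly in run_trainer.py.
def build_model_cmd(config: dict) -> str:
    cmd = (
        f"--model_name {config['model_name']} "
        f"--pretrained_model_name_or_path {config['pretrained_model_name_or_path']} "
        f"--text_encoder_dtype {config['text_encoder_dtype']} "
        f"--text_encoder_2_dtype {config['text_encoder_2_dtype']} "
        f"--text_encoder_3_dtype {config['text_encoder_3_dtype']} "
        f"--vae_dtype {config['vae_dtype']} "
    )
    # The layerwise-upcasting flags are only passed when the feature is enabled.
    if config['layerwise_upcasting_modules'] != 'none':
        cmd += (
            f"--layerwise_upcasting_modules {config['layerwise_upcasting_modules']} "
            f"--layerwise_upcasting_storage_dtype {config['layerwise_upcasting_storage_dtype']} "
            f"--layerwise_upcasting_granularity {config['layerwise_upcasting_granularity']} "
        )
    return cmd

print(build_model_cmd({
    'model_name': 'ltx_video',
    'pretrained_model_name_or_path': 'path/to/pretrained-model',  # placeholder
    'text_encoder_dtype': 'bf16', 'text_encoder_2_dtype': 'bf16',
    'text_encoder_3_dtype': 'bf16', 'vae_dtype': 'bf16',
    'layerwise_upcasting_modules': 'transformer',
    'layerwise_upcasting_storage_dtype': 'float8_e4m3fn',
    'layerwise_upcasting_granularity': 'pytorch_layer',
}))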
tabs/general_tab.py

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ def __init__(self, title, config_file_path, allow_load=False):

         try:
             with self.settings_column:
-                inputs = self.update_form(self.config)
+                inputs = self.update_form()
                 self.components = OrderedDict(inputs)
                 children = []
                 for child in self.settings_column.children:

tabs/prepare_tab.py

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ def __init__(self, title, config_file_path, allow_load=False):

         try:
             with self.settings_column:
-                self.components = OrderedDict(self.update_form(self.config))
+                self.components = OrderedDict(self.update_form())
                 for i in range(len(self.settings_column.children)):
                     keys = list(self.components.keys())
                     properties[keys[i]] = self.settings_column.children[i]

tabs/tab.py

Lines changed: 2 additions & 2 deletions

@@ -71,10 +71,10 @@ def add_buttons(self):
             outputs=[self.save_status, self.config_file_box, *self.get_properties().values()]
         )

-    def update_form(self, config):
+    def update_form(self):
         inputs = dict()

-        for key, value in config.items():
+        for key, value in self.config.items():
             category = 'Other'
             for categories in self.config_categories.keys():
                 if key in self.config_categories[categories]:

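The same signature change is applied across the tab subclasses above and below: update_form() now reads self.config directly instead of receiving it as an argument, so every call site drops the parameter. A minimal, self-contained sketch of the new shape; it is simplified, since the real method creates UI input components inside the settings column rather than returning plain tuples.

from collections import OrderedDict

class Tab:
    def __init__(self, config, config_categories):
        self.config = config                        # dict of config key -> value
        self.config_categories = config_categories  # dict of category -> list of keys

    def update_form(self):  # previously: update_form(self, config)
        inputs = OrderedDict()
        for key, value in self.config.items():
            category = next(
                (name for name, keys in self.config_categories.items() if key in keys),
                'Other',
            )
            inputs[key] = (category, value)
        return inputs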
tabs/training_tab.py

Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,7 @@ def __init__(self, title, config_file_path, allow_load=False):

         try:
             with self.settings_column:
-                inputs = self.update_form(self.config)
+                inputs = self.update_form()
                 self.components = OrderedDict(inputs)
                 children = []
                 for child in self.settings_column.children:

tabs/training_tab_legacy.py

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@ def __init__(self, title, config_file_path, allow_load=False):

         try:
             with self.settings_column:
-                self.components = OrderedDict(self.update_form(self.config))
+                self.components = OrderedDict(self.update_form())
                 for i in range(len(self.settings_column.children)):
                     keys = list(self.components.keys())
                     properties[keys[i]] = self.settings_column.children[i]

trainer_config_validator.py

Lines changed: 0 additions & 1 deletion

@@ -29,7 +29,6 @@ def validate(self):
             'lr_scheduler',
             'lr_warmup_steps',
             'max_grad_norm',
-            'mixed_precision',
             'model_name',
             'nccl_timeout',
             'optimizer',

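With mixed_precision removed from the UI and the trainer command, the validator no longer requires it. A minimal sketch, assumed rather than taken from the project's actual implementation, of the kind of required-key check this file performs; only the key names visible in the diff above come from the repo.

REQUIRED_KEYS = ['lr_scheduler', 'lr_warmup_steps', 'max_grad_norm',
                 'model_name', 'nccl_timeout', 'optimizer']

def find_missing_keys(config):
    # Report required keys that are absent or left empty in the config.
    return [key for key in REQUIRED_KEYS
            if key not in config or config[key] in (None, '')]

For example, find_missing_keys({'model_name': 'ltx_video'}) would report the other five keys as missing.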
0 commit comments