Merged
7 changes: 4 additions & 3 deletions config/config_categories.yaml
@@ -1,5 +1,6 @@
-Dataset: data_root, video_column, caption_column, dataset_file, id_token, image_resolution_buckets, video_resolution_buckets, caption_dropout_p
-Training: training_type, seed, mixed_precision, train_steps, rank, lora_alpha, target_modules, gradient_accumulation_steps, checkpointing_steps, checkpointing_limit, enable_slicing, enable_tiling, batch_size, resume_from_checkpoint
+Dataset: data_root, video_column, caption_column, dataset_file, id_token, image_resolution_buckets, video_resolution_buckets, caption_dropout_p, precompute_conditions
+Training: training_type, seed, train_steps, rank, lora_alpha, target_modules, gradient_accumulation_steps, checkpointing_steps, checkpointing_limit, enable_slicing, enable_tiling, batch_size, resume_from_checkpoint
 Optimizer: optimizer, lr, beta1, beta2, epsilon, weight_decay, max_grad_norm, lr_scheduler, lr_num_cycles, lr_warmup_steps
 Validation: validation_steps, validation_epochs, num_validation_videos, validation_prompts, validation_prompt_separator
-Accelerate: gpu_ids, nccl_timeout, gradient_checkpointing, allow_tf32, dataloader_num_workers, report_to, accelerate_config
+Accelerate: gpu_ids, nccl_timeout, gradient_checkpointing, allow_tf32, dataloader_num_workers, report_to, accelerate_config
+Model: model_name, pretrained_model_name_or_path, text_encoder_dtype, text_encoder_2_dtype, text_encoder_3_dtype, vae_dtype, layerwise_upcasting_modules, layerwise_upcasting_storage_dtype, layerwise_upcasting_granularity
5 changes: 4 additions & 1 deletion config/config_template.yaml
@@ -20,14 +20,16 @@ gpu_ids: '0'
 gradient_accumulation_steps: 4
 gradient_checkpointing: true
 id_token: afkx
+layerwise_upcasting_modules: [none, transformer]
+layerwise_upcasting_skip_modules_pattern: 'patch_embed pos_embed x_embedder context_embedder ^proj_in$ ^proj_out$ norm'
+layerwise_upcasting_storage_dtype: [float8_e4m3fn, float8_e5m2]
 image_resolution_buckets: 512x768
 lora_alpha: 128
 lr: 0.0001
 lr_num_cycles: 1
 lr_scheduler: ['linear', 'cosine', 'cosine_with_restarts', 'polynomial', 'constant', 'constant_with_warmup']
 lr_warmup_steps: 400
 max_grad_norm: 1.0
-mixed_precision: [bf16, fp16, 'no']
 model_name: ltx_video
 nccl_timeout: 1800
 num_validation_videos: 0
@@ -45,6 +47,7 @@ text_encoder_dtype: [bf16, fp16, fp32]
 text_encoder_2_dtype: [bf16, fp16, fp32]
 text_encoder_3_dtype: [bf16, fp16, fp32]
 tracker_name: finetrainers
+transformer_dtype: [bf16, fp16, fp32]
 train_steps: 3000
 training_type: lora
 use_8bit_bnb: false
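Note: the bracketed lists in the template above enumerate the allowed choices for each option, not literal values. A filled-in config using the new layerwise-upcasting keys might look like the following sketch (the values are illustrative picks from the template's choice lists, not recommendations from this PR):

layerwise_upcasting_modules: transformer
layerwise_upcasting_storage_dtype: float8_e4m3fn
layerwise_upcasting_skip_modules_pattern: 'patch_embed pos_embed x_embedder context_embedder ^proj_in$ ^proj_out$ norm'
transformer_dtype: bf16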
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "finetrainers-ui"
-version = "0.9.3"
+version = "0.10.0"
 dependencies = [
     "gradio",
     "torch>=2.4.1"
13 changes: 11 additions & 2 deletions run_trainer.py
@@ -20,7 +20,16 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
         assert config.get('pretrained_model_name_or_path'), "pretrained_model_name_or_path required"
 
         model_cmd = ["--model_name", config.get('model_name'),
-                     "--pretrained_model_name_or_path", config.get('pretrained_model_name_or_path')]
+                     "--pretrained_model_name_or_path", config.get('pretrained_model_name_or_path'),
+                     "--text_encoder_dtype", config.get('text_encoder_dtype'),
+                     "--text_encoder_2_dtype", config.get('text_encoder_2_dtype'),
+                     "--text_encoder_3_dtype", config.get('text_encoder_3_dtype'),
+                     "--vae_dtype", config.get('vae_dtype')]
+
+        if config.get('layerwise_upcasting_modules') != 'none':
+            model_cmd += ["--layerwise_upcasting_modules", config.get('layerwise_upcasting_modules'),
+                          "--layerwise_upcasting_storage_dtype", config.get('layerwise_upcasting_storage_dtype'),
+                          "--layerwise_upcasting_skip_modules_pattern", config.get('layerwise_upcasting_skip_modules_pattern')]
 
         dataset_cmd = ["--data_root", config.get('data_root'),
                        "--video_column", config.get('video_column'),
@@ -36,6 +45,7 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
                        "--text_encoder_2_dtype", config.get('text_encoder_2_dtype'),
                        "--text_encoder_3_dtype", config.get('text_encoder_3_dtype'),
                        "--vae_dtype", config.get('vae_dtype'),
+                       "--transformer_dtype", config.get('transformer_dtype'),
                        '--precompute_conditions' if config.get('precompute_conditions') else '']
         if config.get('dataset_file'):
             dataset_cmd += ["--dataset_file", config.get('dataset_file')]
@@ -47,7 +57,6 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
 
         training_cmd = ["--training_type", config.get('training_type'),
                         "--seed", config.get('seed'),
-                        "--mixed_precision", config.get('mixed_precision'),
                         "--batch_size", config.get('batch_size'),
                         "--train_steps", config.get('train_steps'),
                         "--rank", config.get('rank'),
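With this change the text-encoder and VAE dtype flags are always passed to the trainer, while the three layerwise-upcasting flags are appended only when layerwise_upcasting_modules is something other than 'none'. A rough sketch of the resulting model_cmd (the model path and dtype values below are hypothetical placeholders, not values taken from this PR):

# Hypothetical config: model_name='ltx_video', layerwise_upcasting_modules='transformer'
model_cmd = [
    "--model_name", "ltx_video",
    "--pretrained_model_name_or_path", "<model repo or local path>",
    "--text_encoder_dtype", "bf16",
    "--text_encoder_2_dtype", "bf16",
    "--text_encoder_3_dtype", "bf16",
    "--vae_dtype", "bf16",
    # appended only because layerwise_upcasting_modules != 'none':
    "--layerwise_upcasting_modules", "transformer",
    "--layerwise_upcasting_storage_dtype", "float8_e4m3fn",
    "--layerwise_upcasting_skip_modules_pattern",
    "patch_embed pos_embed x_embedder context_embedder ^proj_in$ ^proj_out$ norm",
]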
2 changes: 1 addition & 1 deletion tabs/general_tab.py
@@ -15,7 +15,7 @@ def __init__(self, title, config_file_path, allow_load=False):
 
         try:
             with self.settings_column:
-                inputs = self.update_form(self.config)
+                inputs = self.update_form()
                 self.components = OrderedDict(inputs)
                 children = []
                 for child in self.settings_column.children:
2 changes: 1 addition & 1 deletion tabs/prepare_tab.py
@@ -20,7 +20,7 @@ def __init__(self, title, config_file_path, allow_load=False):
 
         try:
             with self.settings_column:
-                self.components = OrderedDict(self.update_form(self.config))
+                self.components = OrderedDict(self.update_form())
                 for i in range(len(self.settings_column.children)):
                     keys = list(self.components.keys())
                     properties[keys[i]] = self.settings_column.children[i]
6 changes: 3 additions & 3 deletions tabs/tab.py
@@ -71,10 +71,10 @@ def add_buttons(self):
             outputs=[self.save_status, self.config_file_box, *self.get_properties().values()]
         )
 
-    def update_form(self, config):
+    def update_form(self):
         inputs = dict()
 
-        for key, value in config.items():
+        for key, value in self.config.items():
             category = 'Other'
             for categories in self.config_categories.keys():
                 if key in self.config_categories[categories]:
@@ -114,6 +114,6 @@ def update_properties(self, *args):
 
                 properties_values[index] = value
                 #properties[key].value = value
-            return ["Config loaded. Edit below:", config_file_box, *properties_values]
+            return ["Config loaded.", config_file_box, *properties_values]
         except Exception as e:
             return [f"Error loading config: {e}", config_file_box, *properties_values]
2 changes: 1 addition & 1 deletion tabs/training_tab.py
@@ -30,7 +30,7 @@ def __init__(self, title, config_file_path, allow_load=False):
 
         try:
             with self.settings_column:
-                inputs = self.update_form(self.config)
+                inputs = self.update_form()
                 self.components = OrderedDict(inputs)
                 children = []
                 for child in self.settings_column.children:
2 changes: 1 addition & 1 deletion tabs/training_tab_legacy.py
@@ -17,7 +17,7 @@ def __init__(self, title, config_file_path, allow_load=False):
 
         try:
             with self.settings_column:
-                self.components = OrderedDict(self.update_form(self.config))
+                self.components = OrderedDict(self.update_form())
                 for i in range(len(self.settings_column.children)):
                     keys = list(self.components.keys())
                     properties[keys[i]] = self.settings_column.children[i]
1 change: 0 additions & 1 deletion trainer_config_validator.py
@@ -29,7 +29,6 @@ def validate(self):
             'lr_scheduler',
             'lr_warmup_steps',
             'max_grad_norm',
-            'mixed_precision',
             'model_name',
             'nccl_timeout',
             'optimizer',