diff --git a/config/config_categories.yaml b/config/config_categories.yaml
index b1758f0..3a1809f 100644
--- a/config/config_categories.yaml
+++ b/config/config_categories.yaml
@@ -1,5 +1,6 @@
-Dataset: data_root, video_column, caption_column, dataset_file, id_token, image_resolution_buckets, video_resolution_buckets, caption_dropout_p
-Training: training_type, seed, mixed_precision, train_steps, rank, lora_alpha, target_modules, gradient_accumulation_steps, checkpointing_steps, checkpointing_limit, enable_slicing, enable_tiling, batch_size, resume_from_checkpoint
+Dataset: data_root, video_column, caption_column, dataset_file, id_token, image_resolution_buckets, video_resolution_buckets, caption_dropout_p, precompute_conditions
+Training: training_type, seed, train_steps, rank, lora_alpha, target_modules, gradient_accumulation_steps, checkpointing_steps, checkpointing_limit, enable_slicing, enable_tiling, batch_size, resume_from_checkpoint
 Optimizer: optimizer, lr, beta1, beta2, epsilon, weight_decay, max_grad_norm, lr_scheduler, lr_num_cycles, lr_warmup_steps
 Validation: validation_steps, validation_epochs, num_validation_videos, validation_prompts, validation_prompt_separator
-Accelerate: gpu_ids, nccl_timeout, gradient_checkpointing, allow_tf32, dataloader_num_workers, report_to, accelerate_config
\ No newline at end of file
+Accelerate: gpu_ids, nccl_timeout, gradient_checkpointing, allow_tf32, dataloader_num_workers, report_to, accelerate_config
+Model: model_name, pretrained_model_name_or_path, text_encoder_dtype, text_encoder_2_dtype, text_encoder_3_dtype, vae_dtype, layerwise_upcasting_modules, layerwise_upcasting_storage_dtype, layerwise_upcasting_granularity
\ No newline at end of file
diff --git a/config/config_template.yaml b/config/config_template.yaml
index 07f3c80..3982058 100644
--- a/config/config_template.yaml
+++ b/config/config_template.yaml
@@ -20,6 +20,9 @@ gpu_ids: '0'
 gradient_accumulation_steps: 4
 gradient_checkpointing: true
 id_token: afkx
+layerwise_upcasting_modules: [none, transformer]
+layerwise_upcasting_skip_modules_pattern: 'patch_embed pos_embed x_embedder context_embedder ^proj_in$ ^proj_out$ norm'
+layerwise_upcasting_storage_dtype: [float8_e4m3fn, float8_e5m2]
 image_resolution_buckets: 512x768
 lora_alpha: 128
 lr: 0.0001
@@ -27,7 +30,6 @@ lr_num_cycles: 1
 lr_scheduler: ['linear', 'cosine', 'cosine_with_restarts', 'polynomial', 'constant', 'constant_with_warmup']
 lr_warmup_steps: 400
 max_grad_norm: 1.0
-mixed_precision: [bf16, fp16, 'no']
 model_name: ltx_video
 nccl_timeout: 1800
 num_validation_videos: 0
@@ -45,6 +47,7 @@ text_encoder_dtype: [bf16, fp16, fp32]
 text_encoder_2_dtype: [bf16, fp16, fp32]
 text_encoder_3_dtype: [bf16, fp16, fp32]
 tracker_name: finetrainers
+transformer_dtype: [bf16, fp16, fp32]
 train_steps: 3000
 training_type: lora
 use_8bit_bnb: false
diff --git a/pyproject.toml b/pyproject.toml
index 82b7ecb..4fc6511 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "finetrainers-ui"
-version = "0.9.3"
+version = "0.10.0"
 dependencies = [
     "gradio",
     "torch>=2.4.1"
diff --git a/run_trainer.py b/run_trainer.py
index 2c3b887..e4cc7d6 100644
--- a/run_trainer.py
+++ b/run_trainer.py
@@ -20,7 +20,16 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
         assert config.get('pretrained_model_name_or_path'), "pretrained_model_name_or_path required"
 
         model_cmd = ["--model_name", config.get('model_name'),
-                    "--pretrained_model_name_or_path", config.get('pretrained_model_name_or_path')]
+                    "--pretrained_model_name_or_path", config.get('pretrained_model_name_or_path'),
+                    "--text_encoder_dtype", config.get('text_encoder_dtype'),
+                    "--text_encoder_2_dtype", config.get('text_encoder_2_dtype'),
+                    "--text_encoder_3_dtype", config.get('text_encoder_3_dtype'),
+                    "--vae_dtype", config.get('vae_dtype')]
+
+        if config.get('layerwise_upcasting_modules') != 'none':
+            model_cmd +=["--layerwise_upcasting_modules", config.get('layerwise_upcasting_modules'),
+                    "--layerwise_upcasting_storage_dtype", config.get('layerwise_upcasting_storage_dtype'),
+                    "--layerwise_upcasting_skip_modules_pattern", config.get('layerwise_upcasting_skip_modules_pattern')]
 
         dataset_cmd = ["--data_root", config.get('data_root'),
                     "--video_column", config.get('video_column'),
@@ -36,6 +45,7 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
                     "--text_encoder_2_dtype", config.get('text_encoder_2_dtype'),
                     "--text_encoder_3_dtype", config.get('text_encoder_3_dtype'),
                     "--vae_dtype", config.get('vae_dtype'),
+                    "--transformer_dtype", config.get('transformer_dtype'),
                     '--precompute_conditions' if config.get('precompute_conditions') else '']
         if config.get('dataset_file'):
             dataset_cmd += ["--dataset_file", config.get('dataset_file')]
@@ -47,7 +57,6 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
 
         training_cmd = ["--training_type", config.get('training_type'),
                     "--seed", config.get('seed'),
-                    "--mixed_precision", config.get('mixed_precision'),
                     "--batch_size", config.get('batch_size'),
                     "--train_steps", config.get('train_steps'),
                     "--rank", config.get('rank'),
diff --git a/tabs/general_tab.py b/tabs/general_tab.py
index 842268f..8e47736 100644
--- a/tabs/general_tab.py
+++ b/tabs/general_tab.py
@@ -15,7 +15,7 @@ def __init__(self, title, config_file_path, allow_load=False):
 
         try:
             with self.settings_column:
-                inputs = self.update_form(self.config)
+                inputs = self.update_form()
                 self.components = OrderedDict(inputs)
                 children = []
                 for child in self.settings_column.children:
diff --git a/tabs/prepare_tab.py b/tabs/prepare_tab.py
index 8fd882a..e45bab7 100644
--- a/tabs/prepare_tab.py
+++ b/tabs/prepare_tab.py
@@ -20,7 +20,7 @@ def __init__(self, title, config_file_path, allow_load=False):
 
         try:
             with self.settings_column:
-                self.components = OrderedDict(self.update_form(self.config))
+                self.components = OrderedDict(self.update_form())
                 for i in range(len(self.settings_column.children)):
                     keys = list(self.components.keys())
                     properties[keys[i]] = self.settings_column.children[i]
diff --git a/tabs/tab.py b/tabs/tab.py
index 9bc030f..6d71814 100644
--- a/tabs/tab.py
+++ b/tabs/tab.py
@@ -71,10 +71,10 @@ def add_buttons(self):
             outputs=[self.save_status, self.config_file_box, *self.get_properties().values()]
         )
 
-    def update_form(self, config):
+    def update_form(self):
         inputs = dict()
 
-        for key, value in config.items():
+        for key, value in self.config.items():
             category = 'Other'
             for categories in self.config_categories.keys():
                 if key in self.config_categories[categories]:
@@ -114,6 +114,6 @@ def update_properties(self, *args):
                 properties_values[index] = value
                 #properties[key].value = value
 
-            return ["Config loaded. Edit below:", config_file_box, *properties_values]
+            return ["Config loaded.", config_file_box, *properties_values]
         except Exception as e:
             return [f"Error loading config: {e}", config_file_box, *properties_values]
\ No newline at end of file
diff --git a/tabs/training_tab.py b/tabs/training_tab.py
index 74c6d23..904cec2 100644
--- a/tabs/training_tab.py
+++ b/tabs/training_tab.py
@@ -30,7 +30,7 @@ def __init__(self, title, config_file_path, allow_load=False):
 
         try:
             with self.settings_column:
-                inputs = self.update_form(self.config)
+                inputs = self.update_form()
                 self.components = OrderedDict(inputs)
                 children = []
                 for child in self.settings_column.children:
diff --git a/tabs/training_tab_legacy.py b/tabs/training_tab_legacy.py
index f754ba4..6b65b99 100644
--- a/tabs/training_tab_legacy.py
+++ b/tabs/training_tab_legacy.py
@@ -17,7 +17,7 @@ def __init__(self, title, config_file_path, allow_load=False):
 
         try:
             with self.settings_column:
-                self.components = OrderedDict(self.update_form(self.config))
+                self.components = OrderedDict(self.update_form())
                 for i in range(len(self.settings_column.children)):
                     keys = list(self.components.keys())
                     properties[keys[i]] = self.settings_column.children[i]
diff --git a/trainer_config_validator.py b/trainer_config_validator.py
index dbaa2cf..fe4492e 100644
--- a/trainer_config_validator.py
+++ b/trainer_config_validator.py
@@ -29,7 +29,6 @@ def validate(self):
             'lr_scheduler',
             'lr_warmup_steps',
             'max_grad_norm',
-            'mixed_precision',
             'model_name',
             'nccl_timeout',
             'optimizer',