 import os
 import signal
 import subprocess
-import time
+
+import psutil
 
 from config import Config
 
@@ -18,88 +19,92 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
         assert config.get('data_root'), "Data root required"
         assert config.get('pretrained_model_name_or_path'), "pretrained_model_name_or_path required"
 
-        # Model arguments
-        model_cmd = f"--model_name {config.get('model_name')} \
-            --pretrained_model_name_or_path {config.get('pretrained_model_name_or_path')} \
-            --text_encoder_dtype {config.get('text_encoder_dtype')} \
-            --text_encoder_2_dtype {config.get('text_encoder_2_dtype')} \
-            --text_encoder_3_dtype {config.get('text_encoder_3_dtype')} \
-            --vae_dtype {config.get('vae_dtype')}"
-
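+        # Model arguments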
+        model_cmd = ["--model_name", config.get('model_name'),
+                     "--pretrained_model_name_or_path", config.get('pretrained_model_name_or_path'),
+                     "--text_encoder_dtype", config.get('text_encoder_dtype'),
+                     "--text_encoder_2_dtype", config.get('text_encoder_2_dtype'),
+                     "--text_encoder_3_dtype", config.get('text_encoder_3_dtype'),
+                     "--vae_dtype", config.get('vae_dtype')]
+
         if config.get('layerwise_upcasting_modules') != 'none':
-            model_cmd += f"--layerwise_upcasting_modules {config.get('layerwise_upcasting_modules')} \
-                --layerwise_upcasting_storage_dtype {config.get('layerwise_upcasting_storage_dtype')} \
-                --layerwise_upcasting_granularity {config.get('layerwise_upcasting_granularity')}"
-
-        # Dataset arguments
-        dataset_cmd = f"--data_root {config.get('data_root')} \
-            --video_column {config.get('video_column')} \
-            --caption_column {config.get('caption_column')} \
-            --id_token {config.get('id_token')} \
-            --video_resolution_buckets {config.get('video_resolution_buckets')} \
-            --caption_dropout_p {config.get('caption_dropout_p')} \
-            --caption_dropout_technique {config.get('caption_dropout_technique')} \
-            {'--precompute_conditions' if config.get('precompute_conditions') else ''}"
-
-        # Dataloader arguments
-        dataloader_cmd = f"--dataloader_num_workers {config.get('dataloader_num_workers')}"
+            model_cmd += ["--layerwise_upcasting_modules", config.get('layerwise_upcasting_modules'),
+                          "--layerwise_upcasting_storage_dtype", config.get('layerwise_upcasting_storage_dtype'),
+                          "--layerwise_upcasting_granularity", config.get('layerwise_upcasting_granularity')]
+
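+        # Dataset arguments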
+        dataset_cmd = ["--data_root", config.get('data_root'),
+                       "--video_column", config.get('video_column'),
+                       "--caption_column", config.get('caption_column'),
+                       "--id_token", config.get('id_token'),
+                       "--video_resolution_buckets"]
+        dataset_cmd += config.get('video_resolution_buckets').split(' ')
+        dataset_cmd += ["--image_resolution_buckets"]
+        dataset_cmd += config.get('image_resolution_buckets').split(' ')
+        dataset_cmd += ["--caption_dropout_p", config.get('caption_dropout_p'),
+                        "--caption_dropout_technique", config.get('caption_dropout_technique'),
+                        "--text_encoder_dtype", config.get('text_encoder_dtype'),
+                        "--text_encoder_2_dtype", config.get('text_encoder_2_dtype'),
+                        "--text_encoder_3_dtype", config.get('text_encoder_3_dtype'),
+                        "--vae_dtype", config.get('vae_dtype'),
+                        '--precompute_conditions' if config.get('precompute_conditions') else '']
+        if config.get('dataset_file'):
+            dataset_cmd += ["--dataset_file", config.get('dataset_file')]
+
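+        # Dataloader arguments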
+        dataloader_cmd = ["--dataloader_num_workers", config.get('dataloader_num_workers')]
 
         # Diffusion arguments TODO: replace later
-        diffusion_cmd = f"{config.get('diffusion_options')}"
-
-        # Training arguments
-        training_cmd = f"--training_type {config.get('training_type')} \
-            --seed {config.get('seed')} \
-            --batch_size {config.get('batch_size')} \
-            --train_steps {config.get('train_steps')} \
-            --rank {config.get('rank')} \
-            --lora_alpha {config.get('lora_alpha')} \
-            --target_modules {config.get('target_modules')} \
-            --gradient_accumulation_steps {config.get('gradient_accumulation_steps')} \
-            {'--gradient_checkpointing' if config.get('gradient_checkpointing') else ''} \
-            --checkpointing_steps {config.get('checkpointing_steps')} \
-            --checkpointing_limit {config.get('checkpointing_limit')} \
-            {'--enable_slicing' if config.get('enable_slicing') else ''} \
-            {'--enable_tiling' if config.get('enable_tiling') else ''}"
-
-        # Optimizer arguments
-        optimizer_cmd = f"--optimizer {config.get('optimizer')} \
-            --lr {config.get('lr')} \
-            --lr_scheduler {config.get('lr_scheduler')} \
-            --lr_warmup_steps {config.get('lr_warmup_steps')} \
-            --lr_num_cycles {config.get('lr_num_cycles')} \
-            --beta1 {config.get('beta1')} \
-            --beta2 {config.get('beta2')} \
-            --weight_decay {config.get('weight_decay')} \
-            --epsilon {config.get('epsilon')} \
-            --max_grad_norm {config.get('max_grad_norm')} \
-            {'--use_8bit_bnb' if config.get('use_8bit_bnb') else ''}"
-
-        # Validation arguments
-        validation_cmd = f"--validation_prompts \"{config.get('validation_prompts')}\" \
-            --num_validation_videos {config.get('num_validation_videos')} \
-            --validation_steps {config.get('validation_steps')}"
-
-        # Miscellaneous arguments
-        miscellaneous_cmd = f"--tracker_name {config.get('tracker_name')} \
-            --output_dir {config.get('output_dir')} \
-            --nccl_timeout {config.get('nccl_timeout')} \
-            --report_to {config.get('report_to')}"
-
-        cmd = f"accelerate launch --config_file {finetrainers_path}/accelerate_configs/{config.get('accelerate_config')} --gpu_ids {config.get('gpu_ids')} {finetrainers_path}/train.py \
-            {model_cmd} \
-            {dataset_cmd} \
-            {dataloader_cmd} \
-            {diffusion_cmd} \
-            {training_cmd} \
-            {optimizer_cmd} \
-            {validation_cmd} \
-            {miscellaneous_cmd}"
-
-        print(cmd)
+        # Split on spaces so each option becomes its own argv entry (the command now runs without a shell)
+        diffusion_cmd = (config.get('diffusion_options') or '').split(' ')
+
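+        # Training arguments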
+        training_cmd = ["--training_type", config.get('training_type'),
+                        "--seed", config.get('seed'),
+                        "--mixed_precision", config.get('mixed_precision'),
+                        "--batch_size", config.get('batch_size'),
+                        "--train_steps", config.get('train_steps'),
+                        "--rank", config.get('rank'),
+                        "--lora_alpha", config.get('lora_alpha'),
+                        "--target_modules"]
+        training_cmd += config.get('target_modules').split(' ')
+        training_cmd += ["--gradient_accumulation_steps", config.get('gradient_accumulation_steps'),
+                         '--gradient_checkpointing' if config.get('gradient_checkpointing') else '',
+                         "--checkpointing_steps", config.get('checkpointing_steps'),
+                         "--checkpointing_limit", config.get('checkpointing_limit'),
+                         '--enable_slicing' if config.get('enable_slicing') else '',
+                         '--enable_tiling' if config.get('enable_tiling') else '']
+        if config.get('enable_model_cpu_offload'):
+            training_cmd += ["--enable_model_cpu_offload"]
+
+        if config.get('resume_from_checkpoint'):
+            training_cmd += ["--resume_from_checkpoint", config.get('resume_from_checkpoint')]
+
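+        # Optimizer arguments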
+        optimizer_cmd = ["--optimizer", config.get('optimizer'),
+                         "--lr", config.get('lr'),
+                         "--lr_scheduler", config.get('lr_scheduler'),
+                         "--lr_warmup_steps", config.get('lr_warmup_steps'),
+                         "--lr_num_cycles", config.get('lr_num_cycles'),
+                         "--beta1", config.get('beta1'),
+                         "--beta2", config.get('beta2'),
+                         "--weight_decay", config.get('weight_decay'),
+                         "--epsilon", config.get('epsilon'),
+                         "--max_grad_norm", config.get('max_grad_norm'),
+                         '--use_8bit_bnb' if config.get('use_8bit_bnb') else '']
+
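+        # Validation arguments (the prompt string stays a single argv element, so no shell quoting is needed)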
+        validation_cmd = ["--validation_prompts" if config.get('validation_prompts') else '',
+                          config.get('validation_prompts') or '',
+                          "--num_validation_videos", config.get('num_validation_videos'),
+                          "--validation_steps", config.get('validation_steps')]
+
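+        # Miscellaneous arguments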
+        miscellaneous_cmd = ["--tracker_name", config.get('tracker_name'),
+                             "--output_dir", config.get('output_dir'),
+                             "--nccl_timeout", config.get('nccl_timeout'),
+                             "--report_to", config.get('report_to')]
+        accelerate_cmd = ["accelerate", "launch",
+                          "--config_file", f"{finetrainers_path}/accelerate_configs/{config.get('accelerate_config')}",
+                          "--gpu_ids", config.get('gpu_ids')]
+        cmd = (accelerate_cmd + [f"{finetrainers_path}/train.py"] + model_cmd + dataset_cmd + dataloader_cmd
+               + diffusion_cmd + training_cmd + optimizer_cmd + validation_cmd + miscellaneous_cmd)
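+        # Drop the '' placeholders left by disabled flags and stringify every value for Popen
+        # (assumes no real config value is an empty string)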
+        fixed_cmd = [str(arg) for arg in cmd if arg != '']
+        print(' '.join(fixed_cmd))
         self.running = True
         with open(log_file, "w") as output_file:
-            self.process = subprocess.Popen(cmd, shell=True, stdout=output_file, stderr=output_file, text=True)
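+            # POSIX-only: os.setsid gives the child its own process group so stop() can signal the whole tree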
+            self.process = subprocess.Popen(fixed_cmd, shell=False, stdout=output_file, stderr=output_file, text=True, preexec_fn=os.setsid)
             self.process.communicate()
         return self.process
 
@@ -109,12 +114,20 @@ def stop(self):
         try:
             self.running = False
             if self.process:
-                self.process.terminate()
-                time.sleep(3)
-                if self.process.poll() is None:
-                    self.process.kill()
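+                # Signal the whole process group first, then reap any remaining children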
+                os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
+                self.terminate_process_tree(self.process.pid)
         except Exception as e:
             return f"Error stopping training: {e}"
         finally:
             self.process.wait()
-        return "Training forcibly stopped"
+        return "Training forcibly stopped"
+
+    def terminate_process_tree(self, pid):
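+        # Best-effort cleanup of any processes the trainer spawned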
+        try:
+            parent = psutil.Process(pid)
+            children = parent.children(recursive=True)  # Get child processes
+            for child in children:
+                child.terminate()
+            parent.terminate()
+        except psutil.NoSuchProcess:
+            pass