11import os
22import signal
33import subprocess
4- import time
4+
5+ import psutil
56
67from config import Config
78
@@ -18,87 +19,77 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
1819 assert config .get ('data_root' ), "Data root required"
1920 assert config .get ('pretrained_model_name_or_path' ), "pretrained_model_name_or_path required"
2021
21- # Model arguments
22- model_cmd = f"--model_name { config .get ('model_name' )} \
23- --pretrained_model_name_or_path { config .get ('pretrained_model_name_or_path' )} "
24-
25- # Dataset arguments
26- dataset_cmd = f"--data_root { config .get ('data_root' )} \
27- --video_column { config .get ('video_column' )} \
28- --caption_column { config .get ('caption_column' )} \
29- --id_token { config .get ('id_token' )} \
30- --video_resolution_buckets { config .get ('video_resolution_buckets' )} \
31- --caption_dropout_p { config .get ('caption_dropout_p' )} \
32- --caption_dropout_technique { config .get ('caption_dropout_technique' )} \
33- { '--precompute_conditions' if config .get ('precompute_conditions' ) else '' } \
34- --text_encoder_dtype { config .get ('text_encoder_dtype' )} \
35- --text_encoder_2_dtype { config .get ('text_encoder_2_dtype' )} \
36- --text_encoder_3_dtype { config .get ('text_encoder_3_dtype' )} \
37- --vae_dtype { config .get ('vae_dtype' )} "
38-
39- # Dataloader arguments
40- dataloader_cmd = f"--dataloader_num_workers { config .get ('dataloader_num_workers' )} "
22+ model_cmd = ["--model_name" , config .get ('model_name' ),
23+ "--pretrained_model_name_or_path" , config .get ('pretrained_model_name_or_path' )]
24+
25+ dataset_cmd = ["--data_root" , config .get ('data_root' ),
26+ "--video_column" , config .get ('video_column' ),
27+ "--caption_column" , config .get ('caption_column' ),
28+ "--id_token" , config .get ('id_token' ),
29+ "--video_resolution_buckets" ]
30+ dataset_cmd += config .get ('video_resolution_buckets' ).split (' ' )
31+ dataset_cmd += ["--caption_dropout_p" , config .get ('caption_dropout_p' ),
32+ "--caption_dropout_technique" , config .get ('caption_dropout_technique' ),
33+ "--text_encoder_dtype" , config .get ('text_encoder_dtype' ),
34+ "--text_encoder_2_dtype" , config .get ('text_encoder_2_dtype' ),
35+ "--text_encoder_3_dtype" , config .get ('text_encoder_3_dtype' ),
36+ "--vae_dtype" , config .get ('vae_dtype' ),
37+ '--precompute_conditions' if config .get ('precompute_conditions' ) else '' ]
38+
39+ dataloader_cmd = ["--dataloader_num_workers" , config .get ('dataloader_num_workers' )]
4140
4241 # Diffusion arguments TODO: replace later
43- diffusion_cmd = f" { config .get ('diffusion_options' )} "
44-
45- # Training arguments
46- training_cmd = f "--training_type { config .get ('training_type' ) } \
47- --seed { config .get ('seed' ) } \
48- --mixed_precision { config .get ('mixed_precision' ) } \
49- --batch_size { config .get ('batch_size' ) } \
50- --train_steps { config .get ('train_steps' ) } \
51- --rank { config .get ('rank' ) } \
52- --lora_alpha { config . get ( 'lora_alpha' ) } \
53- --target_modules { config .get ('target_modules' )} \
54- --gradient_accumulation_steps { config .get ('gradient_accumulation_steps' )} \
55- { '--gradient_checkpointing' if config .get ('gradient_checkpointing' ) else '' } \
56- --checkpointing_steps { config .get ('checkpointing_steps' )} \
57- --checkpointing_limit { config .get ('checkpointing_limit' )} \
58- { '--enable_slicing' if config .get ('enable_slicing' ) else '' } \
59- { '--enable_tiling' if config .get ('enable_tiling' ) else '' } "
42+ diffusion_cmd = [ config .get ('diffusion_options' )]
43+
44+ training_cmd = [ "--training_type" , config . get ( 'training_type' ),
45+ "--seed" , config .get ('seed' ),
46+ "--mixed_precision" , config .get ('mixed_precision' ),
47+ "--batch_size" , config .get ('batch_size' ),
48+ "--train_steps" , config .get ('train_steps' ),
49+ "--rank" , config .get ('rank' ),
50+ "--lora_alpha" , config .get ('lora_alpha' ),
51+ "--target_modules" ]
52+ training_cmd += config .get ('target_modules' ). split ( ' ' )
53+ training_cmd += [ " --gradient_accumulation_steps" , config .get ('gradient_accumulation_steps' ),
54+ '--gradient_checkpointing' if config .get ('gradient_checkpointing' ) else '' ,
55+ " --checkpointing_steps" , config .get ('checkpointing_steps' ),
56+ " --checkpointing_limit" , config .get ('checkpointing_limit' ),
57+ '--enable_slicing' if config .get ('enable_slicing' ) else '' ,
58+ '--enable_tiling' if config .get ('enable_tiling' ) else '' ]
6059
6160 if config .get ('resume_from_checkpoint' ):
62- training_cmd += f"--resume_from_checkpoint { config .get ('resume_from_checkpoint' )} "
63-
64- # Optimizer arguments
65- optimizer_cmd = f"--optimizer { config .get ('optimizer' )} \
66- --lr { config .get ('lr' )} \
67- --lr_scheduler { config .get ('lr_scheduler' )} \
68- --lr_warmup_steps { config .get ('lr_warmup_steps' )} \
69- --lr_num_cycles { config .get ('lr_num_cycles' )} \
70- --beta1 { config .get ('beta1' )} \
71- --beta2 { config .get ('beta2' )} \
72- --weight_decay { config .get ('weight_decay' )} \
73- --epsilon { config .get ('epsilon' )} \
74- --max_grad_norm { config .get ('max_grad_norm' )} \
75- { '--use_8bit_bnb' if config .get ('use_8bit_bnb' ) else '' } "
76-
77- # Validation arguments
78- validation_cmd = f"--validation_prompts \" { config .get ('validation_prompts' )} \" \
79- --num_validation_videos { config .get ('num_validation_videos' )} \
80- --validation_steps { config .get ('validation_steps' )} "
81-
82- # Miscellaneous arguments
83- miscellaneous_cmd = f"--tracker_name { config .get ('tracker_name' )} \
84- --output_dir { config .get ('output_dir' )} \
85- --nccl_timeout { config .get ('nccl_timeout' )} \
86- --report_to { config .get ('report_to' )} "
87-
88- cmd = f"accelerate launch --config_file { finetrainers_path } /accelerate_configs/{ config .get ('accelerate_config' )} --gpu_ids { config .get ('gpu_ids' )} { finetrainers_path } /train.py \
89- { model_cmd } \
90- { dataset_cmd } \
91- { dataloader_cmd } \
92- { diffusion_cmd } \
93- { training_cmd } \
94- { optimizer_cmd } \
95- { validation_cmd } \
96- { miscellaneous_cmd } "
97-
98- print (cmd )
61+ training_cmd += ["--resume_from_checkpoint" , config .get ('resume_from_checkpoint' )]
62+
63+ optimizer_cmd = ["--optimizer" , config .get ('optimizer' ),
64+ "--lr" , config .get ('lr' ),
65+ "--lr_scheduler" , config .get ('lr_scheduler' ),
66+ "--lr_warmup_steps" , config .get ('lr_warmup_steps' ),
67+ "--lr_num_cycles" , config .get ('lr_num_cycles' ),
68+ "--beta1" , config .get ('beta1' ),
69+ "--beta2" , config .get ('beta2' ),
70+ "--weight_decay" , config .get ('weight_decay' ),
71+ "--epsilon" , config .get ('epsilon' ),
72+ "--max_grad_norm" , config .get ('max_grad_norm' ),
73+ '--use_8bit_bnb' if config .get ('use_8bit_bnb' ) else '' ]
74+
75+ validation_cmd = ["--validation_prompts" if config .get ('validation_prompts' ) else '' , config .get ('validation_prompts' ) or '' ,
76+ "--num_validation_videos" , config .get ('num_validation_videos' ),
77+ "--validation_steps" , config .get ('validation_steps' )]
78+
79+ miscellaneous_cmd = ["--tracker_name" , config .get ('tracker_name' ),
80+ "--output_dir" , config .get ('output_dir' ),
81+ "--nccl_timeout" , config .get ('nccl_timeout' ),
82+ "--report_to" , config .get ('report_to' )]
83+ accelerate_cmd = ["accelerate" , "launch" , "--config_file" , f"{ finetrainers_path } /accelerate_configs/{ config .get ('accelerate_config' )} " , "--gpu_ids" , config .get ('gpu_ids' )]
84+ cmd = accelerate_cmd + [f"{ finetrainers_path } /train.py" ] + model_cmd + dataset_cmd + dataloader_cmd + diffusion_cmd + training_cmd + optimizer_cmd + validation_cmd + miscellaneous_cmd
85+ fixed_cmd = []
86+ for i in range (len (cmd )):
87+ if cmd [i ] != '' :
88+ fixed_cmd .append (f"{ cmd [i ]} " )
89+ print (' ' .join (fixed_cmd ))
9990 self .running = True
10091 with open (log_file , "w" ) as output_file :
101- self .process = subprocess .Popen (cmd , shell = True , stdout = output_file , stderr = output_file , text = True )
92+ self .process = subprocess .Popen (fixed_cmd , shell = False , stdout = output_file , stderr = output_file , text = True , preexec_fn = os . setsid )
10293 self .process .communicate ()
10394 return self .process
10495
@@ -108,12 +99,20 @@ def stop(self):
10899 try :
109100 self .running = False
110101 if self .process :
111- self .process .terminate ()
112- time .sleep (3 )
113- if self .process .poll () is None :
114- self .process .kill ()
102+ os .killpg (os .getpgid (self .process .pid ), signal .SIGTERM )
103+ self .terminate_process_tree (self .process .pid )
115104 except Exception as e :
116105 return f"Error stopping training: { e } "
117106 finally :
118107 self .process .wait ()
119- return "Training forcibly stopped"
108+ return "Training forcibly stopped"
109+
110+ def terminate_process_tree (pid ):
111+ try :
112+ parent = psutil .Process (pid )
113+ children = parent .children (recursive = True ) # Get child processes
114+ for child in children :
115+ child .terminate ()
116+ parent .terminate ()
117+ except psutil .NoSuchProcess :
118+ pass
0 commit comments