Commit e1cd838

Merge pull request #14 from neph1/remove-shell-use
remove shell, fixes stop button
2 parents 9c63e72 + 989b5d0

File tree

1 file changed (+80 -81 lines)
run_trainer.py

Lines changed: 80 additions & 81 deletions
@@ -1,7 +1,8 @@
 import os
 import signal
 import subprocess
-import time
+
+import psutil
 
 from config import Config
 
@@ -18,87 +19,77 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
         assert config.get('data_root'), "Data root required"
         assert config.get('pretrained_model_name_or_path'), "pretrained_model_name_or_path required"
 
-        # Model arguments
-        model_cmd = f"--model_name {config.get('model_name')} \
-        --pretrained_model_name_or_path {config.get('pretrained_model_name_or_path')}"
-
-        # Dataset arguments
-        dataset_cmd = f"--data_root {config.get('data_root')} \
-        --video_column {config.get('video_column')} \
-        --caption_column {config.get('caption_column')} \
-        --id_token {config.get('id_token')} \
-        --video_resolution_buckets {config.get('video_resolution_buckets')} \
-        --caption_dropout_p {config.get('caption_dropout_p')} \
-        --caption_dropout_technique {config.get('caption_dropout_technique')} \
-        {'--precompute_conditions' if config.get('precompute_conditions') else ''} \
-        --text_encoder_dtype {config.get('text_encoder_dtype')} \
-        --text_encoder_2_dtype {config.get('text_encoder_2_dtype')} \
-        --text_encoder_3_dtype {config.get('text_encoder_3_dtype')} \
-        --vae_dtype {config.get('vae_dtype')} "
-
-        # Dataloader arguments
-        dataloader_cmd = f"--dataloader_num_workers {config.get('dataloader_num_workers')}"
+        model_cmd = ["--model_name", config.get('model_name'),
+                     "--pretrained_model_name_or_path", config.get('pretrained_model_name_or_path')]
+
+        dataset_cmd = ["--data_root", config.get('data_root'),
+                       "--video_column", config.get('video_column'),
+                       "--caption_column", config.get('caption_column'),
+                       "--id_token", config.get('id_token'),
+                       "--video_resolution_buckets"]
+        dataset_cmd += config.get('video_resolution_buckets').split(' ')
+        dataset_cmd += ["--caption_dropout_p", config.get('caption_dropout_p'),
+                        "--caption_dropout_technique", config.get('caption_dropout_technique'),
+                        "--text_encoder_dtype", config.get('text_encoder_dtype'),
+                        "--text_encoder_2_dtype", config.get('text_encoder_2_dtype'),
+                        "--text_encoder_3_dtype", config.get('text_encoder_3_dtype'),
+                        "--vae_dtype", config.get('vae_dtype'),
+                        '--precompute_conditions' if config.get('precompute_conditions') else '']
+
+        dataloader_cmd = ["--dataloader_num_workers", config.get('dataloader_num_workers')]
 
         # Diffusion arguments TODO: replace later
-        diffusion_cmd = f"{config.get('diffusion_options')}"
-
-        # Training arguments
-        training_cmd = f"--training_type {config.get('training_type')} \
-        --seed {config.get('seed')} \
-        --mixed_precision {config.get('mixed_precision')} \
-        --batch_size {config.get('batch_size')} \
-        --train_steps {config.get('train_steps')} \
-        --rank {config.get('rank')} \
-        --lora_alpha {config.get('lora_alpha')} \
-        --target_modules {config.get('target_modules')} \
-        --gradient_accumulation_steps {config.get('gradient_accumulation_steps')} \
-        {'--gradient_checkpointing' if config.get('gradient_checkpointing') else ''} \
-        --checkpointing_steps {config.get('checkpointing_steps')} \
-        --checkpointing_limit {config.get('checkpointing_limit')} \
-        {'--enable_slicing' if config.get('enable_slicing') else ''} \
-        {'--enable_tiling' if config.get('enable_tiling') else ''} "
+        diffusion_cmd = [config.get('diffusion_options')]
+
+        training_cmd = ["--training_type", config.get('training_type'),
+                        "--seed", config.get('seed'),
+                        "--mixed_precision", config.get('mixed_precision'),
+                        "--batch_size", config.get('batch_size'),
+                        "--train_steps", config.get('train_steps'),
+                        "--rank", config.get('rank'),
+                        "--lora_alpha", config.get('lora_alpha'),
+                        "--target_modules"]
+        training_cmd += config.get('target_modules').split(' ')
+        training_cmd += ["--gradient_accumulation_steps", config.get('gradient_accumulation_steps'),
+                         '--gradient_checkpointing' if config.get('gradient_checkpointing') else '',
+                         "--checkpointing_steps", config.get('checkpointing_steps'),
+                         "--checkpointing_limit", config.get('checkpointing_limit'),
+                         '--enable_slicing' if config.get('enable_slicing') else '',
+                         '--enable_tiling' if config.get('enable_tiling') else '']
 
         if config.get('resume_from_checkpoint'):
-            training_cmd += f"--resume_from_checkpoint {config.get('resume_from_checkpoint')}"
-
-        # Optimizer arguments
-        optimizer_cmd = f"--optimizer {config.get('optimizer')} \
-        --lr {config.get('lr')} \
-        --lr_scheduler {config.get('lr_scheduler')} \
-        --lr_warmup_steps {config.get('lr_warmup_steps')} \
-        --lr_num_cycles {config.get('lr_num_cycles')} \
-        --beta1 {config.get('beta1')} \
-        --beta2 {config.get('beta2')} \
-        --weight_decay {config.get('weight_decay')} \
-        --epsilon {config.get('epsilon')} \
-        --max_grad_norm {config.get('max_grad_norm')} \
-        {'--use_8bit_bnb' if config.get('use_8bit_bnb') else ''}"
-
-        # Validation arguments
-        validation_cmd = f"--validation_prompts \"{config.get('validation_prompts')}\" \
-        --num_validation_videos {config.get('num_validation_videos')} \
-        --validation_steps {config.get('validation_steps')}"
-
-        # Miscellaneous arguments
-        miscellaneous_cmd = f"--tracker_name {config.get('tracker_name')} \
-        --output_dir {config.get('output_dir')} \
-        --nccl_timeout {config.get('nccl_timeout')} \
-        --report_to {config.get('report_to')}"
-
-        cmd = f"accelerate launch --config_file {finetrainers_path}/accelerate_configs/{config.get('accelerate_config')} --gpu_ids {config.get('gpu_ids')} {finetrainers_path}/train.py \
-        {model_cmd} \
-        {dataset_cmd} \
-        {dataloader_cmd} \
-        {diffusion_cmd} \
-        {training_cmd} \
-        {optimizer_cmd} \
-        {validation_cmd} \
-        {miscellaneous_cmd}"
-
-        print(cmd)
+            training_cmd += ["--resume_from_checkpoint", config.get('resume_from_checkpoint')]
+
+        optimizer_cmd = ["--optimizer", config.get('optimizer'),
+                         "--lr", config.get('lr'),
+                         "--lr_scheduler", config.get('lr_scheduler'),
+                         "--lr_warmup_steps", config.get('lr_warmup_steps'),
+                         "--lr_num_cycles", config.get('lr_num_cycles'),
+                         "--beta1", config.get('beta1'),
+                         "--beta2", config.get('beta2'),
+                         "--weight_decay", config.get('weight_decay'),
+                         "--epsilon", config.get('epsilon'),
+                         "--max_grad_norm", config.get('max_grad_norm'),
+                         '--use_8bit_bnb' if config.get('use_8bit_bnb') else '']
+
+        validation_cmd = ["--validation_prompts" if config.get('validation_prompts') else '', config.get('validation_prompts') or '',
+                          "--num_validation_videos", config.get('num_validation_videos'),
+                          "--validation_steps", config.get('validation_steps')]
+
+        miscellaneous_cmd = ["--tracker_name", config.get('tracker_name'),
+                             "--output_dir", config.get('output_dir'),
+                             "--nccl_timeout", config.get('nccl_timeout'),
+                             "--report_to", config.get('report_to')]
+        accelerate_cmd = ["accelerate", "launch", "--config_file", f"{finetrainers_path}/accelerate_configs/{config.get('accelerate_config')}", "--gpu_ids", config.get('gpu_ids')]
+        cmd = accelerate_cmd + [f"{finetrainers_path}/train.py"] + model_cmd + dataset_cmd + dataloader_cmd + diffusion_cmd + training_cmd + optimizer_cmd + validation_cmd + miscellaneous_cmd
+        fixed_cmd = []
+        for i in range(len(cmd)):
+            if cmd[i] != '':
+                fixed_cmd.append(f"{cmd[i]}")
+        print(' '.join(fixed_cmd))
         self.running = True
         with open(log_file, "w") as output_file:
-            self.process = subprocess.Popen(cmd, shell=True, stdout=output_file, stderr=output_file, text=True)
+            self.process = subprocess.Popen(fixed_cmd, shell=False, stdout=output_file, stderr=output_file, text=True, preexec_fn=os.setsid)
             self.process.communicate()
         return self.process
 
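Why removing the shell fixes the stop button: with shell=True the command string is handed to /bin/sh, so self.process may name the shell rather than the accelerate launcher, and terminate() in any case signals only that single process while accelerate forks trainer workers of its own. Passing an argument list with shell=False executes the program directly, and preexec_fn=os.setsid starts it in a new session, i.e. a new process group that can be signalled as a whole. A minimal sketch of that pattern under those assumptions, using a stand-in command (sleep) instead of the real accelerate invocation:

    import os
    import signal
    import subprocess

    # Disabled boolean flags contribute '' to the argv list, so they are
    # filtered out first, mirroring the fixed_cmd loop in run(). Numeric
    # config values likewise need stringifying, which run() does via f"{cmd[i]}".
    enable_tiling = False
    cmd = ["sleep", "30", "--enable_tiling" if enable_tiling else ""]
    argv = [arg for arg in cmd if arg != ""]

    # shell=False executes argv directly: process.pid is the launched program,
    # not an intermediate /bin/sh. preexec_fn=os.setsid (equivalent to
    # start_new_session=True) puts the child in its own process group (POSIX only).
    process = subprocess.Popen(argv, shell=False, preexec_fn=os.setsid)

    # One signal now reaches the whole group:
    os.killpg(os.getpgid(process.pid), signal.SIGTERM)
    process.wait()

Note that print(' '.join(fixed_cmd)) above is only a log line; the list itself goes straight to Popen, so no shell quoting is applied or needed.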

@@ -108,12 +99,20 @@ def stop(self):
         try:
             self.running = False
             if self.process:
-                self.process.terminate()
-                time.sleep(3)
-                if self.process.poll() is None:
-                    self.process.kill()
+                os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
+                self.terminate_process_tree(self.process.pid)
         except Exception as e:
             return f"Error stopping training: {e}"
         finally:
             self.process.wait()
-            return "Training forcibly stopped"
+            return "Training forcibly stopped"
+
+    def terminate_process_tree(self, pid):
+        try:
+            parent = psutil.Process(pid)
+            children = parent.children(recursive=True)  # Get child processes
+            for child in children:
+                child.terminate()
+            parent.terminate()
+        except psutil.NoSuchProcess:
+            pass
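os.killpg only reaches processes that stayed in the launcher's process group, so the psutil walk above covers any child that detached into its own group. As committed, terminate_process_tree sends SIGTERM to each process but never escalates; a hedged sketch of the fuller terminate-wait-kill pattern psutil supports (the helper name and the 3-second timeout are illustrative choices, not part of this commit):

    import psutil

    def terminate_tree(pid: int, timeout: float = 3.0) -> None:
        """SIGTERM a process and all its descendants, then SIGKILL survivors."""
        try:
            parent = psutil.Process(pid)
        except psutil.NoSuchProcess:
            return  # already gone
        procs = parent.children(recursive=True) + [parent]
        for p in procs:
            try:
                p.terminate()  # polite SIGTERM first, giving handlers a chance
            except psutil.NoSuchProcess:
                pass
        # wait_procs polls until the timeout elapses and returns (gone, alive)
        gone, alive = psutil.wait_procs(procs, timeout=timeout)
        for p in alive:
            try:
                p.kill()  # force-kill anything that ignored SIGTERM
            except psutil.NoSuchProcess:
                pass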
