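"""Unified entry point for UniRig training, prediction, and validation.

Loads a task YAML, assembles the datasets, tokenizer, model, and Lightning
system it references, then dispatches to the requested mode.

Sketch of the task-file keys this script reads (values are illustrative):

    mode: predict                  # train | predict | validate
    experiment_name: my_experiment
    components:
      data: ...                    # resolved under configs/data/
      transform: ...               # resolved under configs/transform/
      tokenizer: ...               # optional, resolved under configs/tokenizer/
      model: ...                   # optional, resolved under configs/model/
      system: ...                  # optional, resolved under configs/system/
      data_name: raw_data.npz
    trainer: {...}                 # kwargs for lightning.Trainer
    checkpoint: {...}              # optional, kwargs for ModelCheckpoint
    writer: {...}                  # optional, prediction writer settings
    wandb: {...}                   # optional, kwargs for WandbLogger
    resume_from_checkpoint: ...    # optional checkpoint to load
"""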
import argparse
import os
from math import ceil

import lightning as L
import torch
import yaml
from box import Box
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.strategies import FSDPStrategy

from src.data.datapath import Datapath
from src.data.dataset import UniRigDatasetModule, DatasetConfig
from src.data.extract import get_files
from src.data.transform import TransformConfig
from src.inference.download import download
from src.model.parse import get_model
from src.system.parse import get_system, get_writer
from src.tokenizer.parse import get_tokenizer
from src.tokenizer.spec import TokenizerConfig

def load(task: str, path: str) -> Box:
    """Load a YAML config into a Box, appending the .yaml suffix if missing."""
    if not path.endswith('.yaml'):
        path += '.yaml'
    print(f"\033[92mload {task} config: {path}\033[0m")
    with open(path, 'r') as f:
        return Box(yaml.safe_load(f))

def nullable_string(val):
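    """argparse type that converts an empty string to None."""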
if not val:
return None
return val
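
# Example invocations (config and file names are illustrative):
#   python run.py --task configs/task/my_train_task.yaml
#   python run.py --task configs/task/my_predict_task.yaml --input model.glb --output model.fbx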
if __name__ == "__main__":
torch.set_float32_matmul_precision('high')
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, required=True)
parser.add_argument("--seed", type=int, required=False, default=123,
help="random seed")
parser.add_argument("--input", type=nullable_string, required=False, default=None,
help="a single input file or files splited by comma")
parser.add_argument("--input_dir", type=nullable_string, required=False, default=None,
help="input directory")
parser.add_argument("--output", type=nullable_string, required=False, default=None,
help="filename for a single output")
parser.add_argument("--output_dir", type=nullable_string, required=False, default=None,
help="output directory")
parser.add_argument("--npz_dir", type=nullable_string, required=False, default='tmp',
help="intermediate npz directory")
parser.add_argument("--cls", type=nullable_string, required=False, default=None,
help="class name")
parser.add_argument("--data_name", type=nullable_string, required=False, default=None,
help="npz filename from skeleton phase")
args = parser.parse_args()
L.seed_everything(args.seed, workers=True)
task = load('task', args.task)
mode = task.mode
    assert mode in ['train', 'predict', 'validate'], f"unsupported mode: {mode}"
if args.input is not None or args.input_dir is not None:
assert args.output_dir is not None or args.output is not None, 'output or output_dir must be specified'
assert args.npz_dir is not None, 'npz_dir must be specified'
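        # extract each raw input into an intermediate npz file under --npz_dir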
files = get_files(
data_name=task.components.data_name,
inputs=args.input,
input_dataset_dir=args.input_dir,
output_dataset_dir=args.npz_dir,
force_override=True,
warning=False,
)
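        # keep the extracted npz path from each (input, output) pair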
files = [f[1] for f in files]
if len(files) > 1 and args.output is not None:
print("\033[92mwarning: output is specified, but multiple files are detected. Output will be written.\033[0m")
datapath = Datapath(files=files, cls=args.cls)
else:
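        # no CLI inputs: datasets are resolved entirely from the data config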
datapath = None
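    # the task's components entry names the data/transform/tokenizer/model/system
    # configs, each resolved from the matching directory under configs/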
data_config = load('data', os.path.join('configs/data', task.components.data))
transform_config = load('transform', os.path.join('configs/transform', task.components.transform))
# get tokenizer
tokenizer_config = task.components.get('tokenizer', None)
if tokenizer_config is not None:
tokenizer_config = load('tokenizer', os.path.join('configs/tokenizer', task.components.tokenizer))
tokenizer_config = TokenizerConfig.parse(config=tokenizer_config)
# get data name
data_name = task.components.get('data_name', 'raw_data.npz')
if args.data_name is not None:
data_name = args.data_name
# get train dataset
train_dataset_config = data_config.get('train_dataset_config', None)
if train_dataset_config is not None:
train_dataset_config = DatasetConfig.parse(config=train_dataset_config)
# get train transform
train_transform_config = transform_config.get('train_transform_config', None)
if train_transform_config is not None:
train_transform_config = TransformConfig.parse(config=train_transform_config)
# get predict dataset
predict_dataset_config = data_config.get('predict_dataset_config', None)
if predict_dataset_config is not None:
predict_dataset_config = DatasetConfig.parse(config=predict_dataset_config).split_by_cls()
# get predict transform
predict_transform_config = transform_config.get('predict_transform_config', None)
if predict_transform_config is not None:
predict_transform_config = TransformConfig.parse(config=predict_transform_config)
# get validate dataset
validate_dataset_config = data_config.get('validate_dataset_config', None)
if validate_dataset_config is not None:
validate_dataset_config = DatasetConfig.parse(config=validate_dataset_config).split_by_cls()
# get validate transform
validate_transform_config = transform_config.get('validate_transform_config', None)
if validate_transform_config is not None:
validate_transform_config = TransformConfig.parse(config=validate_transform_config)
# get model
model_config = task.components.get('model', None)
if model_config is not None:
model_config = load('model', os.path.join('configs/model', model_config))
if tokenizer_config is not None:
tokenizer = get_tokenizer(config=tokenizer_config)
else:
tokenizer = None
model = get_model(tokenizer=tokenizer, **model_config)
else:
model = None
# set data
data = UniRigDatasetModule(
process_fn=None if model is None else model._process_fn,
train_dataset_config=train_dataset_config,
predict_dataset_config=predict_dataset_config,
predict_transform_config=predict_transform_config,
validate_dataset_config=validate_dataset_config,
train_transform_config=train_transform_config,
validate_transform_config=validate_transform_config,
tokenizer_config=tokenizer_config,
debug=False,
data_name=data_name,
datapath=datapath,
cls=args.cls,
)
    # add callbacks
callbacks = []
## get checkpoint callback
checkpoint_config = task.get('checkpoint', None)
if checkpoint_config is not None:
checkpoint_config['dirpath'] = os.path.join('experiments', task.experiment_name)
callbacks.append(ModelCheckpoint(**checkpoint_config))
## get writer callback
writer_config = task.get('writer', None)
if writer_config is not None:
assert predict_transform_config is not None, 'missing predict_transform_config in transform'
if args.output_dir is not None or args.output is not None:
if args.output is not None:
assert args.output.endswith('.fbx'), 'output must be .fbx'
writer_config['npz_dir'] = args.npz_dir
writer_config['output_dir'] = args.output_dir
writer_config['output_name'] = args.output
writer_config['user_mode'] = True
callbacks.append(get_writer(**writer_config, order_config=predict_transform_config.order_config))
# get trainer
trainer_config = task.get('trainer', {})
# get scheduler
scheduler_config = task.get('scheduler', None)
optimizer_config = task.get('optimizer', None)
loss_config = task.get('loss', None)
# get system
system_config = task.components.get('system', None)
if system_config is not None:
system_config = load('system', os.path.join('configs/system', system_config))
system = get_system(
**system_config,
model=model,
optimizer_config=optimizer_config,
loss_config=loss_config,
scheduler_config=scheduler_config,
            # batches per epoch divided across devices and nodes, rounded up
            steps_per_epoch=1 if train_dataset_config is None else
                ceil(len(data.train_dataloader()) / trainer_config.get('devices', 1) / trainer_config.get('num_nodes', 1)),
)
else:
system = None
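    # get wandb logger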
wandb_config = task.get('wandb', None)
if wandb_config is not None:
logger = WandbLogger(
config={
'task': task,
'data': data_config,
'tokenizer': tokenizer_config,
'train_dataset_config': train_dataset_config,
'validate_dataset_config': validate_dataset_config,
'predict_dataset_config': predict_dataset_config,
'train_transform_config': train_transform_config,
'validate_transform_config': validate_transform_config,
'predict_transform_config': predict_transform_config,
'model_config': model_config,
'optimizer_config': optimizer_config,
'system_config': system_config,
'checkpoint_config': checkpoint_config,
'writer_config': writer_config,
},
log_model=True,
**wandb_config
)
if logger.experiment.id is not None:
print(f"\033[92mWandbLogger started: {logger.experiment.id}\033[0m")
            # prefer get_url() when the run object provides it; fall back to .url
            run_url = logger.experiment.get_url() if hasattr(logger.experiment, 'get_url') else logger.experiment.url
print(f"\033[92mWandbLogger url: {run_url}\033[0m")
else:
print("\033[91mWandbLogger failed to start\033[0m")
else:
logger = None
# set ckpt path
resume_from_checkpoint = task.get('resume_from_checkpoint', None)
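    # download() is expected to fetch a remote checkpoint and return a local
    # path; local paths should pass through unchanged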
resume_from_checkpoint = download(resume_from_checkpoint)
if trainer_config.get('strategy', None) == "fsdp":
        strategy = FSDPStrategy(
            # wrap these layer types in their own FSDP units
            auto_wrap_policy={
                torch.nn.MultiheadAttention,
            },
            # recompute activations for these layers in backward to save memory
            activation_checkpointing_policy={
                torch.nn.Linear,
                torch.nn.MultiheadAttention,
            },
        )
trainer_config['strategy'] = strategy
trainer = L.Trainer(
callbacks=callbacks,
logger=logger,
**trainer_config,
)
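    # dispatch on the requested mode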
if mode == 'train':
trainer.fit(system, datamodule=data, ckpt_path=resume_from_checkpoint)
elif mode == 'predict':
        assert resume_from_checkpoint is not None, 'predict mode requires resume_from_checkpoint in the task config'
trainer.predict(system, datamodule=data, ckpt_path=resume_from_checkpoint, return_predictions=False)
elif mode == 'validate':
trainer.validate(system, datamodule=data, ckpt_path=resume_from_checkpoint)
    else:
        raise ValueError(f"unknown mode: {mode}")