-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathtrainer.py
More file actions
424 lines (365 loc) · 16.7 KB
/
trainer.py
File metadata and controls
424 lines (365 loc) · 16.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
import argparse
import os
from typing import Optional, Union
import torch
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.loggers import TensorBoardLogger
from flow_matching.path import AffineProbPath
from flow_matching.path.scheduler import CondOTScheduler
from utils.general_utils import create_dataloader, load_and_prepare_data, load_config
from utils.motfm_logging import get_logger
from utils.utils_fm import build_model, validate_and_save_samples
logger = get_logger(__name__)
class FlowMatchingDataModule(pl.LightningDataModule):
"""Lightning ``DataModule`` wrapping the existing data helpers."""
def __init__(self, config: dict):
super().__init__()
self.config = config
self.train_data: Optional[dict] = None
self.val_data: Optional[dict] = None
def setup(self, stage: Optional[str] = None) -> None:
data_config = self.config["data_args"]
model_config = self.config.get("model_args", {})
logger.info(
f"Setting up data module for stage='{stage}' with pickle='{data_config['pickle_path']}'."
)
spatial_dims = model_config.get("spatial_dims", None)
if spatial_dims is not None:
spatial_dims = int(spatial_dims)
# Normalization knobs (optional in config; defaults preserve existing behavior).
image_norm = data_config.get("image_norm", "minmax_0_1")
mask_norm = data_config.get("mask_norm", "minmax_0_1")
norm_scope = data_config.get("norm_scope", "global")
clip_percentiles = data_config.get("clip_percentiles", None)
if clip_percentiles is not None:
clip_percentiles = (float(clip_percentiles[0]), float(clip_percentiles[1]))
norm_eps = float(data_config.get("norm_eps", 1e-6))
# Class mapping: prefer explicit ordering if provided.
class_values = data_config.get("class_values", None)
class_to_idx = {c: i for i, c in enumerate(class_values)} if class_values else None
class_conditioning = bool(model_config.get("with_conditioning", False))
expected_num_classes = None
if class_conditioning:
if model_config.get("cross_attention_dim", None) is None:
raise ValueError(
"`model_args.with_conditioning` is True but `model_args.cross_attention_dim` is missing."
)
expected_num_classes = int(model_config["cross_attention_dim"])
if class_values and expected_num_classes != len(class_values):
raise ValueError(
f"`model_args.cross_attention_dim`={expected_num_classes} does not match "
f"`data_args.class_values` length ({len(class_values)})."
)
def _load(split: str) -> dict:
return load_and_prepare_data(
pickle_path=data_config["pickle_path"],
split=split,
convert_classes_to_onehot=True,
spatial_dims=spatial_dims,
image_norm=image_norm,
mask_norm=mask_norm,
norm_scope=norm_scope,
clip_percentiles=clip_percentiles,
norm_eps=norm_eps,
class_to_idx=class_to_idx,
num_classes=expected_num_classes,
class_mapping_split=data_config.get("split_train", "train"),
)
mask_conditioning = bool(model_config.get("mask_conditioning", False))
def _assert_required_keys(data: dict, *, split_name: str) -> None:
if mask_conditioning and "masks" not in data:
raise ValueError(
f"`model_args.mask_conditioning` is True but split '{split_name}' has no masks."
)
if class_conditioning and "classes" not in data:
raise ValueError(
f"`model_args.with_conditioning` is True but split '{split_name}' has no classes."
)
if stage in (None, "fit"):
self.train_data = _load(data_config["split_train"])
self.val_data = _load(data_config["split_val"])
_assert_required_keys(self.train_data, split_name=data_config["split_train"])
_assert_required_keys(self.val_data, split_name=data_config["split_val"])
logger.info(
"Loaded train/val splits: "
f"train={int(self.train_data['images'].shape[0])}, "
f"val={int(self.val_data['images'].shape[0])}."
)
elif stage == "validate":
self.val_data = _load(data_config["split_val"])
_assert_required_keys(self.val_data, split_name=data_config["split_val"])
logger.info(f"Loaded validation split: val={int(self.val_data['images'].shape[0])}.")
def train_dataloader(self) -> torch.utils.data.DataLoader:
tr_args = self.config["train_args"]
sampler = None
shuffle = True
if bool(tr_args.get("class_balanced_sampling", False)):
classes = None if self.train_data is None else self.train_data.get("classes")
if classes is None:
logger.warning(
"Class-balanced sampling is enabled but no classes were found; "
"falling back to shuffle=True."
)
else:
if classes.ndim == 2:
class_idxs = classes.argmax(dim=1).to(dtype=torch.long)
num_classes = int(classes.shape[1])
elif classes.ndim == 1:
class_idxs = classes.to(dtype=torch.long)
num_classes = int(class_idxs.max().item() + 1)
else:
raise ValueError(
f"Unexpected classes tensor shape {tuple(classes.shape)}; "
"expected [N] indices or [N, K] one-hot."
)
# Inverse-frequency weighting with optional tempering via class_balance_power.
counts = torch.bincount(class_idxs, minlength=num_classes).to(dtype=torch.float32)
power = float(tr_args.get("class_balance_power", 1.0))
class_weights = counts.clamp_min(1.0).pow(-power)
sample_weights = class_weights[class_idxs].to(dtype=torch.double)
sampler = torch.utils.data.WeightedRandomSampler(
weights=sample_weights,
num_samples=len(sample_weights),
replacement=True,
)
shuffle = False
logger.info(
f"Using class-balanced sampling with {num_classes} classes and power={power:.3f}."
)
return create_dataloader(
Images=self.train_data["images"],
Masks=self.train_data.get("masks"),
classes=self.train_data.get("classes"),
batch_size=tr_args["batch_size"],
shuffle=shuffle,
sampler=sampler,
num_workers=int(tr_args.get("num_workers", 0)),
pin_memory=tr_args.get("pin_memory", None),
persistent_workers=tr_args.get("persistent_workers", None),
drop_last=bool(tr_args.get("drop_last", False)),
)
def val_dataloader(self) -> torch.utils.data.DataLoader:
tr_args = self.config["train_args"]
return create_dataloader(
Images=self.val_data["images"],
Masks=self.val_data.get("masks"),
classes=self.val_data.get("classes"),
batch_size=tr_args["batch_size"],
shuffle=False,
num_workers=int(tr_args.get("num_workers", 0)),
pin_memory=tr_args.get("pin_memory", None),
persistent_workers=tr_args.get("persistent_workers", None),
drop_last=False,
)
class FlowMatchingLightningModule(pl.LightningModule):
"""Lightning ``Module`` for the flow matching model."""
def __init__(self, config: dict):
super().__init__()
self.save_hyperparameters(config)
self.model = build_model(config["model_args"])
self.mask_conditioning = config["model_args"]["mask_conditioning"]
self.class_conditioning = config["model_args"]["with_conditioning"]
self.path = AffineProbPath(scheduler=CondOTScheduler())
def _compute_loss(self, batch: dict) -> torch.Tensor:
im_batch = batch["images"]
if self.mask_conditioning:
if "masks" not in batch:
raise KeyError(
"mask_conditioning is enabled but the dataloader batch has no 'masks' key."
)
mask_batch = batch["masks"]
else:
mask_batch = None
if self.class_conditioning:
if "classes" not in batch:
raise KeyError(
"class_conditioning is enabled but the dataloader batch has no 'classes' key."
)
class_batch = batch["classes"]
else:
class_batch = None
# Flow-matching target: learn velocity at a random interpolation point between noise and data.
x_0 = torch.randn_like(im_batch)
t = torch.rand(im_batch.shape[0], device=im_batch.device)
sample_info = self.path.sample(t=t, x_0=x_0, x_1=im_batch)
v_pred = self.model(
x=sample_info.x_t,
t=sample_info.t,
masks=mask_batch,
cond=class_batch,
)
return F.mse_loss(v_pred, sample_info.dx_t)
def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
loss = self._compute_loss(batch)
self.log("train/loss", loss, prog_bar=True, on_step=True, on_epoch=True)
return loss
def validation_step(self, batch: dict, batch_idx: int) -> None:
loss = self._compute_loss(batch)
self.log("val/loss", loss, prog_bar=True, on_epoch=True)
def configure_optimizers(self) -> optim.Optimizer:
lr = self.hparams["train_args"]["lr"]
return optim.Adam(self.model.parameters(), lr=lr)
def on_validation_epoch_end(self) -> None:
"""Run sampling/visualization at epoch end similar to utils.validate_and_save_samples."""
# Avoid duplicate work under DDP.
if hasattr(self.trainer, "is_global_zero") and not self.trainer.is_global_zero:
return
# Pull required configs
tr = self.hparams.get("train_args", {})
solver_args = self.hparams.get("solver_args", {})
# Resolve output directory from logger; fallback to default_root_dir
log_dir = None
if getattr(self.trainer, "logger", None) is not None and hasattr(
self.trainer.logger, "log_dir"
):
log_dir = self.trainer.logger.log_dir
if not log_dir:
log_dir = self.trainer.default_root_dir
# Get a fresh val dataloader
val_loader = self.trainer.datamodule.val_dataloader()
# Execute the validation sampling and saving
logger.info(f"Running validation sample export for epoch {self.current_epoch}.")
validate_and_save_samples(
model=self.model,
val_loader=val_loader,
device=self.device,
checkpoint_dir=log_dir,
epoch=self.current_epoch,
solver_config=solver_args,
max_samples=tr.get("num_val_samples", 16),
class_map=None,
mask_conditioning=self.mask_conditioning,
class_conditioning=self.class_conditioning,
)
def _resolve_resume_checkpoint(
explicit_ckpt_path: Optional[str], root_ckpt_dir: str, run_name: str
) -> Optional[str]:
"""Return the checkpoint path to resume from, if any."""
if explicit_ckpt_path:
return explicit_ckpt_path
ckpt_dir = os.path.join(root_ckpt_dir, run_name)
if not os.path.isdir(ckpt_dir):
return None
# Prefer Lightning's rolling checkpoint for exact resume state.
last_ckpt = os.path.join(ckpt_dir, "last.ckpt")
if os.path.isfile(last_ckpt):
return last_ckpt
# Fallback for older/manual checkpoint naming: pick the newest .ckpt file.
_candidates = [
os.path.join(ckpt_dir, fname)
for fname in os.listdir(ckpt_dir)
if fname.endswith(".ckpt") and os.path.isfile(os.path.join(ckpt_dir, fname))
]
if not _candidates:
return None
return max(_candidates, key=os.path.getmtime)
def _resolve_strategy(
accelerator: Union[str, int, list, tuple],
devices: Union[str, int, list, tuple],
):
"""
Use DDP only for true multi-GPU execution.
For single device (or CPU), keep strategy="auto" to avoid needless overhead.
"""
if str(accelerator).lower() not in {"gpu", "cuda"}:
return "auto"
if isinstance(devices, int):
return DDPStrategy(find_unused_parameters=True) if devices > 1 else "auto"
if isinstance(devices, (list, tuple)):
return DDPStrategy(find_unused_parameters=True) if len(devices) > 1 else "auto"
if isinstance(devices, str):
d = devices.strip().lower()
# "auto" and explicit single-device strings should stay on auto strategy.
if d in {"auto", "1"}:
return "auto"
# Comma-separated ids, e.g. "0,1"
if "," in d:
ids = [x for x in d.split(",") if x.strip() != ""]
return DDPStrategy(find_unused_parameters=True) if len(ids) > 1 else "auto"
return "auto"
def main() -> None:
parser = argparse.ArgumentParser(description="Train the flow matching model with Lightning.")
parser.add_argument(
"--config_path",
type=str,
default="configs/default.yaml",
help="Path to the configuration file.",
)
args = parser.parse_args()
config = load_config(args.config_path)
run_name = os.path.splitext(os.path.basename(args.config_path))[0]
tr = config["train_args"]
root_ckpt_dir = tr["checkpoint_dir"]
logger.info(f"Loaded training config: {args.config_path}")
logger.info(f"Run name: {run_name}")
logger.info(f"Checkpoint root directory: {root_ckpt_dir}")
seed = tr.get("seed")
if seed is not None:
seed = int(seed)
pl.seed_everything(seed, workers=True)
logger.info(f"Using seed={seed} for reproducible training.")
# Data and model modules
datamodule = FlowMatchingDataModule(config)
model = FlowMatchingLightningModule(config)
# Logging and callbacks
tb_logger = TensorBoardLogger(save_dir=root_ckpt_dir, name=run_name)
ckpt_every = max(1, int(tr.get("val_freq", 5)))
ckpt_cb = ModelCheckpoint(
dirpath=os.path.join(root_ckpt_dir, run_name),
filename="epoch{epoch:03d}-valloss{val/loss:.6f}",
monitor="val/loss",
mode="min",
save_top_k=3,
save_last=True,
auto_insert_metric_name=False,
every_n_epochs=ckpt_every,
)
lr_cb = LearningRateMonitor(logging_interval="step")
cbs = [ckpt_cb, lr_cb]
# Precision setup with safe bf16/fp16 detection
_bf16_supported = (
torch.cuda.is_available() and getattr(torch.cuda, "is_bf16_supported", lambda: False)()
)
_fp16_supported = torch.cuda.is_available()
default_precision = (
"bf16-mixed" if _bf16_supported else ("16-mixed" if _fp16_supported else "32-true")
)
precision = tr.get("precision", default_precision)
resume_ckpt = _resolve_resume_checkpoint(tr.get("ckpt_path"), root_ckpt_dir, run_name)
if resume_ckpt:
logger.info(f"Resuming training from checkpoint: {resume_ckpt}")
else:
logger.info("No checkpoint found. Starting training from scratch.")
accelerator = tr.get("accelerator", "auto")
devices = tr.get("devices", "auto")
strategy = _resolve_strategy(accelerator=accelerator, devices=devices)
deterministic = bool(tr.get("deterministic", False))
logger.info(
f"Trainer runtime: accelerator={accelerator}, devices={devices}, "
f"strategy={strategy}, precision={precision}, deterministic={deterministic}."
)
trainer = pl.Trainer(
default_root_dir=root_ckpt_dir,
max_epochs=tr["num_epochs"],
precision=precision,
accumulate_grad_batches=tr.get("gradient_accumulation_steps", 8),
gradient_clip_val=tr.get("grad_clip_norm", 0.0) or None,
check_val_every_n_epoch=ckpt_every,
enable_progress_bar=True,
logger=tb_logger,
callbacks=cbs,
# Distributed/accelerator knobs
accelerator=accelerator,
devices=devices,
strategy=strategy,
deterministic=deterministic,
log_every_n_steps=tr.get("log_every_n_steps", 50),
num_sanity_val_steps=tr.get("num_sanity_val_steps", 0),
)
trainer.fit(model, datamodule=datamodule, ckpt_path=resume_ckpt)
if __name__ == "__main__":
main()