Skip to content

Commit 90e8d6c

Browse files
authored
update save format (#28)
* fix * update settings * Update dtype to torch.float32 in DiT model * Remove unnecessary assert statement in DiT class * Refactor logging and checkpoint saving in train_img.py * Update save directory to include global step in ckpt_utils.py and log global step in train_img.py
1 parent ea1f83d commit 90e8d6c

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

opendit/utils/ckpt_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,13 @@ def save(
5757
lr_scheduler: _LRScheduler,
5858
epoch: int,
5959
step: int,
60+
global_step: int,
6061
batch_size: int,
6162
coordinator: DistCoordinator,
6263
save_dir: str,
6364
shape_dict: dict,
6465
):
65-
save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}")
66+
save_dir = os.path.join(save_dir, f"epoch{epoch}-global_step{global_step}")
6667
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
6768

6869
booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
@@ -79,6 +80,7 @@ def save(
7980
running_states = {
8081
"epoch": epoch,
8182
"step": step,
83+
"global_step": global_step,
8284
"sample_start_index": step * batch_size,
8385
}
8486
if coordinator.is_master():

train_img.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,12 +244,15 @@ def main(args):
244244
lr_scheduler,
245245
epoch,
246246
step + 1,
247+
global_step + 1,
247248
args.batch_size,
248249
coordinator,
249250
experiment_dir,
250251
ema_shape_dict,
251252
)
252-
logger.info(f"Saved checkpoint at epoch {epoch} step {step + 1} to {experiment_dir}")
253+
logger.info(
254+
f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {experiment_dir}"
255+
)
253256

254257
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
255258
dataloader.sampler.set_start_index(0)

0 commit comments

Comments
 (0)