Skip to content

Commit 714939f

Browse files
committed
feat: fixed mmdit + load from wandb checkpoint for training
1 parent be099fb commit 714939f

File tree

2 files changed

+69
-5
lines changed

2 files changed

+69
-5
lines changed

flaxdiff/models/simple_mmdit.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ def setup(self):
614614
]
615615

616616
# Encoder blocks (from coarse to fine)
617-
self.encoder_blocks = []
617+
encoder_blocks = []
618618
for stage in range(num_stages):
619619
stage_blocks = [
620620
MMDiTBlock(
@@ -632,7 +632,9 @@ def setup(self):
632632
)
633633
for i in range(self.num_layers[stage] // 2) # Half for encoder, half for decoder
634634
]
635-
self.encoder_blocks.append(stage_blocks)
635+
encoder_blocks.append(stage_blocks)
636+
637+
self.encoder_blocks = encoder_blocks
636638

637639
# Patch expanding layers (from coarse to fine)
638640
if num_stages > 1:
@@ -647,7 +649,7 @@ def setup(self):
647649
]
648650

649651
# Decoder blocks (from coarse to fine)
650-
self.decoder_blocks = []
652+
decoder_blocks = []
651653
for stage in range(num_stages-1, -1, -1):
652654
stage_blocks = [
653655
MMDiTBlock(
@@ -665,7 +667,8 @@ def setup(self):
665667
)
666668
for i in range(self.num_layers[stage] // 2) # Half for encoder, half for decoder
667669
]
668-
self.decoder_blocks.append(stage_blocks)
670+
decoder_blocks.append(stage_blocks)
671+
self.decoder_blocks = decoder_blocks
669672

670673
# Fusion layers for skip connections
671674
if num_stages > 1:

flaxdiff/trainer/simple_trainer.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from flaxdiff.utils import RandomMarkovState
2727
from flax.training import dynamic_scale as dynamic_scale_lib
2828
from dataclasses import dataclass
29+
import shutil
2930
import gc
3031

3132
PROCESS_COLOR_MAP = {
@@ -73,6 +74,54 @@ class SimpleTrainState(train_state.TrainState):
7374
metrics: Metrics
7475
dynamic_scale: dynamic_scale_lib.DynamicScale
7576

77+
def move_contents_to_subdir(target_dir, new_subdir_name):
    """Move everything inside *target_dir* into a (new) subdirectory of it.

    Creates ``target_dir/new_subdir_name`` (idempotent — no error if it
    already exists) and moves every item currently sitting in ``target_dir``
    into it, skipping the new subdirectory itself.  Progress and errors are
    reported via ``print``; error conditions abort the operation early
    rather than raising.  Always returns ``None``.

    Args:
        target_dir: Directory whose contents should be relocated.
        new_subdir_name: Name — or full path — of the subdirectory to move
            contents into.  Callers in this file pass an absolute path
            (``os.path.join`` then uses it as-is), so path comparisons
            below must be normalized.
    """
    # --- 1. Validate the target directory ---
    if not os.path.isdir(target_dir):
        print(f"Error: Target directory '{target_dir}' not found or is not a directory.")
        return
    # --- 2. Resolve the destination path ---
    new_subdir_path = os.path.join(target_dir, new_subdir_name)
    # --- 3. Create the subdirectory (exist_ok makes this idempotent) ---
    try:
        os.makedirs(new_subdir_path, exist_ok=True)
        print(f"Subdirectory '{new_subdir_path}' created or already exists.")
    except OSError as e:
        print(f"Error creating subdirectory '{new_subdir_path}': {e}")
        return  # Stop execution if subdirectory creation fails
    # --- 4. Snapshot the current contents before moving anything ---
    try:
        items_to_move = os.listdir(target_dir)
    except OSError as e:
        print(f"Error listing contents of '{target_dir}': {e}")
        return  # Stop if we can't list directory contents
    # --- 5. Move each item ---
    print(f"Moving items from '{target_dir}' to '{new_subdir_path}'...")
    moved_count = 0
    error_count = 0
    # Normalize once: a plain string compare of the unnormalized joined
    # paths fails when new_subdir_name is absolute while target_dir is
    # relative, which would make us try to move the destination into
    # itself.  abspath makes the self-skip check robust either way.
    resolved_dest = os.path.abspath(new_subdir_path)
    for item_name in items_to_move:
        source_path = os.path.join(target_dir, item_name)
        # IMPORTANT: never move the newly created subdirectory into itself!
        if os.path.abspath(source_path) == resolved_dest:
            continue
        destination_path = os.path.join(new_subdir_path, item_name)
        try:
            shutil.move(source_path, destination_path)
            # print(f" Moved: '{item_name}'") # Uncomment for verbose output
            moved_count += 1
        except Exception as e:
            print(f" Error moving '{item_name}': {e}")
            error_count += 1
    print(f"\nOperation complete.")
    print(f" Successfully moved: {moved_count} item(s).")
    if error_count > 0:
        print(f" Errors encountered: {error_count} item(s).")
124+
76125
@dataclass
77126
class SimpleTrainer:
78127
state: SimpleTrainState
@@ -124,6 +173,17 @@ def __init__(self,
124173
if train_start_step_override is None:
125174
train_start_step_override = run.summary['train/step'] + 1
126175
print(f"Resuming from previous run {wandb_config['id']} with start step {train_start_step_override}")
176+
177+
# If load_from_checkpoint is not set, and an artifact is found, load the artifact
178+
if load_from_checkpoint is None:
179+
model_artifacts = [i for i in run.logged_artifacts() if i.type == 'model']
180+
if model_artifacts:
181+
artifact = model_artifacts[0]
182+
artifact_dir = artifact.download()
183+
print(f"Loading model from artifact {artifact.name} at {artifact_dir}")
184+
# Move the artifact's contents
185+
load_from_checkpoint = os.path.join(artifact_dir, str(run.summary['train/step']))
186+
move_contents_to_subdir(artifact_dir, load_from_checkpoint)
127187

128188
# define our custom x axis metric
129189
self.wandb.define_metric("train/step")
@@ -272,6 +332,7 @@ def load(self, checkpoint_path=None, checkpoint_step=None):
272332
f"{step}")
273333
self.loaded_checkpoint_path = loaded_checkpoint_path
274334
ckpt = checkpointer.restore(step)
335+
275336
state = ckpt['state']
276337
best_state = ckpt['best_state']
277338
rngstate = ckpt['rngs']
@@ -590,5 +651,5 @@ def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps
590651
)
591652
print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
592653

593-
self.save(epochs)
654+
self.save(epochs)#
594655
return self.state

0 commit comments

Comments
 (0)