|
26 | 26 | from flaxdiff.utils import RandomMarkovState |
27 | 27 | from flax.training import dynamic_scale as dynamic_scale_lib |
28 | 28 | from dataclasses import dataclass |
| 29 | +import shutil |
29 | 30 | import gc |
30 | 31 |
|
31 | 32 | PROCESS_COLOR_MAP = { |
@@ -73,6 +74,54 @@ class SimpleTrainState(train_state.TrainState): |
73 | 74 | metrics: Metrics |
74 | 75 | dynamic_scale: dynamic_scale_lib.DynamicScale |
75 | 76 |
|
def move_contents_to_subdir(target_dir, new_subdir_name):
    """Move every entry of ``target_dir`` into a subdirectory of ``target_dir``.

    The destination is ``os.path.join(target_dir, new_subdir_name)`` and is
    created if it does not exist yet. The operation is best-effort: problems
    are reported with ``print`` and never raised to the caller.

    Args:
        target_dir: Directory whose current contents should be relocated.
        new_subdir_name: Name (or relative path such as ``"a/b"``) of the
            destination below ``target_dir``. Because the destination is built
            with ``os.path.join``, an absolute path here replaces
            ``target_dir`` as the destination.

    Returns:
        None. Progress, the number of moved items and the number of failures
        are printed to stdout.
    """
    # --- 1. Validate the target directory ---
    if not os.path.isdir(target_dir):
        print(f"Error: Target directory '{target_dir}' not found or is not a directory.")
        return

    # --- 2. Create the destination ---
    new_subdir_path = os.path.join(target_dir, new_subdir_name)
    try:
        # exist_ok=True: an already existing destination is not an error.
        os.makedirs(new_subdir_path, exist_ok=True)
        print(f"Subdirectory '{new_subdir_path}' created or already exists.")
    except OSError as e:
        print(f"Error creating subdirectory '{new_subdir_path}': {e}")
        return

    # --- 3. Snapshot the entries to move ---
    try:
        items_to_move = os.listdir(target_dir)
    except OSError as e:
        print(f"Error listing contents of '{target_dir}': {e}")
        return

    # --- 4. Find the top-level entry that holds the destination ---
    # Comparing raw path strings misses the destination when the name carries
    # a trailing separator ("ckpt/") or several components ("a/b"); the
    # function would then try to move a directory into itself. Work on
    # normalised paths and skip the first component of the relative path.
    skip_name = None
    try:
        relative_destination = os.path.relpath(
            os.path.abspath(new_subdir_path), os.path.abspath(target_dir)
        )
    except ValueError:
        # relpath cannot relate the two paths (e.g. different Windows drives),
        # so the destination is certainly not inside target_dir.
        relative_destination = os.pardir
    first_component = relative_destination.split(os.sep, 1)[0]
    if first_component not in (os.curdir, os.pardir):
        skip_name = first_component

    # --- 5. Move the entries ---
    print(f"Moving items from '{target_dir}' to '{new_subdir_path}'...")
    moved_count = 0
    error_count = 0
    for item_name in items_to_move:
        if item_name == skip_name:
            # Never move the destination (or the directory containing it).
            continue
        source_path = os.path.join(target_dir, item_name)
        destination_path = os.path.join(new_subdir_path, item_name)
        try:
            shutil.move(source_path, destination_path)
        except (OSError, shutil.Error) as e:
            # Keep going: one failing entry must not abort the remaining moves.
            print(f" Error moving '{item_name}': {e}")
            error_count += 1
        else:
            moved_count += 1

    print("\nOperation complete.")
    print(f" Successfully moved: {moved_count} item(s).")
    if error_count > 0:
        print(f" Errors encountered: {error_count} item(s).")
76 | 125 | @dataclass |
77 | 126 | class SimpleTrainer: |
78 | 127 | state: SimpleTrainState |
@@ -124,6 +173,17 @@ def __init__(self, |
124 | 173 | if train_start_step_override is None: |
125 | 174 | train_start_step_override = run.summary['train/step'] + 1 |
126 | 175 | print(f"Resuming from previous run {wandb_config['id']} with start step {train_start_step_override}") |
| 176 | + |
| 177 | + # If load_from_checkpoint is not set, and an artifact is found, load the artifact |
| 178 | + if load_from_checkpoint is None: |
| 179 | + model_artifacts = [i for i in run.logged_artifacts() if i.type == 'model'] |
| 180 | + if model_artifacts: |
| 181 | + artifact = model_artifacts[0] |
| 182 | + artifact_dir = artifact.download() |
| 183 | + print(f"Loading model from artifact {artifact.name} at {artifact_dir}") |
| 184 | + # Move the artifact's contents |
| 185 | + load_from_checkpoint = os.path.join(artifact_dir, str(run.summary['train/step'])) |
| 186 | + move_contents_to_subdir(artifact_dir, load_from_checkpoint) |
127 | 187 |
|
128 | 188 | # define our custom x axis metric |
129 | 189 | self.wandb.define_metric("train/step") |
@@ -272,6 +332,7 @@ def load(self, checkpoint_path=None, checkpoint_step=None): |
272 | 332 | f"{step}") |
273 | 333 | self.loaded_checkpoint_path = loaded_checkpoint_path |
274 | 334 | ckpt = checkpointer.restore(step) |
| 335 | + |
275 | 336 | state = ckpt['state'] |
276 | 337 | best_state = ckpt['best_state'] |
277 | 338 | rngstate = ckpt['rngs'] |
@@ -590,5 +651,5 @@ def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps |
590 | 651 | ) |
591 | 652 | print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index])) |
592 | 653 |
|
593 | | - self.save(epochs) |
| 654 | + self.save(epochs)# |
594 | 655 | return self.state |
0 commit comments