Binary file added __pycache__/pretrain.cpython-312.pyc
Binary file not shown.
2 changes: 1 addition & 1 deletion dataset/build_arc_dataset.py
@@ -19,7 +19,7 @@ class DataProcessConfig(BaseModel):
output_dir: str
subsets: List[str]
test_set_name: str
test_set_name2: str = "your_test_set"
test_set_name2: str = "evaluation2"
seed: int = 42
num_aug: int = 1000
puzzle_identifiers_start: int = 1 # start > 1 to handle multiple datasets
Binary file added models/__pycache__/ema.cpython-312.pyc
Binary file not shown.
129 changes: 119 additions & 10 deletions pretrain.py
@@ -5,6 +5,8 @@
import yaml
import shutil
import copy
import random
import numpy as np

import torch
import torch.distributed as dist
@@ -127,6 +129,7 @@ def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata,
model_cls = load_model_class(config.arch.name)
loss_head_cls = load_model_class(config.arch.loss.name)

checkpoint_data = None
with torch.device("cuda"):
model: nn.Module = model_cls(model_cfg)
print(model)
@@ -136,8 +139,37 @@ def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata,

# Load checkpoint
if rank == 0:
load_checkpoint(model, config)
checkpoint_data = load_checkpoint(model, config)

# Broadcast checkpoint data (step and optimizers) to ensure all ranks are in sync
if world_size > 1:
to_broadcast = None
if rank == 0 and checkpoint_data is not None:
# Prepare data to broadcast: extract only what's needed and move to CPU
to_broadcast = {
"step": checkpoint_data.get("step", 0),
"optimizers": []
}

# Helper to move optimizer states to CPU
def to_cpu(obj):
if isinstance(obj, torch.Tensor):
return obj.cpu()
if isinstance(obj, dict):
return {k: to_cpu(v) for k, v in obj.items()}
if isinstance(obj, list):
return [to_cpu(v) for v in obj]
return obj

if "optimizers" in checkpoint_data:
to_broadcast["optimizers"] = to_cpu(checkpoint_data["optimizers"])

# Broadcast object list
objs = [to_broadcast]
dist.broadcast_object_list(objs, src=0)
checkpoint_data = objs[0]

with torch.device("cuda"):
# Broadcast parameters from rank 0
if world_size > 1:
with torch.no_grad():
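A minimal sketch (not part of the diff) of the dist.broadcast_object_list pattern used in the hunk above: rank 0 fills a one-element list, every other rank passes a placeholder of the same length, and the call pickles the object on the source and unpickles it everywhere else. Moving optimizer tensors to CPU first, as the hunk does, avoids pickled CUDA tensors trying to land on a device index that may not match the receiving rank. The function name sync_resume_metadata is illustrative.

import torch.distributed as dist

def sync_resume_metadata(checkpoint_data, rank, world_size):
    # Assumes the default process group is already initialized (as in pretrain.py).
    if world_size <= 1:
        return checkpoint_data
    objs = [None]  # same length on every rank; non-source entries are overwritten
    if rank == 0 and checkpoint_data is not None:
        objs[0] = {"step": checkpoint_data.get("step", 0)}
    dist.broadcast_object_list(objs, src=0)
    return objs[0]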
@@ -189,7 +221,18 @@ def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata,
config.lr
]

return model, optimizers, optimizer_lrs
# Load optimizer states if available
if checkpoint_data is not None and "optimizers" in checkpoint_data:
if rank == 0:
print(f"Loading optimizer states for {len(optimizers)} optimizers")
if len(optimizers) != len(checkpoint_data["optimizers"]):
if rank == 0:
print(f"Warning: Number of optimizers ({len(optimizers)}) does not match checkpoint ({len(checkpoint_data['optimizers'])}). Skipping optimizer load.")
else:
for opt, opt_state in zip(optimizers, checkpoint_data["optimizers"]):
opt.load_state_dict(opt_state)

return model, optimizers, optimizer_lrs, checkpoint_data

def mix_weights_direct(device, alpha, net, nets):
sd = []
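For context, a minimal sketch (not part of the diff) of the optimizer state round trip the block above depends on; the zip is positional, which is why the count check matters before load_state_dict. The model and optimizer here are illustrative.

import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

model(torch.randn(2, 4)).sum().backward()
opt.step()  # populates per-parameter moments and step counters

saved = opt.state_dict()  # {"state": {...}, "param_groups": [...]}

opt_resumed = torch.optim.AdamW(model.parameters(), lr=1e-3)
opt_resumed.load_state_dict(saved)  # restores exp_avg / exp_avg_sq / step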
@@ -219,10 +262,15 @@ def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetada
total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size)

# Model
model, optimizers, optimizer_lrs = create_model(config, train_metadata, rank=rank, world_size=world_size)
model, optimizers, optimizer_lrs, checkpoint_data = create_model(config, train_metadata, rank=rank, world_size=world_size)

step = 0
if checkpoint_data is not None and "step" in checkpoint_data:
step = checkpoint_data["step"]
print(f"Resuming from step {step}")

return TrainState(
step=0,
train_state = TrainState(
step=step,
total_steps=total_steps,

model=model,
@@ -231,22 +279,57 @@ def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetada
carry=None
)

return train_state, checkpoint_data

def save_train_state(config: PretrainConfig, train_state: TrainState):

def save_train_state(config: PretrainConfig, train_state: TrainState, ema_helper: Optional[Any] = None):
# FIXME: Only saved model.
if config.checkpoint_path is None:
return

os.makedirs(config.checkpoint_path, exist_ok=True)
torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}"))

checkpoint = {
"model": train_state.model.state_dict(),
"optimizers": [opt.state_dict() for opt in train_state.optimizers],
"step": train_state.step,
"rng": {
"torch": torch.get_rng_state(),
"cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
"numpy": np.random.get_state(),
"random": random.getstate(),
}
}

if ema_helper is not None:
checkpoint["ema_helper"] = ema_helper.state_dict()

torch.save(checkpoint, os.path.join(config.checkpoint_path, f"step_{train_state.step}"))


def load_checkpoint(model: nn.Module, config: PretrainConfig):
if config.load_checkpoint is not None:
print(f"Loading checkpoint {config.load_checkpoint}")

# Load state dict
state_dict = torch.load(config.load_checkpoint, map_location="cuda")
# We need weights_only=False because we save complex objects like optimizer state and RNG states
checkpoint_data = torch.load(config.load_checkpoint, map_location="cuda", weights_only=False)

state_dict = checkpoint_data
# Check if it is the new format
if isinstance(checkpoint_data, dict) and "model" in checkpoint_data:
state_dict = checkpoint_data["model"]

# Restore RNG state
if "rng" in checkpoint_data:
rng_state = checkpoint_data["rng"]
torch.set_rng_state(rng_state["torch"])
if rng_state["cuda"] is not None and torch.cuda.is_available():
torch.cuda.set_rng_state_all(rng_state["cuda"])
np.random.set_state(rng_state["numpy"])
random.setstate(rng_state["random"])
else:
checkpoint_data = None # Old format, no extra data

# Resize and reset puzzle emb if needed
puzzle_emb_name = "_orig_mod.model.inner.puzzle_emb.weights"
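A standalone sketch (not part of the diff) of the RNG round trip that save_train_state and load_checkpoint implement above. One caveat worth flagging: torch.set_rng_state expects a CPU ByteTensor, and the checkpoint is loaded with map_location="cuda", so the torch state may need an explicit .cpu() as shown here. The helper names are illustrative.

import random
import numpy as np
import torch

def capture_rng():
    return {
        "torch": torch.get_rng_state(),
        "cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        "numpy": np.random.get_state(),
        "random": random.getstate(),
    }

def restore_rng(state):
    torch.set_rng_state(state["torch"].cpu())  # CPU ByteTensor required
    if state["cuda"] is not None and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(state["cuda"])
    np.random.set_state(state["numpy"])
    random.setstate(state["random"])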
@@ -261,6 +344,9 @@ def load_checkpoint(model: nn.Module, config: PretrainConfig):
)
model.load_state_dict(state_dict, assign=True)

return checkpoint_data
return None


def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):
return cosine_schedule_with_warmup_lr_lambda(
@@ -524,6 +610,25 @@ def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) ->
if config.checkpoint_path is None:
config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name)

# Automatic resumption: if no explicit checkpoint is given, try to find the latest one in the checkpoint path
if config.load_checkpoint is None and config.checkpoint_path is not None and os.path.exists(config.checkpoint_path):
# Checkpoints are saved as "step_{step}"
max_step = -1
max_ckpt = None
for fname in os.listdir(config.checkpoint_path):
if fname.startswith("step_") and not fname.endswith(".tmp"): # ignore tmp or other files
try:
step_val = int(fname.split("_")[1])
if step_val > max_step:
max_step = step_val
max_ckpt = os.path.join(config.checkpoint_path, fname)
except (ValueError, IndexError):
continue

if max_ckpt is not None:
print(f"Auto-resume: Found latest checkpoint at {max_ckpt} (step {max_step})")
config.load_checkpoint = max_ckpt

objects = [config]

if world_size > 1:
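A small sketch (not part of the diff) of one way to keep the auto-resume scan above safe against partially written files: write to a "*.tmp" name and rename, so the loop (which skips "*.tmp") never picks up a checkpoint that is still being written. atomic_save is an illustrative name; as far as the shown hunks go, save_train_state writes directly to the final name.

import os
import torch

def atomic_save(obj, path):
    tmp_path = path + ".tmp"
    torch.save(obj, tmp_path)
    os.replace(tmp_path, path)  # atomic rename on the same filesystem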
@@ -580,7 +685,7 @@ def launch(hydra_config: DictConfig):
evaluators = []

# Train state
train_state = init_train_state(config, train_metadata, rank=RANK, world_size=WORLD_SIZE)
train_state, checkpoint_data = init_train_state(config, train_metadata, rank=RANK, world_size=WORLD_SIZE)

# Progress bar and logger
progress_bar = None
@@ -594,6 +699,9 @@
print('Setup EMA')
ema_helper = EMAHelper(mu=config.ema_rate)
ema_helper.register(train_state.model)
if checkpoint_data is not None and "ema_helper" in checkpoint_data:
print("Loading EMA helper state")
ema_helper.load_state_dict(checkpoint_data["ema_helper"])

# Training Loop
for _iter_id in range(total_iters):
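For readers without models/ema.py at hand, a rough sketch (not part of the diff) of the register / state_dict / load_state_dict interface the resume path above assumes. The real EMAHelper may differ; this only illustrates why its shadow weights need to be checkpointed along with the model.

import torch

class SimpleEMA:
    def __init__(self, mu=0.999):
        self.mu = mu
        self.shadow = {}

    def register(self, module):
        # Keep a detached copy of every trainable parameter.
        for name, p in module.named_parameters():
            if p.requires_grad:
                self.shadow[name] = p.detach().clone()

    def update(self, module):
        # shadow = mu * shadow + (1 - mu) * param
        for name, p in module.named_parameters():
            if p.requires_grad:
                self.shadow[name].mul_(self.mu).add_(p.detach(), alpha=1.0 - self.mu)

    def state_dict(self):
        return {"mu": self.mu, "shadow": self.shadow}

    def load_state_dict(self, state):
        self.mu = state["mu"]
        self.shadow = state["shadow"]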
@@ -639,7 +747,8 @@ def launch(hydra_config: DictConfig):
if RANK == 0:
print("SAVE CHECKPOINT")
if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)):
save_train_state(config, train_state_eval)
# Save online state (and EMA helper if available) to ensure resumption is correct
save_train_state(config, train_state, ema_helper=ema_helper)

if config.ema:
del train_state_eval
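Finally, a usage sketch (not part of the diff) for inspecting a checkpoint written by the new save_train_state; the path is a placeholder following the default checkpoints/<project_name>/<run_name>/step_<step> layout.

import torch

ckpt = torch.load("checkpoints/<project_name>/<run_name>/step_1000",
                  map_location="cpu", weights_only=False)

if isinstance(ckpt, dict) and "model" in ckpt:
    # New format: full training state.
    print("resumed step:", ckpt["step"])
    print("keys:", sorted(ckpt.keys()))  # model, optimizers, rng, step, optionally ema_helper
    model_state = ckpt["model"]
else:
    # Old format: a bare model state_dict.
    model_state = ckpt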
Binary file added utils/__pycache__/functions.cpython-312.pyc
Binary file not shown.