diff --git a/__pycache__/pretrain.cpython-312.pyc b/__pycache__/pretrain.cpython-312.pyc
new file mode 100644
index 00000000..ac83fcbb
Binary files /dev/null and b/__pycache__/pretrain.cpython-312.pyc differ
diff --git a/dataset/build_arc_dataset.py b/dataset/build_arc_dataset.py
index c1442750..cca9f4da 100644
--- a/dataset/build_arc_dataset.py
+++ b/dataset/build_arc_dataset.py
@@ -19,7 +19,7 @@ class DataProcessConfig(BaseModel):
     output_dir: str
     subsets: List[str]
     test_set_name: str
-    test_set_name2: str = "your_test_set"
+    test_set_name2: str = "evaluation2"
     seed: int = 42
     num_aug: int = 1000
     puzzle_identifiers_start: int = 1  # start > 1 to handle multiple datasets
diff --git a/models/__pycache__/ema.cpython-312.pyc b/models/__pycache__/ema.cpython-312.pyc
new file mode 100644
index 00000000..90dde602
Binary files /dev/null and b/models/__pycache__/ema.cpython-312.pyc differ
diff --git a/pretrain.py b/pretrain.py
index b9072e25..53a42f57 100644
--- a/pretrain.py
+++ b/pretrain.py
@@ -5,6 +5,8 @@
 import yaml
 import shutil
 import copy
+import random
+import numpy as np
 
 import torch
 import torch.distributed as dist
@@ -127,6 +129,7 @@ def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata,
     model_cls = load_model_class(config.arch.name)
     loss_head_cls = load_model_class(config.arch.loss.name)
 
+    checkpoint_data = None
     with torch.device("cuda"):
         model: nn.Module = model_cls(model_cfg)
         print(model)
@@ -136,8 +139,37 @@ def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata,
 
         # Load checkpoint
         if rank == 0:
-            load_checkpoint(model, config)
+            checkpoint_data = load_checkpoint(model, config)
 
+    # Broadcast checkpoint data (step and optimizers) to ensure all ranks are in sync
+    if world_size > 1:
+        to_broadcast = None
+        if rank == 0 and checkpoint_data is not None:
+            # Prepare data to broadcast: extract only what's needed and move to CPU
+            to_broadcast = {
+                "step": checkpoint_data.get("step", 0),
+                "optimizers": []
+            }
+
+            # Helper to move optimizer states to CPU
+            def to_cpu(obj):
+                if isinstance(obj, torch.Tensor):
+                    return obj.cpu()
+                if isinstance(obj, dict):
+                    return {k: to_cpu(v) for k, v in obj.items()}
+                if isinstance(obj, list):
+                    return [to_cpu(v) for v in obj]
+                return obj
+
+            if "optimizers" in checkpoint_data:
+                to_broadcast["optimizers"] = to_cpu(checkpoint_data["optimizers"])
+
+        # Broadcast object list
+        objs = [to_broadcast]
+        dist.broadcast_object_list(objs, src=0)
+        checkpoint_data = objs[0]
+
+    with torch.device("cuda"):
         # Broadcast parameters from rank 0
         if world_size > 1:
             with torch.no_grad():
@@ -189,7 +221,18 @@ def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata,
         config.lr
     ]
 
-    return model, optimizers, optimizer_lrs
+    # Load optimizer states if available
+    if checkpoint_data is not None and "optimizers" in checkpoint_data:
+        if rank == 0:
+            print(f"Loading optimizer states for {len(optimizers)} optimizers")
+        if len(optimizers) != len(checkpoint_data["optimizers"]):
+            if rank == 0:
+                print(f"Warning: Number of optimizers ({len(optimizers)}) does not match checkpoint ({len(checkpoint_data['optimizers'])}). Skipping optimizer load.")
+        else:
+            for opt, opt_state in zip(optimizers, checkpoint_data["optimizers"]):
+                opt.load_state_dict(opt_state)
+
+    return model, optimizers, optimizer_lrs, checkpoint_data
 
 
 def mix_weights_direct(device, alpha, net, nets):
     sd = []
@@ -219,10 +262,15 @@ def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetada
     total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size)
 
     # Model
-    model, optimizers, optimizer_lrs = create_model(config, train_metadata, rank=rank, world_size=world_size)
+    model, optimizers, optimizer_lrs, checkpoint_data = create_model(config, train_metadata, rank=rank, world_size=world_size)
+
+    step = 0
+    if checkpoint_data is not None and "step" in checkpoint_data:
+        step = checkpoint_data["step"]
+        print(f"Resuming from step {step}")
 
-    return TrainState(
-        step=0,
+    train_state = TrainState(
+        step=step,
 
         total_steps=total_steps,
         model=model,
@@ -231,14 +279,32 @@ def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetada
         carry=None
     )
 
+    return train_state, checkpoint_data
 
-def save_train_state(config: PretrainConfig, train_state: TrainState):
+
+def save_train_state(config: PretrainConfig, train_state: TrainState, ema_helper: Optional[Any] = None):
     # FIXME: Only saved model.
     if config.checkpoint_path is None:
         return
 
     os.makedirs(config.checkpoint_path, exist_ok=True)
-    torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}"))
+
+    checkpoint = {
+        "model": train_state.model.state_dict(),
+        "optimizers": [opt.state_dict() for opt in train_state.optimizers],
+        "step": train_state.step,
+        "rng": {
+            "torch": torch.get_rng_state(),
+            "cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
+            "numpy": np.random.get_state(),
+            "random": random.getstate(),
+        }
+    }
+
+    if ema_helper is not None:
+        checkpoint["ema_helper"] = ema_helper.state_dict()
+
+    torch.save(checkpoint, os.path.join(config.checkpoint_path, f"step_{train_state.step}"))
 
 
 def load_checkpoint(model: nn.Module, config: PretrainConfig):
@@ -246,7 +312,24 @@ def load_checkpoint(model: nn.Module, config: PretrainConfig):
         print(f"Loading checkpoint {config.load_checkpoint}")
 
         # Load state dict
-        state_dict = torch.load(config.load_checkpoint, map_location="cuda")
+        # We need weights_only=False because we save complex objects like optimizer state and RNG states
+        checkpoint_data = torch.load(config.load_checkpoint, map_location="cuda", weights_only=False)
+
+        state_dict = checkpoint_data
+        # Check if it is the new format
+        if isinstance(checkpoint_data, dict) and "model" in checkpoint_data:
+            state_dict = checkpoint_data["model"]
+
+            # Restore RNG state
+            if "rng" in checkpoint_data:
+                rng_state = checkpoint_data["rng"]
+                torch.set_rng_state(rng_state["torch"])
+                if rng_state["cuda"] is not None and torch.cuda.is_available():
+                    torch.cuda.set_rng_state_all(rng_state["cuda"])
+                np.random.set_state(rng_state["numpy"])
+                random.setstate(rng_state["random"])
+        else:
+            checkpoint_data = None  # Old format, no extra data
 
         # Resize and reset puzzle emb if needed
         puzzle_emb_name = "_orig_mod.model.inner.puzzle_emb.weights"
@@ -261,6 +344,9 @@ def load_checkpoint(model: nn.Module, config: PretrainConfig):
         )
 
         model.load_state_dict(state_dict, assign=True)
+        return checkpoint_data
+    return None
+
 
 def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):
     return cosine_schedule_with_warmup_lr_lambda(
@@ -524,6 +610,25 @@ def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) ->
     if config.checkpoint_path is None:
         config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name)
 
+    # Automatic resumption: if no explicit checkpoint is given, try to find the latest one in the checkpoint path
+    if config.load_checkpoint is None and config.checkpoint_path is not None and os.path.exists(config.checkpoint_path):
+        # Checkpoints are saved as "step_{step}"
+        max_step = -1
+        max_ckpt = None
+        for fname in os.listdir(config.checkpoint_path):
+            if fname.startswith("step_") and not fname.endswith(".tmp"):  # ignore tmp or other files
+                try:
+                    step_val = int(fname.split("_")[1])
+                    if step_val > max_step:
+                        max_step = step_val
+                        max_ckpt = os.path.join(config.checkpoint_path, fname)
+                except (ValueError, IndexError):
+                    continue
+
+        if max_ckpt is not None:
+            print(f"Auto-resume: Found latest checkpoint at {max_ckpt} (step {max_step})")
+            config.load_checkpoint = max_ckpt
+
     objects = [config]
 
     if world_size > 1:
@@ -580,7 +685,7 @@ def launch(hydra_config: DictConfig):
     evaluators = []
 
     # Train state
-    train_state = init_train_state(config, train_metadata, rank=RANK, world_size=WORLD_SIZE)
+    train_state, checkpoint_data = init_train_state(config, train_metadata, rank=RANK, world_size=WORLD_SIZE)
 
     # Progress bar and logger
     progress_bar = None
@@ -594,6 +699,9 @@ def launch(hydra_config: DictConfig):
         print('Setup EMA')
        ema_helper = EMAHelper(mu=config.ema_rate)
         ema_helper.register(train_state.model)
+        if checkpoint_data is not None and "ema_helper" in checkpoint_data:
+            print("Loading EMA helper state")
+            ema_helper.load_state_dict(checkpoint_data["ema_helper"])
 
     # Training Loop
     for _iter_id in range(total_iters):
@@ -639,7 +747,8 @@ def launch(hydra_config: DictConfig):
         if RANK == 0:
             print("SAVE CHECKPOINT")
         if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)):
-            save_train_state(config, train_state_eval)
+            # Save online state (and EMA helper if available) to ensure resumption is correct
+            save_train_state(config, train_state, ema_helper=ema_helper)
 
         if config.ema:
             del train_state_eval
diff --git a/utils/__pycache__/functions.cpython-312.pyc b/utils/__pycache__/functions.cpython-312.pyc
new file mode 100644
index 00000000..59bdf44f
Binary files /dev/null and b/utils/__pycache__/functions.cpython-312.pyc differ
diff --git a/verify_resumption.py b/verify_resumption.py
new file mode 100644
index 00000000..2d29e7f1
--- /dev/null
+++ b/verify_resumption.py
@@ -0,0 +1,254 @@
+
+import os
+import shutil
+import torch
+import numpy as np
+import random
+import sys
+from unittest.mock import MagicMock, patch
+
+# Disable compilation for testing
+os.environ["DISABLE_COMPILE"] = "1"
+
+# Mock adam_atan2
+sys.modules["adam_atan2"] = MagicMock()
+sys.modules["adam_atan2"].AdamATan2 = torch.optim.AdamW
+
+# Mock models.sparse_embedding
+sys.modules["models.sparse_embedding"] = MagicMock()
+class MockOptimizer(torch.optim.Optimizer):
+    def __init__(self, params, defaults=None):
+        if defaults is None: defaults = {}
+        defaults['lr'] = 0.01
+        super().__init__(params, defaults)
+    def step(self, closure=None):
+        pass
+sys.modules["models.sparse_embedding"].CastedSparseEmbeddingSignSGD_Distributed = MockOptimizer
+
+# Mock puzzle_dataset
+sys.modules["puzzle_dataset"] = MagicMock()
+class MockPuzzleDataset:
+    pass
+sys.modules["puzzle_dataset"].PuzzleDataset = MockPuzzleDataset
+sys.modules["puzzle_dataset"].PuzzleDatasetConfig = MagicMock()
+sys.modules["puzzle_dataset"].PuzzleDatasetMetadata = MagicMock()
+
+# Mock distributed
+sys.modules["torch.distributed"] = MagicMock()
+sys.modules["torch.distributed"].get_rank.return_value = 0
+sys.modules["torch.distributed"].get_world_size.return_value = 1
+sys.modules["torch.distributed"].broadcast_object_list = MagicMock()
+sys.modules["torch.distributed"].broadcast = MagicMock()
+
+import pretrain
+from pretrain import TrainState, PretrainConfig, ArchConfig, LossConfig, create_model, save_train_state, load_checkpoint, init_train_state
+import torch.nn as nn
+
+# Mock classes
+class MockArchConfig:
+    def __init__(self):
+        self.name = "test_arch"
+        self.loss = LossConfig(name="test_loss")
+        self.puzzle_emb_ndim = 0
+        self.__pydantic_extra__ = {}
+
+class MockConfig:
+    def __init__(self, **kwargs):
+        self.arch = MockArchConfig()
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+        self.loss = LossConfig(name="test_loss")
+        self.checkpoint_path = "test_verification_checkpoints"
+        self.load_checkpoint = None
+        self.global_batch_size = 1
+        self.epochs = 1
+        self.lr = 0.01
+        self.lr_min_ratio = 0.0
+        self.lr_warmup_steps = 0
+        self.weight_decay = 0.0
+        self.beta1 = 0.9
+        self.beta2 = 0.999
+        self.puzzle_emb_lr = 0.01
+        self.puzzle_emb_weight_decay = 0.0
+        self.freeze_weights = False
+        self.ema = False
+        self.seed = 42
+
+class MockMetadata:
+    def __init__(self):
+        self.vocab_size = 10
+        self.seq_len = 10
+        self.num_puzzle_identifiers = 1
+        self.total_groups = 1
+        self.mean_puzzle_examples = 1
+
+def test_bitwise_resumption():
+    print("Setting up bitwise resumption test...")
+    if os.path.exists("test_verification_checkpoints"):
+        shutil.rmtree("test_verification_checkpoints")
+    os.makedirs("test_verification_checkpoints")
+
+    seed = 42
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+    class SimpleModel(nn.Module):
+        def __init__(self, config):
+            super().__init__()
+            self.linear = nn.Linear(10, 2)
+            self.model = MagicMock()
+            self.model.puzzle_emb = MagicMock()
+            self.model.puzzle_emb.buffers.return_value = []
+            self.model.puzzle_emb.weights.shape = torch.Size([1, 10])
+
+        def forward(self, carry, batch, return_keys=[]):
+            # Advance all RNGs
+            noise = torch.randn(1, 2)
+            _ = np.random.rand(1)
+            _ = random.random()
+
+            if batch is not None and "inputs" in batch:
+                inp = batch["inputs"]
+            else:
+                inp = torch.randn(1, 10)
+
+            loss = self.linear(inp).sum() + noise.sum()
+            return carry, loss, {}, {}, False
+
+        def initial_carry(self, batch):
+            return None
+
+    pretrain.load_model_class = lambda name, *args: SimpleModel
+
+    config = MockConfig(
+        data_paths=[],
+        global_batch_size=1,
+        epochs=1,
+        lr=0.01,
+        lr_min_ratio=0.0,
+        lr_warmup_steps=0,
+        weight_decay=0.0,
+        beta1=0.9,
+        beta2=0.999,
+        puzzle_emb_lr=0.0,
+        puzzle_emb_weight_decay=0.0,
+        checkpoint_path="test_verification_checkpoints",
+        seed=seed
+    )
+
+    metadata = MockMetadata()
+
+    # Capture original torch.load and torch.device
+    original_load = torch.load
+    original_device = torch.device
+
+    def mock_load(f, map_location=None, **kwargs):
+        return original_load(f, map_location="cpu", **kwargs)
+
+    def real_device_mock(device_str):
+        if device_str == "cuda":
+            return original_device("cpu")
+        return original_device(device_str)
+
+    with patch("torch.load", side_effect=mock_load), \
+         patch("torch.device", side_effect=real_device_mock), \
+         patch("torch.cuda.is_available", return_value=False):
+
+        # --- Continuous Run (3 steps) ---
+        print("Running continuous training (3 steps)...")
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+        random.seed(seed)
+
+        train_state_cont, _ = init_train_state(config, metadata, 0, 1)
+
+        losses_cont = []
+        for i in range(3):
+            train_state_cont.step += 1
+            loss = train_state_cont.model(None, None)[1]
+            loss.backward()
+            for opt in train_state_cont.optimizers:
+                opt.step()
+                opt.zero_grad()
+            losses_cont.append(loss.item())
+            print(f"Step {i+1} Loss: {loss.item()}")
+
+        print(f"Continuous Losses: {losses_cont}")
+
+        # --- Interrupted Run ---
+        print("\nRunning interrupted training...")
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+        random.seed(seed)
+
+        train_state_part1, _ = init_train_state(config, metadata, 0, 1)
+
+        train_state_part1.step += 1
+        loss = train_state_part1.model(None, None)[1]
+        loss.backward()
+        for opt in train_state_part1.optimizers:
+            opt.step()
+            opt.zero_grad()
+        print(f"Part 1 Step 1 Loss: {loss.item()}")
+
+        assert losses_cont[0] == loss.item(), "Step 1 mismatch!"
+
+        print("Saving checkpoint at step 1...")
+        save_train_state(config, train_state_part1)
+
+        print("Resuming...")
+        torch.manual_seed(999)
+        np.random.seed(999)
+        random.seed(999)
+
+        config.load_checkpoint = None
+
+        # Manual auto-resume logic simulation
+        if config.load_checkpoint is None and config.checkpoint_path is not None and os.path.exists(config.checkpoint_path):
+            max_step = -1
+            max_ckpt = None
+            for fname in os.listdir(config.checkpoint_path):
+                if fname.startswith("step_") and not fname.endswith(".tmp"):
+                    try:
+                        step_val = int(fname.split("_")[1])
+                        if step_val > max_step:
+                            max_step = step_val
+                            max_ckpt = os.path.join(config.checkpoint_path, fname)
+                    except (ValueError, IndexError):
+                        continue
+            if max_ckpt is not None:
+                print(f"Auto-resume: Found {max_ckpt}")
+                config.load_checkpoint = max_ckpt
+
+        expected_ckpt = os.path.join(config.checkpoint_path, "step_1")
+        assert config.load_checkpoint == expected_ckpt
+
+        train_state_resumed, checkpoint_data = init_train_state(config, metadata, 0, 1)
+
+        assert checkpoint_data is not None
+        assert train_state_resumed.step == 1
+
+        losses_resumed = [losses_cont[0]]
+        for i in range(2):
+            train_state_resumed.step += 1
+            loss = train_state_resumed.model(None, None)[1]
+            loss.backward()
+            for opt in train_state_resumed.optimizers:
+                opt.step()
+                opt.zero_grad()
+            losses_resumed.append(loss.item())
+            print(f"Resumed Step {i+2} Loss: {loss.item()}")
+
+        print(f"Resumed Losses: {losses_resumed}")
+
+        if np.allclose(losses_cont, losses_resumed):
+            print("\nSUCCESS: Bitwise resumption verified!")
+        else:
+            print("\nFAILURE: Resumption mismatch!")
+            print(f"Continuous: {losses_cont}")
+            print(f"Resumed: {losses_resumed}")
+            exit(1)
+
+if __name__ == "__main__":
+    test_bitwise_resumption()