diff --git a/evaluate.py b/evaluate.py
new file mode 100755
index 00000000..bb415891
--- /dev/null
+++ b/evaluate.py
@@ -0,0 +1,99 @@
+from typing import List
+import yaml
+import os
+
+import torch
+import torch.distributed as dist
+
+import pydantic
+from omegaconf import OmegaConf
+from pretrain import PretrainConfig, init_train_state, evaluate, create_dataloader
+from models.ema import EMAHelper
+import copy
+
+# import torch._dynamo
+# torch._dynamo.config.suppress_errors = True
+
+class EvalConfig(pydantic.BaseModel):
+    checkpoint: str
+
+    save_outputs: List[str] = []
+    # save_outputs: List[str] = ["inputs", "labels", "puzzle_identifiers", "logits", "q_halt_logits", "q_continue_logits"]
+
+
+def launch():
+    eval_cfg = EvalConfig(**OmegaConf.to_container(OmegaConf.from_cli()))  # type: ignore
+
+    RANK = 0
+    WORLD_SIZE = 1
+    CPU_PROCESS_GROUP = None
+    # Initialize distributed training if in distributed environment (e.g. torchrun)
+    if "LOCAL_RANK" in os.environ:
+        # Initialize distributed, default device and dtype
+        dist.init_process_group(backend="nccl")
+
+        RANK = dist.get_rank()
+        WORLD_SIZE = dist.get_world_size()
+
+        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+
+        # CPU GLOO process group
+        CPU_PROCESS_GROUP = dist.new_group(backend="gloo")
+        assert (
+            dist.get_rank(CPU_PROCESS_GROUP) == RANK and dist.get_world_size(CPU_PROCESS_GROUP) == WORLD_SIZE
+        )
+
+    with open(os.path.join(os.path.dirname(eval_cfg.checkpoint), "all_config.yaml"), "r") as f:
+        config = PretrainConfig(**yaml.safe_load(f))
+
+    config.eval_save_outputs = eval_cfg.save_outputs
+    config.checkpoint_path = os.path.dirname(eval_cfg.checkpoint)
+
+    # Dataloader
+    train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
+    eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
+
+    # Models
+    train_state = init_train_state(config, train_metadata, rank=RANK, world_size=WORLD_SIZE)
+
+    # Try unwrap torch.compile
+    try:
+        train_state.model.load_state_dict(torch.load(eval_cfg.checkpoint, map_location="cuda"), assign=True)
+    except:
+        train_state.model.load_state_dict({k.removeprefix("_orig_mod."): v for k, v in torch.load(eval_cfg.checkpoint, map_location="cuda").items()}, assign=True)
+
+    train_state.step = 0
+    ckpt_filename = os.path.basename(eval_cfg.checkpoint)
+    if ckpt_filename.startswith("step_"):
+        train_state.step = int(ckpt_filename.removeprefix("step_"))
+
+    ema_helper = None
+    if config.ema:
+        print('Setup EMA')
+        ema_helper = EMAHelper(mu=config.ema_rate)
+        ema_helper.register(train_state.model)
+    if config.ema:
+        ema_helper.update(train_state.model)
+
+    # Evaluate
+    print("Starting evaluation")
+
+    if config.ema:
+        print("SWITCH TO EMA")
+        train_state_eval = copy.deepcopy(train_state)
+        train_state_eval.model = ema_helper.ema_copy(train_state_eval.model)
+    else:
+        train_state_eval = train_state
+    train_state_eval.model.eval()
+    metrics = evaluate(config, train_state_eval, eval_loader, eval_metadata,
+                       evaluators=[],
+                       rank=RANK,
+                       world_size=WORLD_SIZE,
+                       cpu_group=CPU_PROCESS_GROUP)
+
+    if metrics is not None:
+        print(metrics)
+
+
+if __name__ == "__main__":
+    launch()
diff --git a/models/layers.py b/models/layers.py
old mode 100644
new mode 100755
index 5d5264bf..28fd40ab
--- a/models/layers.py
+++ b/models/layers.py
@@ -89,8 +89,14 @@ def __init__(self, dim, max_position_embeddings, base, device=None):
 
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
-        self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
+        # self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
+        # self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
+
+        # --- CORRECTED CODE BLOCK ---
+        self.register_buffer('cos_cached', emb.cos(), persistent=False)
+        self.register_buffer('sin_cached', emb.sin(), persistent=False)
+        # --- END OF CORRECTION ---
+
     def forward(self):
         return self.cos_cached, self.sin_cached
 
diff --git a/models/recursive_reasoning/trm.py b/models/recursive_reasoning/trm.py
old mode 100644
new mode 100755
index 5c3e39df..e278ec86
--- a/models/recursive_reasoning/trm.py
+++ b/models/recursive_reasoning/trm.py
@@ -150,8 +150,12 @@ def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
         self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
 
         # Initial states
-        self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
-        self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
+        # self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
+        # self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
+        h_init_tensor = trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1)
+        self.register_buffer('H_init', h_init_tensor, persistent=True)
+        l_init_tensor = trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1)
+        self.register_buffer('L_init', l_init_tensor, persistent=True)
 
         # Q head special init
         # Init Q to (almost) zero for faster learning during bootstrapping
diff --git a/models/sparse_embedding.py b/models/sparse_embedding.py
index f369205e..ca64c868 100644
--- a/models/sparse_embedding.py
+++ b/models/sparse_embedding.py
@@ -13,17 +13,18 @@ def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, ini
         super().__init__()
         self.cast_to = cast_to
 
+        # --- CORRECTED CODE BLOCK ---
         # Real Weights
-        # Truncated LeCun normal init
-        self.weights = nn.Buffer(
-            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
-        )
+        weights_tensor = trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
+        self.register_buffer('weights', weights_tensor)
 
         # Local weights and IDs
-        # Local embeddings, with gradient, not persistent
-        self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
-        # Local embedding IDs, not persistent
-        self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)
+        local_weights_tensor = torch.zeros(batch_size, embedding_dim, requires_grad=True)
+        self.register_buffer('local_weights', local_weights_tensor, persistent=False)
+
+        local_ids_tensor = torch.zeros(batch_size, dtype=torch.int32)
+        self.register_buffer('local_ids', local_ids_tensor, persistent=False)
+        # --- END OF CORRECTION ---
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         if not self.training:
@@ -78,21 +79,21 @@ def step(self, closure=None):  # type: ignore
                 else:
                     assert False
 
+            assert local_weights_grad is not None
             assert local_ids is not None
             assert weights is not None
-            
+
             # Apply SignSGD
             # Adam ≈ SignSGD if gradient is very sparse
-            if local_weights_grad is not None:
-                _sparse_emb_signsgd_dist(
-                    local_weights_grad,
-                    local_ids,
-                    weights,
-
-                    lr=group["lr"],
-                    weight_decay=group["weight_decay"],
-                    world_size=group["world_size"]
-                )
+            _sparse_emb_signsgd_dist(
+                local_weights_grad,
+                local_ids,
+                weights,
+
+                lr=group["lr"],
+                weight_decay=group["weight_decay"],
+                world_size=group["world_size"]
+            )
 
 
 def _sparse_emb_signsgd_dist(
@@ -112,10 +113,10 @@ def _sparse_emb_signsgd_dist(
 
     if world_size > 1:
         all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
-        all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)   
+        all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)
 
         dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
-        dist.all_gather_into_tensor(all_ids, local_ids)   
+        dist.all_gather_into_tensor(all_ids, local_ids)
 
     # Unique
     grad_ids, inv = all_ids.unique(return_inverse=True)
diff --git a/pretrain.py b/pretrain.py
old mode 100644
new mode 100755
index b9072e25..727ec291
--- a/pretrain.py
+++ b/pretrain.py
@@ -17,7 +17,7 @@
 import hydra
 import pydantic
 from omegaconf import DictConfig
-from adam_atan2 import AdamATan2
+from adam_atan2_pytorch import AdamAtan2
 from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata
 from utils.functions import load_model_class, get_model_source_path
@@ -147,7 +147,7 @@ def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata,
     # Optimizers and lr
     if config.arch.puzzle_emb_ndim == 0:
         optimizers = [
-            AdamATan2(
+            AdamAtan2(
                 model.parameters(),
                 lr=0,  # Needs to be set by scheduler
                 weight_decay=config.weight_decay,
@@ -161,7 +161,7 @@ def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata,
         optimizers = [
             CastedSparseEmbeddingSignSGD_Distributed(
                 model.model.puzzle_emb.buffers(),  # type: ignore
-                lr=0,  # Needs to be set by scheduler
+                lr=0.0000001,  # Needs to be set by scheduler
                 weight_decay=config.puzzle_emb_weight_decay,
                 world_size=world_size
             ),
@@ -177,9 +177,9 @@ def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata,
                 weight_decay=config.puzzle_emb_weight_decay,
                 world_size=world_size
             ),
-            AdamATan2(
+            AdamAtan2(
                 model.parameters(),
-                lr=0,  # Needs to be set by scheduler
+                lr=0.0000001,  # Needs to be set by scheduler
                 weight_decay=config.weight_decay,
                 betas=(config.beta1, config.beta2)
             )
diff --git a/requirements.txt b/requirements.txt
index cea0f55f..504610ec 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 torch
-adam-atan2
+adam-atan2-pytorch
 einops
 tqdm
 coolname
@@ -17,4 +17,4 @@ setuptools-scm
 pydantic-core
 huggingface_hub
 numba
-triton
\ No newline at end of file
+triton
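A note on invoking the new evaluate.py: the script builds its EvalConfig from command-line key=value pairs via OmegaConf.from_cli, with checkpoint as the only required field, and it expects all_config.yaml to sit in the same directory as the checkpoint file. The commands below are an illustrative sketch; the checkpoint path and GPU count are placeholders, not values taken from this patch.

    # Single GPU (placeholder path; a filename starting with "step_" also sets train_state.step)
    python evaluate.py checkpoint=checkpoints/<run_name>/step_60000

    # Multi-GPU: torchrun sets LOCAL_RANK, which makes the script initialize the NCCL and GLOO process groups
    torchrun --nproc_per_node=8 evaluate.py checkpoint=checkpoints/<run_name>/step_60000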