99 changes: 99 additions & 0 deletions evaluate.py
@@ -0,0 +1,99 @@
from typing import List
import yaml
import os

import torch
import torch.distributed as dist

import pydantic
from omegaconf import OmegaConf
from pretrain import PretrainConfig, init_train_state, evaluate, create_dataloader
from models.ema import EMAHelper
import copy

# import torch._dynamo
# torch._dynamo.config.suppress_errors = True

class EvalConfig(pydantic.BaseModel):
checkpoint: str

save_outputs: List[str] = []
# save_outputs: List[str] = ["inputs", "labels", "puzzle_identifiers", "logits", "q_halt_logits", "q_continue_logits"]


def launch():
eval_cfg = EvalConfig(**OmegaConf.to_container(OmegaConf.from_cli())) # type: ignore

RANK = 0
WORLD_SIZE = 1
CPU_PROCESS_GROUP = None
# Initialize distributed training if in distributed environment (e.g. torchrun)
if "LOCAL_RANK" in os.environ:
# Initialize distributed, default device and dtype
dist.init_process_group(backend="nccl")

RANK = dist.get_rank()
WORLD_SIZE = dist.get_world_size()

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# CPU GLOO process group
CPU_PROCESS_GROUP = dist.new_group(backend="gloo")
assert (
dist.get_rank(CPU_PROCESS_GROUP) == RANK and dist.get_world_size(CPU_PROCESS_GROUP) == WORLD_SIZE
)

with open(os.path.join(os.path.dirname(eval_cfg.checkpoint), "all_config.yaml"), "r") as f:
config = PretrainConfig(**yaml.safe_load(f))

config.eval_save_outputs = eval_cfg.save_outputs
config.checkpoint_path = os.path.dirname(eval_cfg.checkpoint)

# Dataloader
train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)

# Models
train_state = init_train_state(config, train_metadata, rank=RANK, world_size=WORLD_SIZE)

# Try unwrap torch.compile
try:
train_state.model.load_state_dict(torch.load(eval_cfg.checkpoint, map_location="cuda"), assign=True)
except:
train_state.model.load_state_dict({k.removeprefix("_orig_mod."): v for k, v in torch.load(eval_cfg.checkpoint, map_location="cuda").items()}, assign=True)

train_state.step = 0
ckpt_filename = os.path.basename(eval_cfg.checkpoint)
if ckpt_filename.startswith("step_"):
train_state.step = int(ckpt_filename.removeprefix("step_"))

ema_helper = None
if config.ema:
print('Setup EMA')
ema_helper = EMAHelper(mu=config.ema_rate)
ema_helper.register(train_state.model)
if config.ema:
ema_helper.update(train_state.model)

# Evaluate
print ("Starting evaluation")

if config.ema:
print("SWITCH TO EMA")
train_state_eval = copy.deepcopy(train_state)
train_state_eval.model = ema_helper.ema_copy(train_state_eval.model)
else:
train_state_eval = train_state
train_state_eval.model.eval()
metrics = evaluate(config, train_state_eval, eval_loader, eval_metadata,
evaluators=[],
rank=RANK,
world_size=WORLD_SIZE,
cpu_group=CPU_PROCESS_GROUP)

if metrics is not None:
print (metrics)


if __name__ == "__main__":
launch()
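For context, a minimal standalone sketch of how the CLI arguments reach `EvalConfig` through `OmegaConf.from_cli` in the new script; the checkpoint path and the explicit argument list below are hypothetical, shown only to make the parsing step concrete:

```python
# Minimal sketch: CLI key=value pairs are parsed by OmegaConf and validated by
# the pydantic EvalConfig. The checkpoint path is a placeholder.
from typing import List

import pydantic
from omegaconf import OmegaConf


class EvalConfig(pydantic.BaseModel):
    checkpoint: str
    save_outputs: List[str] = []


# Equivalent to running: python evaluate.py checkpoint=ckpts/run1/step_5000 save_outputs=[logits]
cli = OmegaConf.from_cli(["checkpoint=ckpts/run1/step_5000", "save_outputs=[logits]"])
cfg = EvalConfig(**OmegaConf.to_container(cli))  # type: ignore
print(cfg.checkpoint, cfg.save_outputs)  # ckpts/run1/step_5000 ['logits']
```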
10 changes: 8 additions & 2 deletions models/layers.py
100644 → 100755
@@ -89,8 +89,14 @@ def __init__(self, dim, max_position_embeddings, base, device=None):

# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
# self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
# self.sin_cached = nn.Buffer(emb.sin(), persistent=False)

# --- CORRECTED CODE BLOCK ---
self.register_buffer('cos_cached', emb.cos(), persistent=False)
self.register_buffer('sin_cached', emb.sin(), persistent=False)
# --- END OF CORRECTION ---


def forward(self):
return self.cos_cached, self.sin_cached
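The change above swaps the attribute-style `nn.Buffer` registration for `Module.register_buffer`, which behaves the same at runtime but also works on PyTorch builds that do not ship `torch.nn.Buffer`. A small standalone sketch of the pattern (module name and sizes are made up for illustration):

```python
# Sketch: register_buffer attaches the RoPE caches as non-persistent buffers,
# so they move with .to()/.cuda() but stay out of the state_dict.
import torch
import torch.nn as nn


class RotaryCache(nn.Module):
    def __init__(self, dim: int = 8, max_position_embeddings: int = 16, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_position_embeddings).float()
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Same effect as `self.cos_cached = nn.Buffer(emb.cos(), persistent=False)`
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self):
        return self.cos_cached, self.sin_cached


m = RotaryCache()
print("cos_cached" in m.state_dict())  # False: non-persistent buffers are excluded
```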
8 changes: 6 additions & 2 deletions models/recursive_reasoning/trm.py
100644 → 100755
@@ -150,8 +150,12 @@ def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])

# Initial states
self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
# self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
# self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
h_init_tensor = trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1)
self.register_buffer('H_init', h_init_tensor, persistent=True)
l_init_tensor = trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1)
self.register_buffer('L_init', l_init_tensor, persistent=True)

# Q head special init
# Init Q to (almost) zero for faster learning during bootstrapping
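Same substitution as in `models/layers.py`, but with `persistent=True`, so the initial-state buffers are still saved and restored with checkpoints. A tiny sketch of the difference this flag makes (class name and sizes are illustrative only):

```python
# Sketch: persistent buffers appear in state_dict (and thus in checkpoints),
# non-persistent ones do not.
import torch
import torch.nn as nn


class InitStates(nn.Module):
    def __init__(self, hidden_size: int = 4):
        super().__init__()
        self.register_buffer("H_init", torch.randn(hidden_size), persistent=True)
        self.register_buffer("scratch", torch.zeros(hidden_size), persistent=False)


m = InitStates()
print(sorted(m.state_dict().keys()))  # ['H_init'] -- 'scratch' is excluded
```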
43 changes: 22 additions & 21 deletions models/sparse_embedding.py
@@ -13,17 +13,18 @@ def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, ini
super().__init__()
self.cast_to = cast_to

# --- CORRECTED CODE BLOCK ---
# Real Weights
# Truncated LeCun normal init
self.weights = nn.Buffer(
trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
)
weights_tensor = trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
self.register_buffer('weights', weights_tensor)

# Local weights and IDs
# Local embeddings, with gradient, not persistent
self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
# Local embedding IDs, not persistent
self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)
local_weights_tensor = torch.zeros(batch_size, embedding_dim, requires_grad=True)
self.register_buffer('local_weights', local_weights_tensor, persistent=False)

local_ids_tensor = torch.zeros(batch_size, dtype=torch.int32)
self.register_buffer('local_ids', local_ids_tensor, persistent=False)
# --- END OF CORRECTION ---

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
if not self.training:
@@ -78,21 +79,21 @@ def step(self, closure=None): # type: ignore
else:
assert False

assert local_weights_grad is not None
assert local_ids is not None
assert weights is not None

# Apply SignSGD
# Adam ≈ SignSGD if gradient is very sparse
if local_weights_grad is not None:
_sparse_emb_signsgd_dist(
local_weights_grad,
local_ids,
weights,

lr=group["lr"],
weight_decay=group["weight_decay"],
world_size=group["world_size"]
)
_sparse_emb_signsgd_dist(
local_weights_grad,
local_ids,
weights,

lr=group["lr"],
weight_decay=group["weight_decay"],
world_size=group["world_size"]
)


def _sparse_emb_signsgd_dist(
@@ -112,10 +113,10 @@ def _sparse_emb_signsgd_dist(

if world_size > 1:
all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)
all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)

dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
dist.all_gather_into_tensor(all_ids, local_ids)
dist.all_gather_into_tensor(all_ids, local_ids)

# Unique
grad_ids, inv = all_ids.unique(return_inverse=True)
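For readers following the optimizer hunk: the three asserts are replaced by an `if local_weights_grad is not None` guard, and the update itself is the sign-SGD step that the comment likens to Adam under very sparse gradients. Below is a single-process sketch of that update with the distributed gather omitted; the function name and the exact placement of the weight-decay term are assumptions for illustration, not the repository's API:

```python
# Single-process sketch of the sparse SignSGD update: collect the rows touched
# this step, sum duplicate gradients, then apply sign(grad) with decoupled
# weight decay to only those rows.
import torch


def sparse_signsgd_update(weights: torch.Tensor, local_ids: torch.Tensor,
                          local_grads: torch.Tensor, lr: float, weight_decay: float) -> None:
    grad_ids, inv = local_ids.unique(return_inverse=True)
    summed = torch.zeros(grad_ids.numel(), local_grads.shape[-1],
                         dtype=local_grads.dtype, device=local_grads.device)
    summed.index_add_(0, inv, local_grads)          # accumulate duplicate-row gradients

    rows = weights[grad_ids]
    rows = rows * (1.0 - lr * weight_decay) - lr * torch.sign(summed)
    weights[grad_ids] = rows                        # write back only the touched rows


W = torch.zeros(10, 4)
sparse_signsgd_update(W, torch.tensor([1, 1, 3]), torch.ones(3, 4), lr=0.1, weight_decay=0.0)
print(W[1], W[3])  # only rows 1 and 3 were updated
```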
10 changes: 5 additions & 5 deletions pretrain.py
100644 → 100755
@@ -17,7 +17,7 @@
import hydra
import pydantic
from omegaconf import DictConfig
from adam_atan2 import AdamATan2
from adam_atan2_pytorch import AdamAtan2

from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata
from utils.functions import load_model_class, get_model_source_path
@@ -147,7 +147,7 @@ def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata,
# Optimizers and lr
if config.arch.puzzle_emb_ndim == 0:
optimizers = [
AdamATan2(
AdamAtan2(
model.parameters(),
lr=0, # Needs to be set by scheduler
weight_decay=config.weight_decay,
@@ -161,7 +161,7 @@ def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata,
optimizers = [
CastedSparseEmbeddingSignSGD_Distributed(
model.model.puzzle_emb.buffers(), # type: ignore
lr=0, # Needs to be set by scheduler
lr=0.0000001, # Needs to be set by scheduler
weight_decay=config.puzzle_emb_weight_decay,
world_size=world_size
)
@@ -177,9 +177,9 @@ def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata,
weight_decay=config.puzzle_emb_weight_decay,
world_size=world_size
),
AdamATan2(
AdamAtan2(
model.parameters(),
lr=0, # Needs to be set by scheduler
lr=0.0000001, # Needs to be set by scheduler
weight_decay=config.weight_decay,
betas=(config.beta1, config.beta2)
)
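The pretrain.py hunks swap `AdamATan2` from the `adam-atan2` package for `AdamAtan2` from `adam-atan2-pytorch` and replace the `lr=0` placeholder with a tiny positive value; the likely motivation is that the replacement optimizer rejects a non-positive learning rate, though that behavior is an assumption here. Either way the scheduler still overwrites the value per step, roughly as in this sketch (only constructor arguments that appear in the diff are used):

```python
# Sketch: construct the optimizer with a tiny placeholder lr, then let the
# LR scheduler set the real value on each step via param_groups.
import torch
from adam_atan2_pytorch import AdamAtan2  # replacement for adam_atan2.AdamATan2

model = torch.nn.Linear(8, 8)
opt = AdamAtan2(model.parameters(), lr=1e-7, weight_decay=0.1, betas=(0.9, 0.95))


def set_lr(optimizer: torch.optim.Optimizer, lr: float) -> None:
    for group in optimizer.param_groups:
        group["lr"] = lr


set_lr(opt, 1e-4)  # what the scheduler effectively does before each optimizer step
```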
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,5 +1,5 @@
torch
adam-atan2
adam-atan2-pytorch
einops
tqdm
coolname
@@ -17,4 +17,4 @@ setuptools-scm
pydantic-core
huggingface_hub
numba
triton
triton