Commit f8b899c

feature(pu): add use_sigreg option

1 parent: 556b2ec

1 file changed: lzero/policy/unizero.py (+157 -0 lines)
@@ -25,6 +25,90 @@
 from .utils import configure_optimizers_nanogpt


+class SIGReg(torch.nn.Module):
+    """
+    Sketched Isotropic Gaussian Regularization (SIGReg) from LeJEPA.
+
+    This regularization constrains learned embeddings to an optimal isotropic Gaussian distribution,
+    which helps prevent representation collapse and improves the quality of latent states.
+
+    Reference: LeJEPA (https://arxiv.org/abs/2511.08544)
+
+    Args:
+        knots (int): Number of integration points for the quadrature. Default: 17
+        t_max (float): Maximum t value for integration. Default: 3.0
+        num_slices (int): Number of random projections for slicing. Default: 256
+    """
+    def __init__(self, knots=17, t_max=3.0, num_slices=256):
+        super().__init__()
+        # Create integration points from 0 to t_max
+        t = torch.linspace(0, t_max, knots, dtype=torch.float32)
+        dt = t_max / (knots - 1)
+
+        # Trapezoidal rule weights
+        weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
+        weights[[0, -1]] = dt
+
+        # Gaussian window function
+        window = torch.exp(-t.square() / 2.0)
+
+        # Register as buffers (not trainable parameters)
+        self.register_buffer("t", t)
+        self.register_buffer("phi", window)
+        self.register_buffer("weights", weights * window)
+        self.num_slices = num_slices
+
+    def forward(self, embeddings):
+        """
+        Compute SIGReg loss for the given embeddings.
+
+        Args:
+            embeddings (torch.Tensor): Input embeddings of shape (batch_size, seq_len, embed_dim)
+                or (batch_size, embed_dim)
+
+        Returns:
+            torch.Tensor: Scalar SIGReg loss value
+        """
+        # Handle different input shapes
+        if embeddings.dim() == 3:
+            # Shape: (batch_size, seq_len, embed_dim) -> (batch_size * seq_len, embed_dim)
+            batch_size, seq_len, embed_dim = embeddings.shape
+            proj = embeddings.reshape(-1, embed_dim)
+        elif embeddings.dim() == 2:
+            # Shape: (batch_size, embed_dim)
+            proj = embeddings
+        else:
+            raise ValueError(f"Expected embeddings to have 2 or 3 dimensions, got {embeddings.dim()}")
+
+        num_samples, embed_dim = proj.shape
+
+        # Generate random projection matrix A with shape (embed_dim, num_slices)
+        A = torch.randn(embed_dim, self.num_slices, device=proj.device, dtype=proj.dtype)
+        A = A / A.norm(p=2, dim=0, keepdim=True)  # Normalize columns
+
+        # Project embeddings: (num_samples, embed_dim) @ (embed_dim, num_slices) = (num_samples, num_slices)
+        x_proj = proj @ A
+
+        # Compute characteristic function at different t values
+        # x_proj: (num_samples, num_slices), t: (knots,)
+        # x_t: (num_samples, num_slices, knots)
+        x_t = x_proj.unsqueeze(-1) * self.t
+
+        # Compute empirical characteristic function
+        # cos_part: (num_slices, knots), sin_part: (num_slices, knots)
+        cos_part = x_t.cos().mean(0)  # Average over samples
+        sin_part = x_t.sin().mean(0)
+
+        # Compute error from target Gaussian (phi is the target)
+        err = (cos_part - self.phi).square() + sin_part.square()
+
+        # Integrate using quadrature weights
+        statistic = (err * self.weights).sum() * num_samples
+
+        # Average over slices
+        return statistic / self.num_slices
+
+
 def scale_module_weights_vectorized(module: torch.nn.Module, scale_factor: float):
     """
     Efficiently scale all weights of a module using vectorized operations.

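For a quick sense of how the new module behaves in isolation, a minimal sketch (assuming the import path introduced by this diff; the tensor shapes are made up for illustration) is:

import torch
from lzero.policy.unizero import SIGReg

# Same defaults as in the diff: 17 quadrature knots on [0, 3.0], 256 random slices.
sigreg = SIGReg(knots=17, t_max=3.0, num_slices=256)

# A 3-D input (batch_size, seq_len, embed_dim) is flattened to (batch_size * seq_len, embed_dim) internally.
loss_seq = sigreg(torch.randn(32, 10, 768))

# A 2-D input (batch_size, embed_dim) is used as-is.
loss_flat = sigreg(torch.randn(320, 768))

print(loss_seq.item(), loss_flat.item())  # two scalar regularization values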
@@ -781,6 +865,33 @@ def _init_learn(self) -> None:
         self.policy_ls_eps_decay_steps = self._cfg.policy_ls_eps_decay_steps
         logging.info(f"self.policy_ls_eps_start: {self.policy_ls_eps_start}")

+        # ==================== START: SIGReg Initialization ====================
+        # Initialize SIGReg (Sketched Isotropic Gaussian Regularization) for latent state regularization
+        self.use_sigreg = self._cfg.model.world_model_cfg.get('use_sigreg', False)
+        if self.use_sigreg:
+            sigreg_knots = self._cfg.model.world_model_cfg.get('sigreg_knots', 17)
+            sigreg_t_max = self._cfg.model.world_model_cfg.get('sigreg_t_max', 3.0)
+            sigreg_num_slices = self._cfg.model.world_model_cfg.get('sigreg_num_slices', 256)
+            self.sigreg_weight = self._cfg.model.world_model_cfg.get('sigreg_weight', 0.02)
+
+            self.sigreg = SIGReg(
+                knots=sigreg_knots,
+                t_max=sigreg_t_max,
+                num_slices=sigreg_num_slices
+            ).to(self._cfg.device)
+
+            logging.info("=" * 60)
+            logging.info(">>> SIGReg (Sketched Isotropic Gaussian Regularization) Enabled <<<")
+            logging.info(f"  SIGReg Weight (lambda): {self.sigreg_weight:.4f}")
+            logging.info(f"  Integration Knots: {sigreg_knots}")
+            logging.info(f"  T_max: {sigreg_t_max}")
+            logging.info(f"  Number of Slices: {sigreg_num_slices}")
+            logging.info("=" * 60)
+        else:
+            self.sigreg = None
+            self.sigreg_weight = 0.0
+        # ===================== END: SIGReg Initialization =====================
+
     def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
         """
         Overview:

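Because every option is read via world_model_cfg.get(...) with a fallback, enabling SIGReg only requires setting the keys below. The surrounding dict is a hypothetical sketch of a UniZero experiment config, not the exact file layout:

# Hypothetical config fragment: only the sigreg-related keys are consumed by this commit,
# and each default shown here is used when the key is omitted.
world_model_cfg = dict(
    # ... existing world-model settings ...
    use_sigreg=True,        # enable SIGReg on obs_embeddings (default: False)
    sigreg_knots=17,        # quadrature points on [0, sigreg_t_max]
    sigreg_t_max=3.0,       # upper integration limit
    sigreg_num_slices=256,  # number of random 1-D projections
    sigreg_weight=0.02,     # lambda multiplying the SIGReg loss
)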
@@ -991,6 +1102,19 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:

         weighted_total_loss = (weights * losses.loss_total).mean()

+        # ==================== START: SIGReg Loss Computation ====================
+        # Compute SIGReg regularization loss on latent states if enabled
+        sigreg_loss = torch.tensor(0.0, device=self._cfg.device)
+        if self.use_sigreg and self.sigreg is not None:
+            # Get obs_embeddings from intermediate losses
+            if 'obs_embeddings' in losses.intermediate_losses:
+                obs_embeddings = losses.intermediate_losses['obs_embeddings']
+                # Compute SIGReg loss
+                sigreg_loss = self.sigreg(obs_embeddings)
+                # Add SIGReg loss to total loss
+                weighted_total_loss = weighted_total_loss + self.sigreg_weight * sigreg_loss
+        # ===================== END: SIGReg Loss Computation =====================
+
         for loss_name, loss_value in losses.intermediate_losses.items():
             self.intermediate_losses[f"{loss_name}"] = loss_value

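The extra term only grows when obs_embeddings drifts away from an isotropic Gaussian, which is what lets sigreg_weight * sigreg_loss act as an anti-collapse penalty. A rough, hypothetical sanity check of that behaviour, reusing the module added above:

import torch
from lzero.policy.unizero import SIGReg

sigreg = SIGReg()

gaussian_emb = torch.randn(2048, 256)   # roughly isotropic Gaussian embeddings
collapsed_emb = torch.zeros(2048, 256)  # fully collapsed embeddings

# The statistic stays small for the Gaussian batch and becomes much larger for the
# collapsed batch, whose empirical characteristic function is far from exp(-t^2 / 2).
print(sigreg(gaussian_emb).item())
print(sigreg(collapsed_emb).item())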

@@ -1250,6 +1374,39 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
             "current_policy_label_eps":current_policy_label_eps,
         }

+        # ==================== START: Add SIGReg Statistics ====================
+        if self.use_sigreg:
+            return_log_dict['sigreg/loss'] = sigreg_loss.item()
+            return_log_dict['sigreg/weight'] = self.sigreg_weight
+            return_log_dict['sigreg/weighted_loss'] = (self.sigreg_weight * sigreg_loss).item()
+
+            # Additional statistics for debugging
+            if 'obs_embeddings' in losses.intermediate_losses:
+                obs_embeddings = losses.intermediate_losses['obs_embeddings']
+                with torch.no_grad():
+                    # Compute embedding statistics
+                    emb_mean = obs_embeddings.mean().item()
+                    emb_std = obs_embeddings.std().item()
+                    emb_norm = obs_embeddings.norm(p=2, dim=-1).mean().item()
+
+                    return_log_dict['sigreg/embedding_mean'] = emb_mean
+                    return_log_dict['sigreg/embedding_std'] = emb_std
+                    return_log_dict['sigreg/embedding_norm'] = emb_norm
+
+                    # Check for isotropy: compute variance along each dimension
+                    if obs_embeddings.dim() == 3:
+                        # Shape: (batch, seq, dim)
+                        flat_emb = obs_embeddings.reshape(-1, obs_embeddings.shape[-1])
+                    else:
+                        flat_emb = obs_embeddings
+
+                    dim_vars = flat_emb.var(dim=0)
+                    return_log_dict['sigreg/dim_var_mean'] = dim_vars.mean().item()
+                    return_log_dict['sigreg/dim_var_std'] = dim_vars.std().item()
+                    return_log_dict['sigreg/dim_var_max'] = dim_vars.max().item()
+                    return_log_dict['sigreg/dim_var_min'] = dim_vars.min().item()
+        # ===================== END: Add SIGReg Statistics =====================
+
         if norm_log_dict:
             return_log_dict.update(norm_log_dict)

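As a rough reference for reading these logs (an expectation derived from the standard normal, not something computed by the commit): embeddings that already match a standard isotropic Gaussian would show per-dimension variances near 1, a small dim_var_std, and a mean L2 norm close to sqrt(embed_dim). A small, hypothetical snippet reproducing those reference values:

import torch

embed_dim = 768                          # illustrative width
flat_emb = torch.randn(4096, embed_dim)  # stand-in for ideal obs_embeddings

dim_vars = flat_emb.var(dim=0)
print(dim_vars.mean().item())                    # ~1.0  (sigreg/dim_var_mean)
print(dim_vars.std().item())                     # small (sigreg/dim_var_std)
print(flat_emb.norm(p=2, dim=-1).mean().item())  # ~27.7, i.e. about sqrt(768)  (sigreg/embedding_norm)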
