25 | 25 | from .utils import configure_optimizers_nanogpt |
26 | 26 |
|
27 | 27 |
|
| 28 | +class SIGReg(torch.nn.Module): |
| 29 | + """ |
| 30 | + Sketched Isotropic Gaussian Regularization (SIGReg) from LeJEPA. |
| 31 | +
|
| 32 | + This regularization pushes the learned embeddings toward an isotropic Gaussian distribution
| 33 | + (argued in LeJEPA to be the optimal embedding distribution), which helps prevent representation collapse and improves the quality of the latent states.
| 34 | +
|
| 35 | + Reference: LeJEPA (https://arxiv.org/abs/2511.08544) |
| 36 | +
|
| 37 | + Args: |
| 38 | + knots (int): Number of integration points for the quadrature. Default: 17 |
| 39 | + t_max (float): Maximum t value for integration. Default: 3.0 |
| 40 | + num_slices (int): Number of random 1-D projections (slices) used to sketch the embedding distribution. Default: 256
| 41 | + """ |
| 42 | + def __init__(self, knots=17, t_max=3.0, num_slices=256): |
| 43 | + super().__init__() |
| 44 | + # Create integration points from 0 to t_max |
| 45 | + t = torch.linspace(0, t_max, knots, dtype=torch.float32) |
| 46 | + dt = t_max / (knots - 1) |
| 47 | + |
| 48 | + # Trapezoidal-rule weights doubled to cover the symmetric negative-t half of the integral, so only [0, t_max] is evaluated (interior points get 2*dt, endpoints dt)
| 49 | + weights = torch.full((knots,), 2 * dt, dtype=torch.float32) |
| 50 | + weights[[0, -1]] = dt |
| 51 | + |
| 52 | + # Gaussian window exp(-t^2/2): also the characteristic function of N(0, 1), used as the target below
| 53 | + window = torch.exp(-t.square() / 2.0) |
| 54 | + |
| 55 | + # Register as buffers (not trainable parameters) |
| 56 | + self.register_buffer("t", t) |
| 57 | + self.register_buffer("phi", window) |
| 58 | + self.register_buffer("weights", weights * window) |
| 59 | + self.num_slices = num_slices |
| 60 | + |
| 61 | + def forward(self, embeddings): |
| 62 | + """ |
| 63 | + Compute SIGReg loss for the given embeddings. |
| 64 | +
|
| 65 | + Args: |
| 66 | + embeddings (torch.Tensor): Input embeddings of shape (batch_size, seq_len, embed_dim) |
| 67 | + or (batch_size, embed_dim) |
| 68 | +
|
| 69 | + Returns: |
| 70 | + torch.Tensor: Scalar SIGReg loss value |
| 71 | + """ |
| 72 | + # Handle different input shapes |
| 73 | + if embeddings.dim() == 3: |
| 74 | + # Shape: (batch_size, seq_len, embed_dim) -> (batch_size * seq_len, embed_dim) |
| 75 | + batch_size, seq_len, embed_dim = embeddings.shape |
| 76 | + proj = embeddings.reshape(-1, embed_dim) |
| 77 | + elif embeddings.dim() == 2: |
| 78 | + # Shape: (batch_size, embed_dim) |
| 79 | + proj = embeddings |
| 80 | + else: |
| 81 | + raise ValueError(f"Expected embeddings to have 2 or 3 dimensions, got {embeddings.dim()}") |
| 82 | + |
| 83 | + num_samples, embed_dim = proj.shape |
| 84 | + |
| 85 | + # Generate random projection matrix A with shape (embed_dim, num_slices) |
| 86 | + A = torch.randn(embed_dim, self.num_slices, device=proj.device, dtype=proj.dtype) |
| 87 | + A = A / A.norm(p=2, dim=0, keepdim=True) # Normalize columns |
| 88 | + |
| 89 | + # Project embeddings: (num_samples, embed_dim) @ (embed_dim, num_slices) = (num_samples, num_slices) |
| 90 | + x_proj = proj @ A |
| 91 | + |
| 92 | + # Evaluate t * x for every projection and integration point
| 93 | + # x_proj: (num_samples, num_slices), t: (knots,) |
| 94 | + # x_t: (num_samples, num_slices, knots) |
| 95 | + x_t = x_proj.unsqueeze(-1) * self.t |
| 96 | + |
| 97 | + # Compute empirical characteristic function |
| 98 | + # cos_part: (num_slices, knots), sin_part: (num_slices, knots) |
| 99 | + cos_part = x_t.cos().mean(0) # Average over samples |
| 100 | + sin_part = x_t.sin().mean(0) |
| 101 | + |
| 102 | + # Squared deviation of the empirical CF from the target N(0, 1) CF (phi); the target's imaginary part is zero
| 103 | + err = (cos_part - self.phi).square() + sin_part.square() |
| 104 | + |
| 105 | + # Integrate with the quadrature weights; the num_samples factor gives the usual normality-test statistic scaling
| 106 | + statistic = (err * self.weights).sum() * num_samples |
| 107 | + |
| 108 | + # Average over slices |
| 109 | + return statistic / self.num_slices |
| 110 | + |
| 111 | + |
28 | 112 | def scale_module_weights_vectorized(module: torch.nn.Module, scale_factor: float): |
29 | 113 | """ |
30 | 114 | Efficiently scale all weights of a module using vectorized operations. |
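As a quick sanity check of the `SIGReg` module added above, the statistic should be small for embeddings that already follow an isotropic standard Gaussian and much larger for collapsed embeddings. The following is a minimal sketch (shapes are illustrative, and it assumes the class above is importable as-is):

```python
import torch

# Minimal sanity-check sketch for the SIGReg module added above.
sigreg = SIGReg(knots=17, t_max=3.0, num_slices=256)

# Embeddings already distributed as an isotropic standard Gaussian -> small statistic.
gaussian_emb = torch.randn(4096, 64)

# Near-collapsed embeddings (all samples close to one point) -> large statistic.
collapsed_emb = torch.ones(4096, 64) + 0.01 * torch.randn(4096, 64)

print(sigreg(gaussian_emb).item())   # small, roughly order 1
print(sigreg(collapsed_emb).item())  # orders of magnitude larger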
@@ -781,6 +865,33 @@ def _init_learn(self) -> None: |
781 | 865 | self.policy_ls_eps_decay_steps = self._cfg.policy_ls_eps_decay_steps |
782 | 866 | logging.info(f"self.policy_ls_eps_start: {self.policy_ls_eps_start}") |
783 | 867 |
|
| 868 | + # ==================== START: SIGReg Initialization ==================== |
| 869 | + # Initialize SIGReg (Sketched Isotropic Gaussian Regularization) for latent state regularization |
| 870 | + self.use_sigreg = self._cfg.model.world_model_cfg.get('use_sigreg', False) |
| 871 | + if self.use_sigreg: |
| 872 | + sigreg_knots = self._cfg.model.world_model_cfg.get('sigreg_knots', 17) |
| 873 | + sigreg_t_max = self._cfg.model.world_model_cfg.get('sigreg_t_max', 3.0) |
| 874 | + sigreg_num_slices = self._cfg.model.world_model_cfg.get('sigreg_num_slices', 256) |
| 875 | + self.sigreg_weight = self._cfg.model.world_model_cfg.get('sigreg_weight', 0.02) |
| 876 | + |
| 877 | + self.sigreg = SIGReg( |
| 878 | + knots=sigreg_knots, |
| 879 | + t_max=sigreg_t_max, |
| 880 | + num_slices=sigreg_num_slices |
| 881 | + ).to(self._cfg.device) |
| 882 | + |
| 883 | + logging.info("=" * 60) |
| 884 | + logging.info(">>> SIGReg (Sketched Isotropic Gaussian Regularization) Enabled <<<") |
| 885 | + logging.info(f" SIGReg Weight (lambda): {self.sigreg_weight:.4f}") |
| 886 | + logging.info(f" Integration Knots: {sigreg_knots}") |
| 887 | + logging.info(f" T_max: {sigreg_t_max}") |
| 888 | + logging.info(f" Number of Slices: {sigreg_num_slices}") |
| 889 | + logging.info("=" * 60) |
| 890 | + else: |
| 891 | + self.sigreg = None |
| 892 | + self.sigreg_weight = 0.0 |
| 893 | + # ===================== END: SIGReg Initialization ===================== |
| 894 | + |
784 | 895 | def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: |
785 | 896 | """ |
786 | 897 | Overview: |
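The initialization above reads its settings from `world_model_cfg`. A hypothetical config fragment with the keys used by the `.get()` calls might look like the sketch below (the enclosing structure and any values beyond the documented defaults are assumptions, not part of this diff):

```python
# Hypothetical config fragment (sketch). The key names mirror the .get() calls
# in the SIGReg initialization above; the enclosing structure is assumed.
world_model_cfg = dict(
    # ... existing world-model settings ...
    use_sigreg=True,        # enable SIGReg regularization of the encoder latents
    sigreg_knots=17,        # quadrature points on [0, t_max]
    sigreg_t_max=3.0,       # upper integration bound
    sigreg_num_slices=256,  # random 1-D projections per forward pass
    sigreg_weight=0.02,     # lambda multiplying the SIGReg loss
)
```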
@@ -991,6 +1102,19 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in |
991 | 1102 |
|
992 | 1103 | weighted_total_loss = (weights * losses.loss_total).mean() |
993 | 1104 |
|
| 1105 | + # ==================== START: SIGReg Loss Computation ==================== |
| 1106 | + # Compute SIGReg regularization loss on latent states if enabled |
| 1107 | + sigreg_loss = torch.tensor(0.0, device=self._cfg.device) |
| 1108 | + if self.use_sigreg and self.sigreg is not None: |
| 1109 | + # The world model is expected to expose the encoder latents under 'obs_embeddings'; if the key is absent, the regularizer is silently skipped and the loss stays 0
| 1110 | + if 'obs_embeddings' in losses.intermediate_losses: |
| 1111 | + obs_embeddings = losses.intermediate_losses['obs_embeddings'] |
| 1112 | + # Compute SIGReg loss |
| 1113 | + sigreg_loss = self.sigreg(obs_embeddings) |
| 1114 | + # Add SIGReg loss to total loss |
| 1115 | + weighted_total_loss = weighted_total_loss + self.sigreg_weight * sigreg_loss |
| 1116 | + # ===================== END: SIGReg Loss Computation ===================== |
| 1117 | + |
994 | 1118 | for loss_name, loss_value in losses.intermediate_losses.items(): |
995 | 1119 | self.intermediate_losses[f"{loss_name}"] = loss_value |
996 | 1120 |
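In symbols, the objective that reaches the optimizer after the block above is, as a sketch (with w_i the per-sample priority weights, ℓ_i the per-sample world-model losses, z_obs the encoder latents, and λ the configured `sigreg_weight`):

$$
\mathcal{L}_{\text{total}} = \frac{1}{B}\sum_{i=1}^{B} w_i\,\ell_i \;+\; \lambda \cdot \mathrm{SIGReg}\!\left(z_{\text{obs}}\right)
$$

When `use_sigreg` is disabled or `obs_embeddings` is not exposed, the second term is identically zero and the objective is unchanged.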
|
@@ -1250,6 +1374,39 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in |
1250 | 1374 | "current_policy_label_eps":current_policy_label_eps, |
1251 | 1375 | } |
1252 | 1376 |
|
| 1377 | + # ==================== START: Add SIGReg Statistics ==================== |
| 1378 | + if self.use_sigreg: |
| 1379 | + return_log_dict['sigreg/loss'] = sigreg_loss.item() |
| 1380 | + return_log_dict['sigreg/weight'] = self.sigreg_weight |
| 1381 | + return_log_dict['sigreg/weighted_loss'] = (self.sigreg_weight * sigreg_loss).item() |
| 1382 | + |
| 1383 | + # Additional statistics for debugging |
| 1384 | + if 'obs_embeddings' in losses.intermediate_losses: |
| 1385 | + obs_embeddings = losses.intermediate_losses['obs_embeddings'] |
| 1386 | + with torch.no_grad(): |
| 1387 | + # Compute embedding statistics |
| 1388 | + emb_mean = obs_embeddings.mean().item() |
| 1389 | + emb_std = obs_embeddings.std().item() |
| 1390 | + emb_norm = obs_embeddings.norm(p=2, dim=-1).mean().item() |
| 1391 | + |
| 1392 | + return_log_dict['sigreg/embedding_mean'] = emb_mean |
| 1393 | + return_log_dict['sigreg/embedding_std'] = emb_std |
| 1394 | + return_log_dict['sigreg/embedding_norm'] = emb_norm |
| 1395 | + |
| 1396 | + # Check for isotropy: compute variance along each dimension |
| 1397 | + if obs_embeddings.dim() == 3: |
| 1398 | + # Shape: (batch, seq, dim) |
| 1399 | + flat_emb = obs_embeddings.reshape(-1, obs_embeddings.shape[-1]) |
| 1400 | + else: |
| 1401 | + flat_emb = obs_embeddings |
| 1402 | + |
| 1403 | + dim_vars = flat_emb.var(dim=0) |
| 1404 | + return_log_dict['sigreg/dim_var_mean'] = dim_vars.mean().item() |
| 1405 | + return_log_dict['sigreg/dim_var_std'] = dim_vars.std().item() |
| 1406 | + return_log_dict['sigreg/dim_var_max'] = dim_vars.max().item() |
| 1407 | + return_log_dict['sigreg/dim_var_min'] = dim_vars.min().item() |
| 1408 | + # ===================== END: Add SIGReg Statistics ===================== |
| 1409 | + |
1253 | 1410 | if norm_log_dict: |
1254 | 1411 | return_log_dict.update(norm_log_dict) |
1255 | 1412 |
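For reference when reading the `sigreg/*` diagnostics above: if the latents matched an isotropic standard Gaussian, the per-dimension variances would all be close to 1 and the mean L2 norm close to sqrt(embed_dim). A minimal sketch with a hypothetical width of 768:

```python
import torch

# What the isotropy diagnostics above would report for an ideal isotropic
# standard Gaussian. The width (768) is hypothetical; substitute the model's embed_dim.
flat_emb = torch.randn(8192, 768)
dim_vars = flat_emb.var(dim=0)

print(dim_vars.mean().item())                    # ~1.0
print(dim_vars.std().item())                     # small (sampling noise only)
print(flat_emb.norm(p=2, dim=-1).mean().item())  # ~sqrt(768) ≈ 27.7
```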
|
|