16 changes: 14 additions & 2 deletions lzero/model/common.py
@@ -297,7 +297,11 @@ def __init__(

# Initial convolution: stride 2
self.conv1 = nn.Conv2d(observation_shape[0], out_channels // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.norm1 = build_normalization(norm_type, dim=2)(out_channels // 2)
if norm_type == 'BN':
self.norm1 = nn.BatchNorm2d(out_channels // 2)
elif norm_type == 'LN':
self.norm1 = nn.LayerNorm([out_channels // 2, observation_shape[-2] // 2, observation_shape[-1] // 2],
eps=1e-5)

# Stage 1 with residual blocks
self.resblocks1 = nn.ModuleList([
@@ -734,7 +738,15 @@ def __init__(
self.downsample_net = DownSample(observation_shape, num_channels, activation, norm_type)
else:
self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.norm = build_normalization(norm_type, dim=2)(num_channels)
if norm_type == 'BN':
self.norm = nn.BatchNorm2d(num_channels)
elif norm_type == 'LN':
if downsample:
self.norm = nn.LayerNorm(
[num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)],
eps=1e-5)
else:
self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5)

self.resblocks = nn.ModuleList([
ResBlock(in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False)
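The hunks above replace the generic `build_normalization(norm_type, dim=2)` call with explicit construction: `nn.BatchNorm2d` only needs the channel count, while `nn.LayerNorm` normalizes over an explicit `normalized_shape`, so the spatial size after each stride has to be spelled out. A minimal sketch of that shape dependence, assuming a hypothetical 3x64x64 observation (all values illustrative, not taken from this PR):

```python
import math
import torch
import torch.nn as nn

obs_shape = (3, 64, 64)   # hypothetical (C, H, W) observation
out_channels = 64

x = torch.randn(8, *obs_shape)
conv1 = nn.Conv2d(obs_shape[0], out_channels // 2, kernel_size=3, stride=2, padding=1, bias=False)
h = conv1(x)  # stride-2 conv halves H and W -> (8, 32, 32, 32)

# BatchNorm2d normalizes per channel and is agnostic to H and W.
bn = nn.BatchNorm2d(out_channels // 2)
# LayerNorm needs the full (C, H, W) shape after the stride-2 convolution.
ln = nn.LayerNorm([out_channels // 2, obs_shape[-2] // 2, obs_shape[-1] // 2], eps=1e-5)
assert bn(h).shape == ln(h).shape == h.shape

# The downsampled encoder path shrinks H and W by an overall factor of 16,
# which is where the math.ceil(observation_shape[-2] / 16) terms come from.
print(math.ceil(obs_shape[-2] / 16), math.ceil(obs_shape[-1] / 16))  # 4 4
```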
115 changes: 102 additions & 13 deletions lzero/policy/unizero.py
@@ -22,7 +22,7 @@
from torch.nn.utils.convert_parameters import (parameters_to_vector,
vector_to_parameters)

from .utils import configure_optimizers_nanogpt
from .utils import configure_optimizers_nanogpt, SIGReg


def scale_module_weights_vectorized(module: torch.nn.Module, scale_factor: float):
@@ -781,6 +781,33 @@ def _init_learn(self) -> None:
self.policy_ls_eps_decay_steps = self._cfg.policy_ls_eps_decay_steps
logging.info(f"self.policy_ls_eps_start: {self.policy_ls_eps_start}")

# ==================== START: SIGReg Initialization ====================
# Initialize SIGReg (Sketched Isotropic Gaussian Regularization) for latent state regularization
self.use_sigreg = self._cfg.model.world_model_cfg.get('use_sigreg', False)
if self.use_sigreg:
sigreg_knots = self._cfg.model.world_model_cfg.get('sigreg_knots', 17)
sigreg_t_max = self._cfg.model.world_model_cfg.get('sigreg_t_max', 3.0)
sigreg_num_slices = self._cfg.model.world_model_cfg.get('sigreg_num_slices', 256)
self.sigreg_weight = self._cfg.model.world_model_cfg.get('sigreg_weight', 0.02)

self.sigreg = SIGReg(
knots=sigreg_knots,
t_max=sigreg_t_max,
num_slices=sigreg_num_slices
).to(self._cfg.device)

logging.info("=" * 60)
logging.info(">>> SIGReg (Sketched Isotropic Gaussian Regularization) Enabled <<<")
logging.info(f" SIGReg Weight (lambda): {self.sigreg_weight:.4f}")
logging.info(f" Integration Knots: {sigreg_knots}")
logging.info(f" T_max: {sigreg_t_max}")
logging.info(f" Number of Slices: {sigreg_num_slices}")
logging.info("=" * 60)
else:
self.sigreg = None
self.sigreg_weight = 0.0
# ===================== END: SIGReg Initialization =====================

def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
"""
Overview:
@@ -991,6 +1018,19 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in

weighted_total_loss = (weights * losses.loss_total).mean()

# ==================== START: SIGReg Loss Computation ====================
# Compute SIGReg regularization loss on latent states if enabled
sigreg_loss = torch.tensor(0.0, device=self._cfg.device)
if self.use_sigreg and self.sigreg is not None:
# Get obs_embeddings from intermediate losses
if 'obs_embeddings' in losses.intermediate_losses:
obs_embeddings = losses.intermediate_losses['obs_embeddings']
# Compute SIGReg loss
sigreg_loss = self.sigreg(obs_embeddings)
# Add SIGReg loss to total loss
weighted_total_loss = weighted_total_loss + self.sigreg_weight * sigreg_loss
# ===================== END: SIGReg Loss Computation =====================

for loss_name, loss_value in losses.intermediate_losses.items():
self.intermediate_losses[f"{loss_name}"] = loss_value

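For context on how the block above is switched on: all keys are read from `world_model_cfg` in `_init_learn`, so a user config only has to set them there. A hypothetical fragment is shown below; the key names come from the `.get()` calls in the diff, and the values are simply the documented fallback defaults, not tuned recommendations.

```python
# Hypothetical user-config fragment: keys match the .get() reads in _init_learn,
# values are the fallback defaults from the diff, not tuned settings.
world_model_cfg = dict(
    use_sigreg=True,        # enable SIGReg on the latent obs_embeddings
    sigreg_weight=0.02,     # lambda multiplying the SIGReg loss in the total loss
    sigreg_knots=17,        # quadrature points on [0, t_max]
    sigreg_t_max=3.0,       # upper limit of the characteristic-function integral
    sigreg_num_slices=256,  # number of random 1-D projections per update
)
```

With `use_sigreg` left at its default of `False`, `sigreg_loss` stays at zero and the total loss is unchanged.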
@@ -1250,6 +1290,39 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
"current_policy_label_eps":current_policy_label_eps,
}

# ==================== START: Add SIGReg Statistics ====================
if self.use_sigreg:
return_log_dict['sigreg/loss'] = sigreg_loss.item()
return_log_dict['sigreg/weight'] = self.sigreg_weight
return_log_dict['sigreg/weighted_loss'] = (self.sigreg_weight * sigreg_loss).item()

# Additional statistics for debugging
if 'obs_embeddings' in losses.intermediate_losses:
obs_embeddings = losses.intermediate_losses['obs_embeddings']
with torch.no_grad():
# Compute embedding statistics
emb_mean = obs_embeddings.mean().item()
emb_std = obs_embeddings.std().item()
emb_norm = obs_embeddings.norm(p=2, dim=-1).mean().item()

return_log_dict['sigreg/embedding_mean'] = emb_mean
return_log_dict['sigreg/embedding_std'] = emb_std
return_log_dict['sigreg/embedding_norm'] = emb_norm

# Check for isotropy: compute variance along each dimension
if obs_embeddings.dim() == 3:
# Shape: (batch, seq, dim)
flat_emb = obs_embeddings.reshape(-1, obs_embeddings.shape[-1])
else:
flat_emb = obs_embeddings

dim_vars = flat_emb.var(dim=0)
return_log_dict['sigreg/dim_var_mean'] = dim_vars.mean().item()
return_log_dict['sigreg/dim_var_std'] = dim_vars.std().item()
return_log_dict['sigreg/dim_var_max'] = dim_vars.max().item()
return_log_dict['sigreg/dim_var_min'] = dim_vars.min().item()
# ===================== END: Add SIGReg Statistics =====================

if norm_log_dict:
return_log_dict.update(norm_log_dict)

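The `sigreg/dim_var_*` entries above act as a cheap isotropy diagnostic: embeddings close to an isotropic Gaussian have per-dimension variances clustered around a common value, a collapsed representation drives `dim_var_mean` toward zero, and an anisotropic one pulls `dim_var_min` and `dim_var_max` far apart. A small illustrative sketch on random data, not a training run:

```python
import torch

# Stand-ins for flattened obs_embeddings of shape (batch * seq, embed_dim); purely illustrative.
isotropic = torch.randn(4096, 256)                                      # ~ N(0, I)
anisotropic = torch.randn(4096, 256) * torch.linspace(0.01, 3.0, 256)   # variance spread unevenly across dims

for name, emb in [("isotropic", isotropic), ("anisotropic", anisotropic)]:
    dim_vars = emb.var(dim=0)
    print(name, dim_vars.mean().item(), dim_vars.std().item(),
          dim_vars.min().item(), dim_vars.max().item())
```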
@@ -1335,10 +1408,10 @@ def _init_collect(self) -> None:
self._collect_epsilon = 0.0
self.collector_env_num = self._cfg.collector_env_num
if self._cfg.model.model_type == 'conv':
self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
self.last_batch_obs_collect = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
self.last_batch_action_collect = [-1 for i in range(self.collector_env_num)]
elif self._cfg.model.model_type == 'mlp':
self.last_batch_obs = torch.full(
self.last_batch_obs_collect = torch.full(
[self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id,
).to(self._cfg.device)
self.last_batch_action_collect = [-1 for i in range(self.collector_env_num)]
@@ -1390,7 +1463,7 @@ def _forward_collect(
output = {i: None for i in ready_env_id}

with torch.no_grad():
network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action_collect, data, timestep)
network_output = self._collect_model.initial_inference(self.last_batch_obs_collect, self.last_batch_action_collect, data, timestep)
latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)

pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
@@ -1461,7 +1534,7 @@ def _forward_collect(
}
batch_action.append(action)

self.last_batch_obs = data
self.last_batch_obs_collect = data
self.last_batch_action_collect = batch_action

# This logic is a temporary workaround specific to the muzero_segment_collector.
@@ -1505,10 +1578,10 @@ def _init_eval(self) -> None:
self.evaluator_env_num = self._cfg.evaluator_env_num

if self._cfg.model.model_type == 'conv':
self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
self.last_batch_obs_eval = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
self.last_batch_action_eval = [-1 for i in range(self.collector_env_num)]
elif self._cfg.model.model_type == 'mlp':
self.last_batch_obs = torch.full(
self.last_batch_obs_eval = torch.full(
[self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id,
).to(self._cfg.device)
self.last_batch_action_eval = [-1 for i in range(self.collector_env_num)]
@@ -1623,13 +1696,13 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in
- reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset.
"""
if reset_init_data:
self.last_batch_obs = initialize_pad_batch(
self.last_batch_obs_collect = initialize_pad_batch(
self._cfg.model.observation_shape,
self._cfg.collector_env_num,
self._cfg.device,
pad_token_id=self.pad_token_id
)
self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)]
self.last_batch_action_collect = [-1 for _ in range(self._cfg.collector_env_num)]


# We must handle both single int and list of ints for env_id.
@@ -1696,7 +1769,7 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_
self._cfg.device,
pad_token_id=self.pad_token_id
)
logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape)
logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:{self.last_batch_obs_eval.shape}')

else:
self.last_batch_obs_eval = initialize_pad_batch(
@@ -1705,9 +1778,9 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_
self._cfg.device,
pad_token_id=self.pad_token_id
)
logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape)
logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:{self.last_batch_obs_eval.shape}')

self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)]
self.last_batch_action_eval = [-1 for _ in range(self._cfg.evaluator_env_num)]

# This logic handles the crucial end-of-episode cache clearing for evaluation.
# The evaluator calls `_policy.reset([env_id])` when an episode is done.
@@ -1919,7 +1992,23 @@ def _monitor_vars_learn(self) -> List[str]:
'stability/warning_count', # Number of warnings issued in current check
]

return base_vars + norm_vars+ head_clip_vars + enhanced_policy_vars + stability_vars
sigreg_vars = []
# Check if SIGReg is enabled
if self.use_sigreg:
sigreg_vars = [
'sigreg/loss',
'sigreg/weight',
'sigreg/weighted_loss',
'sigreg/embedding_mean',
'sigreg/embedding_std',
'sigreg/embedding_norm',
'sigreg/dim_var_mean',
'sigreg/dim_var_std',
'sigreg/dim_var_max',
'sigreg/dim_var_min',
]

return base_vars + norm_vars + head_clip_vars + enhanced_policy_vars + stability_vars + sigreg_vars


def _state_dict_learn(self) -> Dict[str, Any]:
84 changes: 84 additions & 0 deletions lzero/policy/utils.py
@@ -800,3 +800,87 @@ def mz_network_output_unpack(network_output: Dict) -> Tuple:
value = network_output.value # shape: (batch_size, support_support_size)
policy_logits = network_output.policy_logits # shape: (batch_size, action_space_size)
return latent_state, reward, value, policy_logits


class SIGReg(nn.Module):
"""
Sketched Isotropic Gaussian Regularization (SIGReg) from LeJEPA.

This regularization constrains learned embeddings to an optimal isotropic Gaussian distribution,
which helps prevent representation collapse and improves the quality of latent states.

Reference: LeJEPA (https://arxiv.org/abs/2511.08544)

Args:
knots (int): Number of integration points for the quadrature. Default: 17
t_max (float): Maximum t value for integration. Default: 3.0
num_slices (int): Number of random projections for slicing. Default: 256
"""
def __init__(self, knots=17, t_max=3.0, num_slices=256):
super().__init__()
# Create integration points from 0 to t_max
t = torch.linspace(0, t_max, knots, dtype=torch.float32)
dt = t_max / (knots - 1)

# Trapezoidal-rule weights on [0, t_max]; interior points are doubled so that the
# quadrature matches integrating the (even) integrand over [-t_max, t_max]
weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
weights[[0, -1]] = dt

# Gaussian window function
window = torch.exp(-t.square() / 2.0)

# Register as buffers (not trainable parameters)
self.register_buffer("t", t)
self.register_buffer("phi", window)
self.register_buffer("weights", weights * window)
self.num_slices = num_slices

def forward(self, embeddings):
"""
Compute SIGReg loss for the given embeddings.

Args:
embeddings (torch.Tensor): Input embeddings of shape (batch_size, seq_len, embed_dim)
or (batch_size, embed_dim)

Returns:
torch.Tensor: Scalar SIGReg loss value
"""
# Handle different input shapes
if embeddings.dim() == 3:
# Shape: (batch_size, seq_len, embed_dim) -> (batch_size * seq_len, embed_dim)
batch_size, seq_len, embed_dim = embeddings.shape
proj = embeddings.reshape(-1, embed_dim)
elif embeddings.dim() == 2:
# Shape: (batch_size, embed_dim)
proj = embeddings
else:
raise ValueError(f"Expected embeddings to have 2 or 3 dimensions, got {embeddings.dim()}")

num_samples, embed_dim = proj.shape

# Generate random projection matrix A with shape (embed_dim, num_slices)
A = torch.randn(embed_dim, self.num_slices, device=proj.device, dtype=proj.dtype)
A = A / A.norm(p=2, dim=0, keepdim=True) # Normalize columns

# Project embeddings: (num_samples, embed_dim) @ (embed_dim, num_slices) = (num_samples, num_slices)
x_proj = proj @ A

# Compute characteristic function at different t values
# x_proj: (num_samples, num_slices), t: (knots,)
# x_t: (num_samples, num_slices, knots)
x_t = x_proj.unsqueeze(-1) * self.t

# Compute empirical characteristic function
# cos_part: (num_slices, knots), sin_part: (num_slices, knots)
cos_part = x_t.cos().mean(0) # Average over samples
sin_part = x_t.sin().mean(0)

# Compute error from target Gaussian (phi is the target)
err = (cos_part - self.phi).square() + sin_part.square()

# Integrate using quadrature weights
statistic = (err * self.weights).sum() * num_samples

# Average over slices
return statistic / self.num_slices
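As a reading aid for `SIGReg.forward` above, the quantity the class estimates can be written out as follows; the notation is mine, not from the PR (n is the number of flattened embeddings, K is `num_slices`, and the a_k are the random unit directions):

```latex
% Sketch of the statistic estimated by SIGReg.forward (notation mine, not from the PR).
\[
  \hat{\varphi}_k(t) \;=\; \frac{1}{n}\sum_{j=1}^{n} e^{\,i\,t\,\langle x_j,\, a_k\rangle},
  \qquad
  \mathcal{L}_{\mathrm{SIGReg}} \;\approx\; \frac{n}{K}\sum_{k=1}^{K}
    \int_{-t_{\max}}^{t_{\max}}
      \bigl|\hat{\varphi}_k(t) - e^{-t^{2}/2}\bigr|^{2}\, e^{-t^{2}/2}\,\mathrm{d}t .
\]
% The code evaluates the integral on [0, t_max] only, with doubled interior
% trapezoidal weights, which is equivalent because the integrand is even in t.
```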
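A quick standalone sanity check of the new module, as a sketch: the import path assumes this PR's placement in `lzero/policy/utils.py`, and the shapes and values are illustrative only.

```python
import torch
from lzero.policy.utils import SIGReg  # import path assumes this PR's file placement

sigreg = SIGReg(knots=17, t_max=3.0, num_slices=256)

# Embeddings already close to an isotropic Gaussian -> small statistic.
good = torch.randn(1024, 64)
# Fully collapsed embeddings (every sample identical): the empirical characteristic
# function becomes a pure cosine, far from exp(-t^2 / 2), so the statistic is much
# larger (and it scales with sample count, matching the `* num_samples` factor above).
collapsed = torch.ones(1024, 64)
print(sigreg(good).item(), sigreg(collapsed).item())

# The 3-D case used in _forward_learn, shape (batch, seq, embed_dim), is flattened internally.
seq_embeddings = torch.randn(32, 10, 64)
loss = sigreg(seq_embeddings)  # scalar tensor, added to the total loss as sigreg_weight * loss
```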