diff --git a/examples/resources b/examples/resources new file mode 120000 index 000000000..92db3035c --- /dev/null +++ b/examples/resources @@ -0,0 +1 @@ +/workspace/PufferLib/pufferlib/resources \ No newline at end of file diff --git a/pufferlib/contrastive_loss.py b/pufferlib/contrastive_loss.py new file mode 100644 index 000000000..c77cf49c6 --- /dev/null +++ b/pufferlib/contrastive_loss.py @@ -0,0 +1,296 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Dict, Tuple, Optional, Any +from collections import defaultdict + + +class ContrastiveLoss(nn.Module): + """Contrastive loss for representation learning in PufferLib. + + Implements InfoNCE loss with geometric future positives and shuffled negatives. + The loss samples (st, at) pairs, creates positive examples sf^(1) by looking + Δ ~ GEOM(1-γ) steps ahead, and generates negative examples by shuffling. + """ + + def __init__( + self, + temperature: float = 0.1, + contrastive_coef: float = 1.0, + embedding_dim: int = 256, + discount: float = 0.99, + use_projection_head: bool = False, + device: torch.device = None, + ): + super().__init__() + self.temperature = temperature + self.contrastive_coef = contrastive_coef + self.embedding_dim = embedding_dim + self.discount = discount + self.device = device or torch.device('cpu') + + # Projection head will be created dynamically if needed + self.projection_head = None + self.use_projection_head = use_projection_head + self._value_projection = None + + # Metrics tracking + self.loss_tracker = defaultdict(list) + + def forward( + self, + embeddings: torch.Tensor, + terminals: torch.Tensor, + truncations: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Dict[str, float]]: + """Compute contrastive loss. + + Args: + embeddings: [segments, horizon, embedding_dim] tensor of representations + terminals: [segments, horizon] tensor of done flags + truncations: [segments, horizon] tensor of truncation flags (optional) + + Returns: + loss: Contrastive loss tensor + metrics: Dictionary of metrics for logging + """ + segments, horizon = embeddings.shape[0], embeddings.shape[1] + embedding_dim = embeddings.shape[-1] + + if embedding_dim == 0: + return torch.tensor(0.0, device=self.device), self._empty_metrics() + + # Create done mask combining terminals and truncations + done_mask = terminals.to(dtype=torch.bool) + if truncations is not None: + done_mask = torch.logical_or(done_mask, truncations.to(dtype=torch.bool)) + + # Apply projection head if configured + if self.use_projection_head: + embeddings = self._apply_projection_head(embeddings) + + # Sample contrastive pairs + batch_indices, anchor_steps, positive_steps, sampled_deltas = self._sample_pairs( + done_mask, segments, horizon + ) + + num_pairs = len(batch_indices) + if num_pairs < 2: + return torch.tensor(0.0, device=self.device), self._empty_metrics(num_pairs) + + # Extract embeddings for contrastive learning + batch_idx_tensor = torch.tensor(batch_indices, device=self.device, dtype=torch.long) + anchor_idx_tensor = torch.tensor(anchor_steps, device=self.device, dtype=torch.long) + positive_idx_tensor = torch.tensor(positive_steps, device=self.device, dtype=torch.long) + + anchor_embeddings = embeddings[batch_idx_tensor, anchor_idx_tensor] + positive_embeddings = embeddings[batch_idx_tensor, positive_idx_tensor] + + # Compute similarities and InfoNCE loss + similarities = anchor_embeddings @ positive_embeddings.T + positive_logits = similarities.diagonal().unsqueeze(1) + + # Create negative logits 
by masking out positive pairs + mask = torch.eye(num_pairs, device=self.device, dtype=torch.bool) + negative_logits = similarities[~mask].view(num_pairs, num_pairs - 1) + + # Combine positive and negative logits + logits = torch.cat([positive_logits, negative_logits], dim=1) / self.temperature + labels = torch.zeros(num_pairs, dtype=torch.long, device=self.device) + + # Compute InfoNCE loss + infonce_loss = F.cross_entropy(logits, labels, reduction="mean") + + # Compute metrics + metrics = self._compute_metrics( + positive_logits, negative_logits, num_pairs, sampled_deltas + ) + + return infonce_loss * self.contrastive_coef, metrics + + def _apply_projection_head(self, embeddings: torch.Tensor) -> torch.Tensor: + """Apply projection head to embeddings.""" + if self.projection_head is None: + input_dim = embeddings.shape[-1] + self.projection_head = nn.Linear(input_dim, self.embedding_dim).to(self.device) + + # Reshape for linear layer: [segments, horizon, dim] -> [segments*horizon, dim] + original_shape = embeddings.shape[:2] + embeddings_flat = embeddings.view(-1, embeddings.shape[-1]) + projected_flat = self.projection_head(embeddings_flat) + + # Reshape back: [segments*horizon, embedding_dim] -> [segments, horizon, embedding_dim] + return projected_flat.view(*original_shape, self.embedding_dim) + + def _sample_pairs( + self, + done_mask: torch.Tensor, + segments: int, + horizon: int + ) -> Tuple[list, list, list, list]: + """Sample anchor and positive pairs using geometric distribution.""" + prob = max(1.0 - float(self.discount), 1e-8) + geom_dist = torch.distributions.Geometric( + probs=torch.tensor(prob, device=self.device) + ) + + done_mask_cpu = done_mask.detach().to("cpu") + + batch_indices = [] + anchor_steps = [] + positive_steps = [] + sampled_deltas = [] + + for batch_idx in range(segments): + done_row = done_mask_cpu[batch_idx].view(-1) + + # Find episode boundaries + episode_bounds = [] + start = 0 + for step, done in enumerate(done_row.tolist()): + if done: + episode_bounds.append((start, step)) + start = step + 1 + if start < horizon: + episode_bounds.append((start, horizon - 1)) + + # Collect candidate anchors + candidate_anchors = [] + for episode_start, episode_end in episode_bounds: + if episode_end - episode_start < 1: + continue + for anchor in range(episode_start, episode_end): + candidate_anchors.append((anchor, episode_end)) + + if not candidate_anchors: + continue + + # Sample anchor and positive + choice_idx = int(torch.randint(len(candidate_anchors), (1,), device=self.device).item()) + anchor_step, episode_end = candidate_anchors[choice_idx] + max_future = episode_end - anchor_step + + if max_future < 1: + continue + + # Sample delta using geometric distribution + delta = int(geom_dist.sample().item()) + attempts = 0 + while delta > max_future and attempts < 10: + delta = int(geom_dist.sample().item()) + attempts += 1 + if delta > max_future: + delta = max_future + + positive_step = anchor_step + delta + + batch_indices.append(batch_idx) + anchor_steps.append(anchor_step) + positive_steps.append(positive_step) + sampled_deltas.append(float(delta)) + + return batch_indices, anchor_steps, positive_steps, sampled_deltas + + def _compute_metrics( + self, + positive_logits: torch.Tensor, + negative_logits: torch.Tensor, + num_pairs: int, + sampled_deltas: list, + ) -> Dict[str, float]: + """Compute metrics for logging.""" + return { + "positive_sim_mean": positive_logits.mean().item(), + "negative_sim_mean": negative_logits.mean().item(), + "positive_sim_std": 
positive_logits.std().item(), + "negative_sim_std": negative_logits.std().item(), + "num_pairs": num_pairs, + "delta_mean": float(sum(sampled_deltas) / len(sampled_deltas)) if sampled_deltas else 0.0, + } + + def _empty_metrics(self, num_pairs: int = 0) -> Dict[str, float]: + """Return empty metrics when no loss can be computed.""" + return { + "positive_sim_mean": 0.0, + "negative_sim_mean": 0.0, + "positive_sim_std": 0.0, + "negative_sim_std": 0.0, + "num_pairs": num_pairs, + "delta_mean": 0.0, + } + + +def get_embeddings_from_policy_data( + policy_logits: torch.Tensor, + policy_values: torch.Tensor, + embedding_dim: int, + device: torch.device, +) -> torch.Tensor: + """Extract embeddings from policy outputs. + + This is a helper function to extract embeddings when they're not directly + available from the policy. In practice, you'd want to modify your policy + to return embeddings directly. + + Args: + policy_logits: Action logits from policy forward pass + policy_values: Value predictions from policy + embedding_dim: Desired embedding dimension + device: Target device + + Returns: + embeddings: [batch_size, embedding_dim] tensor + """ + # Fallback: use value as embeddings but create learnable projection + # This is suboptimal but demonstrates the interface + values = policy_values.squeeze(-1) if policy_values.dim() > 1 else policy_values + + if values.dim() == 1: + # Create a simple learnable projection from 1D value to embedding_dim + projection = nn.Linear(1, embedding_dim).to(device) + nn.init.xavier_uniform_(projection.weight) + values = values.unsqueeze(-1) # [N] -> [N, 1] + embeddings = projection(values) # [N, 1] -> [N, embedding_dim] + return embeddings + else: + return values + + +def compute_contrastive_loss_pufferlib( + embeddings: torch.Tensor, + terminals: torch.Tensor, + truncations: Optional[torch.Tensor] = None, + temperature: float = 0.1, + contrastive_coef: float = 1.0, + embedding_dim: int = 256, + discount: float = 0.99, + device: torch.device = None, +) -> Tuple[torch.Tensor, Dict[str, float]]: + """Functional interface for contrastive loss computation. + + This function can be directly integrated into PufferLib's training loop + without needing to modify the main PuffeRL class. + + Args: + embeddings: [segments, horizon, embedding_dim] representation tensor + terminals: [segments, horizon] done flags tensor + truncations: [segments, horizon] truncation flags (optional) + temperature: Temperature for InfoNCE loss + contrastive_coef: Coefficient for contrastive loss + embedding_dim: Target embedding dimension + discount: Discount factor for geometric sampling + device: Target device + + Returns: + loss: Contrastive loss value + metrics: Dictionary of logging metrics + """ + contrastive_loss = ContrastiveLoss( + temperature=temperature, + contrastive_coef=contrastive_coef, + embedding_dim=embedding_dim, + discount=discount, + device=device or embeddings.device, + ) + + return contrastive_loss(embeddings, terminals, truncations) \ No newline at end of file diff --git a/pufferlib/pufferl_with_contrastive.py b/pufferlib/pufferl_with_contrastive.py new file mode 100644 index 000000000..0400d0ac0 --- /dev/null +++ b/pufferlib/pufferl_with_contrastive.py @@ -0,0 +1,269 @@ +""" +Extended PuffeRL trainer with contrastive loss integration. + +This demonstrates how to integrate the contrastive loss into PufferLib's training loop. +The key additions are: +1. Extracting embeddings from the policy +2. Computing contrastive loss alongside standard losses +3. 
Adding contrastive metrics to logging +""" + +import torch +from collections import defaultdict +from .contrastive_loss import compute_contrastive_loss_pufferlib, get_embeddings_from_policy_data + + +def train_with_contrastive_loss(pufferl_instance): + """Modified training function that includes contrastive loss. + + This extends the standard PuffeRL training loop to include contrastive learning. + You would replace the train() method in PuffeRL with this implementation, + or create a subclass that overrides the train method. + """ + profile = pufferl_instance.profile + epoch = pufferl_instance.epoch + profile('train', epoch) + losses = defaultdict(float) + config = pufferl_instance.config + device = config['device'] + + # Standard PPO setup + b0 = config['prio_beta0'] + a = config['prio_alpha'] + clip_coef = config['clip_coef'] + vf_clip = config['vf_clip_coef'] + anneal_beta = b0 + (1 - b0) * a * pufferl_instance.epoch / pufferl_instance.total_epochs + pufferl_instance.ratio[:] = 1 + + # Contrastive loss configuration (should come from config) + use_contrastive = config.get('use_contrastive_loss', False) + contrastive_coef = config.get('contrastive_coef', 1.0) + contrastive_temperature = config.get('contrastive_temperature', 0.1) + contrastive_discount = config.get('contrastive_discount', 0.99) + embedding_dim = config.get('embedding_dim', 256) + + for mb in range(pufferl_instance.total_minibatches): + profile('train_misc', epoch, nest=True) + pufferl_instance.amp_context.__enter__() + + shape = pufferl_instance.values.shape + advantages = torch.zeros(shape, device=device) + + # Import the advantage computation function + from pufferlib.pufferl import compute_puff_advantage + + advantages = compute_puff_advantage( + pufferl_instance.values, + pufferl_instance.rewards, + pufferl_instance.terminals, + pufferl_instance.ratio, + advantages, + config['gamma'], + config['gae_lambda'], + config['vtrace_rho_clip'], + config['vtrace_c_clip'] + ) + + profile('train_copy', epoch) + adv = advantages.abs().sum(axis=1) + prio_weights = torch.nan_to_num(adv**a, 0, 0, 0) + prio_probs = (prio_weights + 1e-6) / (prio_weights.sum() + 1e-6) + idx = torch.multinomial(prio_probs, pufferl_instance.minibatch_segments) + mb_prio = (pufferl_instance.segments * prio_probs[idx, None])**-anneal_beta + mb_obs = pufferl_instance.observations[idx] + mb_actions = pufferl_instance.actions[idx] + mb_logprobs = pufferl_instance.logprobs[idx] + mb_rewards = pufferl_instance.rewards[idx] + mb_terminals = pufferl_instance.terminals[idx] + mb_truncations = pufferl_instance.truncations[idx] + mb_ratio = pufferl_instance.ratio[idx] + mb_values = pufferl_instance.values[idx] + mb_returns = advantages[idx] + mb_values + mb_advantages = advantages[idx] + + profile('train_forward', epoch) + if not config['use_rnn']: + mb_obs_flat = mb_obs.reshape(-1, *pufferl_instance.vecenv.single_observation_space.shape) + else: + mb_obs_flat = mb_obs + + state = dict( + action=mb_actions, + lstm_h=None, + lstm_c=None, + ) + + # Forward pass through policy + logits, newvalue = pufferl_instance.policy(mb_obs_flat, state) + + # Import the sampling function + import pufferlib.pytorch + actions, newlogprob, entropy = pufferlib.pytorch.sample_logits(logits, action=mb_actions) + + profile('train_misc', epoch) + newlogprob = newlogprob.reshape(mb_logprobs.shape) + logratio = newlogprob - mb_logprobs + ratio = logratio.exp() + pufferl_instance.ratio[idx] = ratio.detach() + + with torch.no_grad(): + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio 
- 1) - logratio).mean() + clipfrac = ((ratio - 1.0).abs() > config['clip_coef']).float().mean() + + adv = advantages[idx] + adv = compute_puff_advantage( + mb_values, mb_rewards, mb_terminals, + ratio, adv, config['gamma'], config['gae_lambda'], + config['vtrace_rho_clip'], config['vtrace_c_clip'] + ) + adv = mb_advantages + adv = mb_prio * (adv - adv.mean()) / (adv.std() + 1e-8) + + # Standard PPO losses + pg_loss1 = -adv * ratio + pg_loss2 = -adv * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + newvalue = newvalue.view(mb_returns.shape) + v_clipped = mb_values + torch.clamp(newvalue - mb_values, -vf_clip, vf_clip) + v_loss_unclipped = (newvalue - mb_returns) ** 2 + v_loss_clipped = (v_clipped - mb_returns) ** 2 + v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean() + + entropy_loss = entropy.mean() + + # Standard PPO loss + standard_loss = pg_loss + config['vf_coef'] * v_loss - config['ent_coef'] * entropy_loss + + # Contrastive loss computation + total_loss = standard_loss + contrastive_loss_value = torch.tensor(0.0, device=device) + contrastive_metrics = {} + + if use_contrastive: + try: + # Extract embeddings from policy output + # In practice, you'd want to modify your policy to directly provide embeddings + # This is a fallback that uses value predictions as embeddings + embeddings = get_embeddings_from_policy_data( + logits, newvalue, embedding_dim, device + ) + + # Reshape embeddings to match minibatch structure + if embeddings.dim() == 2: + embeddings = embeddings.view(mb_obs.shape[0], mb_obs.shape[1], -1) + + # Compute contrastive loss + contrastive_loss_value, contrastive_metrics = compute_contrastive_loss_pufferlib( + embeddings=embeddings, + terminals=mb_terminals, + truncations=mb_truncations, + temperature=contrastive_temperature, + contrastive_coef=contrastive_coef, + embedding_dim=embedding_dim, + discount=contrastive_discount, + device=device, + ) + + # Add contrastive loss to total loss + total_loss = total_loss + contrastive_loss_value + + except Exception as e: + # Log error but don't crash training + print(f"Warning: Contrastive loss computation failed: {e}") + contrastive_loss_value = torch.tensor(0.0, device=device) + + # Use total loss for backward pass + pufferl_instance.amp_context.__enter__() + + # Update values as in original + pufferl_instance.values[idx] = newvalue.detach().float() + + # Standard loss logging + profile('train_misc', epoch) + losses['policy_loss'] += pg_loss.item() / pufferl_instance.total_minibatches + losses['value_loss'] += v_loss.item() / pufferl_instance.total_minibatches + losses['entropy'] += entropy_loss.item() / pufferl_instance.total_minibatches + losses['old_approx_kl'] += old_approx_kl.item() / pufferl_instance.total_minibatches + losses['approx_kl'] += approx_kl.item() / pufferl_instance.total_minibatches + losses['clipfrac'] += clipfrac.item() / pufferl_instance.total_minibatches + losses['importance'] += ratio.mean().item() / pufferl_instance.total_minibatches + + # Contrastive loss logging + if use_contrastive: + losses['contrastive_loss'] += contrastive_loss_value.item() / pufferl_instance.total_minibatches + for key, value in contrastive_metrics.items(): + metric_name = f'contrastive_{key}' + if metric_name not in losses: + losses[metric_name] = 0.0 + losses[metric_name] += value / pufferl_instance.total_minibatches + + # Learn on accumulated minibatches + profile('learn', epoch) + total_loss.backward() + if (mb + 1) % pufferl_instance.accumulate_minibatches 
== 0: + torch.nn.utils.clip_grad_norm_(pufferl_instance.policy.parameters(), config['max_grad_norm']) + pufferl_instance.optimizer.step() + pufferl_instance.optimizer.zero_grad() + + # Rest of the training function remains the same as original + profile('train_misc', epoch) + if config['anneal_lr']: + pufferl_instance.scheduler.step() + + y_pred = pufferl_instance.values.flatten() + y_true = advantages.flatten() + pufferl_instance.values.flatten() + var_y = y_true.var() + explained_var = torch.nan if var_y == 0 else 1 - (y_true - y_pred).var() / var_y + losses['explained_variance'] = explained_var.item() + + profile.end() + logs = None + pufferl_instance.epoch += 1 + done_training = pufferl_instance.global_step >= config['total_timesteps'] + + if done_training or pufferl_instance.global_step == 0 or pufferl_instance.uptime > pufferl_instance.last_log_time + 0.25: + logs = pufferl_instance.mean_and_log() + pufferl_instance.losses = losses + pufferl_instance.print_dashboard() + pufferl_instance.stats = defaultdict(list) + pufferl_instance.last_log_time = pufferl_instance.uptime + pufferl_instance.last_log_step = pufferl_instance.global_step + profile.clear() + + if pufferl_instance.epoch % config['checkpoint_interval'] == 0 or done_training: + pufferl_instance.save_checkpoint() + pufferl_instance.msg = f'Checkpoint saved at update {pufferl_instance.epoch}' + + return logs + + +class PuffeRLWithContrastive: + """Example of how to extend PuffeRL with contrastive loss. + + You can use this as a reference for integrating contrastive loss + into your own PufferLib training setup. + """ + + def __init__(self, config, vecenv, policy, logger=None): + # Import and initialize base PuffeRL + from pufferlib.pufferl import PuffeRL + self.base_trainer = PuffeRL(config, vecenv, policy, logger) + + # Add contrastive loss specific configuration + self.use_contrastive = config.get('use_contrastive_loss', False) + if self.use_contrastive: + print(f"Contrastive loss enabled with coefficient {config.get('contrastive_coef', 1.0)}") + + def train(self): + """Training method with contrastive loss.""" + return train_with_contrastive_loss(self.base_trainer) + + def evaluate(self): + """Evaluation remains the same.""" + return self.base_trainer.evaluate() + + def __getattr__(self, name): + """Delegate other attributes to base trainer.""" + return getattr(self.base_trainer, name) \ No newline at end of file diff --git a/simple_test.py b/simple_test.py new file mode 100644 index 000000000..24188343a --- /dev/null +++ b/simple_test.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +"""Simple test to verify PufferLib can be imported and basic functionality works.""" + +import sys +import os + +# Add PufferLib to Python path +sys.path.insert(0, '/workspace/PufferLib') + +try: + # Test basic imports that don't require external dependencies + print("Testing basic PufferLib imports...") + + # First try to import the main module parts individually + try: + import pufferlib.exceptions + print("✓ pufferlib.exceptions imported successfully") + except Exception as e: + print(f"✗ pufferlib.exceptions failed: {e}") + + try: + import pufferlib.utils + print("✓ pufferlib.utils imported successfully") + except Exception as e: + print(f"✗ pufferlib.utils failed: {e}") + + # Test if we can at least load the module structure + print(f"✓ PufferLib directory found at: {os.path.dirname(__file__)}") + print("✓ Basic test completed - PufferLib structure is accessible") + +except Exception as e: + print(f"✗ Error during testing: {e}") + import 
traceback + traceback.print_exc() + +print("\nTo fully test PufferLib, you need to install dependencies:") +print("pip install numpy gymnasium") \ No newline at end of file
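
Below is a minimal usage sketch (not part of the patch) that drives the functional interface in pufferlib/contrastive_loss.py with dummy rollout tensors. The [segments, horizon, embedding_dim] shapes and the small test values are assumptions taken from the docstrings above, not from PufferLib's actual rollout buffers.

import torch

from pufferlib.contrastive_loss import compute_contrastive_loss_pufferlib

# Assumed rollout dimensions for illustration only.
segments, horizon, embedding_dim = 8, 128, 256
device = torch.device('cpu')

# Dummy data standing in for policy embeddings and episode-done flags.
embeddings = torch.randn(segments, horizon, embedding_dim, device=device)
terminals = torch.rand(segments, horizon, device=device) < 0.01
truncations = torch.zeros(segments, horizon, dtype=torch.bool, device=device)

loss, metrics = compute_contrastive_loss_pufferlib(
    embeddings=embeddings,
    terminals=terminals,
    truncations=truncations,
    temperature=0.1,
    contrastive_coef=1.0,
    embedding_dim=embedding_dim,
    discount=0.99,
    device=device,
)
print(f"contrastive loss: {loss.item():.4f}, pairs sampled: {metrics['num_pairs']}")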