From 87799bb30b90a985a43a07c22ec44244ff7d3c56 Mon Sep 17 00:00:00 2001
From: Ruben
Date: Tue, 4 Nov 2025 20:01:46 +0100
Subject: [PATCH 1/2] feat: add legacy post-processing wrappers for backward
 compatibility

---
 pufferlib/postprocess.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)
 create mode 100644 pufferlib/postprocess.py

diff --git a/pufferlib/postprocess.py b/pufferlib/postprocess.py
new file mode 100644
index 000000000..bc2c8c72b
--- /dev/null
+++ b/pufferlib/postprocess.py
@@ -0,0 +1,21 @@
+"""Legacy post-processing wrappers for backward compatibility.
+
+This module keeps the historical ``pufferlib.postprocess`` import path alive by
+re-exporting the wrappers that now live in ``pufferlib.pufferlib``.
+"""
+
+from pufferlib.pufferlib import (
+    ClipAction,
+    EpisodeStats,
+    MeanOverAgents,
+    MultiagentEpisodeStats,
+    PettingZooWrapper,
+)
+
+__all__ = [
+    "ClipAction",
+    "EpisodeStats",
+    "MeanOverAgents",
+    "MultiagentEpisodeStats",
+    "PettingZooWrapper",
+]

From 6131d974c7a9f722853d45f22293371c486e2f5a Mon Sep 17 00:00:00 2001
From: Ruben
Date: Tue, 4 Nov 2025 20:01:52 +0100
Subject: [PATCH 2/2] feat: implement TF32 configuration for improved
 performance in PuffeRL

---
 pufferlib/pufferl.py | 26 +++++++++++++++++++++++++-
 1 file changed, 25 insertions(+), 1 deletion(-)

diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py
index f74fec44b..c3269b12f 100644
--- a/pufferlib/pufferl.py
+++ b/pufferlib/pufferl.py
@@ -57,7 +57,7 @@ class PuffeRL:
 
     def __init__(self, config, vecenv, policy, logger=None):
         # Backend perf optimization
-        torch.set_float32_matmul_precision('high')
+        self._configure_tf32()
         torch.backends.cudnn.deterministic = config['torch_deterministic']
         torch.backends.cudnn.benchmark = True
 
@@ -213,6 +213,30 @@ def __init__(self, config, vecenv, policy, logger=None):
         self.model_size = sum(p.numel() for p in policy.parameters() if p.requires_grad)
         self.print_dashboard(clear=True)
 
+    @staticmethod
+    def _configure_tf32():
+        '''Configure TF32 execution using the latest PyTorch API with fallbacks.'''
+        if not torch.backends.cuda.is_built():
+            return
+
+        try:
+            torch.backends.cuda.matmul.fp32_precision = 'tf32'
+        except Exception:
+            try:
+                torch.backends.cuda.matmul.allow_tf32 = True
+            except Exception:
+                pass
+
+        cudnn_conv = getattr(torch.backends.cudnn, 'conv', None)
+        if cudnn_conv is not None:
+            try:
+                cudnn_conv.fp32_precision = 'tf32'
+            except Exception:
+                try:
+                    torch.backends.cudnn.allow_tf32 = True
+                except Exception:
+                    pass
+
     @property
     def uptime(self):
         return time.time() - self.start_time