Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions pufferlib/postprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Legacy post-processing wrappers for backward compatibility.

This module keeps the historical ``pufferlib.postprocess`` import path alive by
re-exporting the wrappers that now live in ``pufferlib.pufferlib``.
"""

from pufferlib.pufferlib import (
ClipAction,
EpisodeStats,
MeanOverAgents,
MultiagentEpisodeStats,
PettingZooWrapper,
)

__all__ = [
"ClipAction",
"EpisodeStats",
"MeanOverAgents",
"MultiagentEpisodeStats",
"PettingZooWrapper",
]
26 changes: 25 additions & 1 deletion pufferlib/pufferl.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
class PuffeRL:
def __init__(self, config, vecenv, policy, logger=None):
# Backend perf optimization
torch.set_float32_matmul_precision('high')
self._configure_tf32()
torch.backends.cudnn.deterministic = config['torch_deterministic']
torch.backends.cudnn.benchmark = True

Expand Down Expand Up @@ -213,6 +213,30 @@ def __init__(self, config, vecenv, policy, logger=None):
self.model_size = sum(p.numel() for p in policy.parameters() if p.requires_grad)
self.print_dashboard(clear=True)

@staticmethod
def _configure_tf32():
'''Configure TF32 execution using the latest PyTorch API with fallbacks.'''
if not torch.backends.cuda.is_built():
return

try:
torch.backends.cuda.matmul.fp32_precision = 'tf32'
except Exception:
try:
torch.backends.cuda.matmul.allow_tf32 = True
except Exception:
pass

cudnn_conv = getattr(torch.backends.cudnn, 'conv', None)
if cudnn_conv is not None:
try:
cudnn_conv.fp32_precision = 'tf32'
except Exception:
try:
torch.backends.cudnn.allow_tf32 = True
except Exception:
pass

@property
def uptime(self):
return time.time() - self.start_time
Expand Down