2 changes: 1 addition & 1 deletion pufferlib/config/default.ini
@@ -23,7 +23,7 @@ project = ablations
seed = 42
torch_deterministic = True
cpu_offload = False
device = cuda
device = default
Author comment:
Not sure if this is controversial, but it means people don't need to manually pass --train.device mps every time; the best available device is selected automatically via the get_accelerator function.
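
For context, a minimal sketch of what the automatic selection resolves to on a recent PyTorch (illustrative only; the actual logic is the get_accelerator function added in pufferl.py below):

import torch

# Illustrative: what 'device = default' picks on the current machine.
# torch.accelerator is the generic accelerator API in recent PyTorch (2.6+).
if torch.accelerator.is_available():
    print(torch.accelerator.current_accelerator())  # e.g. device(type='mps') on Apple Silicon
else:
    print('cpu')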

optimizer = muon
precision = float32
total_timesteps = 10_000_000
44 changes: 44 additions & 0 deletions pufferlib/extensions/mps/pufferlib.metal
@@ -0,0 +1,44 @@
#include <metal_stdlib>
using namespace metal;

// One thread per row: a backward pass of VTrace-clipped GAE along the horizon.
kernel void puff_advantage_kernel(
device const float* values [[buffer(0)]],
device const float* rewards [[buffer(1)]],
device const float* dones [[buffer(2)]],
device const float* importance [[buffer(3)]],
device float* advantages [[buffer(4)]],
constant float& gamma [[buffer(5)]],
constant float& lambda [[buffer(6)]],
constant float& rho_clip [[buffer(7)]],
constant float& c_clip [[buffer(8)]],
constant int& horizon [[buffer(9)]],
uint row [[thread_position_in_grid]])
{
int offset = row * horizon;
device const float* row_values = values + offset;
device const float* row_rewards = rewards + offset;
device const float* row_dones = dones + offset;
device const float* row_importance = importance + offset;
device float* row_advantages = advantages + offset;

float gamma_lambda = gamma * lambda;

// Backward recursion: row_advantages[horizon - 1] keeps its initialized
// value, since the sweep starts at t = horizon - 2.
float lastpufferlam = 0.0f;
for (int t = horizon - 2; t >= 0; t--) {
int t_next = t + 1;

float importance_t = row_importance[t];
float done_next = row_dones[t_next];
float value_t = row_values[t];
float value_next = row_values[t_next];
float reward_next = row_rewards[t_next];

float rho_t = fmin(importance_t, rho_clip);
float c_t = fmin(importance_t, c_clip);

float nextnonterminal = 1.0f - done_next;
float delta = rho_t * (reward_next + gamma * value_next * nextnonterminal - value_t);
lastpufferlam = delta + gamma_lambda * c_t * lastpufferlam * nextnonterminal;
row_advantages[t] = lastpufferlam;
}
}
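
For anyone who wants to sanity-check the kernel, here is a pure-PyTorch reference of the same backward recursion. This is a verification sketch, not part of the PR; it assumes (num_rows, horizon) float32 tensors and, like the kernel, leaves advantages[:, -1] at zero:

import torch

def puff_advantage_reference(values, rewards, dones, importance,
                             gamma, lam, rho_clip, c_clip):
    # Mirrors puff_advantage_kernel: one backward sweep per row.
    advantages = torch.zeros_like(values)
    lastpufferlam = torch.zeros(values.shape[0], device=values.device)
    for t in range(values.shape[1] - 2, -1, -1):
        rho_t = importance[:, t].clamp(max=rho_clip)  # VTrace rho clip
        c_t = importance[:, t].clamp(max=c_clip)      # VTrace c clip
        nonterminal = 1.0 - dones[:, t + 1]
        delta = rho_t * (rewards[:, t + 1]
                         + gamma * values[:, t + 1] * nonterminal
                         - values[:, t])
        lastpufferlam = delta + gamma * lam * c_t * lastpufferlam * nonterminal
        advantages[:, t] = lastpufferlam
    return advantages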
106 changes: 106 additions & 0 deletions pufferlib/extensions/mps/pufferlib.mm
@@ -0,0 +1,106 @@
#import <Metal/Metal.h>
#import <Foundation/Foundation.h>
#include <torch/extension.h>

namespace pufferlib {

static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}

void compute_puff_advantage_mps(torch::Tensor values, torch::Tensor rewards,
torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
double gamma, double lambda, double rho_clip, double c_clip) {

@autoreleasepool {
TORCH_CHECK(values.device().is_mps(), "All tensors must be on MPS device");
TORCH_CHECK(values.is_contiguous(), "values must be contiguous");
TORCH_CHECK(rewards.is_contiguous(), "rewards must be contiguous");
TORCH_CHECK(dones.is_contiguous(), "dones must be contiguous");
TORCH_CHECK(importance.is_contiguous(), "importance must be contiguous");
TORCH_CHECK(advantages.is_contiguous(), "advantages must be contiguous");
TORCH_CHECK(values.scalar_type() == torch::kFloat32, "All tensors must be float32");

int num_steps = values.size(0);
int horizon = values.size(1);

id<MTLDevice> device = MTLCreateSystemDefaultDevice();
NSError* error = nil;

// Caching the compiled pipeline probably isn't strictly necessary, but it saves ~0.1ms per call
static id<MTLFunction> function = nil;
static id<MTLComputePipelineState> pipelineState = nil;

if (function == nil) {
// read the file & compile the shader
NSString* sourcePath = [[@(__FILE__) stringByDeletingLastPathComponent]
stringByAppendingPathComponent:@"pufferlib.metal"];
NSString* source = [NSString stringWithContentsOfFile:sourcePath
encoding:NSUTF8StringEncoding error:&error];
TORCH_CHECK(source, "Failed to read Metal source file: ",
error ? [[error localizedDescription] UTF8String] : "unknown error");

id<MTLLibrary> library = [device newLibraryWithSource:source options:nil error:&error];
TORCH_CHECK(library, "Failed to compile Metal library: ",
[[error localizedDescription] UTF8String]);

function = [library newFunctionWithName:@"puff_advantage_kernel"];
TORCH_CHECK(function, "Failed to find puff_advantage_kernel function");

pipelineState = [device newComputePipelineStateWithFunction:function error:&error];
TORCH_CHECK(pipelineState, "Failed to create compute pipeline: ",
[[error localizedDescription] UTF8String]);
}


id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");

dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();

dispatch_sync(serialQueue, ^{
id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
TORCH_CHECK(encoder, "Failed to create compute command encoder");

[encoder setComputePipelineState:pipelineState];
[encoder setBuffer:getMTLBufferStorage(values)
offset:values.storage_offset() * values.element_size() atIndex:0];
[encoder setBuffer:getMTLBufferStorage(rewards)
offset:rewards.storage_offset() * rewards.element_size() atIndex:1];
[encoder setBuffer:getMTLBufferStorage(dones)
offset:dones.storage_offset() * dones.element_size() atIndex:2];
[encoder setBuffer:getMTLBufferStorage(importance)
offset:importance.storage_offset() * importance.element_size() atIndex:3];
[encoder setBuffer:getMTLBufferStorage(advantages)
offset:advantages.storage_offset() * advantages.element_size() atIndex:4];

float gamma_f = gamma, lambda_f = lambda, rho_clip_f = rho_clip, c_clip_f = c_clip;
int horizon_i = horizon;

[encoder setBytes:&gamma_f length:sizeof(float) atIndex:5];
[encoder setBytes:&lambda_f length:sizeof(float) atIndex:6];
[encoder setBytes:&rho_clip_f length:sizeof(float) atIndex:7];
[encoder setBytes:&c_clip_f length:sizeof(float) atIndex:8];
[encoder setBytes:&horizon_i length:sizeof(int) atIndex:9];

MTLSize gridSize = MTLSizeMake(num_steps, 1, 1);

NSUInteger threadGroupSize = pipelineState.maxTotalThreadsPerThreadgroup;
if (threadGroupSize > (NSUInteger)num_steps) {
threadGroupSize = num_steps;
}
MTLSize threadgroupSize = MTLSizeMake(threadGroupSize, 1, 1);

[encoder dispatchThreads:gridSize threadsPerThreadgroup:threadgroupSize];
[encoder endEncoding];

torch::mps::commit();
});
}
}

TORCH_LIBRARY_IMPL(pufferlib, MPS, m) {
m.impl("compute_puff_advantage", &compute_puff_advantage_mps);
}

}
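
A hypothetical smoke test for this backend, assuming the extension is built and the op schema matches the call site in pufferl.py (shapes and hyperparameters are illustrative):

import torch

num_steps, horizon = 8, 128
values     = torch.randn(num_steps, horizon, device='mps')
rewards    = torch.randn(num_steps, horizon, device='mps')
terminals  = torch.zeros(num_steps, horizon, device='mps')
ratio      = torch.ones(num_steps, horizon, device='mps')
advantages = torch.zeros(num_steps, horizon, device='mps')

# gamma, gae_lambda, vtrace_rho_clip, vtrace_c_clip
torch.ops.pufferlib.compute_puff_advantage(
    values, rewards, terminals, ratio, advantages,
    0.99, 0.95, 1.0, 1.0)
torch.mps.synchronize()  # flush the command buffer before reading results
print(advantages[0, :4])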
9 changes: 8 additions & 1 deletion pufferlib/models.py
@@ -117,7 +117,14 @@ def __init__(self, env, policy, input_size=128, hidden_size=128):
if "bias" in name:
nn.init.constant_(param, 0)
elif "weight" in name and param.ndim >= 2:
nn.init.orthogonal_(param, 1.0)
if param.device.type == 'mps':
    # Apple MPS does not support orthogonal_, so run the init on a
    # CPU copy and copy the result back. Tensor.to() is out-of-place,
    # so param.to('cpu') on its own would not move the parameter.
    cpu_param = param.detach().cpu()
    nn.init.orthogonal_(cpu_param, 1.0)
    with torch.no_grad():
        param.copy_(cpu_param)
else:
    nn.init.orthogonal_(param, 1.0)

self.lstm = nn.LSTM(input_size, hidden_size)

25 changes: 17 additions & 8 deletions pufferlib/pufferl.py
@@ -53,6 +53,7 @@
# Assume the advantage kernel has been built if torch was compiled with CUDA or HIP
# support and CUDA or HIP can be found in the system; likewise for MPS on Apple Silicon
ADVANTAGE_CUDA = bool(CUDA_HOME or ROCM_HOME)
ADVANTAGE_MPS = bool(torch.backends.mps.is_available())

class PuffeRL:
def __init__(self, config, vecenv, policy, logger=None):
@@ -664,7 +665,8 @@ def compute_puff_advantage(values, rewards, terminals,
compile the fast version.'''

device = values.device
if not ADVANTAGE_CUDA:

if not ADVANTAGE_CUDA and not ADVANTAGE_MPS:
values = values.cpu()
rewards = rewards.cpu()
terminals = terminals.cpu()
@@ -674,7 +676,7 @@
torch.ops.pufferlib.compute_puff_advantage(values, rewards, terminals,
ratio, advantages, gamma, gae_lambda, vtrace_rho_clip, vtrace_c_clip)

if not ADVANTAGE_CUDA:
if not ADVANTAGE_CUDA and not ADVANTAGE_MPS:
return advantages.to(device)

return advantages
@@ -692,6 +694,12 @@ def abbreviate(num, b2, c2):
else:
return f'{b2}{num/1e12:.2f}{c2}T'

def get_accelerator(device):
if device == 'default':
return torch.accelerator.current_accelerator() if torch.accelerator.is_available() else 'cpu'
else:
return device
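
Illustrative behavior (return values depend on the machine and PyTorch build):

get_accelerator('default')  # device(type='mps') on Apple Silicon, device(type='cuda') on a CUDA box, else 'cpu'
get_accelerator('cuda:1')   # 'cuda:1' -- explicit devices pass through unchanged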

def duration(seconds, b2, c2):
if seconds < 0:
return f"{b2}0{c2}s"
@@ -736,8 +744,8 @@ def __call__(self, name, epoch, nest=False):
if (epoch + 1) % self.frequency != 0:
return

if torch.cuda.is_available():
torch.cuda.synchronize()
if torch.accelerator.is_available():
torch.accelerator.synchronize()

tick = time.time()
if len(self.stack) != 0 and not nest:
@@ -754,8 +762,8 @@ def pop(self, end):
profile['elapsed'] += delta * self.frequency

def end(self):
if torch.cuda.is_available():
torch.cuda.synchronize()
if torch.accelerator.is_available():
torch.accelerator.synchronize()

end = time.time()
for i in range(len(self.stack)):
@@ -992,7 +1000,7 @@ def eval(env_name, args=None, vecenv=None, policy=None):
ob, info = vecenv.reset()
driver = vecenv.driver_env
num_agents = vecenv.observation_space.shape[0]
device = args['train']['device']
device = get_accelerator(args['train']['device'])

state = {}
if args['train']['use_rnn']:
@@ -1144,7 +1152,7 @@ def load_policy(args, vecenv, env_name=''):
module_name = 'pufferlib.ocean' if package == 'ocean' else f'pufferlib.environments.{package}'
env_module = importlib.import_module(module_name)

device = args['train']['device']
device = get_accelerator(args['train']['device'])
policy_cls = getattr(env_module.torch, args['policy_name'])
policy = policy_cls(vecenv.driver_env, **args['policy'])

@@ -1282,6 +1290,7 @@ def auto_type(value):

args['train']['env'] = args['env_name'] or '' # for trainer dashboard
args['train']['use_rnn'] = args['rnn_name'] is not None
args['train']['device'] = get_accelerator(args['train']['device'])
return args

def main():
8 changes: 7 additions & 1 deletion pufferlib/pytorch.py
@@ -164,7 +164,13 @@ def _flattened_tensor_size(native_dtype):

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
"""CleanRL's default layer initialization"""
torch.nn.init.orthogonal_(layer.weight, std)
if layer.weight.device.type == 'mps':
    # Apple MPS does not support orthogonal_, so initialize a CPU copy
    # and copy it back in place. Tensor.to() is out-of-place, so the
    # stored weight must be updated with copy_().
    weight = layer.weight.detach().cpu()
    torch.nn.init.orthogonal_(weight, std)
    with torch.no_grad():
        layer.weight.copy_(weight)
else:
    torch.nn.init.orthogonal_(layer.weight, std)
torch.nn.init.constant_(layer.bias, bias_const)
return layer
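
A quick usage sketch, assuming a machine where MPS is available (shapes are hypothetical):

import torch
from torch import nn
from pufferlib.pytorch import layer_init

layer = nn.Linear(128, 128).to('mps')
layer = layer_init(layer)                  # orthogonal init runs on CPU, result copied back
assert layer.weight.device.type == 'mps'   # the parameter itself never leaves MPS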

7 changes: 7 additions & 0 deletions setup.py
@@ -21,11 +21,15 @@
CUDA_HOME,
ROCM_HOME
)
from torch.backends import mps

# build cuda extension if torch can find CUDA or HIP/ROCM in the system
# may require `uv pip install --no-build-isolation` or `python setup.py build_ext --inplace`
BUID_CUDA_EXT = bool(CUDA_HOME or ROCM_HOME)

# build mps extension if torch can find MPS in the system
BUILD_MPS_EXT = bool(mps.is_available())

# Build with DEBUG=1 to enable debug symbols
DEBUG = os.getenv("DEBUG", "0") == "1"
NO_OCEAN = os.getenv("NO_OCEAN", "0") == "1"
@@ -243,6 +247,9 @@ def run(self):
if BUID_CUDA_EXT:
extension = CUDAExtension
torch_sources.append("pufferlib/extensions/cuda/pufferlib.cu")
elif BUILD_MPS_EXT:
extension = CppExtension
torch_sources.append("pufferlib/extensions/mps/pufferlib.mm")
else:
extension = CppExtension
