Skip to content

Commit 46d598d

Browse files
committed
fix: use in pufferrl, missed in merge conflict
1 parent cb781d3 commit 46d598d

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

pufferlib/pufferl.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
# Assume advantage kernel has been built if torch has been compiled with CUDA or HIP support
5454
# and can find CUDA or HIP in the system
5555
ADVANTAGE_CUDA = bool(CUDA_HOME or ROCM_HOME)
56+
ADVANTAGE_MPS = bool(torch.backends.mps.is_available())
5657

5758
class PuffeRL:
5859
def __init__(self, config, vecenv, policy, logger=None):
@@ -664,7 +665,8 @@ def compute_puff_advantage(values, rewards, terminals,
664665
compile the fast version.'''
665666

666667
device = values.device
667-
if not ADVANTAGE_CUDA:
668+
669+
if not ADVANTAGE_CUDA and not ADVANTAGE_MPS:
668670
values = values.cpu()
669671
rewards = rewards.cpu()
670672
terminals = terminals.cpu()
@@ -674,7 +676,7 @@ def compute_puff_advantage(values, rewards, terminals,
674676
torch.ops.pufferlib.compute_puff_advantage(values, rewards, terminals,
675677
ratio, advantages, gamma, gae_lambda, vtrace_rho_clip, vtrace_c_clip)
676678

677-
if not ADVANTAGE_CUDA:
679+
if not ADVANTAGE_CUDA and not ADVANTAGE_MPS:
678680
return advantages.to(device)
679681

680682
return advantages
@@ -1134,7 +1136,9 @@ def autotune(args=None, env_name=None, vecenv=None, policy=None):
11341136

11351137
def load_env(env_name, args):
11361138
package = args['package']
1139+
print("package", package)
11371140
module_name = 'pufferlib.ocean' if package == 'ocean' else f'pufferlib.environments.{package}'
1141+
print("module_name", module_name)
11381142
env_module = importlib.import_module(module_name)
11391143
make_env = env_module.env_creator(env_name)
11401144
return pufferlib.vector.make(make_env, env_kwargs=args['env'], **args['vec'])

0 commit comments

Comments
 (0)