# Assume the advantage kernel has been built if torch has been compiled with
# CUDA or HIP support and can find CUDA or HIP in the system.
ADVANTAGE_CUDA = bool(CUDA_HOME or ROCM_HOME)
# The same assumption is made for Apple's Metal (MPS) backend when it is usable.
# getattr guards torch builds that do not expose torch.backends.mps at all, so
# importing this module cannot fail with AttributeError on them.
ADVANTAGE_MPS = bool(
    getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available()
)
5657
5758class PuffeRL :
5859 def __init__ (self , config , vecenv , policy , logger = None ):
@@ -664,7 +665,8 @@ def compute_puff_advantage(values, rewards, terminals,
664665 compile the fast version.'''
665666
666667 device = values .device
667- if not ADVANTAGE_CUDA :
668+
669+ if not ADVANTAGE_CUDA and not ADVANTAGE_MPS :
668670 values = values .cpu ()
669671 rewards = rewards .cpu ()
670672 terminals = terminals .cpu ()
@@ -674,7 +676,7 @@ def compute_puff_advantage(values, rewards, terminals,
674676 torch .ops .pufferlib .compute_puff_advantage (values , rewards , terminals ,
675677 ratio , advantages , gamma , gae_lambda , vtrace_rho_clip , vtrace_c_clip )
676678
677- if not ADVANTAGE_CUDA :
679+ if not ADVANTAGE_CUDA and not ADVANTAGE_MPS :
678680 return advantages .to (device )
679681
680682 return advantages
@@ -1134,7 +1136,9 @@ def autotune(args=None, env_name=None, vecenv=None, policy=None):
11341136
def load_env(env_name, args):
    '''Import the environment package named in ``args`` and build a vecenv.

    Args:
        env_name: Name passed to the package's ``env_creator``.
        args: Mapping that must provide ``'package'`` (``'ocean'`` selects
            ``pufferlib.ocean``; any other value selects
            ``pufferlib.environments.<package>``), ``'env'`` (forwarded as
            ``env_kwargs``) and ``'vec'`` (expanded as keyword arguments to
            ``pufferlib.vector.make``).

    Returns:
        Whatever ``pufferlib.vector.make`` returns for the created env.

    Raises:
        KeyError: If ``args`` lacks one of the keys above.
        ImportError: If the resolved module cannot be imported.
    '''
    package = args['package']
    if package == 'ocean':
        module_name = 'pufferlib.ocean'
    else:
        module_name = f'pufferlib.environments.{package}'

    env_module = importlib.import_module(module_name)
    make_env = env_module.env_creator(env_name)
    return pufferlib.vector.make(make_env, env_kwargs=args['env'], **args['vec'])
0 commit comments