51 changes: 42 additions & 9 deletions intermediate_source/mario_rl_tutorial.py
@@ -32,10 +32,24 @@
#
# %%bash
# pip install gym-super-mario-bros==7.4.0
-# pip install tensordict==0.3.0
-# pip install torchrl==0.3.0
+# pip install 'tensordict>=0.3.0'
+# pip install 'torchrl>=0.3.0'
# pip install matplotlib
#
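
######################################################################
# Note that the ``>=`` specifiers must be quoted: an unquoted ``>``
# would be parsed by the shell as output redirection rather than as
# part of the version constraint.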

+######### Patch for numpy 2.x compatibility
+import nes_py._rom
+
+def patched_prg_rom_stop(self):
+    return self.prg_rom_start + int(self.prg_rom_size) * 2**10
+
+def patched_chr_rom_stop(self):
+    return self.chr_rom_start + int(self.chr_rom_size) * 2**10
+
+nes_py._rom.ROM.prg_rom_stop = property(patched_prg_rom_stop)
+nes_py._rom.ROM.chr_rom_stop = property(patched_chr_rom_stop)
+########################################################################

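######################################################################
# Why the patch above helps: the ROM header fields ``prg_rom_size`` and
# ``chr_rom_size`` are NumPy scalars read from the cartridge header, and
# under NumPy 2.x promotion rules multiplying a ``uint8`` scalar by
# ``2**10`` overflows instead of upcasting, so the ``int()`` casts keep
# the address arithmetic in plain Python integers.
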
import torch
from torch import nn
from torchvision import transforms as T
@@ -84,7 +98,6 @@
# the action in a state. We try to approximate this function.
#
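
######################################################################
# For reference, the function being approximated is the optimal
# action-value function, which satisfies the Bellman optimality
# equation
#
# .. math:: Q^{*}(s, a) = \mathbb{E}\left[ r + \gamma \max_{a'} Q^{*}(s', a') \right]
#
# where :math:`\gamma` is the discount factor; Deep Q-Learning trains a
# neural network to satisfy this recursion.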


######################################################################
# Environment
# """"""""""""""""
@@ -99,11 +112,26 @@
# (next) state, reward and other info.
#
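
######################################################################
# One exchange with the environment looks like the following sketch
# (using the five-tuple returned by the new gym step API, as elsewhere
# in this tutorial):
#
# .. code-block:: python
#
#    action = env.action_space.sample()  # a random action, for illustration
#    next_state, reward, done, trunc, info = env.step(action)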


+################### Patch for NumPy 2.x: add np.bool8 alias if missing
+if not hasattr(np, "bool8"):
+    np.bool8 = np.bool_
+
+
+################### Patch the _x_position property to cast RAM values to int
+def patched_x_position(self):
+    # Cast to int to avoid numpy uint8 overflow
+    return int(self.ram[0x6d]) * 0x100 + int(self.ram[0x86])
+
+gym_super_mario_bros.smb_env.SuperMarioBrosEnv._x_position = property(patched_x_position)
+
+#######################################################################################

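######################################################################
# Background on the two patches above: ``np.bool8`` was removed in
# NumPy 2.0 while some gym releases still reference it when checking
# ``done`` flags, and ``ram[0x6d]``/``ram[0x86]`` hold the high and low
# bytes of Mario's x position, so without the ``int()`` casts the
# ``uint8`` multiply-and-add would overflow.
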
# Initialize Super Mario environment (for gym v0.26 and later, set render_mode to 'human' to see results on the screen)
if gym.__version__ < '0.26':
-    env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0", new_step_api=True)
+    env = gym_super_mario_bros.make("SuperMarioBros-1-1-v3", new_step_api=True)
else:
-    env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0", render_mode='rgb', apply_api_compatibility=True)
+    env = gym_super_mario_bros.make("SuperMarioBros-1-1-v3", render_mode='rgb', apply_api_compatibility=True)
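
######################################################################
# The ``-v3`` suffix selects the "rectangle" variant of the ROM in
# gym-super-mario-bros (simplified block graphics), whereas ``-v0`` is
# the standard-resolution version.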

# Limit the action-space to
# 0. walk right
@@ -292,7 +320,11 @@ def __init__(self, state_dim, action_dim, save_dir):
self.action_dim = action_dim
self.save_dir = save_dir

-self.device = "cuda" if torch.cuda.is_available() else "cpu"
+use_accel = torch.accelerator.current_accelerator()
+if use_accel is None:
+    use_accel = "cpu"
+self.device = use_accel
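# (``current_accelerator()`` returns the ``torch.device`` detected at
# build time, or ``None`` when there is none; in recent PyTorch
# releases, passing ``check_available=True`` additionally verifies that
# the device is usable at runtime.)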

# Mario's DNN to predict the optimal action - we implement this in the Learn section
self.net = MarioNet(self.state_dim, self.action_dim).float()
@@ -735,9 +767,10 @@ def record(self, episode, epsilon, step):
# In this example we run the training loop for 40 episodes, but for Mario to truly learn the ways of
# his world, we suggest running the loop for at least 40,000 episodes!
#
-use_cuda = torch.cuda.is_available()
-print(f"Using CUDA: {use_cuda}")
-print()
+use_accel = torch.accelerator.current_accelerator()
+if use_accel is None:
+    use_accel = "cpu"
+print(f"Using device: {use_accel}\n")

save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
save_dir.mkdir(parents=True)
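
######################################################################
# The loop that follows is untouched by this diff; schematically it is
# the usual agent/environment cycle:
#
# .. code-block:: python
#
#    for e in range(episodes):
#        state = env.reset()
#        while True:
#            action = mario.act(state)
#            next_state, reward, done, trunc, info = env.step(action)
#            mario.cache(state, next_state, action, reward, done)
#            q, loss = mario.learn()
#            logger.log_step(reward, loss, q)
#            state = next_state
#            if done or info["flag_get"]:
#                break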