-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_agent.py
More file actions
68 lines (53 loc) · 2.13 KB
/
test_agent.py
File metadata and controls
68 lines (53 loc) · 2.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#import torch
import numpy as np
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros import make
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from agent import DQNAgent # Assuming DQNAgent is in agent.py
# Preprocess observation
import cv2
def preprocess_observation(obs, size=(84, 84)):
    """
    Preprocess a raw frame: convert to grayscale, resize, and normalize.

    Args:
        obs: Raw colour frame from the environment, an (H, W, 3) array
            (presumably uint8 in [0, 255]). NOTE(review): nes_py /
            gym_super_mario_bros frames are assumed to be RGB-ordered rather
            than OpenCV's native BGR; confirm against the installed nes_py.
        size: Target ``(width, height)`` passed to ``cv2.resize``.
            Defaults to ``(84, 84)``.

    Returns:
        A float32 array of shape ``(1, height, width)`` with values in
        ``[0, 1]``; the leading axis is a single channel dimension.
    """
    # Emulator frames are RGB, so use RGB2GRAY; BGR2GRAY would apply the
    # red luminance weight to the blue channel and vice versa.
    gray_obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
    # INTER_AREA is OpenCV's recommended filter when shrinking an image; it
    # avoids the aliasing the default bilinear filter produces on big downscales.
    resized_obs = cv2.resize(gray_obs, size, interpolation=cv2.INTER_AREA)
    # float32 halves memory versus float64 and matches the usual model input
    # dtype; dividing by 255 maps 8-bit pixel values into [0, 1].
    normalized_obs = resized_obs.astype(np.float32) / 255.0
    return np.expand_dims(normalized_obs, axis=0)  # Add channel dimension
# Shape of a state produced by preprocess_observation with its default size.
# Kept for reference; it is not currently passed to DQNAgent.
state_shape = (1, 84, 84)


def reset_observation(reset_result):
    """Return the observation from an ``env.reset()`` result.

    Older gym (< 0.26) returns only the observation, while newer gym and
    gymnasium return an ``(observation, info)`` tuple. Both forms are accepted.
    """
    if isinstance(reset_result, tuple) and len(reset_result) == 2:
        return reset_result[0]
    return reset_result


def unpack_step(step_result):
    """Normalize an ``env.step()`` result to ``(obs, reward, done, info)``.

    Older gym returns a 4-tuple with a single ``done`` flag. Newer gym and
    gymnasium return a 5-tuple with separate ``terminated`` and ``truncated``
    flags; these are OR-ed together so the episode ends on either.
    """
    if len(step_result) == 5:
        obs, reward, terminated, truncated, info = step_result
        return obs, reward, terminated or truncated, info
    obs, reward, done, info = step_result
    return obs, reward, done, info


def main(episodes=10, gamma=0.99, exploration_rate=1.0, exploration_min=0.1,
         exploration_decay=0.995, render=True):
    """Run epsilon-greedy episodes on SuperMarioBros-1-1, training every step.

    Args:
        episodes: Number of episodes to run.
        gamma: Discount factor forwarded to ``agent.train``.
        exploration_rate: Initial epsilon (probability of a random action).
        exploration_min: Lower bound for epsilon.
        exploration_decay: Multiplicative epsilon decay applied once per episode.
        render: Whether to call ``env.render()`` after every step.
    """
    # Initialize environment and agent
    env = JoypadSpace(make('SuperMarioBros-1-1-v0'), SIMPLE_MOVEMENT)
    action_space = env.action_space.n  # Number of possible actions
    agent = DQNAgent(action_space=action_space, learning_rate=0.001)

    try:
        for episode in range(episodes):
            state = preprocess_observation(reset_observation(env.reset()))
            total_reward = 0
            done = False
            while not done:
                # Exploration vs. exploitation (epsilon-greedy)
                if np.random.rand() < exploration_rate:
                    action = env.action_space.sample()  # Random action
                else:
                    action = agent.act(state)  # Action from the agent

                next_obs, reward, done, _ = unpack_step(env.step(action))
                next_state = preprocess_observation(next_obs)
                total_reward += reward

                # Online update on the single latest transition; no replay
                # buffer is kept in this script.
                transition = (state, action, reward, next_state, done)
                agent.train([transition], gamma)
                state = next_state

                if render:
                    env.render()

            # Decay exploration once per episode, never below the floor.
            exploration_rate = max(exploration_min,
                                   exploration_rate * exploration_decay)
            print(f"Episode {episode + 1}/{episodes} - Total Reward: {total_reward}")
    finally:
        # Release the emulator even if training raises or is interrupted.
        env.close()


# Guard so that importing this module (e.g. pytest collecting test_agent.py)
# does not launch a full training run.
if __name__ == "__main__":
    main()