-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
57 lines (45 loc) · 1.67 KB
/
main.py
File metadata and controls
57 lines (45 loc) · 1.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os

import numpy as np
import torch

from agent import DQNAgent
from environment import create_environment, preprocess_observation
from utils import ReplayBuffer, save_model, load_model
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
# Training schedule.
EPISODES = 1_000            # number of episodes the training loop runs
BATCH_SIZE = 32             # size of each minibatch drawn from the replay buffer
REPLAY_CAPACITY = 10_000    # ReplayBuffer constructor argument (presumably max transitions kept)

# Learning.
GAMMA = 0.99                # discount factor, handed to agent.train()
LEARNING_RATE = 1e-3        # handed to the DQNAgent constructor

# Epsilon-greedy exploration: the rate starts at EXPLORATION_MAX and is
# decayed multiplicatively by EXPLORATION_DECAY in the training loop,
# never dropping below EXPLORATION_MIN.
EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.1
EXPLORATION_DECAY = 0.995

# ---------------------------------------------------------------------------
# Environment, agent and experience replay
# ---------------------------------------------------------------------------
env = create_environment()
n_actions = env.action_space.n
agent = DQNAgent(n_actions, LEARNING_RATE)
replay_buffer = ReplayBuffer(REPLAY_CAPACITY)
# Training loop.
#
# Each episode: reset the environment, act epsilon-greedily until the episode
# ends, store every transition in the replay buffer, and train the agent on a
# minibatch once the buffer holds more than BATCH_SIZE transitions. The
# exploration rate decays once per episode, and a checkpoint is written every
# SAVE_EVERY episodes. The environment is closed even if training raises or is
# interrupted (e.g. Ctrl+C).
#
# NOTE(review): the scraped source lost its indentation; this layout (train
# per step, decay/log/save per episode) follows the standard DQN loop —
# confirm against the original repository.

SAVE_EVERY = 50  # episodes between checkpoints

# Ensure the checkpoint directory exists, in case save_model() (project code)
# does not create it; otherwise the first save could fail mid-training.
os.makedirs("models", exist_ok=True)

exploration_rate = EXPLORATION_MAX
try:
    for episode in range(EPISODES):
        state = preprocess_observation(env.reset())
        done = False
        total_reward = 0
        while not done:
            # Choose action using an epsilon-greedy strategy.
            if np.random.rand() < exploration_rate:
                action = env.action_space.sample()  # Explore
            else:
                action = agent.act(state)  # Exploit

            # Take action in the environment.
            # NOTE(review): assumes the classic Gym API where step() returns
            # (obs, reward, done, info); gym>=0.26 returns 5 values — confirm
            # against create_environment().
            next_state, reward, done, _ = env.step(action)
            next_state = preprocess_observation(next_state)
            replay_buffer.store(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward

            # Train only once the buffer can supply a full minibatch.
            if len(replay_buffer) > BATCH_SIZE:
                agent.train(replay_buffer.sample(BATCH_SIZE), GAMMA)

        # Decay exploration once per episode, floored at EXPLORATION_MIN.
        exploration_rate = max(EXPLORATION_MIN, exploration_rate * EXPLORATION_DECAY)
        print(f"Episode {episode + 1}: Total Reward = {total_reward}")

        # Save model periodically.
        if (episode + 1) % SAVE_EVERY == 0:
            save_model(agent, f"models/mario_agent_{episode + 1}.pth")
finally:
    # Always release the environment, even on error or KeyboardInterrupt.
    env.close()