agent.py
from collections import defaultdict
import random

import numpy as np


class Agent:
    """Tabular Q-learning agent with epsilon-greedy exploration."""

    def __init__(
        self,
        env,
        learning_rate: float,
        initial_epsilon: float,
        epsilon_decay: float,
        final_epsilon: float,
        discount_factor: float,
        initial_value: float,
    ):
        self.lr = learning_rate
        self.discount_factor = discount_factor

        self.epsilon = initial_epsilon
        self.epsilon_decay = epsilon_decay
        self.final_epsilon = final_epsilon

        # Q-table mapping each observation (stored as a string key) to an
        # array of per-action values, initialised to `initial_value`.
        # `initial_value` is assigned before the defaultdict so the factory
        # lambda never references an unset attribute.
        self.initial_value = initial_value
        self.q_values = defaultdict(
            lambda: np.full(env.action_space.n, self.initial_value)
        )

        self.training_error = []

    def get_action(self, env, obs):
        """Returns (action, flag): flag is 0 for a random (exploratory)
        action and 1 for a greedy one."""
        dice = random.random()
        if dice < self.epsilon:
            # Explore: sample a random action from the environment.
            return env.action_space.sample(), 0
        else:
            # Exploit: pick a highest-valued action, breaking ties at random.
            obs = str(obs)
            m = np.max(self.q_values[obs])
            action = random.choice(np.where(self.q_values[obs] == m)[0])
            return action, 1

    def update(self, env, obs, action, reward, terminated, next_obs):
        """Updates the Q-value of an action."""
        # Convert obs and next_obs to hashable keys for the Q-table.
        obs = str(obs)
        next_obs = str(next_obs)

        # The bootstrap term is zero on terminal transitions.
        future_q_value = (not terminated) * np.max(self.q_values[next_obs])
        temporal_difference = (
            reward + self.discount_factor * future_q_value - self.q_values[obs][action]
        )

        self.q_values[obs][action] = (
            self.q_values[obs][action] + self.lr * temporal_difference
        )
        self.training_error.append(temporal_difference)

    def decay_epsilon(self):
        """Linearly decays epsilon, never going below final_epsilon."""
        self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay)
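
Below is a minimal usage sketch, not part of agent.py, showing how this Agent could be driven by a standard training loop. It assumes a Gymnasium-style environment with a discrete action space; the environment name ("Blackjack-v1"), the hyperparameter values, and the episode count are illustrative assumptions, not values taken from this repository.

# usage_sketch.py (illustrative only; assumes gymnasium is installed)
import gymnasium as gym

from agent import Agent

env = gym.make("Blackjack-v1")  # any discrete-action env would do

n_episodes = 10_000
agent = Agent(
    env,
    learning_rate=0.01,
    initial_epsilon=1.0,
    epsilon_decay=1.0 / (n_episodes / 2),  # linear decay over half the run
    final_epsilon=0.1,
    discount_factor=0.95,
    initial_value=0.0,
)

for _ in range(n_episodes):
    obs, info = env.reset()
    done = False
    while not done:
        # Second return value flags whether the action was greedy; unused here.
        action, _greedy = agent.get_action(env, obs)
        next_obs, reward, terminated, truncated, info = env.step(action)
        agent.update(env, obs, action, reward, terminated, next_obs)
        done = terminated or truncated
        obs = next_obs
    agent.decay_epsilon()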