qlearn.py — 85 lines (78 loc) · 2.59 KB
## packages
import numpy as np
import os, time
## define function
def SimProc(action_value, reward, trans_mat, steps, gamma, alpha, epsilon):
    """Run one Q-learning episode of at most `steps` transitions.

    Starts from a random non-terminal state (1..14), follows an
    epsilon-greedy policy via GetAction, and applies the one-step
    Q-learning update (ValueUpdate) after every transition. The episode
    stops early once a terminal state (0 or 15) is entered.

    Returns the (mutated) action_value table.
    """
    history = []
    current = np.random.randint(1, 15)
    for t in range(steps):
        # Choose an action, then look up the deterministic successor
        # state from the transition tensor trans_mat[next, state, action].
        chosen = GetAction(action_value, epsilon, current)
        successor = np.argmax(trans_mat[:, current, chosen])
        history.append([current, chosen, reward[successor], successor])
        # Update Q(current, chosen) from the transition just recorded.
        action_value[current, chosen] = ValueUpdate(action_value, history[t], alpha, gamma)
        current = successor
        # States 0 and 15 are the terminal corners of the 4x4 grid.
        if current == 0 or current == 15:
            break
    return action_value
def GetAction(action_value, epsilon, next_state):
    """Pick an action for state `next_state` via an epsilon-greedy policy.

    With probability (1 - epsilon) the greedy action (argmax over the
    state's Q-values) is chosen; otherwise a uniformly random action
    in {0, 1, 2, 3} is drawn.

    Returns a plain scalar int. (The original was inconsistent: the
    greedy branch returned a scalar while the exploration branch
    returned a 1-element ndarray from np.random.randint(0, 4, 1).)
    """
    # np.random.rand() yields a scalar float, avoiding the size-1 array
    # comparison produced by np.random.rand(1) >= epsilon.
    if np.random.rand() >= epsilon:
        policy = np.argmax(action_value, axis=1)
        action = policy[next_state]
    else:
        # Scalar draw instead of a length-1 array.
        action = np.random.randint(0, 4)
    return int(action)
def ValueUpdate(action_value, record, alpha, gamma):
    """Compute the one-step Q-learning update for a recorded transition.

    `record` is [state, action, reward, next_state]. Returns the new
    value Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a)).
    """
    state, action, reward, next_state = record
    old_q = action_value[state, action]
    td_target = reward + gamma * np.max(action_value[next_state, :])
    return old_q + alpha * (td_target - old_q)
def PrintGreedyPolicy(now_episode, action_value):
    """Clear the terminal and print the current greedy policy as a grid.

    Non-terminal states 1..14 are rendered as arrows (0:'^', 1:'<',
    2:'v', 3:'>'); the terminal corners (states 0 and 15) as '*'.
    Also prints the greedy state values max_a Q(s, a) reshaped 4x4.
    """
    arrows = {0: '^', 1: '<', 2: 'v', 3: '>'}
    greedy = np.argmax(action_value, axis=1)
    symbols = ['*'] + [arrows[int(greedy[s])] for s in range(1, 15)] + ['*']
    grid = np.array(symbols)
    os.system('cls' if os.name == 'nt' else 'clear')
    print('='*60)
    print('[Greedy Policy]')
    print('Episode: ' + str(now_episode+1))
    print(grid.reshape(4, 4))
    print(np.max(action_value, axis=1).reshape(4, 4))
    print('='*60)
## main function
def main(Episodes):
    """Train tabular Q-learning on the 4x4 gridworld for `Episodes` episodes."""
    # Environment: 16 states x 4 actions; reward is -1 everywhere except
    # the two terminal corners (states 0 and 15), which pay 0.
    ActionValue = np.zeros([16, 4])
    Reward = np.full(16, -1)
    Reward[0] = 0
    Reward[-1] = 0
    # Transition tensor indexed as T[next_state, state, action].
    # NOTE(review): shape inferred from the indexing in SimProc — confirm.
    TransMat = np.load('./gridworld/T.npy')
    # Hyperparameters.
    Gamma = 0.99   # discount factor
    Steps = 50     # max transitions per episode
    Alpha = 0.05   # learning rate
    for episode in range(Episodes):
        # Harmonically decaying exploration rate: 1, 1/2, 1/3, ...
        Epsilon = 1/(episode+1)
        ActionValue = SimProc(ActionValue, Reward, TransMat, Steps, Gamma, Alpha, Epsilon)
        PrintGreedyPolicy(episode, ActionValue)
        #time.sleep(1)
if __name__ == '__main__':
    main(1000)