Skip to content

Commit 88b3cf8

Browse files
committed
Q learning and Epsilon greedy
1 parent 2870716 commit 88b3cf8

File tree

1 file changed

+53
-7
lines changed

1 file changed

+53
-7
lines changed

2022/FA22/intro-ai-series/workshop-3-reinforcement-learning/src/qlearningAgents.py

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(self, **args):
4343
ReinforcementAgent.__init__(self, **args)
4444

4545
"*** YOUR CODE HERE ***"
46+
self.qvalues = util.Counter()
4647

4748
def getQValue(self, state, action):
4849
"""
@@ -51,7 +52,11 @@ def getQValue(self, state, action):
5152
or the Q node value otherwise
5253
"""
5354
"*** YOUR CODE HERE ***"
54-
util.raiseNotDefined()
55+
#if self.qvalues.has_key((state,action))==True:
56+
if (state,action) in self.qvalues:
57+
return self.qvalues[(state,action)]
58+
else:
59+
return 0
5560

5661
def computeValueFromQValues(self, state):
5762
"""
@@ -61,7 +66,10 @@ def computeValueFromQValues(self, state):
6166
terminal state, you should return a value of 0.0.
6267
"""
6368
"*** YOUR CODE HERE ***"
64-
util.raiseNotDefined()
69+
actions = self.getLegalActions(state)
70+
if len(actions) ==0:
71+
return 0.0
72+
return max([self.getQValue(state,action) for action in actions])
6573

6674
def computeActionFromQValues(self, state):
6775
"""
@@ -70,7 +78,23 @@ def computeActionFromQValues(self, state):
7078
you should return None.
7179
"""
7280
"*** YOUR CODE HERE ***"
73-
util.raiseNotDefined()
81+
possibleActions = self.getLegalActions(state)
82+
83+
84+
85+
if possibleActions==():
86+
return None
87+
else:
88+
maxVal=-9999999
89+
maxAction=0
90+
91+
for action in possibleActions:
92+
93+
if self.getQValue(state,action)>maxVal:
94+
maxVal=self.getQValue(state,action)
95+
maxAction = action
96+
97+
return maxAction
7498

7599
def getAction(self, state):
76100
"""
@@ -86,10 +110,19 @@ def getAction(self, state):
86110
legalActions = self.getLegalActions(state)
87111
action = None
88112
"*** YOUR CODE HERE ***"
89-
util.raiseNotDefined()
113+
if legalActions==():
114+
return None
115+
116+
chooseRandomAction = util.flipCoin(self.epsilon)
117+
118+
if chooseRandomAction == True:
119+
action = random.choice(legalActions)
120+
else:
121+
action = self.computeActionFromQValues(state)
90122

91123
return action
92124

125+
93126
def update(self, state, action, nextState, reward: float):
94127
"""
95128
The parent class calls this to observe a
@@ -99,7 +132,11 @@ def update(self, state, action, nextState, reward: float):
99132
it will be called on your behalf
100133
"""
101134
"*** YOUR CODE HERE ***"
102-
util.raiseNotDefined()
135+
new_value =(1-self.alpha)*self.getQValue(state,action) + self.alpha*( reward + self.discount*self.computeValueFromQValues(nextState) )
136+
137+
self.qvalues[(state,action)] = new_value
138+
139+
return None
103140

104141
def getPolicy(self, state):
105142
return self.computeActionFromQValues(state)
@@ -159,14 +196,23 @@ def getQValue(self, state, action):
159196
where * is the dotProduct operator
160197
"""
161198
"*** YOUR CODE HERE ***"
162-
util.raiseNotDefined()
199+
features = self.featExtractor.getFeatures(state,action)
200+
201+
return self.weights * features
163202

164203
def update(self, state, action, nextState, reward: float):
165204
"""
166205
Should update your weights based on transition
167206
"""
168207
"*** YOUR CODE HERE ***"
169-
util.raiseNotDefined()
208+
features = self.featExtractor.getFeatures(state,action)
209+
210+
diff = (reward + self.discount*self.computeValueFromQValues(nextState)) - self.getQValue(state,action)
211+
212+
for key in features.keys():
213+
self.weights[key] = self.weights[key] + features[key]*self.alpha*diff
214+
215+
return None
170216

171217
def final(self, state):
172218
"""Called at the end of each game."""

0 commit comments

Comments
 (0)