-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathtrain.py
More file actions
74 lines (52 loc) · 2.08 KB
/
train.py
File metadata and controls
74 lines (52 loc) · 2.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Train a free-energy agent by minimising its surprisal.

Created on Sun Feb 11 10:32:53 2018 (header also dated Wed Apr 11 00:31:50 2018)
@author: aidanrocke
"""
import random
import tensorflow as tf
import numpy as np
from free_energy_agent import free_agent
## Seed every RNG this module has in scope (Python, NumPy and the default
## TF graph) so training runs are reproducible. NumPy was previously left
## unseeded even though it is imported and used alongside the agent.
random.seed(42)
np.random.seed(42)
tf.set_random_seed(42)
def train(epochs, batch_size, basic_needs, success_probability, *,
          eval_interval=100, learning_rate=0.01):
    """Train a ``free_agent`` by minimising its surprisal with Adagrad.

    Every training step feeds the same batch, in which each of the
    ``batch_size`` rows equals ``basic_needs``. After each optimisation step
    the surprisal is re-evaluated and recorded. Every ``eval_interval`` steps
    (starting at step 0) the agent's ``strategy`` is evaluated on a
    single-row feed and stored.

    Args:
        epochs: Number of optimisation steps to run.
        batch_size: Number of rows in the (constant) training batch.
        basic_needs: Value placed in every row fed to ``F.survival``.
        success_probability: Forwarded unchanged to ``free_agent``.
        eval_interval: Positive number of steps between policy evaluations
            (default 100, the previously hard-coded value).
        learning_rate: Adagrad learning rate (default 0.01, the previously
            hard-coded value).

    Returns:
        A tuple ``(log_loss, food_policies, total_consumption)``:

        - ``log_loss``: shape ``(epochs,)``, the surprisal after each step.
        - ``food_policies``: one 24-wide row per evaluation.
        - ``total_consumption``: the sum of each evaluated row.
    """
    # An evaluation happens for every i in range(epochs) with
    # i % eval_interval == 0, i.e. ceil(epochs / eval_interval) times.
    # The old int(epochs / 100) under-counted whenever epochs was not a
    # multiple of 100, making food_policies[count] raise IndexError.
    n_evals = len(range(0, epochs, eval_interval))

    # NOTE(review): assumes F.strategy evaluates to 24 values for a single
    # input row — TODO confirm against free_agent.
    food_policies = np.zeros((n_evals, 24))
    total_consumption = np.zeros(n_evals)
    log_loss = np.zeros(epochs)

    with tf.Session() as sess:
        F = free_agent(basic_needs, sess, success_probability)

        # NOTE(review): the loss is the negated mean of F.surprise(); the
        # sign suggests surprise() returns a log-likelihood — confirm.
        surprisal = -1.0 * tf.reduce_mean(F.surprise())

        optimizer = tf.train.AdagradOptimizer(learning_rate)
        train_agent = optimizer.minimize(surprisal)

        # Initialise after the optimizer is built so its own slot variables
        # are covered by the initializer as well.
        sess.run(tf.global_variables_initializer())

        # The training batch and the evaluation feed never change between
        # steps, so build them once instead of on every iteration.
        mini_batch = basic_needs * np.ones((batch_size, 1), dtype=np.float32)
        train_feed = {F.survival: mini_batch}
        evaluation_feed = {F.survival: mini_batch[:1]}  # first row, shape (1, 1)

        for i in range(epochs):
            sess.run(train_agent, feed_dict=train_feed)
            # Recorded in a separate run, i.e. after this step's update.
            log_loss[i] = sess.run(surprisal, feed_dict=train_feed)

            if i % eval_interval == 0:
                count = i // eval_interval
                food_policies[count] = F.sess.run(
                    F.strategy, feed_dict=evaluation_feed)
                total_consumption[count] = np.sum(food_policies[count])

    return log_loss, food_policies, total_consumption