-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSimpleNN.py
More file actions
108 lines (82 loc) · 3.86 KB
/
SimpleNN.py
File metadata and controls
108 lines (82 loc) · 3.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import numpy as np
import matplotlib.pyplot as plt
fig2 = plt.figure(figsize=(8,7))  # figure for the cost-vs-epoch plot drawn at the end of the script
ax2 = fig2.add_subplot(111)
#----SETTINGS----
learningRate = 0.001  # gradient-descent step size
epoch = 1000  # how many training cycles to run
maxVal = 5  # max value in the range of inputs; the min value is -maxVal
totalInputs = 10000  # number of entries in the dataset
# Network shape is 2 inputs -> 3 hidden -> 1 output.
# NOTE: draw order matters for reproducibility — these four calls consume the RNG stream in sequence.
weights0 = np.random.random((3,2))  # input -> hidden weights
weights1 = np.random.random((1,3))  # hidden -> output weights
bias0 = np.random.random((3,1))  # hidden-layer biases
bias1 = np.random.random((1,1))  # output bias
#----------------
#Since the weights and biases start out random, you can get unlucky and diverge,
#producing numbers that are way too large for Python to handle
def resetFunc():
    """Draw a fresh random set of weights and biases.

    Intended for restarting training when an unlucky initialization
    diverges. Returns (weights0, weights1, bias0, bias1) with shapes
    (3,2), (1,3), (3,1), (1,1), drawn in that order.
    """
    layer_shapes = ((3, 2), (1, 3), (3, 1), (1, 1))
    # Draw in the same fixed order so the RNG stream is consumed identically.
    return tuple(np.random.random(shape) for shape in layer_shapes)
def reluFunc(x):
    """ReLU activation: clamp negative inputs to 0, pass the rest through."""
    return 0 if x < 0 else x
def DerivReluFunc(x):
    """Derivative of ReLU: 0 for negative input, 1 otherwise (1 at x == 0)."""
    return 0 if x < 0 else 1
def correctValue(xIN, yIN):
    """Target function the network learns: f(x, y) = x^2 + y^2."""
    return xIN * xIN + yIN * yIN
# Elementwise (vectorized) versions so the scalar ReLU functions apply to whole numpy arrays.
vFunc = np.vectorize(reluFunc)
vFunc2 = np.vectorize(DerivReluFunc)
# Partial derivative of each weight; refer to the note sheet for the drawn-out
# diagram of the equations. Each partial derivative depends on the next and
# previous nodes that the weight connects, so the node/weight multiplication
# must be matched carefully per entry.
def findSlope(weightSlope, AnyWeights, PrevNodes, NextNodes):  # zval??
    """Fill `weightSlope` in place: for each weight [r][c] store
    DerivRelu(NextNodes[r]) * AnyWeights[r][c] * PrevNodes[c].

    NOTE(review): this includes the weight's own value as a factor and takes
    the ReLU derivative at the post-activation node value — presumably the
    author's note-sheet convention; confirm against the derivation.
    """
    n_rows, n_cols = np.shape(AnyWeights)
    for row in range(n_rows):
        gate = DerivReluFunc(NextNodes[row])
        for col in range(n_cols):
            weightSlope[row][col] = gate * AnyWeights[row][col] * PrevNodes[col]
def generateData(totalInputs, maxVal):
    """Draw `totalInputs` (x, y) pairs uniformly from [-maxVal, maxVal).

    Returns a (totalInputs, 2) array; each row is one training input.
    """
    uniform01 = np.random.random((totalInputs, 2))
    # Scale [0, 1) up to [0, 2*maxVal) then shift down to [-maxVal, maxVal).
    return uniform01 * 2 * maxVal - maxVal
def oneTrainingCycle(weights0, weights1, bias0, bias1, inputValArray):
    """Run one training cycle (SGD over the first 10 rows of the dataset).

    Forward pass: 2 inputs -> 3 hidden ReLU units -> 1 ReLU output.
    Backward pass: standard backprop of the squared error, with the ReLU
    derivative evaluated at the PRE-activations (the original evaluated it
    at the post-activations, which wrongly yields 1 for clipped units, and
    multiplied the bias gradients by the bias values themselves).

    Parameters are the current weight/bias arrays — shapes (3,2), (1,3),
    (3,1), (1,1) — and the (N, 2) input dataset.
    Returns the updated (weights0, weights1, bias0, bias1) plus the mean
    squared error over the samples actually processed (a (1,1) array).
    """
    sumTotal = 0.0
    nSamples = 10  # mini-batch: only the first 10 dataset rows, as before
    for i in range(nSamples):
        x = inputValArray[i][:, np.newaxis]      # (2,1) column input
        # Forward pass, keeping pre-activations for the backward pass.
        z1 = np.matmul(weights0, x) + bias0      # (3,1) hidden pre-activations
        node1 = vFunc(z1)                        # (3,1) hidden activations
        z2 = np.matmul(weights1, node1) + bias1  # (1,1) output pre-activation
        node2 = vFunc(z2)                        # (1,1) network output
        err = node2 - correctValue(x[0], x[1])   # (1,1) signed error
        sumTotal += err ** 2
        # Backward pass (chain rule); dz/db = 1, dz/dw = upstream activation.
        delta2 = 2 * err * vFunc2(z2)                          # (1,1)
        weightSlope1 = delta2 * np.transpose(node1)            # (1,3)
        biasSlope1 = delta2                                    # (1,1)
        delta1 = np.transpose(weights1) * delta2 * vFunc2(z1)  # (3,1)
        weightSlope0 = np.matmul(delta1, np.transpose(x))      # (3,2)
        biasSlope0 = delta1                                    # (3,1)
        # Gradient-descent step.
        weights1 = weights1 - learningRate * weightSlope1
        weights0 = weights0 - learningRate * weightSlope0
        bias1 = bias1 - learningRate * biasSlope1
        bias0 = bias0 - learningRate * biasSlope0
    # Mean over the samples actually summed (was divided by 16 — a bug,
    # since the loop runs 10 iterations).
    meanSquared = sumTotal / nSamples
    return weights0, weights1, bias0, bias1, meanSquared
# ---- Training driver and cost plot ----
meanSquaredValues = np.array([])
inputValArray = generateData(totalInputs, maxVal)
# Hoisted out of the loop: np.arange(epoch) is loop-invariant and was being
# rebuilt on every one of the `epoch` iterations.
epochTime = np.arange(epoch)
for t in range(epoch):
    weights0, weights1, bias0, bias1, meanSquared = oneTrainingCycle(weights0, weights1, bias0, bias1, inputValArray)
    meanSquaredValues = np.append(meanSquaredValues, meanSquared)
# Plot the learning curve: cost should trend downward as training progresses.
ax2.plot(epochTime, meanSquaredValues, '.r')
ax2.set_xlabel("Epoch\n(Number of iterations)")
ax2.set_ylabel("Cost Value\n(How bad the network is)")
dataString = "Final Cost Value: {}".format(meanSquaredValues[epoch - 1])
ax2.text(0.50, 0.95, dataString, transform = ax2.transAxes)
plt.show()