Skip to content

Commit 209d58c

Browse files
committed
Add backprop implementation, still needs work
1 parent f276368 commit 209d58c

File tree

6 files changed

+78
-17
lines changed

6 files changed

+78
-17
lines changed

SkalaNet/build.sbt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
// import scala.scalanative.build.*

scalaVersion := "3.1.3"

// Scala Native build settings, currently disabled; re-enable by
// uncommenting this block and the import above.
/*
enablePlugins(ScalaNativePlugin)

nativeConfig ~= {
  _.withLTO(LTO.thin)
    .withMode(Mode.releaseFull)
    .withGC(GC.commix)
}
*/

SkalaNet/src/main/scala/Image.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@ object Image:
2020

2121
// One label byte per image, after an 8-byte file header.
private def readLabels(labelFile: String): Seq[Int] = readBytes(labelFile).drop(8).map(_.toInt)

// Reads the image file (16-byte header, then 28x28 pixels per image, one
// byte per pixel) and pairs each image with its label from the label file.
def readImages(imageFile: String, labelFile: String): IndexedSeq[Image] =
  val labels = readLabels(labelFile)
  val pixelStream = readBytes(imageFile)
    .drop(16)
    .map(_.toInt & 255) // convert to unsigned "byte" by masking with 0b11111111
  pixelStream
    .grouped(28 * 28)            // one chunk per image
    .map(_.grouped(28).toArray)  // rows of 28 pixels
    .zip(labels)
    .map((pixels, label) => Image(pixels = pixels, label = label))
    .toIndexedSeq

SkalaNet/src/main/scala/Main.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def tryNetwork() =
2929

3030
// Runs stochastic gradient descent over the training images,
// feeding each image to the network as a column vector.
def trainNetwork() = nn.SGD(
  trainingData = trainingImages.map(img => (img.toColumnVector(), img.label)),
  epochs = 2,
  batchSize = 100
)
3535

SkalaNet/src/main/scala/Matrix.scala

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
package SkalaNet
22

33
import SkalaNet.Types.*
4-
import scalanative.unsafe.*
4+
/*import scalanative.unsafe.*
55
66
@extern
77
def mult(
88
n: CInt, m: CInt, p: CInt,
99
A: Ptr[CFloat], B: Ptr[CFloat],
1010
res: Ptr[CFloat]
1111
): Unit = extern
12-
12+
*/
1313
extension (M: Matrix)
1414

1515
def rows: Int = M.size
@@ -29,13 +29,8 @@ extension (M: Matrix)
2929
// Element-wise matrix subtraction; both operands must share dimensions.
def -(other: Matrix): Matrix =
  assert(rows == other.rows && cols == other.cols, "Matrix dimensions do not match!")

  val diff = Array.ofDim[Float](rows, cols)
  for i <- 0 until rows do
    for j <- 0 until cols do
      diff(i)(j) = M(i)(j) - other(i)(j)
  diff
33+
/*
3934
def *(other: Matrix): Matrix =
4035
assert(cols == other.rows, "Dimensions are not valid for multiplication!")
4136
@@ -60,9 +55,28 @@ extension (M: Matrix)
6055
newM(i)(j) = !(res + i * p + j)
6156
6257
newM
63-
58+
*/
59+
60+
// Matrix multiplication (naive triple loop): (n x m) * (m x p) -> (n x p).
// BUG FIX: the inner loop previously ASSIGNED `res(i)(j) = ...` on every k,
// so each cell kept only the last term instead of accumulating the full
// dot product. It must add (`+=`) across k.
def *(other: Matrix): Matrix =
  assert(cols == other.rows)
  val (n, m, p) = (rows, cols, other.cols)
  val res = Array.ofDim[Float](n, p) // zero-initialized, safe to accumulate into
  for i <- 0 until n do
    for j <- 0 until p do
      for k <- 0 until m do
        res(i)(j) += M(i)(k) * other(k)(j)
  res
69+
6470
// Scalar multiplication: scales every entry by c.
def *(c: Float): Matrix =
  M.map(_.map(z => c * z))

// Element-wise (Hadamard) product of two same-shaped matrices.
// NOTE(review): this method's name was lost in extraction (`def (other: ...)`),
// most likely a non-ASCII operator symbol; restored here as `⊙` to match the
// garbled infix call sites in NeuralNetwork.backprop — confirm against the
// original source.
def ⊙(other: Matrix): Matrix =
  assert(rows == other.rows && cols == other.cols, "Matrix dimensions differ!")
  val res = Array.ofDim[Float](rows, cols)
  for i <- 0 until rows do
    for j <- 0 until cols do
      res(i)(j) = M(i)(j) * other(i)(j)
  res
6680

6781
object Matrix:
6882

SkalaNet/src/main/scala/NeuralNetwork.scala

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package SkalaNet
22

33
import SkalaNet.Types.*
4+
import collection.mutable.ArrayBuffer
5+
import Utils.zip
46

57
extension (x: Float)
68
def **(y: Int): Float =
@@ -14,15 +16,20 @@ case class NeuralNetwork private (private val layerSizes: Seq[Int]):
1416
// ReLU ;) — rectified linear unit applied to every entry.
private def __/(m: Matrix): Matrix = m.map(_.map(z => math.max(z, 0)))

// ReLU derivative: 1 where the pre-activation is positive, 0 elsewhere.
private def reluPrime(m: Matrix): Matrix = m.map(_.map(z => if z > 0 then 1 else 0))

// Pushes the input column vector through every layer, applying the affine
// transform followed by ReLU at each step (the output layer included).
private def feedforward(inp: Matrix): Matrix =
  var activation = inp
  for (w, b) <- weights.zip(biases) do
    activation = __/(w * activation + b)
  activation
1923

24+
// Gradient of the squared-error cost wrt the output activations: 2 * (a - y).
private def costPrime(output: Matrix, expectedOutput: Matrix): Matrix =
  (output - expectedOutput) * 2

// query the network using a matrix representing the image; returns the index
// of the highest-scoring output neuron (tuple ordering breaks score ties in
// favour of the larger index).
def apply(inp: Matrix): Int =
  val scored = feedforward(inp).flatten.zipWithIndex
  scored.max._2
2330

2431
// perform stochastic gradient descent
25-
def SGD(trainingData: Seq[(Matrix, Int)], epochs: Int, batchSize: Int): Unit =
32+
def SGD(trainingData: IndexedSeq[(Matrix, Int)], epochs: Int, batchSize: Int): Unit =
2633
import util.Random.shuffle
2734
val n = trainingData.size
2835
for epoch <- 1 to epochs do
@@ -43,10 +50,32 @@ case class NeuralNetwork private (private val layerSizes: Seq[Int]):
4350
// NOTE(review): if `len` (declared outside this hunk) is an integral type,
// `1 / len` is integer division and truncates to 0, making both updates
// no-ops — confirm `len` is a Float. Also note there is no explicit
// learning-rate factor here; presumably intentional for now.
weights = weights.zip(nablaW).map((w, nw) => w - nw * (1 / len))
4451
biases = biases.zip(nablaB).map((b, nb) => b - nb * (1 / len))
4552

46-
private def backprop(inp: Matrix, expectedAns: Int): (Seq[Matrix], Seq[Matrix]) = ???
53+
// Backpropagation for a single training example.
// Returns the per-layer gradients as (weight gradients, bias gradients),
// ordered from the first layer to the last, matching `weights`/`biases`.
private def backprop(inp: Matrix, expectedAns: Int): (Seq[Matrix], Seq[Matrix]) =
  val deltaW = ArrayBuffer[Matrix]()
  val deltaB = ArrayBuffer[Matrix]()

  // Forward pass, recording every pre-activation z and activation a
  // (as(0) is the input itself).
  val zs = ArrayBuffer[Matrix]()
  val as = ArrayBuffer[Matrix](inp)
  weights.zip(biases).foldLeft(inp) { case (x, (w, b)) =>
    val z = w * x + b
    zs.append(z)
    val a = __/(z)
    as.append(a)
    a
  }

  // One-hot column vector for the expected answer. Generalized from the
  // previous hard-coded 10 to the actual output-layer size.
  val expectedOutput = Array.ofDim[Float](layerSizes.last, 1)
  expectedOutput(expectedAns)(0) = 1f

  // Output layer: delta = dC/da ⊙ relu'(z).
  // NOTE(review): the element-wise product operator was garbled away in
  // extraction; `⊙` is assumed here — confirm against the original source.
  var delta = costPrime(as.last, expectedOutput) ⊙ reluPrime(zs.last)
  deltaW.append(delta * as.init.last.transpose)
  deltaB.append(delta)

  // Hidden layers, last to first: delta_l = (W_{l+1}^T delta_{l+1}) ⊙ relu'(z_l),
  // with dW_l = delta_l * a_{l-1}^T.
  for (w_next, z, a_prev) <- zip(weights.tail, zs.init, as.init.init).reverse do
    delta = (w_next.transpose * delta) ⊙ reluPrime(z)
    deltaB.append(delta)
    deltaW.append(delta * a_prev.transpose)

  // Gradients were accumulated back-to-front; reverse into layer order.
  (deltaW.reverse.toSeq, deltaB.reverse.toSeq)
5079

5180
object NeuralNetwork:
5281

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package SkalaNet

object Utils:

  /** Zips three collections element-wise, truncating to the shortest.
    *
    * BUG FIX: the previous version iterated `0 until max(sizes)` and called
    * `next` unconditionally, which (a) threw NoSuchElementException whenever
    * the inputs had unequal lengths — standard `zip` truncates to the
    * shortest instead — and (b) called `.size` on the `IterableOnce`
    * arguments after already taking their iterators, double-consuming
    * one-shot inputs such as `Iterator`.
    */
  def zip[A, B, C](l1: IterableOnce[A], l2: IterableOnce[B], l3: IterableOnce[C]): Seq[(A, B, C)] =
    val (i1, i2, i3) = (l1.iterator, l2.iterator, l3.iterator)
    val out = Seq.newBuilder[(A, B, C)]
    while i1.hasNext && i2.hasNext && i3.hasNext do
      out += ((i1.next(), i2.next(), i3.next()))
    out.result()

0 commit comments

Comments
 (0)