|
12 | 12 |
|
13 | 13 | ## 🤔 What is this project? |
14 | 14 |
|
15 | | -This project is a lightweight neural network implementation designed to run on microcontrollers like the **ESP32** and **Arduino**. It demonstrates how even resource-constrained devices can train and perform simple tasks like **XOR** prediction. Maybe you’ll find a use case for simple robot projects. |
| 15 | +This project is a lightweight neural network implementation designed to run on microcontrollers like the **ESP32** and **Arduino**. It demonstrates how even resource-constrained devices can train and perform simple tasks like **XOR** prediction. Maybe you’ll find a use case for simple robot or sensor projects. |
16 | 16 |
|
17 | | -While it takes just some **seconds** to train on the ESP32, the Arduino requires significantly more time due to limited processing power. |
| 17 | +The project has two supported modes, inference and training mode. Inference mode uses an existing torch model and converts it to a header file, which can be loaded to your esp or arduino. |
| 18 | +For fun or testing purposes, you can also run your training directly on the microchip, but for larger models, the performance gets weak pretty fast and you run into memory constraints. |
18 | 19 |
|
19 | 20 | ## 📎 [Blog to this project](https://medium.com/@FrozenAssassine/neural-network-from-scratch-on-esp32-2a53a7b65f9f) |
20 | 21 |
|
21 | 22 | ## 🛠️ Features |
22 | | -- **On-device training**: Train your neural network directly on ESP32 or Arduino. |
23 | | -- **XOR**: Predict simple numbers like in xor. |
24 | | -- **Activation Functions**: Use activation functions like Sigmoid, Relu, Softmax, TanH and LeakyRelu |
25 | | -- **Fast Training**: The ESP32 can train in just a few seconds, while the Arduino requires longer due to its slow processor. |
| 23 | + |
| 24 | +- **Inference only**: Use a python script to convert your pytorch models to include file for esp32 and Arduino. |
| 25 | +- **On-device training**: Train your neural network directly on ESP32 or Arduino (no weight saving atm). |
| 26 | +- **Activation Functions**: Use activation functions like Softmax, Sigmoid, Relu, TanH and LeakyRelu |
26 | 27 | - **Xavier Initialization**: Optimizes weight distribution for faster training. |
| 28 | +- **Simple building structure**: The oop approach makes building the initial model really simple. |
| 29 | + |
27 | 30 | ## 🔮 Future features |
28 | | -- Train on PC and load weights to chip |
29 | | -- Save and load weights |
30 | | -- More layer types |
31 | 31 |
|
32 | | -## 🚀 Performance |
33 | | -- ESP32: Fast training (~seconds). |
34 | | -- Arduino: Slower training (~minutes or more). |
| 32 | +- Save and load weights from on device training |
| 33 | +- More layer types |
35 | 34 |
|
36 | 35 | ## 🫶 Code considerations |
| 36 | + |
37 | 37 | I tried to keep the code as simple and easy to understand as possible. The neural network is completely built using OOP principles, which means that everything is its own class. This is useful for structuring the model later. |
38 | | -For the individual layers, I used the basic principle of inheritance, where I have a BaseLayer class and each layer inherits from it. The BaseLayer also implements some functions, like Train and FeedForward, as well as pointers to the weights, values, biases, and errors. In my inherited classes, I only have to override these functions with the training logic and variable implementations. This is very useful when adding new layers. |
| 38 | +For the individual layers, I used the basic principle of inheritance, where there is a BaseLayer class and each layer inherits from it. The BaseLayer also implements some functions, for Training and FeedForward, as well as pointers to the weights, values, biases, and errors. In the inherited classes, those functions can be overriden with the actual training logic and variable implementations. This is very useful for adding new layers. |
39 | 39 |
|
40 | | -## 🏗️ How to Use |
| 40 | +## 🏗️ Run the code |
41 | 41 |
|
42 | | -1. Clone this repository and open the project in Arduino IDE. |
43 | | -2. Upload the code to your ESP32 or Arduino using Arduino IDE |
| 42 | +1. Clone this repository and open the project with PlatformIO. |
| 43 | +2. Upload the code to your ESP32 or Arduino |
44 | 44 | 3. Monitor the predictions via Serial Monitor at 115200 baud rate. |
45 | 45 |
|
46 | | -Here is an example code: |
| 46 | +## 1. Training mode |
47 | 47 |
|
48 | 48 | ```cpp |
49 | | -#include "Layers.h" |
50 | | -#include "NeuralNetwork.h" |
51 | | - |
52 | | -void setup() { |
53 | | - Serial.begin(115200); |
| 49 | +#include "nn/layers.h" |
| 50 | +#include "nn/neuralNetwork.h" |
| 51 | +#include <nn/predictionHelper.h> |
| 52 | +#include <Arduino.h> |
54 | 53 |
|
| 54 | +void TrainAndTest() |
| 55 | +{ |
55 | 56 | NeuralNetwork *nn = new NeuralNetwork(3); |
56 | 57 | nn->StackLayer(new InputLayer(2)); |
57 | 58 | nn->StackLayer(new DenseLayer(4, ActivationKind::TanH)); |
58 | | - nn->StackLayer(new OutputLayer(1, ActivationKind::Sigmoid)); |
59 | | - nn->Build(); |
| 59 | + nn->StackLayer(new OutputLayer(2, ActivationKind::Softmax)); |
| 60 | + nn->Build(false); // training and prediction |
| 61 | + |
| 62 | + float inputs[4][2] = { |
| 63 | + {0, 0}, |
| 64 | + {0, 1}, |
| 65 | + {1, 0}, |
| 66 | + {1, 1}}; |
| 67 | + |
| 68 | + float desired[4][2] = { |
| 69 | + {1, 0}, |
| 70 | + {0, 1}, |
| 71 | + {0, 1}, |
| 72 | + {1, 0}}; |
| 73 | + |
| 74 | + nn->Train((float *)inputs, (float *)desired, 4, 2, 220, 0.1); |
| 75 | + |
| 76 | + Serial.println("Predictions:"); |
| 77 | + for (uint8_t i = 0; i < 4; i++) |
| 78 | + { |
| 79 | + float *pred = nn->Predict(inputs[i], 2); |
| 80 | + Serial.printf( |
| 81 | + "Input: [%.0f, %.0f] -> Softmax: [%.4f, %.4f] -> Class: %d\n", |
| 82 | + inputs[i][0], inputs[i][1], pred[0], pred[1], ArgMax(pred, 2)); |
| 83 | + } |
| 84 | +} |
60 | 85 |
|
61 | | - float inputs[4][2] = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } }; |
62 | | - float desired[4][1] = { { 0 }, { 1 }, { 1 }, { 0 } }; |
| 86 | +void setup() |
| 87 | +{ |
| 88 | + Serial.begin(115200); |
| 89 | + delay(1000); |
| 90 | + |
| 91 | + TrainAndTest(); |
| 92 | +} |
| 93 | +void loop() { } |
| 94 | +``` |
63 | 95 |
|
64 | | - nn->Train((float*)inputs, (float*)desired, 4, 2, 600, 0.1f); |
| 96 | +**Output:** |
65 | 97 |
|
66 | | - // Predict XOR results: |
67 | | - for (int i = 0; i < 4; i++) { |
| 98 | +``` |
| 99 | +Training Done! |
| 100 | +Predictions: |
| 101 | +Input: [0, 0] -> Softmax: [0.9665, 0.0335] -> Class: 0 |
| 102 | +Input: [0, 1] -> Softmax: [0.0324, 0.9676] -> Class: 1 |
| 103 | +Input: [1, 0] -> Softmax: [0.0783, 0.9217] -> Class: 1 |
| 104 | +Input: [1, 1] -> Softmax: [0.9355, 0.0645] -> Class: 0 |
| 105 | +``` |
| 106 | +
|
| 107 | +## 2. Inference only |
| 108 | +
|
| 109 | +```cpp |
| 110 | +#include "nn/layers.h" |
| 111 | +#include "nn/neuralNetwork.h" |
| 112 | +#include <nn/predictionHelper.h> |
| 113 | +#include <Arduino.h> |
| 114 | +
|
| 115 | +void InferenceOnly() |
| 116 | +{ |
| 117 | + Serial.println("Testing model inference only (XOR Classification)"); |
| 118 | +
|
| 119 | + NeuralNetwork *nn = new NeuralNetwork(3); |
| 120 | + nn->StackLayer(new InputLayer(2)); |
| 121 | + nn->StackLayer(new DenseLayer(4, ActivationKind::TanH)); |
| 122 | + nn->StackLayer(new OutputLayer(2, ActivationKind::Softmax)); |
| 123 | + nn->Build(true); // inference only |
| 124 | +
|
| 125 | + float inputs[4][2] = { |
| 126 | + {0, 0}, |
| 127 | + {0, 1}, |
| 128 | + {1, 0}, |
| 129 | + {1, 1}}; |
| 130 | +
|
| 131 | + Serial.println("Predictions:"); |
| 132 | + for (uint8_t i = 0; i < 4; i++) |
| 133 | + { |
68 | 134 | float *pred = nn->Predict(inputs[i], 2); |
69 | | - Serial.print("PREDICTION "); |
70 | | - Serial.print(inputs[i][0]); |
71 | | - Serial.print(" "); |
72 | | - Serial.print(inputs[i][1]); |
73 | | - Serial.print(" = "); |
74 | | - Serial.println(pred[0]); |
| 135 | + Serial.printf( |
| 136 | + "Input: [%.0f, %.0f] -> Softmax: [%.4f, %.4f] -> Class: %d\n", |
| 137 | + inputs[i][0], inputs[i][1], pred[0], pred[1], ArgMax(pred, 2)); |
75 | 138 | } |
76 | 139 | } |
77 | 140 |
|
78 | | -void loop() { |
| 141 | +void setup() |
| 142 | +{ |
| 143 | + Serial.begin(115200); |
79 | 144 | delay(1000); |
| 145 | +
|
| 146 | + InferenceOnly(); |
80 | 147 | } |
| 148 | +void loop() { } |
81 | 149 | ``` |
82 | 150 |
|
83 | | -# 📷 Images: |
84 | | - |
| 151 | +**Output:** |
85 | 152 |
|
| 153 | +``` |
| 154 | +Testing model inference only (XOR Classification) |
| 155 | +Predictions: |
| 156 | +Input: [0, 0] -> Softmax: [0.9523, 0.0477] -> Class: 0 |
| 157 | +Input: [0, 1] -> Softmax: [0.0702, 0.9298] -> Class: 1 |
| 158 | +Input: [1, 0] -> Softmax: [0.0817, 0.9183] -> Class: 1 |
| 159 | +Input: [1, 1] -> Softmax: [0.9112, 0.0888] -> Class: 0 |
| 160 | +``` |
0 commit comments