Skip to content

Commit 355c974

Browse files
Merge pull request #2 from FrozenAssassine/pio-loadpythonmodel-fixes
feat: use pio, load torch models from python, inference mode, code fixes
2 parents b98207a + 6652aab commit 355c974

32 files changed

+1154
-82519
lines changed

.gitignore

Lines changed: 6 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,7 @@
1-
app/bin/
2-
app/pde.jar
3-
build/macosx/work/
4-
arduino-core/bin/
5-
arduino-core/arduino-core.jar
6-
hardware/arduino/bootloaders/caterina_LUFA/Descriptors.o
7-
hardware/arduino/bootloaders/caterina_LUFA/Descriptors.lst
8-
hardware/arduino/bootloaders/caterina_LUFA/Caterina.sym
9-
hardware/arduino/bootloaders/caterina_LUFA/Caterina.o
10-
hardware/arduino/bootloaders/caterina_LUFA/Caterina.map
11-
hardware/arduino/bootloaders/caterina_LUFA/Caterina.lst
12-
hardware/arduino/bootloaders/caterina_LUFA/Caterina.lss
13-
hardware/arduino/bootloaders/caterina_LUFA/Caterina.elf
14-
hardware/arduino/bootloaders/caterina_LUFA/Caterina.eep
15-
hardware/arduino/bootloaders/caterina_LUFA/.dep/
16-
build/*.zip
17-
build/*.tar.bz2
18-
build/windows/work/
19-
build/windows/*.zip
20-
build/windows/*.tgz
21-
build/windows/*.tar.bz2
22-
build/windows/libastylej*
23-
build/windows/liblistSerials*
24-
build/windows/arduino-*.zip
25-
build/windows/dist/*.tar.gz
26-
build/windows/dist/*.tar.bz2
27-
build/windows/launch4j-*.tgz
28-
build/windows/launch4j-*.zip
29-
build/windows/launcher/launch4j
30-
build/windows/WinAVR-*.zip
31-
build/macosx/arduino-*.zip
32-
build/macosx/dist/*.tar.gz
33-
build/macosx/dist/*.tar.bz2
34-
build/macosx/*.tar.bz2
35-
build/macosx/libastylej*
36-
build/macosx/appbundler*.jar
37-
build/macosx/appbundler*.zip
38-
build/macosx/appbundler
39-
build/macosx/appbundler-1.0ea-arduino?
40-
build/macosx/appbundler-1.0ea-arduino*.zip
41-
build/macosx/appbundler-1.0ea-upstream*.zip
42-
build/linux/work/
43-
build/linux/dist/*.tar.gz
44-
build/linux/dist/*.tar.bz2
45-
build/linux/*.tgz
46-
build/linux/*.tar.xz
47-
build/linux/*.tar.bz2
48-
build/linux/*.zip
49-
build/linux/libastylej*
50-
build/linux/liblistSerials*
51-
build/shared/arduino-examples*
52-
build/shared/reference*.zip
53-
build/shared/Edison*.zip
54-
build/shared/Galileo*.zip
55-
build/shared/WiFi101-Updater-ArduinoIDE-Plugin*.zip
56-
test-bin
57-
*.iml
58-
.idea
59-
.DS_Store
60-
.directory
61-
hardware/arduino/avr/libraries/Bridge/examples/XivelyClient/passwords.h
62-
avr-toolchain-*.zip
63-
/app/nbproject/private/
64-
/arduino-core/nbproject/private/
65-
/app/build/
66-
/arduino-core/build/
1+
.pio
2+
.vscode/.browse.c_cpp.db*
3+
.vscode/c_cpp_properties.json
4+
.vscode/launch.json
5+
.vscode/ipch
676

68-
manifest.mf
69-
nbbuild.xml
70-
nbproject
71-
main/esp32.svd
72-
main/debug.cfg
73-
main/debug_custom.json
7+
python/out

README.md

Lines changed: 113 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,74 +12,149 @@
1212

1313
## 🤔 What is this project?
1414

15-
This project is a lightweight neural network implementation designed to run on microcontrollers like the **ESP32** and **Arduino**. It demonstrates how even resource-constrained devices can train and perform simple tasks like **XOR** prediction. Maybe you’ll find a use case for simple robot projects.
15+
This project is a lightweight neural network implementation designed to run on microcontrollers like the **ESP32** and **Arduino**. It demonstrates how even resource-constrained devices can train and perform simple tasks like **XOR** prediction. Maybe you’ll find a use case for simple robot or sensor projects.
1616

17-
While it takes just some **seconds** to train on the ESP32, the Arduino requires significantly more time due to limited processing power.
17+
The project has two supported modes, inference and training mode. Inference mode uses an existing torch model and converts it to a header file, which can be loaded to your esp or arduino.
18+
For fun or testing purposes, you can also run your training directly on the microchip, but for larger models, the performance gets weak pretty fast and you run into memory constraints.
1819

1920
## 📎 [Blog to this project](https://medium.com/@FrozenAssassine/neural-network-from-scratch-on-esp32-2a53a7b65f9f)
2021

2122
## 🛠️ Features
22-
- **On-device training**: Train your neural network directly on ESP32 or Arduino.
23-
- **XOR**: Predict simple numbers like in xor.
24-
- **Activation Functions**: Use activation functions like Sigmoid, Relu, Softmax, TanH and LeakyRelu
25-
- **Fast Training**: The ESP32 can train in just a few seconds, while the Arduino requires longer due to its slow processor.
23+
24+
- **Inference only**: Use a python script to convert your pytorch models to include file for esp32 and Arduino.
25+
- **On-device training**: Train your neural network directly on ESP32 or Arduino (no weight saving atm).
26+
- **Activation Functions**: Use activation functions like Softmax, Sigmoid, Relu, TanH and LeakyRelu
2627
- **Xavier Initialization**: Optimizes weight distribution for faster training.
28+
- **Simple building structure**: The oop approach makes building the initial model really simple.
29+
2730
## 🔮 Future features
28-
- Train on PC and load weights to chip
29-
- Save and load weights
30-
- More layer types
3131

32-
## 🚀 Performance
33-
- ESP32: Fast training (~seconds).
34-
- Arduino: Slower training (~minutes or more).
32+
- Save and load weights from on device training
33+
- More layer types
3534

3635
## 🫶 Code considerations
36+
3737
I tried to keep the code as simple and easy to understand as possible. The neural network is completely built using OOP principles, which means that everything is its own class. This is useful for structuring the model later.
38-
For the individual layers, I used the basic principle of inheritance, where I have a BaseLayer class and each layer inherits from it. The BaseLayer also implements some functions, like Train and FeedForward, as well as pointers to the weights, values, biases, and errors. In my inherited classes, I only have to override these functions with the training logic and variable implementations. This is very useful when adding new layers.
38+
For the individual layers, I used the basic principle of inheritance, where there is a BaseLayer class and each layer inherits from it. The BaseLayer also implements some functions, for Training and FeedForward, as well as pointers to the weights, values, biases, and errors. In the inherited classes, those functions can be overriden with the actual training logic and variable implementations. This is very useful for adding new layers.
3939

40-
## 🏗️ How to Use
40+
## 🏗️ Run the code
4141

42-
1. Clone this repository and open the project in Arduino IDE.
43-
2. Upload the code to your ESP32 or Arduino using Arduino IDE
42+
1. Clone this repository and open the project with PlatformIO.
43+
2. Upload the code to your ESP32 or Arduino
4444
3. Monitor the predictions via Serial Monitor at 115200 baud rate.
4545

46-
Here is an example code:
46+
## 1. Training mode
4747

4848
```cpp
49-
#include "Layers.h"
50-
#include "NeuralNetwork.h"
51-
52-
void setup() {
53-
Serial.begin(115200);
49+
#include "nn/layers.h"
50+
#include "nn/neuralNetwork.h"
51+
#include <nn/predictionHelper.h>
52+
#include <Arduino.h>
5453

54+
void TrainAndTest()
55+
{
5556
NeuralNetwork *nn = new NeuralNetwork(3);
5657
nn->StackLayer(new InputLayer(2));
5758
nn->StackLayer(new DenseLayer(4, ActivationKind::TanH));
58-
nn->StackLayer(new OutputLayer(1, ActivationKind::Sigmoid));
59-
nn->Build();
59+
nn->StackLayer(new OutputLayer(2, ActivationKind::Softmax));
60+
nn->Build(false); // training and prediction
61+
62+
float inputs[4][2] = {
63+
{0, 0},
64+
{0, 1},
65+
{1, 0},
66+
{1, 1}};
67+
68+
float desired[4][2] = {
69+
{1, 0},
70+
{0, 1},
71+
{0, 1},
72+
{1, 0}};
73+
74+
nn->Train((float *)inputs, (float *)desired, 4, 2, 220, 0.1);
75+
76+
Serial.println("Predictions:");
77+
for (uint8_t i = 0; i < 4; i++)
78+
{
79+
float *pred = nn->Predict(inputs[i], 2);
80+
Serial.printf(
81+
"Input: [%.0f, %.0f] -> Softmax: [%.4f, %.4f] -> Class: %d\n",
82+
inputs[i][0], inputs[i][1], pred[0], pred[1], ArgMax(pred, 2));
83+
}
84+
}
6085

61-
float inputs[4][2] = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
62-
float desired[4][1] = { { 0 }, { 1 }, { 1 }, { 0 } };
86+
void setup()
87+
{
88+
Serial.begin(115200);
89+
delay(1000);
90+
91+
TrainAndTest();
92+
}
93+
void loop() { }
94+
```
6395
64-
nn->Train((float*)inputs, (float*)desired, 4, 2, 600, 0.1f);
96+
**Output:**
6597
66-
// Predict XOR results:
67-
for (int i = 0; i < 4; i++) {
98+
```
99+
Training Done!
100+
Predictions:
101+
Input: [0, 0] -> Softmax: [0.9665, 0.0335] -> Class: 0
102+
Input: [0, 1] -> Softmax: [0.0324, 0.9676] -> Class: 1
103+
Input: [1, 0] -> Softmax: [0.0783, 0.9217] -> Class: 1
104+
Input: [1, 1] -> Softmax: [0.9355, 0.0645] -> Class: 0
105+
```
106+
107+
## 2. Inference only
108+
109+
```cpp
110+
#include "nn/layers.h"
111+
#include "nn/neuralNetwork.h"
112+
#include <nn/predictionHelper.h>
113+
#include <Arduino.h>
114+
115+
void InferenceOnly()
116+
{
117+
Serial.println("Testing model inference only (XOR Classification)");
118+
119+
NeuralNetwork *nn = new NeuralNetwork(3);
120+
nn->StackLayer(new InputLayer(2));
121+
nn->StackLayer(new DenseLayer(4, ActivationKind::TanH));
122+
nn->StackLayer(new OutputLayer(2, ActivationKind::Softmax));
123+
nn->Build(true); // inference only
124+
125+
float inputs[4][2] = {
126+
{0, 0},
127+
{0, 1},
128+
{1, 0},
129+
{1, 1}};
130+
131+
Serial.println("Predictions:");
132+
for (uint8_t i = 0; i < 4; i++)
133+
{
68134
float *pred = nn->Predict(inputs[i], 2);
69-
Serial.print("PREDICTION ");
70-
Serial.print(inputs[i][0]);
71-
Serial.print(" ");
72-
Serial.print(inputs[i][1]);
73-
Serial.print(" = ");
74-
Serial.println(pred[0]);
135+
Serial.printf(
136+
"Input: [%.0f, %.0f] -> Softmax: [%.4f, %.4f] -> Class: %d\n",
137+
inputs[i][0], inputs[i][1], pred[0], pred[1], ArgMax(pred, 2));
75138
}
76139
}
77140
78-
void loop() {
141+
void setup()
142+
{
143+
Serial.begin(115200);
79144
delay(1000);
145+
146+
InferenceOnly();
80147
}
148+
void loop() { }
81149
```
82150

83-
# 📷 Images:
84-
![image](https://github.com/user-attachments/assets/4b32f9ee-a1e9-4b4f-b626-1c4d5d9a3861)
151+
**Output:**
85152

153+
```
154+
Testing model inference only (XOR Classification)
155+
Predictions:
156+
Input: [0, 0] -> Softmax: [0.9523, 0.0477] -> Class: 0
157+
Input: [0, 1] -> Softmax: [0.0702, 0.9298] -> Class: 1
158+
Input: [1, 0] -> Softmax: [0.0817, 0.9183] -> Class: 1
159+
Input: [1, 1] -> Softmax: [0.9112, 0.0888] -> Class: 0
160+
```

firmware/.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
.pio
2+
.vscode/.browse.c_cpp.db*
3+
.vscode/c_cpp_properties.json
4+
.vscode/launch.json
5+
.vscode/ipch

firmware/.vscode/extensions.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
// See http://go.microsoft.com/fwlink/?LinkId=827846
3+
// for the documentation about the extensions.json format
4+
"recommendations": [
5+
"platformio.platformio-ide"
6+
],
7+
"unwantedRecommendations": [
8+
"ms-vscode.cpptools-extension-pack"
9+
]
10+
}

firmware/include/nn_trained.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#pragma once
2+
3+
#include "nn/layerData.h"
4+
5+
// Example arrays (small XOR-like model)
6+
static float layer0_weights[8] = {1.694597f, 1.308419f, -1.314386f, -0.903650f, -1.036660f, 2.091955f, -2.517021f, 2.006923f};
7+
static float layer0_bias[4] = {-0.033063f, -0.264372f, 0.208891f, -1.040136f};
8+
9+
static float layer1_weights[8] = {-1.449604f, 0.621033f, 1.519490f, -1.582430f, 0.827401f, -0.810482f, -1.936075f, 1.971920f};
10+
static float layer1_bias[2] = {-0.410137f, -0.222768f};
11+
12+
static const LayerData nn_layers[] = {
13+
{nullptr, nullptr, 0, 2},
14+
{layer0_weights, layer0_bias, 2, 4},
15+
{layer1_weights, layer1_bias, 4, 2}};
16+
17+
static const uint8_t nn_total_layers = 3;
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# NeuralNetwork (PlatformIO/Arduino library)
2+
3+
Small, header/source based neural network library for Arduino/PlatformIO.
4+
5+
Usage
6+
7+
- Copy the `NeuralNetwork` folder into your project's `lib/` directory, or add it via `lib_extra_dirs`/`lib_deps`.
8+
- Include public headers like:
9+
10+
#include <nn/neuralNetwork.h>
11+
#include <nn/layers.h>
12+
13+
Example
14+
15+
- See `examples/BasicExample` for a minimal inference sketch using the built-in trained model.
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#pragma once
2+
3+
#include <cstdint>
4+
5+
struct LayerData
6+
{
7+
float *weights;
8+
float *bias;
9+
uint16_t inputSize;
10+
uint16_t outputSize;
11+
};

0 commit comments

Comments
 (0)