Skip to content

Commit 872c356

Browse files
authored
Merge pull request #6 from gabrielmfern/preprocessing
add a way to preprocess inputs for the training options and the fit method
2 parents 2bc11e4 + 528af7b commit 872c356

File tree

7 files changed

+1129
-1085
lines changed

7 files changed

+1129
-1085
lines changed

Cargo.lock

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "intricate"
3-
version = "0.6.0"
3+
version = "0.6.1"
44
edition = "2021"
55
license = "MIT"
66
authors = ["Gabriel Miranda"]

README.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,16 @@ xor_model
185185
&mut TrainingOptions {
186186
loss_fn: &mut loss, // the type of loss function that should be used for Intricate
187187
// to determine how bad the Model is
188+
// these two functions are quite useful for a Model that needs to work with very large
189+
// data that will cost a lot of RAM and computing
190+
from_inputs_to_vectors: &(|inputs| Ok(inputs.to_vec())), // a function to
191+
// preprocess the inputs
192+
from_expected_outputs_to_vectors: &(|outputs| Ok(outputs.to_vec())), // a function
193+
// to
194+
// preprocess
195+
// the
196+
// expected
197+
// outputs
188198
verbosity: TrainingVerbosity {
189199
show_current_epoch: true, // show a message for each epoch like `epoch #5`
190200
show_epoch_progress: false, // show a progress bar of the training steps in a
@@ -208,7 +218,6 @@ xor_model
208218
)
209219
.unwrap();
210220
```
211-
212221
As you can see it is extremely easy creating these models, and blazingly fast as well.
213222

214223
---
@@ -250,3 +259,4 @@ to use the Model after loading it, you **must** call the `init` method in the `l
250259
- add a way to send into the training process a callback closure that would be called every time an epoch finishes, or even every step, with some cool info
251260
- make an example after doing the thing above ^, that uses that same function to plot the loss realtime using a crate like `textplots`
252261
- add embedding layers for text such as bag of words with an expected vocabulary size
262+
- add a better way to define training options, so that such verbose code is not required when there is no need for it

examples/xor/main.rs

Lines changed: 100 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,95 +1,101 @@
1-
use intricate::layers::activations::TanH;
2-
use intricate::layers::Dense;
3-
4-
use intricate::loss_functions::MeanSquared;
5-
use intricate::optimizers::BasicOptimizer;
6-
use intricate::types::{ModelLayer, TrainingOptions, TrainingVerbosity, HaltingCondition};
7-
use intricate::utils::opencl::DeviceType;
8-
use intricate::utils::setup_opencl;
9-
use intricate::Model;
10-
11-
use savefile::{load_file, save_file};
12-
13-
fn main() -> () {
14-
// Defining the training data
15-
let training_inputs: Vec<Vec<f32>> = vec![
16-
vec![0.0, 0.0],
17-
vec![0.0, 1.0],
18-
vec![1.0, 0.0],
19-
vec![1.0, 1.0],
20-
];
21-
22-
let expected_outputs: Vec<Vec<f32>> = vec![
23-
vec![0.0],
24-
vec![1.0],
25-
vec![1.0],
26-
vec![0.0],
27-
];
28-
29-
// Defining the layers for our XoR Model
30-
let layers: Vec<ModelLayer> = vec![
31-
Dense::new(2, 3),
32-
TanH::new(3),
33-
Dense::new(3, 1),
34-
TanH::new(1),
35-
];
36-
37-
// Actually instantiate the Model with the layers
38-
let mut xor_model = Model::new(layers);
39-
// you can change this to DeviceType::GPU if you want
40-
let opencl_state = setup_opencl(DeviceType::CPU).unwrap();
41-
xor_model.init(&opencl_state).unwrap();
42-
43-
let mut loss = MeanSquared::new();
44-
let mut optimizer = BasicOptimizer::new(0.1);
45-
46-
// Fit the model however many times we want
47-
xor_model
48-
.fit(
49-
&training_inputs,
50-
&expected_outputs,
51-
&mut TrainingOptions {
52-
loss_fn: &mut loss, // the type of loss function that should be used for Intricate
53-
// to determine how bad the Model is
54-
verbosity: TrainingVerbosity {
55-
show_current_epoch: true, // show a message for each epoch like `epoch #5`
56-
show_epoch_progress: false, // show a progress bar of the training steps in a
57-
// epoch
58-
show_epoch_elapsed: true, // show elapsed time in calculations for one epoch
59-
print_accuracy: true, // should print the accuracy after each epoch
60-
print_loss: true, // should print the loss after each epoch
61-
halting_condition_warning: true,
62-
},
63-
// a condition for stopping the training if a min accuracy is reached
64-
halting_condition: Some(HaltingCondition::MinAccuracyReached(0.95)),
65-
compute_accuracy: false, // if Intricate should compute the accuracy after each
66-
// training step
67-
compute_loss: true, // if Intricate should compute the loss after each training
68-
// step
69-
optimizer: &mut optimizer,
70-
batch_size: 4, // the size of the mini-batch being used in Intricate's Mini-batch
71-
// Gradient Descent
72-
epochs: 10000,
73-
},
74-
)
75-
.unwrap();
76-
77-
// for saving Intricate uses the 'savefile' crate
78-
// that simply needs to call the 'save_file' function to the path you want
79-
// for the Model as follows
80-
xor_model.sync_data_from_buffers_to_host().unwrap();
81-
save_file("xor-model.bin", 0, &xor_model).unwrap();
82-
83-
// as for loading we can just call the 'load_file' function
84-
// on the path we saved to before
85-
let mut loaded_xor_model: Model = load_file("xor-model.bin", 0).unwrap();
86-
loaded_xor_model.init(&opencl_state).unwrap();
87-
88-
loaded_xor_model.predict(&training_inputs).unwrap();
89-
xor_model.predict(&training_inputs).unwrap();
90-
91-
let model_prediction = xor_model.get_last_prediction().unwrap();
92-
let loaded_model_prediction = loaded_xor_model.get_last_prediction().unwrap();
93-
94-
assert_eq!(loaded_model_prediction, model_prediction);
1+
use intricate::layers::activations::TanH;
2+
use intricate::layers::Dense;
3+
4+
use intricate::loss_functions::MeanSquared;
5+
use intricate::optimizers::BasicOptimizer;
6+
use intricate::types::{HaltingCondition, ModelLayer, TrainingOptions, TrainingVerbosity};
7+
use intricate::utils::opencl::DeviceType;
8+
use intricate::utils::setup_opencl;
9+
use intricate::Model;
10+
11+
use savefile::{load_file, save_file};
12+
13+
fn main() -> () {
14+
// Defining the training data
15+
let training_inputs: Vec<Vec<f32>> = vec![
16+
vec![0.0, 0.0],
17+
vec![0.0, 1.0],
18+
vec![1.0, 0.0],
19+
vec![1.0, 1.0],
20+
];
21+
22+
let expected_outputs: Vec<Vec<f32>> = vec![vec![0.0], vec![1.0], vec![1.0], vec![0.0]];
23+
24+
// Defining the layers for our XoR Model
25+
let layers: Vec<ModelLayer> = vec![
26+
Dense::new(2, 3),
27+
TanH::new(3),
28+
Dense::new(3, 1),
29+
TanH::new(1),
30+
];
31+
32+
// Actually instantiate the Model with the layers
33+
let mut xor_model = Model::new(layers);
34+
// you can change this to DeviceType::GPU if you want
35+
let opencl_state = setup_opencl(DeviceType::CPU).unwrap();
36+
xor_model.init(&opencl_state).unwrap();
37+
38+
let mut loss = MeanSquared::new();
39+
let mut optimizer = BasicOptimizer::new(0.1);
40+
41+
// Fit the model however many times we want
42+
xor_model
43+
.fit(
44+
&training_inputs,
45+
&expected_outputs,
46+
&mut TrainingOptions {
47+
loss_fn: &mut loss, // the type of loss function that should be used for Intricate
48+
// to determine how bad the Model is
49+
// these two functions are quite useful for a Model that needs to work with very large
50+
// data that will cost a lot of RAM and computing
51+
from_inputs_to_vectors: &(|inputs| Ok(inputs.to_vec())), // a function to
52+
// preprocess the inputs
53+
from_expected_outputs_to_vectors: &(|outputs| Ok(outputs.to_vec())), // a function
54+
// to
55+
// preprocess
56+
// the
57+
// expected
58+
// outputs
59+
verbosity: TrainingVerbosity {
60+
show_current_epoch: true, // show a message for each epoch like `epoch #5`
61+
show_epoch_progress: false, // show a progress bar of the training steps in a
62+
// epoch
63+
show_epoch_elapsed: true, // show elapsed time in calculations for one epoch
64+
print_accuracy: true, // should print the accuracy after each epoch
65+
print_loss: true, // should print the loss after each epoch
66+
halting_condition_warning: true,
67+
},
68+
// a condition for stopping the training if the Model gets to 95%
69+
// accuracy
70+
halting_condition: Some(HaltingCondition::MinAccuracyReached(0.95)),
71+
compute_accuracy: false, // if Intricate should compute the accuracy after each
72+
// training step
73+
compute_loss: true, // if Intricate should compute the loss after each training
74+
// step
75+
optimizer: &mut optimizer,
76+
batch_size: 4, // the size of the mini-batch being used in Intricate's Mini-batch
77+
// Gradient Descent
78+
epochs: 10000,
79+
},
80+
)
81+
.unwrap();
82+
83+
// for saving Intricate uses the 'savefile' crate
84+
// that simply needs to call the 'save_file' function to the path you want
85+
// for the Model as follows
86+
xor_model.sync_data_from_buffers_to_host().unwrap();
87+
save_file("xor-model.bin", 0, &xor_model).unwrap();
88+
89+
// as for loading we can just call the 'load_file' function
90+
// on the path we saved to before
91+
let mut loaded_xor_model: Model = load_file("xor-model.bin", 0).unwrap();
92+
loaded_xor_model.init(&opencl_state).unwrap();
93+
94+
loaded_xor_model.predict(&training_inputs).unwrap();
95+
xor_model.predict(&training_inputs).unwrap();
96+
97+
let model_prediction = xor_model.get_last_prediction().unwrap();
98+
let loaded_model_prediction = loaded_xor_model.get_last_prediction().unwrap();
99+
100+
assert_eq!(loaded_model_prediction, model_prediction);
95101
}

0 commit comments

Comments (0)