Home
Welcome to the ANN4j wiki!
This package provides object-oriented neural networks for building explainable networks. An object-oriented network structure makes it possible to observe every element of the model. ANN4j is developed for XAI (Explainable AI) research and development.
- Observable implementation for Artificial Neural Networks (ANN)
- XAI method for relevance propagation
- Stochastic/batch gradient descent
- No hard-coded implementations, so researchers can change the parameters as they want.
- Plug-and-play MNIST-type data. Other data files can be handled via extension.
ANN4j is a Java package that provides object-oriented functionality for neural networks. It implements multilayer perceptrons in Java using objects instead of matrix multiplications: every neuron is a separate object. Although this is far less efficient than matrix multiplication, it supports research in Explainable AI, which aims to make models interpretable. By pausing and observing the network at different stages, researchers can study neural networks more effectively. Individually observable neurons are easier to inspect than matrices, and operations that are difficult to perform on matrices become easier with this technique.
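To illustrate the object-per-neuron idea, here is a minimal sketch in plain Java. The classes `SketchNeuron` and `SketchConnection` are hypothetical and are not ANN4j's own classes; they only show how an activation can be computed by walking a neuron's left (incoming) connections instead of multiplying matrices, assuming a sigmoid activation.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical connection object (not ANN4j's): one weight between two neurons.
class SketchConnection {
    final SketchNeuron left;
    final SketchNeuron right;
    double weight;

    SketchConnection(SketchNeuron left, SketchNeuron right, double weight) {
        this.left = left;
        this.right = right;
        this.weight = weight;
        left.rightConnections.add(this);  // outgoing from the left neuron
        right.leftConnections.add(this);  // incoming to the right neuron
    }
}

// Hypothetical neuron object (not ANN4j's) holding its own state and connections.
class SketchNeuron {
    final List<SketchConnection> leftConnections = new ArrayList<>();
    final List<SketchConnection> rightConnections = new ArrayList<>();
    double bias;
    double activation;

    // Weighted sum of incoming activations plus bias, passed through a sigmoid.
    void computeActivation() {
        double z = bias;
        for (SketchConnection c : leftConnections) {
            z += c.weight * c.left.activation;
        }
        activation = 1.0 / (1.0 + Math.exp(-z));
    }
}
```

Because every weight lives in its own object, a single connection or neuron can be inspected without slicing a matrix.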
import ann4j.*;
-
Setting the output file to be output.txt and enabling command line logging.
parameter.setOutputFile("output.txt", true);
-
Setting the number of neurons in each layer.
parameter.setLayerArray(784, 32, 16, 16, 26);
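As a worked example of what this layer array describes (784 inputs match a 28×28 EMNIST image, and 26 outputs match the 26 letter classes), the sketch below counts weights and biases. It assumes every layer is fully connected to the next and that each non-input neuron has one bias; `LayerCounts` is a hypothetical helper, not part of ANN4j.

```java
// Hypothetical helper (not part of ANN4j): parameter counts for a
// stack of fully connected layers.
class LayerCounts {
    // Sum of fan-in * fan-out over each pair of adjacent layers.
    static long weights(int... layers) {
        long total = 0;
        for (int i = 0; i + 1 < layers.length; i++) {
            total += (long) layers[i] * layers[i + 1];
        }
        return total;
    }

    // Assumes one bias per non-input neuron.
    static long biases(int... layers) {
        long total = 0;
        for (int i = 1; i < layers.length; i++) {
            total += layers[i];
        }
        return total;
    }
}
```

For 784, 32, 16, 16, 26 this gives 784·32 + 32·16 + 16·16 + 16·26 = 26272 weights and 32 + 16 + 16 + 26 = 90 biases, so an object-per-connection design would hold roughly 26272 connection objects.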
-
Setting the training file to emnist-letters-train.csv and the file type to mnist.
parameter.setTrainingFileReader("emnist-letters-train.csv", "mnist");
-
Setting the testing file
parameter.setTestingFileReader("emnist-letters-test.csv", "mnist");
-
Setting the learning rate for weights
parameter.setLearningRate(1);
-
Setting the learning rate for the bias to 1.
parameter.setBiasLearningRate(1);
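Separate learning rates let weights and biases take gradient steps of different sizes. A minimal sketch of the standard gradient-descent update for one neuron, assuming backpropagation has already produced the neuron's error term δ = ∂L/∂z; `NeuronUpdate` is a hypothetical name, not ANN4j's API.

```java
// Hypothetical sketch (not ANN4j's API): one gradient-descent step for a
// single neuron, given its error term delta = dL/dz from backpropagation.
class NeuronUpdate {
    // dL/dw_i = delta * a_i, so w_i <- w_i - learningRate * delta * a_i
    static void updateWeights(double[] weights, double[] inputs, double delta, double learningRate) {
        for (int i = 0; i < weights.length; i++) {
            weights[i] -= learningRate * delta * inputs[i];
        }
    }

    // dL/db = delta, so b <- b - biasLearningRate * delta
    static double updateBias(double bias, double delta, double biasLearningRate) {
        return bias - biasLearningRate * delta;
    }
}
```

With both rates set to 1 as above, weights and biases move by their full gradients; lowering only the bias rate would make biases change more slowly than weights.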
-
Setting the epsilon value for the relevance propagation algorithm.
parameter.setEpsillion(0);
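For reference, the epsilon rule commonly used in layer-wise relevance propagation (LRP) passes the relevance $R_k$ of each upper-layer neuron $k$ down to the lower-layer neurons $j$ in proportion to their contributions $a_j w_{jk}$ (bias terms omitted here). That ANN4j's epsilon parameter plays exactly this role is an assumption based on its description:

```latex
R_j = \sum_k \frac{a_j \, w_{jk}}{\epsilon + \sum_{j'} a_{j'} \, w_{j'k}} \, R_k
```

With $\epsilon = 0$, as set above, this reduces to the basic LRP-0 rule, and the total relevance is conserved from layer to layer; a small positive $\epsilon$ absorbs part of the relevance when contributions are weak or contradictory, giving sparser, less noisy explanations.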
-
Setting the batch size
parameter.setBatchsize(10);
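With a batch size of B, samples are processed in groups and the parameters are updated once per group, typically with the averaged gradient. A small sketch of that bookkeeping, assuming a final partial batch still triggers an update (how ANN4j handles a remainder is not stated here); `BatchMath` is a hypothetical helper.

```java
// Hypothetical helpers (not part of ANN4j) for mini-batch bookkeeping.
class BatchMath {
    // Number of parameter updates per epoch: ceil(samples / batchSize).
    static int updatesPerEpoch(int samples, int batchSize) {
        return (samples + batchSize - 1) / batchSize;
    }

    // Mean of the per-sample gradients in one batch, used for a single update.
    static double averageGradient(double[] perSampleGradients) {
        double sum = 0.0;
        for (double g : perSampleGradients) {
            sum += g;
        }
        return sum / perSampleGradients.length;
    }
}
```

With 88800 training samples and a batch size of 10 this gives 8880 updates per epoch; a batch size of 1 corresponds to stochastic gradient descent, with one update per sample.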
-
Setting the rectification (activation) function.
parameter.setRectificationFunction("sigmoid");
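The sigmoid maps any weighted sum z into the range (0, 1), and its derivative, which backpropagation needs, can be written in terms of the sigmoid itself. A minimal sketch; the class name `Sigmoid` is illustrative and this is not ANN4j's implementation.

```java
// Illustrative sigmoid activation and derivative (not ANN4j's implementation).
class Sigmoid {
    // sigma(z) = 1 / (1 + e^(-z))
    static double apply(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // sigma'(z) = sigma(z) * (1 - sigma(z))
    static double derivative(double z) {
        double s = apply(z);
        return s * (1.0 - s);
    }
}
```

Reusing the already computed activation for the derivative avoids a second call to `Math.exp` during backpropagation.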
-
Creating a new instance of the Trainer class.
Trainer myTrainer = new Trainer();
-
Training the network on 88800 samples for the chosen number of epochs.
int epochs = 1; // number of training epochs, set as required
myTrainer.train(88800, epochs);
-
Creating a new instance of the NeuronObserver class. This class observes the neurons and responds whenever one of their parameters changes.
NeuronObserver myNeuronObserver = new NeuronObserver();
-
Testing the network with 9990 samples.
myTrainer.test(9990);
-
Adding the neuron at layer 1 and index 31 to be observed.
myNeuronObserver.addNeuronToBeObserved(1, 31);
-
Training accuracy
myTrainer.getModelEvaluator().getTrainingAccuracy();
-
Testing accuracy
myTrainer.getModelEvaluator().getTestingAccuracy();
-
Confusion Matrix
myTrainer.getModelEvaluator().printConfusionMatrix();
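The confusion matrix and the accuracy figures are closely related: entry (i, j) counts samples whose true class is i and whose predicted class is j, and accuracy is the diagonal sum divided by the total. A self-contained sketch; `ConfusionSketch` is hypothetical and says nothing about how the model evaluator stores its data.

```java
// Hypothetical sketch (not ANN4j's ModelEvaluator): build a confusion matrix
// and derive accuracy from it.
class ConfusionSketch {
    // matrix[actual][predicted] counts how often each pair occurred.
    static int[][] build(int[] actual, int[] predicted, int numClasses) {
        int[][] matrix = new int[numClasses][numClasses];
        for (int i = 0; i < actual.length; i++) {
            matrix[actual[i]][predicted[i]]++;
        }
        return matrix;
    }

    // Accuracy = correctly classified samples (diagonal) / all samples.
    static double accuracy(int[][] matrix) {
        long correct = 0;
        long total = 0;
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[i].length; j++) {
                total += matrix[i][j];
                if (i == j) {
                    correct += matrix[i][j];
                }
            }
        }
        return total == 0 ? 0.0 : (double) correct / total;
    }
}
```

Off-diagonal entries show which classes are confused with each other, which a single accuracy number hides.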
-
XAI algorithm for relevance propagation.
myTrainer.relevancePropagate(2, 3);
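A sketch of how relevance can be passed back through one fully connected layer with the LRP epsilon rule. It works on plain arrays rather than ANN4j's neuron objects, ignores biases, and makes no assumption about the meaning of the arguments to `relevancePropagate`; `LrpSketch` is hypothetical.

```java
// Hypothetical sketch (not ANN4j's implementation) of the LRP epsilon rule
// for one fully connected layer, biases ignored.
class LrpSketch {
    // a: lower-layer activations (index j)
    // w: w[j][k] connects lower neuron j to upper neuron k
    // rUpper: relevance of the upper layer (index k)
    // returns: relevance of the lower layer
    static double[] propagate(double[] a, double[][] w, double[] rUpper, double epsilon) {
        double[] rLower = new double[a.length];
        for (int k = 0; k < rUpper.length; k++) {
            // Denominator: epsilon + sum_j a_j * w_jk
            double z = epsilon;
            for (int j = 0; j < a.length; j++) {
                z += a[j] * w[j][k];
            }
            // Each lower neuron receives a share proportional to a_j * w_jk.
            for (int j = 0; j < a.length; j++) {
                rLower[j] += a[j] * w[j][k] / z * rUpper[k];
            }
        }
        return rLower;
    }
}
```

Applying `propagate` layer by layer from the output back to the input yields one relevance score per input pixel.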
-
XAI algorithm for finding the most significant input neurons.
myTrainer.forwardPropagatewithExclusionInputLayerOnKSamples(2);
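The method name suggests ranking input neurons by what happens when they are excluded from the forward pass. One common way to do this is occlusion: set one input to zero, recompute the output, and score the input by how much the output drops. The toy single-neuron model and the name `OcclusionSketch` are hypothetical; whether ANN4j scores inputs this way is an assumption.

```java
import java.util.Arrays;

// Hypothetical occlusion sketch on a toy model: output = sigmoid(w . x + b).
class OcclusionSketch {
    static double output(double[] x, double[] w, double b) {
        double z = b;
        for (int i = 0; i < x.length; i++) {
            z += w[i] * x[i];
        }
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Score of input i = output with all inputs minus output with input i set to 0.
    static double[] scores(double[] x, double[] w, double b) {
        double base = output(x, w, b);
        double[] s = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            double[] excluded = Arrays.copyOf(x, x.length);
            excluded[i] = 0.0;
            s[i] = base - output(excluded, w, b);
        }
        return s;
    }
}
```

A positive score means the input pushed the output up, a negative score means it pushed the output down, and a score near zero means excluding it changed little.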
ANN4j lets you extend InputFileReader to handle dataset formats other than MNIST-type files.
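The InputFileReader API is not shown on this page, so the sketch below uses a hypothetical `SampleReader` interface to show the idea of a format-specific reader: it turns one line of a file into a label and normalised input values. The MNIST-style parser assumes each CSV row is the label followed by pixel intensities from 0 to 255.

```java
// Hypothetical parsed sample: a class label and input values in [0, 1].
class Sample {
    final int label;
    final double[] inputs;

    Sample(int label, double[] inputs) {
        this.label = label;
        this.inputs = inputs;
    }
}

// Hypothetical extension point (not ANN4j's InputFileReader): one implementation per format.
interface SampleReader {
    Sample parseLine(String line);
}

// MNIST-style CSV row: "label,p1,p2,...,pn" with pixel values 0..255 (assumed layout).
class MnistCsvReader implements SampleReader {
    @Override
    public Sample parseLine(String line) {
        String[] fields = line.trim().split(",");
        int label = Integer.parseInt(fields[0].trim());
        double[] inputs = new double[fields.length - 1];
        for (int i = 1; i < fields.length; i++) {
            // Scale each pixel to [0, 1] so it suits a sigmoid network.
            inputs[i - 1] = Integer.parseInt(fields[i].trim()) / 255.0;
        }
        return new Sample(label, inputs);
    }
}
```

Supporting another format then means writing one more `SampleReader` implementation while the training code stays unchanged.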
In ANN4j, every neuron is an object of its own, and every neuron can be observed by the NeuronObserver class whenever its values are updated. The NeuronObserver class can be extended to observe whichever parameters are required. Neuron objects can also be obtained and observed independently.
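The observer idea can be sketched with a plain listener interface: a neuron keeps a list of observers and notifies them whenever its activation or bias is set. The names `NeuronListener` and `ObservableNeuron` are hypothetical stand-ins for ANN4j's Neuron and NeuronObserver classes, whose internals are not shown here.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical observer callback, invoked after a neuron parameter changes.
interface NeuronListener {
    void onChange(ObservableNeuron neuron, String parameter, double newValue);
}

// Hypothetical neuron (not ANN4j's) that notifies its listeners on every update.
class ObservableNeuron {
    private final List<NeuronListener> listeners = new ArrayList<>();
    private double activation;
    private double bias;

    void addListener(NeuronListener listener) {
        listeners.add(listener);
    }

    void setActivation(double value) {
        activation = value;
        notifyListeners("activation", value);
    }

    void setBias(double value) {
        bias = value;
        notifyListeners("bias", value);
    }

    double getActivation() { return activation; }

    double getBias() { return bias; }

    private void notifyListeners(String parameter, double value) {
        for (NeuronListener l : listeners) {
            l.onChange(this, parameter, value);
        }
    }
}
```

With this design a researcher could, for example, log every bias change of one neuron during training without touching the training loop.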
-
Get a neuron object from a layer (layerNum and neuronNum are the layer and neuron indices).
var neuron = myTrainer.getLayerManager().getLayer(layerNum).getNeuron(neuronNum);
-
Get activation of a neuron
neuron.getActivation();
-
Get bias of the neuron
neuron.getBias();
-
Get the ArrayList of the left or right connections of the neuron
var leftConnections = neuron.leftConnections;
var rightConnections = neuron.rightConnections;
-
Get weight of a connection
connection.getWeight();
-
Example code https://github.com/Aatmaj-Zephyr/ANN4j/blob/main/Main.java
-
Sample output https://github.com/Aatmaj-Zephyr/ANN4j/blob/3721148ec24371bf095e1394fe39fc471f391466/output.txt
-
Sample output https://github.com/Aatmaj-Zephyr/ANN4j/blob/ef0f34b505e6e6316f94b5a660b9ef651582667d/output.txt
- More about Artificial Neural Networks https://www.3blue1brown.com/topics/neural-networks
- Relevance propagation example https://towardsdatascience.com/indepth-layer-wise-relevance-propagation-340f95deb1ea
- Rectification functions https://www.quora.com/What-is-the-purpose-of-rectifier-functions-in-neural-networks
Thank you for visiting. Please share your love by starring this repository and following me!