
Commit 6edd62e

fixed training
Parent: 92a3a2d

5 files changed: 130 additions, 15 deletions

src/main/java/org/aika/Neuron.java

Lines changed: 41 additions & 1 deletion
@@ -58,6 +58,31 @@ public Activation addInput(Document doc, int begin, int end) {
         return addInput(doc, begin, end, null, doc.bottom);
     }
 
+    /**
+     * Propagate an input activation into the network.
+     *
+     * @param doc The current document
+     * @param begin The range begin
+     * @param end The range end
+     * @param value The activation value of this input activation
+     */
+    public Activation addInput(Document doc, int begin, int end, double value) {
+        return addInput(doc, begin, end, null, doc.bottom, value);
+    }
+
+    /**
+     * Propagate an input activation into the network.
+     *
+     * @param doc The current document
+     * @param begin The range begin
+     * @param end The range end
+     * @param value The activation value of this input activation
+     * @param targetValue The target activation value for supervised learning
+     */
+    public Activation addInput(Document doc, int begin, int end, double value, double targetValue) {
+        return addInput(doc, begin, end, null, doc.bottom, value, targetValue);
+    }
+
 
     /**
      * Propagate an input activation into the network.
@@ -110,7 +135,22 @@ public Activation addInput(Document doc, int begin, int end, Integer rid, Interp
      * @param value The activation value of this input activation
      */
     public Activation addInput(Document doc, int begin, int end, Integer rid, InterprNode o, double value) {
-        return get().addInput(doc, begin, end, rid, o, value);
+        return addInput(doc, begin, end, rid, o, value, 0.0);
+    }
+
+    /**
+     * Propagate an input activation into the network.
+     *
+     * @param doc The current document
+     * @param begin The range begin
+     * @param end The range end
+     * @param rid The relational id (e.g. the word position)
+     * @param o The interpretation node
+     * @param value The activation value of this input activation
+     * @param targetValue The target activation value for supervised learning
+     */
+    public Activation addInput(Document doc, int begin, int end, Integer rid, InterprNode o, double value, double targetValue) {
+        return get().addInput(doc, begin, end, rid, o, value, targetValue);
     }
 
 
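For orientation, a minimal usage sketch of the two new convenience overloads (illustrative only, not part of this commit; neuron names and ranges are made up, mirroring TrainingTest below):

    Model m = new Model();
    Neuron in = m.createNeuron("IN");
    Neuron out = m.createNeuron("OUT");

    Document doc = m.createDocument("Bla");

    // Plain input activation with activation value 1.0 (no supervision).
    in.addInput(doc, 0, 3, 1.0);

    // Input activation with value 0.0 and target value 1.0 for supervised learning;
    // a non-zero target value registers the activation in doc.targetActivations.
    out.addInput(doc, 0, 3, 0.0, 1.0);
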
src/main/java/org/aika/corpus/Document.java

Lines changed: 11 additions & 9 deletions
@@ -75,6 +75,9 @@ public class Document implements Comparable<Document> {
     public TreeSet<INeuron> activatedNeurons = new TreeSet<>();
     public TreeSet<INeuron> finallyActivatedNeurons = new TreeSet<>();
     public TreeSet<Activation> inputNeuronActivations = new TreeSet<>();
+    public TreeSet<Activation> targetActivations = new TreeSet<>();
+    public TreeSet<Activation> errorSignalActivations = new TreeSet<>();
+
     public TreeMap<NodeActivation.Key, NodeActivation> activationsByRid = new TreeMap<>(new Comparator<NodeActivation.Key>() {
         @Override
         public int compare(NodeActivation.Key act1, NodeActivation.Key act2) {
@@ -259,18 +262,18 @@ public TrainConfig setPerformBackpropagation(boolean performBackpropagation) {
 
 
     public void train(TrainConfig trainConfig) {
+        for(Activation tAct: targetActivations) {
+            tAct.key.n.neuron.get().computeOutputErrorSignal(this, tAct);
+        }
+
         if(trainConfig.performBackpropagation) {
             bQueue.backpropagtion();
         }
 
-        for (INeuron n : finallyActivatedNeurons) {
-            ThreadState<OrNode, Activation> th = n.node.get().getThreadState(threadId, false);
-            if (th != null) {
-                for (Activation act : th.activations.values()) {
-                    n.train(this, act, trainConfig.learnRate, trainConfig.synapseEvaluation);
-                }
-            }
+        for (Activation act : errorSignalActivations) {
+            act.key.n.neuron.get().train(this, act, trainConfig.learnRate, trainConfig.synapseEvaluation);
         }
+        errorSignalActivations.clear();
     }
 
     /**
@@ -368,7 +371,6 @@ public String neuronActivationsToString(boolean withWeights, boolean withTextSni
             Activation.State s = me.getValue();
             sb.append("[R:" + me.getKey());
             sb.append(" VALUE:" + Utils.round(s.value));
-            sb.append(" F:" + s.fired);
             sb.append(" W:" + Utils.round(s.weight.w));
             sb.append(" N:" + Utils.round(s.weight.n));
             sb.append("]");
@@ -655,7 +657,7 @@ public void backpropagtion() {
             Activation act = queue.pollFirst();
 
             act.isQueued = false;
-            act.key.n.neuron.get().computeErrorSignal(Document.this, act);
+            act.key.n.neuron.get().computeBackpropagationErrorSignal(Document.this, act);
         }
     }
 }
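
A caller-side sketch of the reworked training flow (illustrative, not part of this commit; it reuses the Model and Neurons from the sketch above, and the synapse evaluation lambda is the one from TrainingTest below). train() now computes output error signals for doc.targetActivations, optionally backpropagates them, and then trains only the activations that received a non-zero error signal:

    Document doc = m.createDocument("Bla");
    in.addInput(doc, 0, 3, 1.0);         // regular input activation
    doc.process();

    out.addInput(doc, 0, 3, 0.0, 1.0);   // target value 1.0 -> lands in doc.targetActivations

    doc.train(new Document.TrainConfig()
            .setLearnRate(2.0)
            .setPerformBackpropagation(false)
            .setSynapseEvaluation((iAct, oAct) -> new Synapse.Key(
                    false, 0, null,
                    Range.Operator.EQUALS, Range.Mapping.START, true,
                    Range.Operator.EQUALS, Range.Mapping.END, true)));

    doc.clearActivations();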

src/main/java/org/aika/neuron/Activation.java

Lines changed: 1 addition & 1 deletion
@@ -46,8 +46,8 @@ public final class Activation extends NodeActivation<OrNode> {
     public long currentStateV;
     public StateChange currentStateChange;
 
-    public double initialErrorSignal;
     public double errorSignal;
+    public double targetValue;
 
 
     public Activation(int id, Document doc, Key key) {

src/main/java/org/aika/neuron/INeuron.java

Lines changed: 23 additions & 4 deletions
@@ -121,7 +121,7 @@ public INeuron(Model m, String label, String outputText) {
      * @param o The interpretation node
      * @param value The activation value of this input activation
      */
-    public Activation addInput(Document doc, int begin, int end, Integer rid, InterprNode o, double value) {
+    public Activation addInput(Document doc, int begin, int end, Integer rid, InterprNode o, double value, double targetValue) {
         Node.addActivationAndPropagate(doc, new NodeActivation.Key(node.get(), new Range(begin, end), rid, o), Collections.emptySet());
 
         doc.propagate();
@@ -130,6 +130,7 @@ public Activation addInput(Document doc, int begin, int end, Integer rid, Interp
         State s = new State(value, 0, NormWeight.ZERO_WEIGHT);
         act.rounds.set(0, s);
         act.finalState = s;
+        act.targetValue = targetValue;
         act.upperBound = value;
         act.isInput = true;
 
@@ -138,6 +139,10 @@ public Activation addInput(Document doc, int begin, int end, Integer rid, Interp
 
         doc.ubQueue.add(act);
 
+        if(targetValue != 0.0) {
+            doc.targetActivations.add(act);
+        }
+
         doc.propagate();
 
         return act;
@@ -294,15 +299,29 @@ public InputState(SynapseActivation sa, State s) {
     }
 
 
-    public void computeErrorSignal(Document doc, Activation act) {
-        act.errorSignal = act.initialErrorSignal;
+    public void computeOutputErrorSignal(Document doc, Activation act) {
+        act.errorSignal += act.targetValue - act.finalState.value;
+
+        if(act.errorSignal != 0.0) {
+            doc.errorSignalActivations.add(act);
+        }
+        for (SynapseActivation sa : act.neuronInputs) {
+            doc.bQueue.add(sa.input);
+        }
+    }
+
+
+    public void computeBackpropagationErrorSignal(Document doc, Activation act) {
         for (SynapseActivation sa : act.neuronOutputs) {
             Synapse s = sa.s;
             Activation oAct = sa.output;
 
             act.errorSignal += s.w * oAct.errorSignal * (1.0 - act.finalState.value);
         }
 
+        if(act.errorSignal != 0.0) {
+            doc.errorSignalActivations.add(act);
+        }
         for (SynapseActivation sa : act.neuronInputs) {
             doc.bQueue.add(sa.input);
         }
@@ -315,7 +334,7 @@ public void train(Document doc, Activation targetAct, double learnRate, Document
         long v = doc.visitedCounter++;
 
         double x = learnRate * targetAct.errorSignal;
-        bias += x;
+        bias = Math.min(0.0, bias + x);
         for (INeuron n : doc.finallyActivatedNeurons) {
             for(Activation iAct: n.getFinalActivations(doc)) {
                 Synapse.Key sk = se.evaluate(iAct, targetAct);
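
In compact form (notation mine, not from the source): with final activation value v_i, target value t_i, synapse weight w_{ij} from i to j, and learn rate \eta, the error-signal code above corresponds to

    \delta_i^{\text{out}} = t_i - v_i
    \delta_i^{\text{bp}}  = \sum_{j \in \text{out}(i)} w_{ij}\,\delta_j\,(1 - v_i)
    b_i \leftarrow \min\bigl(0,\; b_i + \eta\,\delta_i\bigr)

where the code accumulates the signals with += rather than assigning them, and the clamp of the bias to non-positive values is new in this commit (bias = Math.min(0.0, bias + x)).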

src/test/java/org/aika/network/TrainingTest.java

Lines changed: 54 additions & 0 deletions
@@ -23,6 +23,7 @@
 import org.aika.neuron.Activation;
 import org.aika.neuron.INeuron;
 import org.aika.neuron.Synapse;
+import org.junit.Assert;
 import org.junit.Test;
 
 /**
@@ -64,6 +65,59 @@ public void testTraining() {
 
         doc = m.createDocument("Bla");
         in.addInput(doc, 0, 3, 0, doc.bottom, 1.0);
+    }
+
+
+    @Test
+    public void testTraining1() {
+        Model m = new Model();
+
+        Neuron inA = m.createNeuron("A");
+        Neuron inB = m.createNeuron("B");
+
+        Neuron outC = m.createNeuron("C");
+
+        {
+            Document doc = m.createDocument("Bla");
+            inA.addInput(doc, 0, 3, 1.0);
+            inB.addInput(doc, 0, 3, 1.0);
+
+            doc.process();
+
+
+            outC.addInput(doc, 0, 3, 0.0, 1.0);
+
+            doc.train(
+                    new Document.TrainConfig()
+                            .setLearnRate(2.0)
+                            .setPerformBackpropagation(false)
+                            .setSynapseEvaluation((iAct, oAct) -> new Synapse.Key(
+                                    false,
+                                    0,
+                                    null,
+                                    Range.Operator.EQUALS,
+                                    Range.Mapping.START,
+                                    true,
+                                    Range.Operator.EQUALS,
+                                    Range.Mapping.END,
+                                    true
+                            ))
+            );
+
+            doc.clearActivations();
+        }
+
+        {
+            Document doc = m.createDocument("Bla");
+            inA.addInput(doc, 0, 3, 1.0);
+            inB.addInput(doc, 0, 3, 1.0);
+
+            doc.process();
+
+            System.out.println(doc.neuronActivationsToString(true, false, true));
+            Assert.assertFalse(outC.getFinalActivations(doc).isEmpty());
 
+            doc.clearActivations();
+        }
     }
 }
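
As a possible follow-up (not part of this commit, and untested here): the same scenario could exercise the new backpropagation path by enabling it in the TrainConfig, which routes the accumulated error signals through computeBackpropagationErrorSignal before the weight update:

    doc.train(new Document.TrainConfig()
            .setLearnRate(2.0)
            .setPerformBackpropagation(true)
            .setSynapseEvaluation((iAct, oAct) -> new Synapse.Key(
                    false, 0, null,
                    Range.Operator.EQUALS, Range.Mapping.START, true,
                    Range.Operator.EQUALS, Range.Mapping.END, true)));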
