Skip to content

Commit 248d2c4

Browse files
author
Lukas Molzberger
committed
reimplementation of computeSoftMax
1 parent 7aa82c2 commit 248d2c4

File tree

4 files changed

+44
-20
lines changed

4 files changed

+44
-20
lines changed

src/main/java/network/aika/Document.java

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import network.aika.neuron.activation.Activation.OscillatingActivationsException;
2828
import network.aika.neuron.activation.Position;
2929
import network.aika.neuron.activation.link.Linker;
30+
import network.aika.neuron.activation.search.Option;
3031
import network.aika.neuron.activation.search.SearchNode;
3132
import network.aika.neuron.activation.search.SearchNode.TimeoutException;
3233
import network.aika.neuron.activation.*;
@@ -38,6 +39,7 @@
3839

3940
import static network.aika.neuron.INeuron.Type.*;
4041
import static network.aika.neuron.activation.Activation.CANDIDATE_COMP;
42+
import static network.aika.neuron.activation.search.Decision.SELECTED;
4143
import static network.aika.neuron.activation.search.Decision.UNKNOWN;
4244

4345

@@ -445,24 +447,9 @@ public void storeFinalState() {
445447

446448

447449
private void computeSoftMax() {
448-
/* for (Activation act : activationsById.values()) {
449-
double offset = Double.MAX_VALUE;
450-
for (Option option : act.getOptions()) {
451-
offset = Math.min(offset, Math.log(option.cacheFactor) + option.weight);
452-
}
453-
454-
double norm = 0.0;
455-
for (Option option : act.getOptions()) {
456-
norm += Math.exp(Math.log(option.cacheFactor) + option.weight - offset);
457-
}
458-
459-
for (Option option : act.getOptions()) {
460-
if (option.decision == SELECTED) {
461-
option.p = Math.exp(Math.log(option.cacheFactor) + option.weight - offset) / norm;
462-
}
463-
}
450+
for (Activation act : activationsById.values()) {
451+
act.computeSoftMax();
464452
}
465-
*/
466453
}
467454

468455

src/main/java/network/aika/neuron/activation/Activation.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,23 @@ public boolean match(Predicate<Link> filter) {
777777
}
778778

779779

780+
public void computeSoftMax() {
781+
rootOption.traverse((o) -> o.computeRemainingWeight());
782+
783+
final double[] offset = new double[] {Double.MAX_VALUE};
784+
rootOption.traverse(o -> offset[0] = Math.min(offset[0], Math.log(o.cacheFactor) + o.remainingWeight));
785+
786+
final double[] norm = new double[] {0.0};
787+
rootOption.traverse(o -> norm[0] += Math.log(o.cacheFactor) + o.remainingWeight - offset[0]);
788+
789+
rootOption.traverse(o -> {
790+
if (o.decision == SELECTED) {
791+
o.p = Math.exp(Math.log(o.cacheFactor) + o.remainingWeight - offset[0]) / norm[0];
792+
}
793+
});
794+
}
795+
796+
780797
public String toString() {
781798
return id + " " + getNeuron().getId() + ":" + getINeuron().typeToString() + " " + getLabel() + " " + slotsToString() + " " + identityToString() + " - " +
782799
" UB:" + Utils.round(upperBound) +

src/main/java/network/aika/neuron/activation/search/Option.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
import network.aika.neuron.activation.State;
2323

2424
import java.util.*;
25+
import java.util.concurrent.Callable;
26+
import java.util.function.BiFunction;
27+
import java.util.function.Consumer;
28+
import java.util.function.Predicate;
2529
import java.util.stream.Collectors;
2630

2731
import static network.aika.Document.MAX_ROUND;
@@ -50,6 +54,7 @@ public class Option implements Comparable<Option> {
5054
public Decision decision;
5155

5256
public double weight;
57+
public double remainingWeight;
5358
public int cacheFactor = 1;
5459
public double p;
5560

@@ -154,6 +159,24 @@ public Activation getAct() {
154159
}
155160

156161

162+
public void computeRemainingWeight() {
163+
double sum = 0.0;
164+
for(Option c: children) {
165+
sum += c.weight;
166+
}
167+
168+
remainingWeight = weight - sum;
169+
}
170+
171+
172+
public void traverse(Consumer<Option> f) {
173+
for(Option c: children) {
174+
c.traverse(f);
175+
f.accept(c);
176+
}
177+
}
178+
179+
157180
public String toString() {
158181
StringBuilder sb = new StringBuilder();
159182
sb.append(" snId:" + (searchNode != null ? searchNode.getId() : "-") + " d:" + decision + " cacheFactor:" + cacheFactor + " w:" + Utils.round(weight) + " p:" + p + " value:" + Utils.round(state.value));

src/main/java/network/aika/neuron/activation/search/SearchNode.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,6 @@ public static void search(Document doc, SearchNode root, long v, Long timeoutInM
233233
returnWeight = sn.finalStep();
234234
returnWeightSum = sn.getWeightSum();
235235

236-
sn.currentChildDecision = UNKNOWN;
237236
sn = sn.parent;
238237
break;
239238
default:
@@ -319,8 +318,6 @@ private boolean prepareStep(Document doc, Decision d) throws OscillatingActivati
319318

320319
act.debugDecisionCounts[d.ordinal()]++;
321320

322-
currentChildDecision = d;
323-
324321
return true;
325322
}
326323

0 commit comments

Comments
 (0)