Skip to content

Commit 14d9d5e

Browse files
author
Lukas Molzberger
committed
performance optimization: cache accumulated weights during interpretation search
1 parent 6f9a7b6 commit 14d9d5e

File tree

1 file changed

+33
-16
lines changed

1 file changed

+33
-16
lines changed

src/main/java/org/aika/corpus/SearchNode.java

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,13 @@ public void dumpDebugState() {
185185
}
186186

187187

188-
private double search(Document doc, int[] searchSteps, Candidate[] candidates) {
188+
private NormWeight search(Document doc, int[] searchSteps, Candidate[] candidates) {
189189
if(candidate == null) {
190190
return processResult(doc);
191191
}
192192

193-
double selectedWeight = 0.0;
194-
double excludedWeight = 0.0;
193+
NormWeight selectedWeight = NormWeight.ZERO_WEIGHT;
194+
NormWeight excludedWeight = NormWeight.ZERO_WEIGHT;
195195

196196
boolean alreadySelected = checkSelected(refinement);
197197
boolean alreadyExcluded = checkExcluded(refinement, doc.visitedCounter++);
@@ -218,7 +218,7 @@ private double search(Document doc, int[] searchSteps, Candidate[] candidates) {
218218
debugState = DebugState.EXPLORE;
219219
}
220220

221-
Boolean cd = !alreadyExcluded && !alreadySelected ? getCachedDecision() : null;
221+
CachedEntry cd = !alreadyExcluded && !alreadySelected ? getCachedDecision() : null;
222222

223223
candidate.debugCounts[debugState.ordinal()]++;
224224

@@ -229,22 +229,22 @@ private double search(Document doc, int[] searchSteps, Candidate[] candidates) {
229229
markSelected(changed, refinement);
230230
markExcluded(changed, refinement);
231231

232-
if (cd == null || cd) {
232+
if (cd == null || (cd.dir && accumulatedWeight.add(cd.weight).getNormWeight() >= getSelectedAccumulatedWeight(doc))) {
233233
Candidate c = candidates.length > level + 1 ? candidates[level + 1] : null;
234234
SearchNode child = new SearchNode(doc, this, excludedParent, c, level + 1, changed);
235235
selectedWeight = child.search(doc, searchSteps, candidates);
236236
child.changeState(StateChange.Mode.OLD);
237237
}
238238
}
239239
if(doc.interrupted) {
240-
return 0.0;
240+
return NormWeight.ZERO_WEIGHT;
241241
}
242242

243243
if(!alreadySelected) {
244244
candidate.refinement.markedExcludedRefinement = true;
245245
List<InterprNode> changed = Collections.singletonList(candidate.refinement);
246246

247-
if (cd == null || !cd) {
247+
if (cd == null || (!cd.dir && accumulatedWeight.add(cd.weight).getNormWeight() >= getSelectedAccumulatedWeight(doc))) {
248248
Candidate c = candidates.length > level + 1 ? candidates[level + 1] : null;
249249
SearchNode child = new SearchNode(doc, selectedParent, this, c, level + 1, changed);
250250
excludedWeight = child.search(doc, searchSteps, candidates);
@@ -254,23 +254,29 @@ private double search(Document doc, int[] searchSteps, Candidate[] candidates) {
254254
candidate.refinement.markedExcludedRefinement = false;
255255
}
256256

257+
boolean dir = selectedWeight.getNormWeight() >= excludedWeight.getNormWeight();
257258
if(cd == null && !alreadyExcluded && !alreadySelected) {
258-
candidate.cache.put(this, selectedWeight >= excludedWeight);
259+
candidate.cache.put(this, new CachedEntry(dir, dir ? selectedWeight.sub(accumulatedWeight) : excludedWeight.sub(accumulatedWeight)));
259260
}
260261

261-
return Math.max(selectedWeight, excludedWeight);
262+
return dir ? selectedWeight : excludedWeight;
262263
}
263264

264-
private double processResult(Document doc) {
265+
266+
private NormWeight processResult(Document doc) {
265267
double accNW = accumulatedWeight.getNormWeight();
266-
double selectedAccNW = doc.selectedSearchNode != null ? doc.selectedSearchNode.accumulatedWeight.getNormWeight() : -1.0;
267268

268-
if (accNW > selectedAccNW) {
269+
if (accNW > getSelectedAccumulatedWeight(doc)) {
269270
doc.selectedSearchNode = this;
270271
doc.bottom.storeFinalWeight(doc.visitedCounter++);
271272
}
272273

273-
return accNW;
274+
return accumulatedWeight;
275+
}
276+
277+
278+
private double getSelectedAccumulatedWeight(Document doc) {
279+
return doc.selectedSearchNode != null ? doc.selectedSearchNode.accumulatedWeight.getNormWeight() : -1.0;
274280
}
275281

276282

@@ -604,8 +610,8 @@ private boolean getDecision() {
604610
}
605611

606612

607-
public Boolean getCachedDecision() {
608-
x: for(Map.Entry<SearchNode, Boolean> me: candidate.cache.entrySet()) {
613+
public CachedEntry getCachedDecision() {
614+
x: for(Map.Entry<SearchNode, CachedEntry> me: candidate.cache.entrySet()) {
609615
SearchNode n = this;
610616
SearchNode cn = me.getKey();
611617
do {
@@ -640,8 +646,19 @@ public boolean affectsUnknown(SearchNode p) {
640646
}
641647

642648

649+
private static class CachedEntry {
650+
boolean dir;
651+
NormWeight weight;
652+
653+
private CachedEntry(boolean dir, NormWeight weight) {
654+
this.dir = dir;
655+
this.weight = weight;
656+
}
657+
}
658+
659+
643660
private static class Candidate implements Comparable<Candidate> {
644-
public TreeMap<SearchNode, Boolean> cache = new TreeMap<>();
661+
public TreeMap<SearchNode, CachedEntry> cache = new TreeMap<>();
645662
public InterprNode refinement;
646663

647664
int[] debugCounts = new int[3];

0 commit comments

Comments
 (0)