Commit 1df3a35
refactor(web): maintain correction-path cost thresholding
We now track the maximum individual probability of any one input in the full, source fat-finger distribution. For some tokenizations, the corresponding input will not be considered, but that should not affect thresholding behavior.

Build-bot: skip build:web
Test-bot: skip
1 parent 3bfe361 commit 1df3a35
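
The change is easiest to see end-to-end. Below is a minimal, self-contained TypeScript sketch of the mechanism described above, assembled only from the math visible in the diffs that follow; the helper names (ProbMass, bestProb, bestCostFloorByDepth, shouldPrune) are illustrative and do not exist in the Keyman sources.

// Illustrative sketch of the correction-path thresholding idea in this commit.
interface ProbMass { p: number; }

// Best individual probability across the FULL fat-finger distribution for one keystroke,
// taken before the distribution is split into per-tokenization subsets (see context-state.ts below).
function bestProb(distribution: ProbMass[]): number {
  return distribution.reduce((best, curr) => Math.max(best, curr.p), 0);
}

// Costs are -ln(p), so the "best possible" cost accumulates keystroke by keystroke;
// this mirrors lowestCostAtDepth in distance-modeler.ts below.
function bestCostFloorByDepth(bestProbs: number[]): number[] {
  const floors: number[] = [];
  let running = 0;
  for (const p of bestProbs) {
    running += -Math.log(p);
    floors.push(running);
  }
  return floors;
}

// A correction path at a given depth is abandoned once its cost exceeds that depth's floor
// plus a fixed margin (2.5 * EDIT_DISTANCE_COST_SCALE in the real code).
function shouldPrune(pathCost: number, floorAtDepth: number, margin: number): boolean {
  return pathCost > floorAtDepth + margin;
}

// Example: best probabilities of 0.6 and 0.8 give floors of roughly 0.51 and 0.73.
console.log(bestCostFloorByDepth([0.6, 0.8]));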

File tree

8 files changed: +126 −75 lines

web/src/engine/predictive-text/worker-thread/src/main/correction/context-state.ts

Lines changed: 5 additions & 2 deletions
@@ -192,7 +192,7 @@ export class ContextState {
    */
   analyzeTransition(
     context: Context,
-    transformDistribution?: Distribution<Transform>,
+    transformDistribution: Distribution<Transform>,
     // overrides checks for token substitution that can fail for large applied suggestions.
     isApplyingSuggestion?: boolean
   ): ContextTransition {
@@ -245,8 +245,11 @@
     // and then fold all resulting search spaces (on the final token) into one.
     const tokenizationAnalysis = trueInputSubset.pendingSet.get(baseTokenization);

+    // Determine the best probability from among ALL available inputs, before they're split
+    // into subsets.
+    const bestProb = transformDistribution.reduce((best, curr) => Math.max(best, curr.p), 0);
     // Should gain one per subsetBuilder.subsets entry.
-    const resultTokenization = baseTokenization.evaluateTransition(tokenizationAnalysis, lexicalModel, trueInput);
+    const resultTokenization = baseTokenization.evaluateTransition(tokenizationAnalysis, lexicalModel, trueInput, bestProb);

     // ------------
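
As a quick worked example of the reduce added above (the distribution values here are invented):

// Hypothetical fat-finger distribution; only the { p: number } shape matters for this step.
const transformDistribution = [{ p: 0.62 }, { p: 0.27 }, { p: 0.11 }];
const bestProb = transformDistribution.reduce((best, curr) => Math.max(best, curr.p), 0);
// bestProb === 0.62; the 0 seed also keeps the result defined if the distribution is empty.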

web/src/engine/predictive-text/worker-thread/src/main/correction/context-token.ts

Lines changed: 7 additions & 4 deletions
@@ -24,6 +24,7 @@ import Transform = LexicalModelTypes.Transform;
 export interface TokenInputSource {
   trueTransform: Transform;
   inputStartIndex: number;
+  bestProbFromSet: number;
 }

 /**
@@ -129,9 +130,10 @@ export class ContextToken {
       rawTransformDistributions.forEach((entry) => {
         this._inputRange.push({
           trueTransform: entry[0].sample,
-          inputStartIndex: 0
+          inputStartIndex: 0,
+          bestProbFromSet: 1
         });
-        this.searchSpace.addInput(entry);
+        this.searchSpace.addInput(entry, 1);
       });
     }
   }
@@ -142,7 +144,7 @@
    */
   addInput(inputSource: TokenInputSource, distribution: Distribution<Transform>) {
     this._inputRange.push(inputSource);
-    this.searchSpace.addInput(distribution);
+    this.searchSpace.addInput(distribution, inputSource.bestProbFromSet);
   }

   /**
@@ -350,7 +352,8 @@
       backupToken = new ContextToken(constructingToken);
       constructingToken.addInput({
         trueTransform: priorSourceInput.trueTransform,
-        inputStartIndex: priorSourceInput.inputStartIndex + extraCharsAdded
+        inputStartIndex: priorSourceInput.inputStartIndex + extraCharsAdded,
+        bestProbFromSet: priorSourceInput.bestProbFromSet
       }, tailDistribution);

       const lenToCommit = lenBeforeLastApply + extraCharsAdded;
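
A note on the hard-coded 1 values above: a bestProbFromSet of 1 is the neutral default, since -Math.log(1) === 0 adds nothing to the per-depth cost floor. A tiny illustrative check (not part of the Keyman sources):

// bestProbFromSet: 1 leaves the depth floor unchanged; smaller values raise it.
const costForCertainInput = -Math.log(1);    // 0
const costForLikelyInput  = -Math.log(0.75); // ≈ 0.288
console.log(costForCertainInput, costForLikelyInput);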

web/src/engine/predictive-text/worker-thread/src/main/correction/context-tokenization.ts

Lines changed: 7 additions & 2 deletions
@@ -497,12 +497,17 @@ export class ContextTokenization {
    * @param lexicalModel The active lexical model
    * @param sourceInput The Transform associated with the keystroke triggering
    * the transition.
+   * @param bestProbFromSet The probability of the single most likely input
+   * transform in the overall transformDistribution associated with the
+   * keystroke triggering the transition. It need not be represented by the
+   * pendingTokenization to be built.
    * @returns
    */
   evaluateTransition(
     pendingTokenization: PendingTokenization,
     lexicalModel: LexicalModel,
-    sourceInput: Transform
+    sourceInput: Transform,
+    bestProbFromSet: number
   ): ContextTokenization {
     const { alignment: alignment, inputs } = pendingTokenization;
     const sliceIndex = alignment.edgeWindow.sliceIndex;
@@ -581,7 +586,7 @@
     if(affectedToken.inputRange.length == 0 && distribution[0].sample.deleteLeft != 0) {
       distribution = distribution.map((mass) => ({sample: { ...mass.sample, deleteLeft: 0 }, p: mass.p }));
     }
-    affectedToken.addInput({trueTransform: sourceInput, inputStartIndex: appliedLength}, distribution);
+    affectedToken.addInput({trueTransform: sourceInput, inputStartIndex: appliedLength, bestProbFromSet: bestProbFromSet}, distribution);
     appliedLength += KMWString.length(distribution[0].sample.insert);

     const tokenize = determineModelTokenizer(lexicalModel);
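
To make the new @param note concrete: the value is the maximum over the full keystroke distribution, so the subset feeding a particular tokenization may not even contain the transform that produced it. A hedged sketch, with an invented distribution and subset split:

// Full fat-finger distribution for one keystroke (probabilities invented for illustration).
const full = [
  { insert: 'a', p: 0.55 },
  { insert: ' ', p: 0.30 }, // suppose this one belongs to a different tokenization subset
  { insert: 's', p: 0.15 },
];
const bestProbFromSet = full.reduce((best, t) => Math.max(best, t.p), 0); // 0.55

// A subset holding only { insert: ' ', p: 0.30 } would still receive 0.55 as its threshold,
// which is the "need not be represented" case called out in the doc comment above.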

web/src/engine/predictive-text/worker-thread/src/main/correction/distance-modeler.ts

Lines changed: 16 additions & 3 deletions
@@ -652,6 +652,12 @@ export class SearchSpace {
    */
   private processedEdgeSet: {[pathKey: string]: boolean} = {};

+  /**
+   * Provides a heuristic for the base cost at each depth if the best
+   * individual input were taken at that level.
+   */
+  private lowestCostAtDepth: number[];
+
   /**
    * Clone constructor. Deep-copies its internal queues, but not search nodes.
    * @param instance
@@ -670,6 +676,7 @@
     this.rootNode = arg1.rootNode;
     // Re-use already-checked Nodes.
     this.completedPaths = [].concat(arg1.completedPaths);
+    this.lowestCostAtDepth = arg1.lowestCostAtDepth.slice();
     this.returnedValues = {...arg1.returnedValues};
     this.processedEdgeSet = {...arg1.processedEdgeSet};

@@ -688,6 +695,7 @@
     this.selectionQueue = new PriorityQueue<SearchNode>(QUEUE_NODE_COMPARATOR);
     this.rootNode = new SearchNode(model.traverseFromRoot(), model.toKey ? model.toKey.bind(model) : null);
     this.selectionQueue.enqueue(this.rootNode);
+    this.lowestCostAtDepth = [];

     this.completedPaths = [];
   }
@@ -722,8 +730,12 @@
    * @param inputDistribution The fat-finger distribution for the incoming keystroke (or
    * just the raw keystroke if corrections are disabled)
    */
-  addInput(inputDistribution: Distribution<Transform>) {
-    this._inputSequence.push(inputDistribution);
+  addInput(inputDistribution: Distribution<Transform>, bestProbFromSet: number) {
+    const input = inputDistribution;
+    this._inputSequence.push(input);
+    const lastDepthCost = this.lowestCostAtDepth[this.lowestCostAtDepth.length - 1] ?? 0;
+    const logTierCost = -Math.log(bestProbFromSet);
+    this.lowestCostAtDepth.push(lastDepthCost + logTierCost);

     // Assumes that `inputDistribution` is already sorted.
     this.minInputCost.push(-Math.log(inputDistribution[0].p));
@@ -822,7 +834,8 @@
     // ... or even just not the then-current layer of the keyboard.
     //
     // TODO: still consider the lowest-cost individual edges for THIS specific criterion.
-    if(currentNode.currentCost > /* tierMinCost */ + 2.5 * SearchSpace.EDIT_DISTANCE_COST_SCALE) {
+    const tierMinCost = this.lowestCostAtDepth[currentNode.priorInput.length-1];
+    if(currentNode.currentCost > tierMinCost + 2.5 * SearchSpace.EDIT_DISTANCE_COST_SCALE) {
       return unmatchedResult;
     }
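
The behavioral difference in the final hunk is easiest to see with numbers. The old check compared against a fixed 2.5 * SearchSpace.EDIT_DISTANCE_COST_SCALE (the tierMinCost term was commented out), while the new check adds the accumulated per-depth floor. A hedged sketch with invented values (EDIT_DISTANCE_COST_SCALE is taken as 1 purely for illustration):

const EDIT_DISTANCE_COST_SCALE = 1; // illustrative stand-in for the real constant
const margin = 2.5 * EDIT_DISTANCE_COST_SCALE;

// Best per-keystroke probabilities of 0.5 accumulate a floor of -ln(0.5) ≈ 0.69 per depth.
const lowestCostAtDepth = [0.69, 1.39, 2.08, 2.77];

// A path that matched the most likely input at every depth has cost equal to the floor.
const pathCost = lowestCostAtDepth[3];
const prunedBefore = pathCost > margin;                         // true: cut off at depth 4
const prunedAfter  = pathCost > lowestCostAtDepth[3] + margin;  // false: kept
console.log(prunedBefore, prunedAfter);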

web/src/test/auto/headless/engine/predictive-text/worker-thread/context/context-token.tests.ts

Lines changed: 52 additions & 34 deletions
@@ -128,22 +128,25 @@ describe('ContextToken', function() {

     token1.addInput({
       trueTransform: srcTransform,
-      inputStartIndex: 0
+      inputStartIndex: 0,
+      bestProbFromSet: 1
     }, [{sample: {insert: 'can', deleteLeft: 0, deleteRight: 0, id: 1}, p: 1}]);

     token2.addInput({
       trueTransform: srcTransform,
-      inputStartIndex: 3
+      inputStartIndex: 3,
+      bestProbFromSet: 1
     }, [{sample: {insert: "'", deleteLeft: 0, deleteRight: 0, id: 1}, p: 1}]);

     token3.addInput({
       trueTransform: srcTransform,
-      inputStartIndex: 4
+      inputStartIndex: 4,
+      bestProbFromSet: 1
     }, [{sample: {insert: 't', deleteLeft: 0, deleteRight: 0, id: 1}, p: 1}]);

     const merged = ContextToken.merge([token1, token2, token3], plainModel);
     assert.equal(merged.exampleInput, "can't");
-    assert.deepEqual(merged.inputRange, [ { trueTransform: srcTransform, inputStartIndex: 0 } ]);
+    assert.deepEqual(merged.inputRange, [ { trueTransform: srcTransform, inputStartIndex: 0, bestProbFromSet: 1 } ]);
     assert.deepEqual(merged.searchSpace.inputSequence, [[{sample: srcTransform, p: 1}]]);
   });

@@ -168,35 +171,41 @@ describe('ContextToken', function() {

     token1.addInput({
       trueTransform: srcTransform1,
-      inputStartIndex: 0
+      inputStartIndex: 0,
+      bestProbFromSet: 1
     }, [{sample: srcTransform1, p: 1}]);
     token1.addInput({
       trueTransform: srcTransform2,
-      inputStartIndex: 0
+      inputStartIndex: 0,
+      bestProbFromSet: 1
     }, [{sample: {insert: 's', deleteLeft: 0, deleteRight: 0, id: 2}, p: 1}]);

     token2.addInput({
       trueTransform: srcTransform2,
-      inputStartIndex: 1
+      inputStartIndex: 1,
+      bestProbFromSet: 1
     }, [{sample: {insert: "and", deleteLeft: 0, deleteRight: 0, id: 2}, p: 1}]);

     token3.addInput({
       trueTransform: srcTransform2,
-      inputStartIndex: 4
+      inputStartIndex: 4,
+      bestProbFromSet: 1
     }, [{sample: {insert: 's', deleteLeft: 0, deleteRight: 0, id: 2}, p: 1}]);
     token3.addInput({
       trueTransform: srcTransform3,
-      inputStartIndex: 0
+      inputStartIndex: 0,
+      bestProbFromSet: 1
     }, [{sample: srcTransform3, p: 1}]);

     token4.addInput({
       trueTransform: srcTransform4,
-      inputStartIndex: 0
+      inputStartIndex: 0,
+      bestProbFromSet: 1
     }, [{sample: srcTransform4, p: 1}]);

     const merged = ContextToken.merge(tokensToMerge, plainModel);
     assert.equal(merged.exampleInput, "applesandsourgrapes");
-    assert.deepEqual(merged.inputRange, srcTransforms.map((t) => ({ trueTransform: t, inputStartIndex: 0 }) ));
+    assert.deepEqual(merged.inputRange, srcTransforms.map((t) => ({ trueTransform: t, inputStartIndex: 0, bestProbFromSet: 1 }) ));
     assert.deepEqual(merged.searchSpace.inputSequence, srcTransforms.map((t) => [{sample: t, p: 1}]));
   });

@@ -221,35 +230,41 @@ describe('ContextToken', function() {

     token1.addInput({
       trueTransform: srcTransform1,
-      inputStartIndex: 0
+      inputStartIndex: 0,
+      bestProbFromSet: 1
     }, [{sample: srcTransform1, p: 1}]);
     token1.addInput({
       trueTransform: srcTransform2,
-      inputStartIndex: 0
+      inputStartIndex: 0,
+      bestProbFromSet: 1
     }, [{sample: {insert: toMathematicalSMP('s'), deleteLeft: 0, deleteRight: 0, id: 2}, p: 1}]);

     token2.addInput({
       trueTransform: srcTransform2,
-      inputStartIndex: 1
+      inputStartIndex: 1,
+      bestProbFromSet: 1
     }, [{sample: {insert: toMathematicalSMP("and"), deleteLeft: 0, deleteRight: 0, id: 2}, p: 1}]);

     token3.addInput({
       trueTransform: srcTransform2,
-      inputStartIndex: 4
+      inputStartIndex: 4,
+      bestProbFromSet: 1
     }, [{sample: {insert: toMathematicalSMP('s'), deleteLeft: 0, deleteRight: 0, id: 2}, p: 1}]);
     token3.addInput({
       trueTransform: srcTransform3,
-      inputStartIndex: 0
+      inputStartIndex: 0,
+      bestProbFromSet: 1
     }, [{sample: srcTransform3, p: 1}]);

     token4.addInput({
       trueTransform: srcTransform4,
-      inputStartIndex: 0
+      inputStartIndex: 0,
+      bestProbFromSet: 1
     }, [{sample: srcTransform4, p: 1}]);

     const merged = ContextToken.merge(tokensToMerge, plainModel);
     assert.equal(merged.exampleInput, toMathematicalSMP("applesandsourgrapes"));
-    assert.deepEqual(merged.inputRange, srcTransforms.map((t) => ({ trueTransform: t, inputStartIndex: 0 }) ));
+    assert.deepEqual(merged.inputRange, srcTransforms.map((t) => ({ trueTransform: t, inputStartIndex: 0, bestProbFromSet: 1 }) ));
     assert.deepEqual(merged.searchSpace.inputSequence, srcTransforms.map((t) => [{sample: t, p: 1}]));
   });
 });
@@ -278,7 +293,7 @@ describe('ContextToken', function() {

     const tokenToSplit = new ContextToken(plainModel);
     for(let i = 0; i < keystrokeDistributions.length; i++) {
-      tokenToSplit.addInput({trueTransform: keystrokeDistributions[i][0].sample, inputStartIndex: 0}, keystrokeDistributions[i]);
+      tokenToSplit.addInput({trueTransform: keystrokeDistributions[i][0].sample, inputStartIndex: 0, bestProbFromSet: .75}, keystrokeDistributions[i]);
     };

     assert.equal(tokenToSplit.sourceText, 'can\'');
@@ -316,7 +331,7 @@ describe('ContextToken', function() {

     const tokenToSplit = new ContextToken(plainModel);
     for(let i = 0; i < keystrokeDistributions.length; i++) {
-      tokenToSplit.addInput({trueTransform: keystrokeDistributions[i][0].sample, inputStartIndex: 0}, keystrokeDistributions[i]);
+      tokenToSplit.addInput({trueTransform: keystrokeDistributions[i][0].sample, inputStartIndex: 0, bestProbFromSet: 1}, keystrokeDistributions[i]);
     };

     assert.equal(tokenToSplit.sourceText, 'biglargetransform');
@@ -343,7 +358,9 @@ describe('ContextToken', function() {
         insert: 'biglargetransform',
         deleteLeft: 0,
         deleteRight: 0
-      }, inputStartIndex: i
+      },
+      inputStartIndex: i,
+      bestProbFromSet: 1
     })));
     assert.sameDeepOrderedMembers(resultsOfSplit.map(t => t.searchSpace.inputSequence[0]), splitTextArray.map(t => [{
       sample: { insert: t, deleteLeft: 0, deleteRight: 0 }, p: 1
@@ -365,7 +382,7 @@ describe('ContextToken', function() {

     const tokenToSplit = new ContextToken(plainModel);
     for(let i = 0; i < keystrokeDistributions.length; i++) {
-      tokenToSplit.addInput({trueTransform: keystrokeDistributions[i][0].sample, inputStartIndex: 0}, keystrokeDistributions[i]);
+      tokenToSplit.addInput({trueTransform: keystrokeDistributions[i][0].sample, inputStartIndex: 0, bestProbFromSet: 1}, keystrokeDistributions[i]);
     };

     assert.equal(tokenToSplit.exampleInput, 'largelongtransforms');
@@ -388,15 +405,15 @@ describe('ContextToken', function() {
     assert.equal(resultsOfSplit.length, 3);
     assert.sameOrderedMembers(resultsOfSplit.map(t => t.exampleInput), splitTextArray);
     assert.deepEqual(resultsOfSplit[0].inputRange, [
-      { trueTransform: keystrokeDistributions[0][0].sample, inputStartIndex: 0 },
-      { trueTransform: keystrokeDistributions[1][0].sample, inputStartIndex: 0 },
+      { trueTransform: keystrokeDistributions[0][0].sample, inputStartIndex: 0, bestProbFromSet: 1 },
+      { trueTransform: keystrokeDistributions[1][0].sample, inputStartIndex: 0, bestProbFromSet: 1 },
     ]);
     assert.deepEqual(resultsOfSplit[1].inputRange, [
-      { trueTransform: keystrokeDistributions[1][0].sample, inputStartIndex: 'arge'.length },
-      { trueTransform: keystrokeDistributions[2][0].sample, inputStartIndex: 0 },
+      { trueTransform: keystrokeDistributions[1][0].sample, inputStartIndex: 'arge'.length, bestProbFromSet: 1 },
+      { trueTransform: keystrokeDistributions[2][0].sample, inputStartIndex: 0, bestProbFromSet: 1 },
     ]);
     assert.deepEqual(resultsOfSplit[2].inputRange, [
-      { trueTransform: keystrokeDistributions[2][0].sample, inputStartIndex: 'ng'.length }
+      { trueTransform: keystrokeDistributions[2][0].sample, inputStartIndex: 'ng'.length, bestProbFromSet: 1 }
     ]);

     assert.deepEqual(resultsOfSplit[0].searchSpace.inputSequence, [
@@ -459,7 +476,7 @@ describe('ContextToken', function() {

     const tokenToSplit = new ContextToken(plainModel);
     for(let i = 0; i < keystrokeDistributions.length; i++) {
-      tokenToSplit.addInput({trueTransform: keystrokeDistributions[i][0].sample, inputStartIndex: 0}, keystrokeDistributions[i]);
+      tokenToSplit.addInput({trueTransform: keystrokeDistributions[i][0].sample, inputStartIndex: 0, bestProbFromSet: 1}, keystrokeDistributions[i]);
     };

     assert.equal(tokenToSplit.exampleInput, toMathematicalSMP('largelongtransforms'));
@@ -482,15 +499,15 @@ describe('ContextToken', function() {
     assert.equal(resultsOfSplit.length, 3);
     assert.sameOrderedMembers(resultsOfSplit.map(t => t.exampleInput), splitTextArray);
     assert.deepEqual(resultsOfSplit[0].inputRange, [
-      { trueTransform: keystrokeDistributions[0][0].sample, inputStartIndex: 0 },
-      { trueTransform: keystrokeDistributions[1][0].sample, inputStartIndex: 0 },
+      { trueTransform: keystrokeDistributions[0][0].sample, inputStartIndex: 0, bestProbFromSet: 1 },
+      { trueTransform: keystrokeDistributions[1][0].sample, inputStartIndex: 0, bestProbFromSet: 1 },
     ]);
     assert.deepEqual(resultsOfSplit[1].inputRange, [
-      { trueTransform: keystrokeDistributions[1][0].sample, inputStartIndex: 'arge'.length },
-      { trueTransform: keystrokeDistributions[2][0].sample, inputStartIndex: 0 },
+      { trueTransform: keystrokeDistributions[1][0].sample, inputStartIndex: 'arge'.length, bestProbFromSet: 1 },
+      { trueTransform: keystrokeDistributions[2][0].sample, inputStartIndex: 0, bestProbFromSet: 1 },
     ]);
     assert.deepEqual(resultsOfSplit[2].inputRange, [
-      { trueTransform: keystrokeDistributions[2][0].sample, inputStartIndex: 'ng'.length }
+      { trueTransform: keystrokeDistributions[2][0].sample, inputStartIndex: 'ng'.length, bestProbFromSet: 1 }
     ]);

     assert.deepEqual(resultsOfSplit[0].searchSpace.inputSequence, [
@@ -550,7 +567,8 @@ describe('preprocessInputSources', () => {

     const results = preprocessInputSources(transforms.map((t) => ({
       trueTransform: t,
-      inputStartIndex: 0
+      inputStartIndex: 0,
+      bestProbFromSet: 1
     })));

     assert.equal(results.length, transforms.length);
