Skip to content

Commit 700fc7a

Browse files
author
Maarten
committed
Make parallellization prettier
1 parent 9901844 commit 700fc7a

File tree

1 file changed

+25
-17
lines changed
  • src/main/java/org/leibnizcenter/cfg/earleyparser/chart

1 file changed

+25
-17
lines changed

src/main/java/org/leibnizcenter/cfg/earleyparser/chart/Chart.java

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ public class Chart<T> {
4747
public final Grammar<T> grammar;
4848
private final ParseOptions<T> callbacks;
4949

50+
private boolean parallelizePredict = false;
51+
private boolean parallelizeScan = false;
5052
private boolean parallelizeComplete = true;
5153

5254
/**
@@ -221,19 +223,20 @@ public void predict(int i, TokenWithCategories<T> token) {
221223
void predict(int index) {
222224
// O(|stateset(i)|) = O(|grammar|): For all states <code>i: X<sub>k</sub> → λ·Zμ</code>...
223225
final Set<State> activeOnNonTerminals = stateSets.activeStates.getActiveOnNonTerminals(index);
224-
if (activeOnNonTerminals != null)
226+
if (activeOnNonTerminals != null) {
225227
// Copy set to avoid concurrent modification
226-
new HashSet<>(activeOnNonTerminals).stream().parallel()
227-
228+
HashSet<State> activeOnNonTerminalsCp = new HashSet<>(activeOnNonTerminals);
229+
(parallelizePredict ? activeOnNonTerminalsCp.parallelStream() : activeOnNonTerminalsCp.stream())
228230
// For all productions Y → value such that R(Z =*L> Y) is nonzero
229-
.flatMap(grammar::streamNonZeroLeftStarRulesWithPrecedingState).parallel()
230-
231+
.flatMap(grammar::streamNonZeroLeftStarRulesWithPrecedingState)
231232
// we predict state <code>i: Y<sub>i</sub> → ·value</code>
232233
.map(statePredecessor_Y_to_v -> predictNextStateAndScores(index, statePredecessor_Y_to_v))
233234

234235
// Now that we've calculated the scores, add to chart...
235236
.sequential()
237+
236238
.forEach(stateSets::setScores);
239+
}
237240
}
238241

239242
public void scan(int i, TokenWithCategories<T> token) {
@@ -257,7 +260,7 @@ public void scan(int i, TokenWithCategories<T> token) {
257260
* @param scanProbability Function that provides the probability of scanning the given token at this position. Might be null for a probability of 1.0.
258261
*/
259262
@SuppressWarnings("WeakerAccess")
260-
public void scan(
263+
void scan(
261264
final int tokenPosition,
262265
final TokenWithCategories<T> tokenWithCategories,
263266
final ScanProbability<T> scanProbability
@@ -277,15 +280,14 @@ public void scan(
277280
* Get all states that are active on a terminal
278281
* O(|stateset(i)|) = O(|grammar|): For all states <code>i: X<sub>k</sub> → λ·tμ</code>, where t is a terminal that matches the given token...
279282
*/
280-
tokenWithCategories.categories.stream()
281-
.parallel()
283+
final Set<Terminal<T>> categories = tokenWithCategories.categories;
284+
(parallelizeScan ? categories.parallelStream() : categories.stream())
282285
.flatMap((final Terminal<T> terminalType) -> {
283286
final Set<State> statesActiveOnTerminals = stateSets.activeStates.getActiveOn(tokenPosition, terminalType);
284287
return statesActiveOnTerminals == null
285288
? Stream.empty()
286289
: statesActiveOnTerminals.stream();
287290
})
288-
.parallel() // Parallellize for performance: everything we do in map does not mutate state
289291
.map(preScanState -> new Scan.Delta<>(
290292
token,
291293
preScanState,
@@ -335,10 +337,9 @@ private void completeNoViterbi(final int position,
335337

336338
/* Safe to parallelize here */
337339
if (parallelizeComplete) stream = stream.parallel();
338-
stream = stream.flatMap(stateSets.activeStates::streamAllStatesToAdvance);
339-
if (parallelizeComplete) stream = stream.parallel();
340340

341341
List<Complete.Delta> deltas = stream
342+
.flatMap(stateSets.activeStates::streamAllStatesToAdvance)
342343
.map(stateInformation -> completeNoViterbiForTriple(
343344
position,
344345
addInnerScores.getOrCreate(stateInformation.stateToAdvance, stateSets.innerScores.getAtom(stateInformation.stateToAdvance)),
@@ -383,7 +384,7 @@ private void completeNoViterbi(final int position,
383384
* @param completedState Completed state to calculate Viterbi score for
384385
*/
385386
@SuppressWarnings("WeakerAccess")
386-
public void computeViterbiScoresForCompletedState(
387+
private void computeViterbiScoresForCompletedState(
387388
final State completedState
388389
) {
389390
if (stateSets.viterbiScores.get(completedState) == null)
@@ -435,15 +436,22 @@ private Complete.ViterbiDelta computeViterbiForState(State completedState, doubl
435436
: null;
436437
}
437438

438-
private State.ViterbiScore getNewViterbiScore(State completedState, double completedViterbi, State stateToAdvance, State resultingState) {
439+
private State.ViterbiScore getNewViterbiScore(
440+
State completedState,
441+
double completedViterbi,
442+
State stateToAdvance,
443+
State resultingState
444+
) {
445+
final double oldViterbiScore = getViterbiScore(stateToAdvance).innerScore;
446+
439447
return new State.ViterbiScore(
440-
stateSets.grammar.semiring.times(
448+
grammar.semiring.times(
441449
completedViterbi,
442-
stateSets.viterbiScores.get(stateToAdvance).innerScore // must be set
450+
oldViterbiScore // must be set
443451
),
444452
completedState,
445453
resultingState,
446-
stateSets.grammar.semiring
454+
grammar.semiring
447455
);
448456
}
449457

@@ -452,7 +460,7 @@ private State.ViterbiScore getNewViterbiScore(State completedState, double compl
452460
*
453461
* @param i The index to make completions at.
454462
*/
455-
void completeNoViterbi(
463+
private void completeNoViterbi(
456464
final int i
457465
) {
458466
ExpressionSemiring semiring = grammar.semiring;

0 commit comments

Comments
 (0)