@@ -21,6 +21,7 @@
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashSet;
+import java.util.Iterator;
 import java.util.List;
 import java.util.PriorityQueue;
 import java.util.Set;
@@ -267,68 +268,88 @@ default <T, R extends CompactionSSTable> List<T> splitSSTablesInShardsLimited(Co
     {
         if (coveredShards <= maxParallelism)
             return splitSSTablesInShards(sstables, operationRange, numShardsForDensity, maker);
-        // We may be in a simple case where we can reduce the number of shards by some power of 2.
-        int multiple = Integer.highestOneBit(coveredShards / maxParallelism);
-        if (maxParallelism * multiple == coveredShards)
-            return splitSSTablesInShards(sstables, operationRange, numShardsForDensity / multiple, maker);
 
         var shards = splitSSTablesInShards(sstables,
                                            operationRange,
                                            numShardsForDensity,
                                            (rangeSSTables, range) -> Pair.create(Set.copyOf(rangeSSTables), range));
+
         return applyMaxParallelism(maxParallelism, maker, shards);
     }
 
-    private static <T, R extends CompactionSSTable> List<T> applyMaxParallelism(int maxParallelism, BiFunction<Collection<R>, Range<Token>, T> maker, List<Pair<Set<R>, Range<Token>>> shards)
+    private static <T, R extends CompactionSSTable> List<T> applyMaxParallelism(int maxParallelism,
+                                                                                BiFunction<Collection<R>, Range<Token>, T> maker,
+                                                                                List<Pair<Set<R>, Range<Token>>> shards)
     {
-        int actualParallelism = shards.size();
-        if (maxParallelism >= actualParallelism)
-        {
-            // We can fit within the parallelism limit without grouping, because some ranges are empty.
-            // This is not expected to happen often, but if it does, take advantage.
-            List<T> tasks = new ArrayList<>();
-            for (Pair<Set<R>, Range<Token>> pair : shards)
-                tasks.add(maker.apply(pair.left, pair.right));
-            return tasks;
-        }
-
-        // Otherwise we have to group shards together. Define a target token span per task and greedily group
-        // to be as close to it as possible.
-        double spanPerTask = shards.stream().map(Pair::right).mapToDouble(t -> t.left.size(t.right)).sum() / maxParallelism;
-        double currentSpan = 0;
-        Set<R> currentSSTables = new HashSet<>();
-        Token rangeStart = null;
-        Token prevEnd = null;
+        Iterator<Pair<Set<R>, Range<Token>>> iter = shards.iterator();
         List<T> tasks = new ArrayList<>(maxParallelism);
-        for (var pair : shards)
+        int shardsRemaining = shards.size();
+        int tasksRemaining = maxParallelism;
+
+        if (shardsRemaining > tasksRemaining)
         {
-            final Token currentEnd = pair.right.right;
-            final Token currentStart = pair.right.left;
-            double span = currentStart.size(currentEnd);
-            if (rangeStart == null)
-                rangeStart = currentStart;
-            if (currentSpan + span >= spanPerTask - 0.001) // rounding error safety
+            double totalSpan = shards.stream().map(Pair::right).mapToDouble(r -> r.left.size(r.right)).sum();
+            double spanPerTask = totalSpan / maxParallelism;
+
+            Set<R> currentSSTables = new HashSet<>();
+            Token rangeStart = null;
+            double currentSpan = 0;
+
+            // While we have more shards to process than there are tasks, we need to bunch shards up into tasks.
+            while (shardsRemaining > tasksRemaining)
             {
-                boolean includeCurrent = currentSpan + span - spanPerTask <= spanPerTask - currentSpan;
-                if (includeCurrent)
-                    currentSSTables.addAll(pair.left);
-                tasks.add(maker.apply(currentSSTables, new Range<>(rangeStart, includeCurrent ? currentEnd : prevEnd)));
-                currentSpan -= spanPerTask;
-                rangeStart = null;
-                currentSSTables.clear();
-                if (!includeCurrent)
-                {
-                    currentSSTables.addAll(pair.left);
+                Pair<Set<R>, Range<Token>> pair = iter.next(); // shardsRemaining counts the shards so iter can't be exhausted at this point
+                Token currentStart = pair.right.left;
+                Token currentEnd = pair.right.right;
+                double span = currentStart.size(currentEnd);
+
+                if (rangeStart == null)
                     rangeStart = currentStart;
+
+                currentSSTables.addAll(pair.left);
+                currentSpan += span;
+
+                // If there is only one task remaining, we should not issue it until we are processing the last shard.
+                // The latter condition is normally guaranteed, but floating point rounding has a very small chance of making the calculations wrong.
+                if (currentSpan >= spanPerTask && tasksRemaining > 1)
+                {
+                    tasks.add(maker.apply(currentSSTables, new Range<>(rangeStart, currentEnd)));
+                    --tasksRemaining;
+                    currentSSTables = new HashSet<>();
+                    rangeStart = null;
+                    currentSpan = 0;
                 }
+                --shardsRemaining;
             }
-            else
+
+            // At this point there are as many tasks remaining as there are shards
+            // (this includes the case of issuing a task for the last shard when only one task remains).
+
+            // Add any already collected sstables to the next task.
+            if (!currentSSTables.isEmpty())
+            {
+                assert shardsRemaining > 0;
+                Pair<Set<R>, Range<Token>> pair = iter.next(); // shardsRemaining counts the shards so iter can't be exhausted at this point
                 currentSSTables.addAll(pair.left);
+                Token currentEnd = pair.right.right;
+                tasks.add(maker.apply(currentSSTables, new Range<>(rangeStart, currentEnd)));
+                --tasksRemaining;
+                --shardsRemaining;
+            }
+            assert shardsRemaining == tasksRemaining : shardsRemaining + " != " + tasksRemaining;
+        }
 
-            currentSpan += span;
-            prevEnd = currentEnd;
+        // If we still have tasks and shards to process, produce one task for each shard.
+        while (iter.hasNext())
+        {
+            Pair<Set<R>, Range<Token>> pair = iter.next(); // shardsRemaining counts the shards so iter can't be exhausted at this point
+            tasks.add(maker.apply(pair.left, pair.right));
+            --tasksRemaining;
+            --shardsRemaining;
         }
-        assert currentSSTables.isEmpty();
+
+        assert tasks.size() == Math.min(maxParallelism, shards.size()) : tasks.size() + " != " + maxParallelism;
+        assert shardsRemaining == 0 : shardsRemaining + " != 0";
         return tasks;
     }
 
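For illustration only, not part of the commit: a minimal, self-contained sketch of the grouping strategy used by the new applyMaxParallelism, with each shard reduced to a plain double span instead of an sstable set plus Token range. The class and method names below (ShardGroupingSketch, group) are made up for the example.

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

public class ShardGroupingSketch
{
    // Groups `spans` into at most `maxParallelism` buckets. Buckets are grown
    // greedily until they cover roughly totalSpan / maxParallelism.
    static List<List<Double>> group(List<Double> spans, int maxParallelism)
    {
        List<List<Double>> tasks = new ArrayList<>(maxParallelism);
        Iterator<Double> iter = spans.iterator();
        int shardsRemaining = spans.size();
        int tasksRemaining = maxParallelism;

        if (shardsRemaining > tasksRemaining)
        {
            double totalSpan = spans.stream().mapToDouble(Double::doubleValue).sum();
            double spanPerTask = totalSpan / maxParallelism;
            List<Double> current = new ArrayList<>();
            double currentSpan = 0;

            // Bunch shards up while more shards remain than tasks.
            while (shardsRemaining > tasksRemaining)
            {
                double span = iter.next();
                current.add(span);
                currentSpan += span;
                // Hold the last task back until the last shard is reached.
                if (currentSpan >= spanPerTask && tasksRemaining > 1)
                {
                    tasks.add(current);
                    --tasksRemaining;
                    current = new ArrayList<>();
                    currentSpan = 0;
                }
                --shardsRemaining;
            }

            // Attach any partially built bucket to the next shard.
            if (!current.isEmpty())
            {
                current.add(iter.next());
                tasks.add(current);
                --tasksRemaining;
                --shardsRemaining;
            }
        }

        // One task per remaining shard.
        while (iter.hasNext())
            tasks.add(List.of(iter.next()));

        return tasks;
    }

    public static void main(String[] args)
    {
        // Seven equal shards limited to three tasks -> groups of 3, 3 and 1.
        System.out.println(group(List.of(1d, 1d, 1d, 1d, 1d, 1d, 1d), 3));
    }
}

With more shards than tasks, buckets are grown greedily until they reach totalSpan / maxParallelism, and the last task is held back until the final shard so every shard is consumed even if rounding leaves the last bucket short of the target; when shards do not outnumber tasks, each shard simply becomes its own task.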