Skip to content

Commit 77b459c

Browse files
authored
Improve accuracy of write load forecast when shard numbers change (elastic#129990)
1 parent 8acf94b commit 77b459c

File tree

4 files changed

+215
-36
lines changed

4 files changed

+215
-36
lines changed

docs/changelog/129990.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 129990
2+
summary: Make forecast write load accurate when shard numbers change
3+
area: Allocation
4+
type: bug
5+
issues: []

server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ public String toString() {
362362
* <p>If the recommendation is to INCREASE/DECREASE shards the reported cooldown period will be TimeValue.ZERO.
363363
* If the auto sharding service thinks the number of shards must be changed but it can't recommend a change due to the cooldown
364364
* period not lapsing, the result will be of type {@link AutoShardingType#COOLDOWN_PREVENTED_INCREASE} or
365-
* {@link AutoShardingType#COOLDOWN_PREVENTED_INCREASE} with the remaining cooldown configured and the number of shards that should
365+
* {@link AutoShardingType#COOLDOWN_PREVENTED_DECREASE} with the remaining cooldown configured and the number of shards that should
366366
* be configured for the data stream once the remaining cooldown lapses as the target number of shards.
367367
*
368368
* <p>The NOT_APPLICABLE type result will report a cooldown period of TimeValue.MAX_VALUE.

x-pack/plugin/write-load-forecaster/src/main/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecaster.java

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,12 @@ public ProjectMetadata.Builder withWriteLoadForecastForWriteIndex(String dataStr
108108
}
109109

110110
final IndexMetadata writeIndex = metadata.getSafe(dataStream.getWriteIndex());
111-
metadata.put(IndexMetadata.builder(writeIndex).indexWriteLoadForecast(forecastIndexWriteLoad.getAsDouble()).build(), false);
111+
metadata.put(
112+
IndexMetadata.builder(writeIndex)
113+
.indexWriteLoadForecast(forecastIndexWriteLoad.getAsDouble() / writeIndex.getNumberOfShards())
114+
.build(),
115+
false
116+
);
112117

113118
return metadata;
114119
}
@@ -129,25 +134,48 @@ private static void clearPreviousForecast(DataStream dataStream, ProjectMetadata
129134
}
130135
}
131136

137+
/**
138+
* This calculates the weighted average total write-load for all recent indices.
139+
*
140+
* @param indicesWriteLoadWithinMaxAgeRange The indices considered "recent"
141+
* @return The weighted average total write-load. To get the per-shard write load, this number must be divided by the number of shards
142+
*/
132143
// Visible for testing
133144
static OptionalDouble forecastIndexWriteLoad(List<IndexWriteLoad> indicesWriteLoadWithinMaxAgeRange) {
134-
double totalWeightedWriteLoad = 0;
135-
long totalShardUptime = 0;
145+
double allIndicesWriteLoad = 0;
146+
long allIndicesUptime = 0;
136147
for (IndexWriteLoad writeLoad : indicesWriteLoadWithinMaxAgeRange) {
148+
double totalShardWriteLoad = 0;
149+
long totalShardUptimeInMillis = 0;
150+
long maxShardUptimeInMillis = 0;
137151
for (int shardId = 0; shardId < writeLoad.numberOfShards(); shardId++) {
138152
final OptionalDouble writeLoadForShard = writeLoad.getWriteLoadForShard(shardId);
139153
final OptionalLong uptimeInMillisForShard = writeLoad.getUptimeInMillisForShard(shardId);
140154
if (writeLoadForShard.isPresent()) {
141155
assert uptimeInMillisForShard.isPresent();
142156
double shardWriteLoad = writeLoadForShard.getAsDouble();
143157
long shardUptimeInMillis = uptimeInMillisForShard.getAsLong();
144-
totalWeightedWriteLoad += shardWriteLoad * shardUptimeInMillis;
145-
totalShardUptime += shardUptimeInMillis;
158+
totalShardWriteLoad += shardWriteLoad * shardUptimeInMillis;
159+
totalShardUptimeInMillis += shardUptimeInMillis;
160+
maxShardUptimeInMillis = Math.max(maxShardUptimeInMillis, shardUptimeInMillis);
146161
}
147162
}
163+
double weightedAverageShardWriteLoad = totalShardWriteLoad / totalShardUptimeInMillis;
164+
double totalIndexWriteLoad = weightedAverageShardWriteLoad * writeLoad.numberOfShards();
165+
// We need to weight the contribution from each index somehow, but we only know
166+
// the write-load from the final allocation of each shard at rollover time. It's
167+
// possible the index is much older than any of those shards, but we don't have
168+
// any write-load data beyond their lifetime.
169+
// To avoid making assumptions about periods for which we have no data, we'll weight
170+
// each index's contribution to the forecast by the maximum shard uptime observed in
171+
// that index. It should be safe to extrapolate our weighted average out to the
172+
// maximum uptime observed, based on the assumption that write-load is roughly
173+
// evenly distributed across shards of a datastream index.
174+
allIndicesWriteLoad += totalIndexWriteLoad * maxShardUptimeInMillis;
175+
allIndicesUptime += maxShardUptimeInMillis;
148176
}
149177

150-
return totalShardUptime == 0 ? OptionalDouble.empty() : OptionalDouble.of(totalWeightedWriteLoad / totalShardUptime);
178+
return allIndicesUptime == 0 ? OptionalDouble.empty() : OptionalDouble.of(allIndicesWriteLoad / allIndicesUptime);
151179
}
152180

153181
@Override

x-pack/plugin/write-load-forecaster/src/test/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecasterTests.java

Lines changed: 175 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.apache.logging.log4j.Level;
1111
import org.apache.logging.log4j.core.LogEvent;
12+
import org.apache.lucene.util.hnsw.IntToIntFunction;
1213
import org.elasticsearch.cluster.metadata.DataStream;
1314
import org.elasticsearch.cluster.metadata.IndexMetadata;
1415
import org.elasticsearch.cluster.metadata.IndexMetadataStats;
@@ -24,16 +25,19 @@
2425
import org.elasticsearch.test.MockLog;
2526
import org.elasticsearch.threadpool.TestThreadPool;
2627
import org.elasticsearch.threadpool.ThreadPool;
28+
import org.hamcrest.Matcher;
2729
import org.junit.After;
2830
import org.junit.Before;
2931

3032
import java.util.ArrayList;
3133
import java.util.List;
3234
import java.util.Map;
35+
import java.util.Objects;
3336
import java.util.OptionalDouble;
3437
import java.util.concurrent.TimeUnit;
3538
import java.util.concurrent.atomic.AtomicBoolean;
3639
import java.util.concurrent.atomic.AtomicInteger;
40+
import java.util.stream.Collectors;
3741

3842
import static org.elasticsearch.xpack.writeloadforecaster.LicensedWriteLoadForecaster.forecastIndexWriteLoad;
3943
import static org.hamcrest.Matchers.closeTo;
@@ -42,6 +46,7 @@
4246
import static org.hamcrest.Matchers.equalTo;
4347
import static org.hamcrest.Matchers.greaterThan;
4448
import static org.hamcrest.Matchers.is;
49+
import static org.hamcrest.Matchers.lessThan;
4550

4651
public class LicensedWriteLoadForecasterTests extends ESTestCase {
4752
ThreadPool threadPool;
@@ -67,33 +72,15 @@ public void testWriteLoadForecastIsAddedToWriteIndex() {
6772

6873
writeLoadForecaster.refreshLicense();
6974

70-
final ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(randomProjectIdOrDefault());
7175
final String dataStreamName = "logs-es";
7276
final int numberOfBackingIndices = 10;
73-
final int numberOfShards = randomIntBetween(1, 5);
74-
final List<Index> backingIndices = new ArrayList<>();
75-
for (int i = 0; i < numberOfBackingIndices; i++) {
76-
final IndexMetadata indexMetadata = createIndexMetadata(
77-
DataStream.getDefaultBackingIndexName(dataStreamName, i),
78-
numberOfShards,
79-
randomIndexWriteLoad(numberOfShards),
80-
System.currentTimeMillis() - (maxIndexAge.millis() / 2)
81-
);
82-
backingIndices.add(indexMetadata.getIndex());
83-
metadataBuilder.put(indexMetadata, false);
84-
}
85-
86-
final IndexMetadata writeIndexMetadata = createIndexMetadata(
87-
DataStream.getDefaultBackingIndexName(dataStreamName, numberOfBackingIndices),
88-
numberOfShards,
89-
null,
90-
System.currentTimeMillis()
77+
final ProjectMetadata.Builder metadataBuilder = createMetadataBuilderWithDataStream(
78+
dataStreamName,
79+
numberOfBackingIndices,
80+
randomIntBetween(1, 5),
81+
maxIndexAge
9182
);
92-
backingIndices.add(writeIndexMetadata.getIndex());
93-
metadataBuilder.put(writeIndexMetadata, false);
94-
95-
final DataStream dataStream = createDataStream(dataStreamName, backingIndices);
96-
metadataBuilder.put(dataStream);
83+
final DataStream dataStream = metadataBuilder.dataStream(dataStreamName);
9784

9885
final ProjectMetadata.Builder updatedMetadataBuilder = writeLoadForecaster.withWriteLoadForecastForWriteIndex(
9986
dataStream.getName(),
@@ -253,7 +240,7 @@ public void testWriteLoadForecast() {
253240
)
254241
);
255242
assertThat(writeLoadForecast.isPresent(), is(true));
256-
assertThat(writeLoadForecast.getAsDouble(), is(equalTo(14.4)));
243+
assertThat(writeLoadForecast.getAsDouble(), is(equalTo(72.0)));
257244
}
258245

259246
{
@@ -264,14 +251,14 @@ public void testWriteLoadForecast() {
264251
.withShardWriteLoad(1, 24, 999, 999, 5)
265252
.withShardWriteLoad(2, 24, 999, 999, 5)
266253
.withShardWriteLoad(3, 24, 999, 999, 5)
267-
.withShardWriteLoad(4, 24, 999, 999, 4)
254+
.withShardWriteLoad(4, 24, 999, 999, 5)
268255
.build(),
269256
// Since this shard uptime is really low, it doesn't add much to the avg
270257
IndexWriteLoad.builder(1).withShardWriteLoad(0, 120, 999, 999, 1).build()
271258
)
272259
);
273260
assertThat(writeLoadForecast.isPresent(), is(true));
274-
assertThat(writeLoadForecast.getAsDouble(), is(equalTo(15.36)));
261+
assertThat(writeLoadForecast.getAsDouble(), is(closeTo(72.59, 0.01)));
275262
}
276263

277264
{
@@ -283,7 +270,7 @@ public void testWriteLoadForecast() {
283270
)
284271
);
285272
assertThat(writeLoadForecast.isPresent(), is(true));
286-
assertThat(writeLoadForecast.getAsDouble(), is(equalTo(12.0)));
273+
assertThat(writeLoadForecast.getAsDouble(), is(equalTo(16.0)));
287274
}
288275

289276
{
@@ -302,7 +289,7 @@ public void testWriteLoadForecast() {
302289
)
303290
);
304291
assertThat(writeLoadForecast.isPresent(), is(true));
305-
assertThat(writeLoadForecast.getAsDouble(), is(closeTo(15.83, 0.01)));
292+
assertThat(writeLoadForecast.getAsDouble(), is(closeTo(31.66, 0.01)));
306293
}
307294
}
308295

@@ -404,4 +391,163 @@ public boolean innerMatch(LogEvent event) {
404391
);
405392
}, LicensedWriteLoadForecaster.class, collectingLoggingAssertion);
406393
}
394+
395+
public void testShardIncreaseDoesNotIncreaseTotalLoad() {
396+
testShardChangeDoesNotChangeTotalForecastLoad(ShardCountChange.INCREASE);
397+
}
398+
399+
public void testShardDecreaseDoesNotDecreaseTotalLoad() {
400+
testShardChangeDoesNotChangeTotalForecastLoad(ShardCountChange.DECREASE);
401+
}
402+
403+
private void testShardChangeDoesNotChangeTotalForecastLoad(ShardCountChange shardCountChange) {
404+
final TimeValue maxIndexAge = TimeValue.timeValueDays(7);
405+
final AtomicBoolean hasValidLicense = new AtomicBoolean(true);
406+
final AtomicInteger licenseCheckCount = new AtomicInteger();
407+
final WriteLoadForecaster writeLoadForecaster = new LicensedWriteLoadForecaster(() -> {
408+
licenseCheckCount.incrementAndGet();
409+
return hasValidLicense.get();
410+
}, threadPool, maxIndexAge);
411+
writeLoadForecaster.refreshLicense();
412+
413+
final String dataStreamName = randomIdentifier();
414+
final ProjectMetadata.Builder originalMetadata = writeLoadForecaster.withWriteLoadForecastForWriteIndex(
415+
dataStreamName,
416+
createMetadataBuilderWithDataStream(dataStreamName, randomIntBetween(5, 15), shardCountChange.originalShardCount(), maxIndexAge)
417+
);
418+
419+
// Generate the same data stream, but with a different number of shards in the write index
420+
final ProjectMetadata.Builder changedShardCountMetadata = writeLoadForecaster.withWriteLoadForecastForWriteIndex(
421+
dataStreamName,
422+
updateWriteIndexShardCount(dataStreamName, originalMetadata, shardCountChange)
423+
);
424+
425+
IndexMetadata originalWriteIndexMetadata = originalMetadata.getSafe(originalMetadata.dataStream(dataStreamName).getWriteIndex());
426+
IndexMetadata changedShardCountWriteIndexMetadata = changedShardCountMetadata.getSafe(
427+
changedShardCountMetadata.dataStream(dataStreamName).getWriteIndex()
428+
);
429+
430+
// The shard count changed
431+
assertThat(
432+
changedShardCountWriteIndexMetadata.getNumberOfShards(),
433+
shardCountChange.expectedChangeFromOriginal(originalWriteIndexMetadata.getNumberOfShards())
434+
);
435+
// But the total write-load did not
436+
assertThat(
437+
changedShardCountWriteIndexMetadata.getNumberOfShards() * writeLoadForecaster.getForecastedWriteLoad(
438+
changedShardCountWriteIndexMetadata
439+
).getAsDouble(),
440+
closeTo(
441+
originalWriteIndexMetadata.getNumberOfShards() * writeLoadForecaster.getForecastedWriteLoad(originalWriteIndexMetadata)
442+
.getAsDouble(),
443+
0.01
444+
)
445+
);
446+
}
447+
448+
public enum ShardCountChange implements IntToIntFunction {
449+
INCREASE(1, 15) {
450+
@Override
451+
public int apply(int originalShardCount) {
452+
return randomIntBetween(originalShardCount + 1, originalShardCount * 3);
453+
}
454+
455+
public Matcher<Integer> expectedChangeFromOriginal(int originalShardCount) {
456+
return greaterThan(originalShardCount);
457+
}
458+
},
459+
DECREASE(10, 30) {
460+
@Override
461+
public int apply(int originalShardCount) {
462+
return randomIntBetween(1, originalShardCount - 1);
463+
}
464+
465+
public Matcher<Integer> expectedChangeFromOriginal(int originalShardCount) {
466+
return lessThan(originalShardCount);
467+
}
468+
};
469+
470+
private final int originalMinimumShardCount;
471+
private final int originalMaximumShardCount;
472+
473+
ShardCountChange(int originalMinimumShardCount, int originalMaximumShardCount) {
474+
this.originalMinimumShardCount = originalMinimumShardCount;
475+
this.originalMaximumShardCount = originalMaximumShardCount;
476+
}
477+
478+
public int originalShardCount() {
479+
return randomIntBetween(originalMinimumShardCount, originalMaximumShardCount);
480+
}
481+
482+
abstract Matcher<Integer> expectedChangeFromOriginal(int originalShardCount);
483+
}
484+
485+
private ProjectMetadata.Builder updateWriteIndexShardCount(
486+
String dataStreamName,
487+
ProjectMetadata.Builder originalMetadata,
488+
ShardCountChange shardCountChange
489+
) {
490+
final ProjectMetadata.Builder updatedShardCountMetadata = ProjectMetadata.builder(originalMetadata.getId());
491+
492+
final DataStream originalDataStream = originalMetadata.dataStream(dataStreamName);
493+
final Index existingWriteIndex = Objects.requireNonNull(originalDataStream.getWriteIndex());
494+
final IndexMetadata originalWriteIndexMetadata = originalMetadata.getSafe(existingWriteIndex);
495+
496+
// Copy all non-write indices over unchanged
497+
final List<IndexMetadata> backingIndexMetadatas = originalDataStream.getIndices()
498+
.stream()
499+
.filter(index -> index != existingWriteIndex)
500+
.map(originalMetadata::getSafe)
501+
.collect(Collectors.toList());
502+
503+
// Create a new write index with an updated shard count
504+
final IndexMetadata writeIndexMetadata = createIndexMetadata(
505+
DataStream.getDefaultBackingIndexName(dataStreamName, backingIndexMetadatas.size()),
506+
shardCountChange.apply(originalWriteIndexMetadata.getNumberOfShards()),
507+
null,
508+
System.currentTimeMillis()
509+
);
510+
backingIndexMetadatas.add(writeIndexMetadata);
511+
backingIndexMetadatas.forEach(indexMetadata -> updatedShardCountMetadata.put(indexMetadata, false));
512+
513+
final DataStream dataStream = createDataStream(
514+
dataStreamName,
515+
backingIndexMetadatas.stream().map(IndexMetadata::getIndex).toList()
516+
);
517+
updatedShardCountMetadata.put(dataStream);
518+
return updatedShardCountMetadata;
519+
}
520+
521+
private ProjectMetadata.Builder createMetadataBuilderWithDataStream(
522+
String dataStreamName,
523+
int numberOfBackingIndices,
524+
int numberOfShards,
525+
TimeValue maxIndexAge
526+
) {
527+
final ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(randomProjectIdOrDefault());
528+
final List<Index> backingIndices = new ArrayList<>();
529+
for (int i = 0; i < numberOfBackingIndices; i++) {
530+
final IndexMetadata indexMetadata = createIndexMetadata(
531+
DataStream.getDefaultBackingIndexName(dataStreamName, i),
532+
numberOfShards,
533+
randomIndexWriteLoad(numberOfShards),
534+
System.currentTimeMillis() - (maxIndexAge.millis() / 2)
535+
);
536+
backingIndices.add(indexMetadata.getIndex());
537+
metadataBuilder.put(indexMetadata, false);
538+
}
539+
540+
final IndexMetadata writeIndexMetadata = createIndexMetadata(
541+
DataStream.getDefaultBackingIndexName(dataStreamName, numberOfBackingIndices),
542+
numberOfShards,
543+
null,
544+
System.currentTimeMillis()
545+
);
546+
backingIndices.add(writeIndexMetadata.getIndex());
547+
metadataBuilder.put(writeIndexMetadata, false);
548+
549+
final DataStream dataStream = createDataStream(dataStreamName, backingIndices);
550+
metadataBuilder.put(dataStream);
551+
return metadataBuilder;
552+
}
407553
}

0 commit comments

Comments
 (0)