PR #129990 (merged): Make forecast write load accurate when shard numbers change
docs/changelog/129990.yaml (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
pr: 129990
summary: Make forecast write load accurate when shard numbers change
area: Allocation
type: bug
issues: []
@@ -362,7 +362,7 @@ public String toString() {
* <p>If the recommendation is to INCREASE/DECREASE shards the reported cooldown period will be TimeValue.ZERO.
* If the auto sharding service thinks the number of shards must be changed but it can't recommend a change due to the cooldown
* period not lapsing, the result will be of type {@link AutoShardingType#COOLDOWN_PREVENTED_INCREASE} or
- * {@link AutoShardingType#COOLDOWN_PREVENTED_INCREASE} with the remaining cooldown configured and the number of shards that should
+ * {@link AutoShardingType#COOLDOWN_PREVENTED_DECREASE} with the remaining cooldown configured and the number of shards that should
Comment (Contributor Author): Fixed a typo too

* be configured for the data stream once the remaining cooldown lapses as the target number of shards.
*
* <p>The NOT_APPLICABLE type result will report a cooldown period of TimeValue.MAX_VALUE.
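The javadoc above specifies a small result protocol for the auto-sharding service. Below is a minimal caller sketch of that protocol with a hypothetical Result record: the record and accessor names are illustrative stand-ins, not the actual Elasticsearch API, and only the AutoShardingType constant names come from the javadoc itself.

public class AutoShardingResultSketch {
    // Local stand-ins mirroring the types the javadoc names; not the real classes.
    enum AutoShardingType { INCREASE_SHARDS, DECREASE_SHARDS, COOLDOWN_PREVENTED_INCREASE, COOLDOWN_PREVENTED_DECREASE, NOT_APPLICABLE }

    // Hypothetical result shape: the type, the remaining cooldown, and the target shard count.
    record Result(AutoShardingType type, long cooldownRemainingMillis, int targetNumberOfShards) {}

    static void handle(Result result) {
        switch (result.type()) {
            case INCREASE_SHARDS, DECREASE_SHARDS -> {
                // Cooldown is reported as zero: the new shard count can be applied immediately.
                System.out.println("apply " + result.targetNumberOfShards() + " shards now");
            }
            case COOLDOWN_PREVENTED_INCREASE, COOLDOWN_PREVENTED_DECREASE -> {
                // A change is warranted, but the remaining cooldown must lapse first.
                System.out.println("wait " + result.cooldownRemainingMillis() + " ms, then apply " + result.targetNumberOfShards() + " shards");
            }
            case NOT_APPLICABLE -> {
                // The reported cooldown is effectively infinite (TimeValue.MAX_VALUE per the javadoc).
                System.out.println("auto-sharding does not apply to this data stream");
            }
        }
    }

    public static void main(String[] args) {
        handle(new Result(AutoShardingType.COOLDOWN_PREVENTED_DECREASE, 30_000, 2));
    }
}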
LicensedWriteLoadForecaster.java
@@ -108,7 +108,12 @@ public ProjectMetadata.Builder withWriteLoadForecastForWriteIndex(String dataStr
}

final IndexMetadata writeIndex = metadata.getSafe(dataStream.getWriteIndex());
- metadata.put(IndexMetadata.builder(writeIndex).indexWriteLoadForecast(forecastIndexWriteLoad.getAsDouble()).build(), false);
+ metadata.put(
+ IndexMetadata.builder(writeIndex)
+ .indexWriteLoadForecast(forecastIndexWriteLoad.getAsDouble() / writeIndex.getNumberOfShards())
+ .build(),
+ false
+ );

return metadata;
}
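This division is the substance of the fix: forecastIndexWriteLoad now returns a total write-load for the data stream (see the next hunk), and the stored per-index forecast divides that total by the write index's own shard count. A quick arithmetic sketch, with made-up numbers, of why the implied total then survives a rollover that changes the shard count:

public class PerShardForecastSketch {
    public static void main(String[] args) {
        double totalForecast = 6.0; // hypothetical total write-load derived from history

        int oldShards = 3;                              // shard count before a rollover
        double perShardOld = totalForecast / oldShards; // stored forecast: 2.0 per shard

        int newShards = 2;                              // write index resized at rollover
        double perShardNew = totalForecast / newShards; // stored forecast: 3.0 per shard

        // shards * per-shard forecast reproduces the same total either way,
        // which is what the new tests in this PR assert.
        System.out.println(oldShards * perShardOld); // 6.0
        System.out.println(newShards * perShardNew); // 6.0
    }
}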
@@ -129,25 +134,48 @@ private static void clearPreviousForecast(DataStream dataStream, ProjectMetadata
}
}

+ /**
+ * This calculates the weighted average total write-load for all recent indices.
+ *
+ * @param indicesWriteLoadWithinMaxAgeRange The indices considered "recent"
+ * @return The weighted average total write-load. To get the per-shard write load, this number must be divided by the number of shards
+ */
// Visible for testing
static OptionalDouble forecastIndexWriteLoad(List<IndexWriteLoad> indicesWriteLoadWithinMaxAgeRange) {
- double totalWeightedWriteLoad = 0;
- long totalShardUptime = 0;
+ double allIndicesWriteLoad = 0;
+ long allIndicesUptime = 0;
for (IndexWriteLoad writeLoad : indicesWriteLoadWithinMaxAgeRange) {
+ double totalShardWriteLoad = 0;
+ long totalShardUptimeInMillis = 0;
+ long maxShardUptimeInMillis = 0;
for (int shardId = 0; shardId < writeLoad.numberOfShards(); shardId++) {
final OptionalDouble writeLoadForShard = writeLoad.getWriteLoadForShard(shardId);
final OptionalLong uptimeInMillisForShard = writeLoad.getUptimeInMillisForShard(shardId);
if (writeLoadForShard.isPresent()) {
assert uptimeInMillisForShard.isPresent();
double shardWriteLoad = writeLoadForShard.getAsDouble();
long shardUptimeInMillis = uptimeInMillisForShard.getAsLong();
- totalWeightedWriteLoad += shardWriteLoad * shardUptimeInMillis;
- totalShardUptime += shardUptimeInMillis;
+ totalShardWriteLoad += shardWriteLoad * shardUptimeInMillis;
+ totalShardUptimeInMillis += shardUptimeInMillis;
+ maxShardUptimeInMillis = Math.max(maxShardUptimeInMillis, shardUptimeInMillis);
}
}
+ double weightedAverageShardWriteLoad = totalShardWriteLoad / totalShardUptimeInMillis;
+ double totalIndexWriteLoad = weightedAverageShardWriteLoad * writeLoad.numberOfShards();
+ // We need to weight the contribution from each index somehow, but we only know
+ // the write-load from the final allocation of each shard at rollover time. It's
+ // possible the index is much older than any of those shards, but we don't have
+ // any write-load data beyond their lifetime.
+ // To avoid making assumptions about periods for which we have no data, we'll weight
+ // each index's contribution to the forecast by the maximum shard uptime observed in
+ // that index. It should be safe to extrapolate our weighted average out to the
+ // maximum uptime observed, based on the assumption that write-load is roughly
+ // evenly distributed across shards of a datastream index.
+ allIndicesWriteLoad += totalIndexWriteLoad * maxShardUptimeInMillis;
+ allIndicesUptime += maxShardUptimeInMillis;
}

- return totalShardUptime == 0 ? OptionalDouble.empty() : OptionalDouble.of(totalWeightedWriteLoad / totalShardUptime);
+ return allIndicesUptime == 0 ? OptionalDouble.empty() : OptionalDouble.of(allIndicesWriteLoad / allIndicesUptime);
}

@Override
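The comment block above carries the key reasoning of this method: average each index's per-shard write-load weighted by shard uptime, scale that back up to an index total, then combine indices weighted by their maximum observed shard uptime. A worked example of that calculation with made-up numbers (not values from the PR's tests):

public class ForecastMathSketch {
    public static void main(String[] args) {
        // Index A: two shards, write-loads 4.0 (uptime 100 ms) and 2.0 (uptime 50 ms).
        double avgA = (4.0 * 100 + 2.0 * 50) / (100 + 50); // uptime-weighted per-shard average, 500/150 = 3.33
        double totalA = avgA * 2;                          // scale back to an index total, 6.67
        long weightA = 100;                                // max shard uptime observed in index A

        // Index B: one shard, write-load 3.0 over 200 ms.
        double totalB = 3.0 * 1;                           // index total = 3.0
        long weightB = 200;                                // max shard uptime observed in index B

        // Combine the index totals, weighting each by its max shard uptime.
        double forecastTotal = (totalA * weightA + totalB * weightB) / (weightA + weightB);
        System.out.println(forecastTotal);                 // (666.7 + 600) / 300 = 4.22

        // The stored per-shard value would then be forecastTotal / writeIndexShardCount.
    }
}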
LicensedWriteLoadForecasterTests.java
@@ -9,6 +9,7 @@

import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.core.LogEvent;
+ import org.apache.lucene.util.hnsw.IntToIntFunction;
import org.elasticsearch.cluster.metadata.DataStream;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.IndexMetadataStats;
@@ -24,16 +25,19 @@
import org.elasticsearch.test.MockLog;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
+ import org.hamcrest.Matcher;
import org.junit.After;
import org.junit.Before;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
+ import java.util.Objects;
import java.util.OptionalDouble;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
+ import java.util.stream.Collectors;

import static org.elasticsearch.xpack.writeloadforecaster.LicensedWriteLoadForecaster.forecastIndexWriteLoad;
import static org.hamcrest.Matchers.closeTo;
@@ -42,6 +46,7 @@
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
+ import static org.hamcrest.Matchers.lessThan;

public class LicensedWriteLoadForecasterTests extends ESTestCase {
ThreadPool threadPool;
@@ -67,33 +72,15 @@ public void testWriteLoadForecastIsAddedToWriteIndex() {

writeLoadForecaster.refreshLicense();

- final ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(randomProjectIdOrDefault());
final String dataStreamName = "logs-es";
final int numberOfBackingIndices = 10;
- final int numberOfShards = randomIntBetween(1, 5);
- final List<Index> backingIndices = new ArrayList<>();
- for (int i = 0; i < numberOfBackingIndices; i++) {
- final IndexMetadata indexMetadata = createIndexMetadata(
- DataStream.getDefaultBackingIndexName(dataStreamName, i),
- numberOfShards,
- randomIndexWriteLoad(numberOfShards),
- System.currentTimeMillis() - (maxIndexAge.millis() / 2)
- );
- backingIndices.add(indexMetadata.getIndex());
- metadataBuilder.put(indexMetadata, false);
- }

- final IndexMetadata writeIndexMetadata = createIndexMetadata(
- DataStream.getDefaultBackingIndexName(dataStreamName, numberOfBackingIndices),
- numberOfShards,
- null,
- System.currentTimeMillis()
+ final ProjectMetadata.Builder metadataBuilder = createMetadataBuilderWithDataStream(
+ dataStreamName,
+ numberOfBackingIndices,
+ randomIntBetween(1, 5),
+ maxIndexAge
);
- backingIndices.add(writeIndexMetadata.getIndex());
- metadataBuilder.put(writeIndexMetadata, false);

- final DataStream dataStream = createDataStream(dataStreamName, backingIndices);
- metadataBuilder.put(dataStream);
+ final DataStream dataStream = metadataBuilder.dataStream(dataStreamName);

final ProjectMetadata.Builder updatedMetadataBuilder = writeLoadForecaster.withWriteLoadForecastForWriteIndex(
dataStream.getName(),
@@ -253,7 +240,7 @@ public void testWriteLoadForecast() {
)
);
assertThat(writeLoadForecast.isPresent(), is(true));
- assertThat(writeLoadForecast.getAsDouble(), is(equalTo(14.4)));
+ assertThat(writeLoadForecast.getAsDouble(), is(equalTo(72.0)));
}

{
@@ -264,14 +251,14 @@ public void testWriteLoadForecast() {
.withShardWriteLoad(1, 24, 999, 999, 5)
.withShardWriteLoad(2, 24, 999, 999, 5)
.withShardWriteLoad(3, 24, 999, 999, 5)
- .withShardWriteLoad(4, 24, 999, 999, 4)
+ .withShardWriteLoad(4, 24, 999, 999, 5)
.build(),
// Since this shard uptime is really low, it doesn't add much to the avg
IndexWriteLoad.builder(1).withShardWriteLoad(0, 120, 999, 999, 1).build()
)
);
assertThat(writeLoadForecast.isPresent(), is(true));
- assertThat(writeLoadForecast.getAsDouble(), is(equalTo(15.36)));
+ assertThat(writeLoadForecast.getAsDouble(), is(closeTo(72.59, 0.01)));
}

{
@@ -283,7 +270,7 @@ public void testWriteLoadForecast() {
)
);
assertThat(writeLoadForecast.isPresent(), is(true));
- assertThat(writeLoadForecast.getAsDouble(), is(equalTo(12.0)));
+ assertThat(writeLoadForecast.getAsDouble(), is(equalTo(16.0)));
}

{
@@ -302,7 +289,7 @@ public void testWriteLoadForecast() {
)
);
assertThat(writeLoadForecast.isPresent(), is(true));
- assertThat(writeLoadForecast.getAsDouble(), is(closeTo(15.83, 0.01)));
+ assertThat(writeLoadForecast.getAsDouble(), is(closeTo(31.66, 0.01)));
}
}

@@ -404,4 +391,163 @@ public boolean innerMatch(LogEvent event) {
);
}, LicensedWriteLoadForecaster.class, collectingLoggingAssertion);
}

public void testShardIncreaseDoesNotIncreaseTotalLoad() {
testShardChangeDoesNotChangeTotalForecastLoad(ShardCountChange.INCREASE);
}

public void testShardDecreaseDoesNotDecreaseTotalLoad() {
testShardChangeDoesNotChangeTotalForecastLoad(ShardCountChange.DECREASE);
}

private void testShardChangeDoesNotChangeTotalForecastLoad(ShardCountChange shardCountChange) {
final TimeValue maxIndexAge = TimeValue.timeValueDays(7);
final AtomicBoolean hasValidLicense = new AtomicBoolean(true);
final AtomicInteger licenseCheckCount = new AtomicInteger();
final WriteLoadForecaster writeLoadForecaster = new LicensedWriteLoadForecaster(() -> {
licenseCheckCount.incrementAndGet();
return hasValidLicense.get();
}, threadPool, maxIndexAge);
writeLoadForecaster.refreshLicense();

final String dataStreamName = randomIdentifier();
final ProjectMetadata.Builder originalMetadata = writeLoadForecaster.withWriteLoadForecastForWriteIndex(
dataStreamName,
createMetadataBuilderWithDataStream(dataStreamName, randomIntBetween(5, 15), shardCountChange.originalShardCount(), maxIndexAge)
);

// Generate the same data stream, but with a different number of shards in the write index
final ProjectMetadata.Builder changedShardCountMetadata = writeLoadForecaster.withWriteLoadForecastForWriteIndex(
dataStreamName,
updateWriteIndexShardCount(dataStreamName, originalMetadata, shardCountChange)
);

IndexMetadata originalWriteIndexMetadata = originalMetadata.getSafe(originalMetadata.dataStream(dataStreamName).getWriteIndex());
IndexMetadata changedShardCountWriteIndexMetadata = changedShardCountMetadata.getSafe(
changedShardCountMetadata.dataStream(dataStreamName).getWriteIndex()
);

// The shard count changed
assertThat(
changedShardCountWriteIndexMetadata.getNumberOfShards(),
shardCountChange.expectedChangeFromOriginal(originalWriteIndexMetadata.getNumberOfShards())
);
// But the total write-load did not
assertThat(
changedShardCountWriteIndexMetadata.getNumberOfShards() * writeLoadForecaster.getForecastedWriteLoad(
changedShardCountWriteIndexMetadata
).getAsDouble(),
closeTo(
originalWriteIndexMetadata.getNumberOfShards() * writeLoadForecaster.getForecastedWriteLoad(originalWriteIndexMetadata)
.getAsDouble(),
0.01
Comment (Contributor Author): There are sometimes small rounding errors

)
);
}
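One note on the closeTo(..., 0.01) tolerance that the author's comment above explains: dividing a total by a shard count and multiplying back need not round-trip exactly in double arithmetic. A tiny sketch with arbitrary values:

public class RoundingSketch {
    public static void main(String[] args) {
        double total = 72.59; // arbitrary value, not representable exactly in binary
        int shards = 7;
        double roundTripped = (total / shards) * shards;
        System.out.println(total == roundTripped);                 // may print false
        System.out.println(Math.abs(total - roundTripped) < 0.01); // true: within the test's tolerance
    }
}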

public enum ShardCountChange implements IntToIntFunction {
INCREASE(1, 15) {
@Override
public int apply(int originalShardCount) {
return randomIntBetween(originalShardCount + 1, originalShardCount * 3);
}

public Matcher<Integer> expectedChangeFromOriginal(int originalShardCount) {
return greaterThan(originalShardCount);
}
},
DECREASE(10, 30) {
@Override
public int apply(int originalShardCount) {
return randomIntBetween(1, originalShardCount - 1);
}

public Matcher<Integer> expectedChangeFromOriginal(int originalShardCount) {
return lessThan(originalShardCount);
}
};

private final int originalMinimumShardCount;
private final int originalMaximumShardCount;

ShardCountChange(int originalMinimumShardCount, int originalMaximumShardCount) {
this.originalMinimumShardCount = originalMinimumShardCount;
this.originalMaximumShardCount = originalMaximumShardCount;
}

public int originalShardCount() {
return randomIntBetween(originalMinimumShardCount, originalMaximumShardCount);
}

abstract Matcher<Integer> expectedChangeFromOriginal(int originalShardCount);
}

private ProjectMetadata.Builder updateWriteIndexShardCount(
String dataStreamName,
ProjectMetadata.Builder originalMetadata,
ShardCountChange shardCountChange
) {
final ProjectMetadata.Builder updatedShardCountMetadata = ProjectMetadata.builder(originalMetadata.getId());

final DataStream originalDataStream = originalMetadata.dataStream(dataStreamName);
final Index existingWriteIndex = Objects.requireNonNull(originalDataStream.getWriteIndex());
final IndexMetadata originalWriteIndexMetadata = originalMetadata.getSafe(existingWriteIndex);

// Copy all non-write indices over unchanged
final List<IndexMetadata> backingIndexMetadatas = originalDataStream.getIndices()
.stream()
.filter(index -> index != existingWriteIndex)
.map(originalMetadata::getSafe)
.collect(Collectors.toList());

// Create a new write index with an updated shard count
final IndexMetadata writeIndexMetadata = createIndexMetadata(
DataStream.getDefaultBackingIndexName(dataStreamName, backingIndexMetadatas.size()),
shardCountChange.apply(originalWriteIndexMetadata.getNumberOfShards()),
null,
System.currentTimeMillis()
);
backingIndexMetadatas.add(writeIndexMetadata);
backingIndexMetadatas.forEach(indexMetadata -> updatedShardCountMetadata.put(indexMetadata, false));

final DataStream dataStream = createDataStream(
dataStreamName,
backingIndexMetadatas.stream().map(IndexMetadata::getIndex).toList()
);
updatedShardCountMetadata.put(dataStream);
return updatedShardCountMetadata;
}

private ProjectMetadata.Builder createMetadataBuilderWithDataStream(
String dataStreamName,
int numberOfBackingIndices,
int numberOfShards,
TimeValue maxIndexAge
) {
final ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(randomProjectIdOrDefault());
final List<Index> backingIndices = new ArrayList<>();
for (int i = 0; i < numberOfBackingIndices; i++) {
final IndexMetadata indexMetadata = createIndexMetadata(
DataStream.getDefaultBackingIndexName(dataStreamName, i),
numberOfShards,
randomIndexWriteLoad(numberOfShards),
System.currentTimeMillis() - (maxIndexAge.millis() / 2)
);
backingIndices.add(indexMetadata.getIndex());
metadataBuilder.put(indexMetadata, false);
}

final IndexMetadata writeIndexMetadata = createIndexMetadata(
DataStream.getDefaultBackingIndexName(dataStreamName, numberOfBackingIndices),
numberOfShards,
null,
System.currentTimeMillis()
);
backingIndices.add(writeIndexMetadata.getIndex());
metadataBuilder.put(writeIndexMetadata, false);

final DataStream dataStream = createDataStream(dataStreamName, backingIndices);
metadataBuilder.put(dataStream);
return metadataBuilder;
}
}