Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ public void allocate(RoutingAllocation allocation) {
allocation,
balancerSettings.getThreshold(),
balancingWeights,
balancerSettings.completeEarlyOnShardAssignmentChange()
balancerSettings.completeEarlyOnShardAssignmentChange(),
balancerSettings.getDiskUsageBalanceFactor() == 0
);

boolean shardAssigned = false, shardMoved = false, shardBalanced = false;
Expand Down Expand Up @@ -248,7 +249,8 @@ public ShardAllocationDecision decideShardAllocation(final ShardRouting shard, f
allocation,
balancerSettings.getThreshold(),
balancingWeightsFactory.create(),
balancerSettings.completeEarlyOnShardAssignmentChange()
balancerSettings.completeEarlyOnShardAssignmentChange(),
balancerSettings.getDiskUsageBalanceFactor() == 0
);
AllocateUnassignedDecision allocateUnassignedDecision = AllocateUnassignedDecision.NOT_TAKEN;
MoveDecision moveDecision = MoveDecision.NOT_TAKEN;
Expand Down Expand Up @@ -309,13 +311,15 @@ public static class Balancer {
private final BalancingWeights balancingWeights;
private final NodeSorters nodeSorters;
private final boolean completeEarlyOnShardAssignmentChange;
private final boolean skipDiskUsageCalculation;

private Balancer(
WriteLoadForecaster writeLoadForecaster,
RoutingAllocation allocation,
float threshold,
BalancingWeights balancingWeights,
boolean completeEarlyOnShardAssignmentChange
boolean completeEarlyOnShardAssignmentChange,
boolean skipDiskUsageCalculation
) {
this.writeLoadForecaster = writeLoadForecaster;
this.allocation = allocation;
Expand All @@ -324,11 +328,14 @@ private Balancer(
this.threshold = threshold;
avgShardsPerNode = WeightFunction.avgShardPerNode(metadata, routingNodes);
avgWriteLoadPerNode = WeightFunction.avgWriteLoadPerNode(writeLoadForecaster, metadata, routingNodes);
avgDiskUsageInBytesPerNode = WeightFunction.avgDiskUsageInBytesPerNode(allocation.clusterInfo(), metadata, routingNodes);
nodes = Collections.unmodifiableMap(buildModelFromAssigned());
avgDiskUsageInBytesPerNode = skipDiskUsageCalculation
? 0
: WeightFunction.avgDiskUsageInBytesPerNode(allocation.clusterInfo(), metadata, routingNodes);
nodes = Collections.unmodifiableMap(buildModelFromAssigned(skipDiskUsageCalculation));
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cost saving is realistic. My main question is whether the approach is considered hacky.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about if rather than having the additional flag, we passed the weighting around, and the expensive parts could only perform the calculation if the weighting was non-zero?

I'm not sure if that's better, but it's a thought.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could refer to this.balancingWeights perhaps?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See below flamegraph (from many-shards benchmark) that shows the time spent on disk related computation (purple color) inside allocate calls.
Screenshot 2025-10-10 at 4 40 09 pm

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We synced offline and agreed to change the boolean flag to be a method on BalancingWeights.

this.nodeSorters = balancingWeights.createNodeSorters(nodesArray(), this);
this.balancingWeights = balancingWeights;
this.completeEarlyOnShardAssignmentChange = completeEarlyOnShardAssignmentChange;
this.skipDiskUsageCalculation = skipDiskUsageCalculation;
}

private static long getShardDiskUsageInBytes(ShardRouting shardRouting, IndexMetadata indexMetadata, ClusterInfo clusterInfo) {
Expand All @@ -345,7 +352,12 @@ private float getShardWriteLoad(ProjectIndex index) {
return (float) writeLoadForecaster.getForecastedWriteLoad(projectMetadata.index(index.indexName)).orElse(0.0);
}

// This method is used only by NodeSorter#minWeightDelta to compute the node weight delta.
// Hence, we can return 0 when disk usage is ignored. Any future usage of this method should review whether this still holds.
private float maxShardSizeBytes(ProjectIndex index) {
if (skipDiskUsageCalculation) {
return 0;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This bit I have minor apprehensions about. But it seems like it's not an easy one to skip on the caller side. And we do have that information available to us here via the balancingWeights.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can remove this change if you prefer. It does not really show up in the flamegraph. So I could be over-zealous.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah maybe that would be nicer. The behaviour seems a little surprising potentially. It would seem safer if the caller was skipping the call rather than the callee just returning zero.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It turns out that we can check it at the call site. Not sure why I initially thought it was not feasible ... Pushed 0affb99
Let me know if this works for you. Thanks!

final var indexMetadata = indexMetadata(index);
if (indexMetadata.ignoreDiskWatermarks()) {
// disk watermarks are ignored for partial searchable snapshots
Expand Down Expand Up @@ -1156,10 +1168,10 @@ private Decision decideCanForceAllocateForVacate(ShardRouting shardRouting, Rout
* on the target node which we respect during the allocation / balancing
* process. In short, this method recreates the status-quo in the cluster.
*/
private Map<String, ModelNode> buildModelFromAssigned() {
private Map<String, ModelNode> buildModelFromAssigned(boolean skipDiskUsageCalculation) {
Map<String, ModelNode> nodes = Maps.newMapWithExpectedSize(routingNodes.size());
for (RoutingNode rn : routingNodes) {
ModelNode node = new ModelNode(writeLoadForecaster, metadata, allocation.clusterInfo(), rn);
ModelNode node = new ModelNode(writeLoadForecaster, metadata, allocation.clusterInfo(), rn, skipDiskUsageCalculation);
nodes.put(rn.nodeId(), node);
for (ShardRouting shard : rn) {
assert rn.nodeId().equals(shard.currentNodeId());
Expand Down Expand Up @@ -1476,13 +1488,21 @@ public static class ModelNode implements Iterable<ModelIndex> {
private final ClusterInfo clusterInfo;
private final RoutingNode routingNode;
private final Map<ProjectIndex, ModelIndex> indices;
private final boolean skipDiskUsageCalculation;

public ModelNode(WriteLoadForecaster writeLoadForecaster, Metadata metadata, ClusterInfo clusterInfo, RoutingNode routingNode) {
public ModelNode(
WriteLoadForecaster writeLoadForecaster,
Metadata metadata,
ClusterInfo clusterInfo,
RoutingNode routingNode,
boolean skipDiskUsageCalculation
) {
this.writeLoadForecaster = writeLoadForecaster;
this.metadata = metadata;
this.clusterInfo = clusterInfo;
this.routingNode = routingNode;
this.indices = Maps.newMapWithExpectedSize(routingNode.size() + 10);// some extra to account for shard movements
this.skipDiskUsageCalculation = skipDiskUsageCalculation;
}

public ModelIndex getIndex(ProjectIndex index) {
Expand Down Expand Up @@ -1527,7 +1547,9 @@ public void addShard(ProjectIndex index, ShardRouting shard) {
indices.computeIfAbsent(index, t -> new ModelIndex()).addShard(shard);
IndexMetadata indexMetadata = metadata.getProject(index.project).index(shard.index());
writeLoad += writeLoadForecaster.getForecastedWriteLoad(indexMetadata).orElse(0.0);
diskUsageInBytes += Balancer.getShardDiskUsageInBytes(shard, indexMetadata, clusterInfo);
if (skipDiskUsageCalculation == false) {
diskUsageInBytes += Balancer.getShardDiskUsageInBytes(shard, indexMetadata, clusterInfo);
}
numShards++;
}

Expand All @@ -1541,7 +1563,9 @@ public void removeShard(ProjectIndex projectIndex, ShardRouting shard) {
}
IndexMetadata indexMetadata = metadata.getProject(projectIndex.project).index(shard.index());
writeLoad -= writeLoadForecaster.getForecastedWriteLoad(indexMetadata).orElse(0.0);
diskUsageInBytes -= Balancer.getShardDiskUsageInBytes(shard, indexMetadata, clusterInfo);
if (skipDiskUsageCalculation == false) {
diskUsageInBytes -= Balancer.getShardDiskUsageInBytes(shard, indexMetadata, clusterInfo);
}
numShards--;
}

Expand Down