Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/144637.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
area: Geo
issues:
- 144504
pr: 144637
summary: Fix `geo_centroid` over `geo_shape` merging multiple shards
type: bug
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@

package org.elasticsearch.search.aggregations.metrics;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.geo.SpatialPoint;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.lucene.spatial.DimensionalShapeType;
import org.elasticsearch.search.aggregations.AggregationReduceContext;
import org.elasticsearch.search.aggregations.AggregatorReducer;
import org.elasticsearch.search.aggregations.InternalAggregation;
Expand All @@ -28,15 +30,31 @@
* Serialization and merge logic for {@link GeoCentroidAggregator}.
*/
public abstract class InternalCentroid extends InternalAggregation implements CentroidAggregation {

private static final TransportVersion SHAPE_CENTROID_SUPPORT = TransportVersion.fromName("geo_centroid_shape_weighted_sums");

/**
 * Holds the raw weighted sums and dimensional shape type needed for correct cross-shard reduction
 * of shape centroids. This is {@code null} for geo_point centroids and for results from old nodes,
 * avoiding any memory overhead in the common geo_point case.
 *
 * @param firstWeightedSum  weighted sum of the first coordinate (the shape aggregators in this change pass
 *                          the latitude sum for geo and the x sum for cartesian)
 * @param secondWeightedSum weighted sum of the second coordinate (longitude sum for geo, y sum for cartesian)
 * @param totalWeight       sum of the weights; the reduced centroid is each weighted sum divided by this value
 * @param shapeType         dimensional type of the shapes that contributed to the sums; the reducer compares
 *                          these to decide whether to accumulate, reset, or ignore a shard result
 */
public record ShapeData(double firstWeightedSum, double secondWeightedSum, double totalWeight, DimensionalShapeType shapeType) {}

protected final SpatialPoint centroid;
protected final long count;
protected final ShapeData shapeData;

/**
 * Creates a centroid result that carries no {@link ShapeData}; delegates to the five-argument
 * constructor with a {@code null} shape data.
 *
 * @param name     the aggregation name
 * @param centroid the computed centroid, or {@code null} when {@code count} is zero
 * @param count    the number of values that contributed to the centroid
 * @param metadata the aggregation metadata
 */
public InternalCentroid(String name, SpatialPoint centroid, long count, Map<String, Object> metadata) {
this(name, centroid, count, null, metadata);
}

/**
 * Creates a centroid result, optionally carrying the raw weighted sums used for shape-aware reduction.
 *
 * @param name      the aggregation name
 * @param centroid  the computed centroid; must be {@code null} exactly when {@code count} is zero
 * @param count     the number of values that contributed to the centroid; must not be negative
 * @param shapeData the raw weighted sums and shape type, or {@code null} when not available
 * @param metadata  the aggregation metadata
 */
public InternalCentroid(String name, SpatialPoint centroid, long count, ShapeData shapeData, Map<String, Object> metadata) {
    super(name, metadata);

    // Invariants: a centroid exists if and only if something was counted, and counts are never negative.
    assert (centroid == null) == (count == 0);
    assert count >= 0;

    this.centroid = centroid;
    this.count = count;
    this.shapeData = shapeData;
}

protected abstract SpatialPoint centroidFromStream(StreamInput in) throws IOException;
Expand All @@ -55,6 +73,20 @@ protected InternalCentroid(StreamInput in) throws IOException {
} else {
centroid = null;
}
if (in.getTransportVersion().supports(SHAPE_CENTROID_SUPPORT)) {
if (in.readBoolean()) {
shapeData = new ShapeData(
in.readDouble(),
in.readDouble(),
in.readDouble(),
DimensionalShapeType.fromOrdinalByte(in.readByte())
);
} else {
shapeData = null;
}
} else {
shapeData = null;
}
}

@Override
Expand All @@ -66,6 +98,17 @@ protected void doWriteTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
if (out.getTransportVersion().supports(SHAPE_CENTROID_SUPPORT)) {
if (shapeData != null) {
out.writeBoolean(true);
out.writeDouble(shapeData.firstWeightedSum);
out.writeDouble(shapeData.secondWeightedSum);
out.writeDouble(shapeData.totalWeight);
out.writeByte((byte) shapeData.shapeType.ordinal());
} else {
out.writeBoolean(false);
}
}
}

@Override
Expand All @@ -80,33 +123,81 @@ public long count() {

protected abstract InternalCentroid copyWith(SpatialPoint result, long count);

/** Create a new centroid by reducing from the sums and total count */
/** Create a new centroid by reducing from the sums and total count (count-weighted path for geo_point). */
protected abstract InternalCentroid copyWith(double firstSum, double secondSum, long totalCount);

/** Create a new centroid from shape-aware weighted sums (area-weighted path for geo_shape). */
protected abstract InternalCentroid copyWithShapeFields(ShapeData shapeData, long count);

protected AggregatorReducer getLeaderReducer(AggregationReduceContext reduceContext, int size) {
return new AggregatorReducer() {

// Count-weighted accumulator (geo_point or old nodes)
double firstSum = Double.NaN;
double secondSum = Double.NaN;
long totalCount = 0;

// Shape-aware accumulator (geo_shape)
double combinedFirstWeighted = 0;
double combinedSecondWeighted = 0;
double combinedWeight = 0;
long shapeCount = 0;
DimensionalShapeType combinedShapeType = DimensionalShapeType.POINT;
boolean hasShapeValues = false;

@Override
public void accept(InternalAggregation aggregation) {
InternalCentroid centroidAgg = (InternalCentroid) aggregation;
if (centroidAgg.count > 0) {
totalCount += centroidAgg.count;
if (Double.isNaN(firstSum)) {
firstSum = centroidAgg.count * extractFirst(centroidAgg.centroid);
secondSum = centroidAgg.count * extractSecond(centroidAgg.centroid);
} else {
firstSum += centroidAgg.count * extractFirst(centroidAgg.centroid);
secondSum += centroidAgg.count * extractSecond(centroidAgg.centroid);
if (centroidAgg.shapeData != null && centroidAgg.shapeData.totalWeight > 0) {
// Shape-aware path: respect dimensional type priority
int cmp = centroidAgg.shapeData.shapeType.compareTo(combinedShapeType);
if (hasShapeValues == false || cmp > 0) {
// First shape value or higher dimension — reset
combinedFirstWeighted = centroidAgg.shapeData.firstWeightedSum;
combinedSecondWeighted = centroidAgg.shapeData.secondWeightedSum;
combinedWeight = centroidAgg.shapeData.totalWeight;
shapeCount = centroidAgg.count;
combinedShapeType = centroidAgg.shapeData.shapeType;
hasShapeValues = true;
} else if (cmp == 0) {
// Same dimension — accumulate
combinedFirstWeighted += centroidAgg.shapeData.firstWeightedSum;
combinedSecondWeighted += centroidAgg.shapeData.secondWeightedSum;
combinedWeight += centroidAgg.shapeData.totalWeight;
shapeCount += centroidAgg.count;
}
// cmp < 0: lower dimension — ignore
} else if (centroidAgg.centroid != null) {
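// Count-weighted path: geo_point results, or BWC results from an old node that carry no ShapeData.
// NOTE(review): this path is order-dependent. A result accepted here before the first ShapeData result
// is accumulated into firstSum/secondSum/totalCount, which get() ignores once hasShapeValues is true.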
if (hasShapeValues) {
// BWC: approximate old-node shape result as same dimension, count as weight
combinedFirstWeighted += centroidAgg.count * extractFirst(centroidAgg.centroid);
combinedSecondWeighted += centroidAgg.count * extractSecond(centroidAgg.centroid);
combinedWeight += centroidAgg.count;
shapeCount += centroidAgg.count;
} else {
totalCount += centroidAgg.count;
if (Double.isNaN(firstSum)) {
firstSum = centroidAgg.count * extractFirst(centroidAgg.centroid);
secondSum = centroidAgg.count * extractSecond(centroidAgg.centroid);
} else {
firstSum += centroidAgg.count * extractFirst(centroidAgg.centroid);
secondSum += centroidAgg.count * extractSecond(centroidAgg.centroid);
}
}
}
}
}

@Override
public InternalAggregation get() {
if (hasShapeValues) {
return copyWithShapeFields(
new ShapeData(combinedFirstWeighted, combinedSecondWeighted, combinedWeight, combinedShapeType),
shapeCount
);
}
return copyWith(firstSum, secondSum, totalCount);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ public InternalGeoCentroid(String name, SpatialPoint centroid, long count, Map<S
super(name, centroid, count, metadata);
}

/**
 * Constructor for shape centroid results that carry raw weighted sums for correct cross-shard reduction.
 *
 * @param name      the aggregation name
 * @param centroid  the computed centroid
 * @param count     the number of values that contributed to the centroid
 * @param shapeData the raw weighted sums and dimensional shape type, passed through to the superclass
 * @param metadata  the aggregation metadata
 */
public InternalGeoCentroid(String name, SpatialPoint centroid, long count, ShapeData shapeData, Map<String, Object> metadata) {
super(name, centroid, count, shapeData, metadata);
}

/**
* Read from a stream.
*/
Expand Down Expand Up @@ -73,6 +80,14 @@ protected InternalGeoCentroid copyWith(double firstSum, double secondSum, long t
return copyWith(result, totalCount);
}

/**
 * Builds a geo centroid from shape-aware weighted sums: each weighted sum is divided by the total
 * weight to obtain the point, and the raw sums are retained on the result.
 *
 * @param shapeData the combined weighted sums, total weight and shape type
 * @param count     the document count to report on the result
 * @return a new result whose centroid is {@code null} unless the total weight is positive
 */
@Override
protected InternalGeoCentroid copyWithShapeFields(ShapeData shapeData, long count) {
    GeoPoint reduced = null;
    if (shapeData.totalWeight() > 0) {
        final double first = shapeData.firstWeightedSum() / shapeData.totalWeight();
        final double second = shapeData.secondWeightedSum() / shapeData.totalWeight();
        reduced = new GeoPoint(first, second);
    }
    return new InternalGeoCentroid(name, reduced, count, shapeData, getMetadata());
}

@Override
protected String nameFirst() {
return "lat";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
9325000,9250010,9185023,8841087
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/8.19.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
initial_8.19.13,8841086
geo_centroid_shape_weighted_sums,8841087
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/9.2.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
initial_9.2.7,9185022
geo_centroid_shape_weighted_sums,9185023
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/9.3.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
initial_9.3.2,9250009
geo_centroid_shape_weighted_sums,9250010
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/9.4.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
inference_api_chat_completion_reasoning_max_tokens_removed,9324000
geo_centroid_shape_weighted_sums,9325000
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.LeafBucketCollectorBase;
import org.elasticsearch.search.aggregations.metrics.CompensatedSum;
import org.elasticsearch.search.aggregations.metrics.InternalGeoCentroid;
import org.elasticsearch.search.aggregations.metrics.MetricsAggregator;
import org.elasticsearch.search.aggregations.support.AggregationContext;
import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
Expand Down Expand Up @@ -126,17 +125,21 @@ public InternalAggregation buildAggregation(long bucket) {
if (bucket >= counts.size()) {
return buildEmptyAggregation();
}
final long bucketCount = counts.get(bucket);
final double bucketXSum = lonSum.get(bucket); // x-coordinate sum (named "lon" for historical reasons)
final double bucketYSum = latSum.get(bucket); // y-coordinate sum (named "lat" for historical reasons)
final double bucketWeight = weightSum.get(bucket);
final CartesianPoint bucketCentroid = (bucketWeight > 0)
? new CartesianPoint(lonSum.get(bucket) / bucketWeight, latSum.get(bucket) / bucketWeight)
final long bucketCount = counts.get(bucket);
final DimensionalShapeType bucketShapeType = DimensionalShapeType.fromOrdinalByte(dimensionalShapeTypes.get(bucket));
final CartesianPoint bucketCentroid = bucketWeight > 0
? new CartesianPoint(bucketXSum / bucketWeight, bucketYSum / bucketWeight)
: null;
return new InternalCartesianCentroid(name, bucketCentroid, bucketCount, metadata());
var shapeData = new InternalCartesianCentroid.ShapeData(bucketXSum, bucketYSum, bucketWeight, bucketShapeType);
return new InternalCartesianCentroid(name, bucketCentroid, bucketCount, shapeData, metadata());
}

@Override
public InternalAggregation buildEmptyAggregation() {
return InternalGeoCentroid.empty(name, metadata());
return InternalCartesianCentroid.empty(name, metadata());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,14 @@ public InternalAggregation buildAggregation(long bucket) {
if (bucket >= counts.size()) {
return buildEmptyAggregation();
}
final long bucketCount = counts.get(bucket);
final double bucketLatSum = latSum.get(bucket);
final double bucketLonSum = lonSum.get(bucket);
final double bucketWeight = weightSum.get(bucket);
final GeoPoint bucketCentroid = (bucketWeight > 0)
? new GeoPoint(latSum.get(bucket) / bucketWeight, lonSum.get(bucket) / bucketWeight)
: null;
return new InternalGeoCentroid(name, bucketCentroid, bucketCount, metadata());
final long bucketCount = counts.get(bucket);
final DimensionalShapeType bucketShapeType = DimensionalShapeType.fromOrdinalByte(dimensionalShapeTypes.get(bucket));
final GeoPoint bucketCentroid = bucketWeight > 0 ? new GeoPoint(bucketLatSum / bucketWeight, bucketLonSum / bucketWeight) : null;
var shapeData = new InternalGeoCentroid.ShapeData(bucketLatSum, bucketLonSum, bucketWeight, bucketShapeType);
return new InternalGeoCentroid(name, bucketCentroid, bucketCount, shapeData, metadata());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ public InternalCartesianCentroid(String name, SpatialPoint centroid, long count,
super(name, centroid, count, metadata);
}

/**
 * Constructor for shape centroid results that carry raw weighted sums for correct cross-shard reduction.
 *
 * @param name      the aggregation name
 * @param centroid  the computed centroid
 * @param count     the number of values that contributed to the centroid
 * @param shapeData the raw weighted sums and dimensional shape type, passed through to the superclass
 * @param metadata  the aggregation metadata
 */
public InternalCartesianCentroid(String name, SpatialPoint centroid, long count, ShapeData shapeData, Map<String, Object> metadata) {
super(name, centroid, count, shapeData, metadata);
}

/**
* Read from a stream.
*/
Expand Down Expand Up @@ -72,6 +79,17 @@ protected InternalCartesianCentroid copyWith(double firstSum, double secondSum,
return copyWith(result, totalCount);
}

/**
 * Builds a cartesian centroid from shape-aware weighted sums: each weighted sum is divided by the
 * total weight to obtain the point, and the raw sums are retained on the result.
 *
 * @param shapeData the combined weighted sums, total weight and shape type
 * @param count     the document count to report on the result
 * @return a new result whose centroid is {@code null} unless the total weight is positive
 */
@Override
protected InternalCartesianCentroid copyWithShapeFields(ShapeData shapeData, long count) {
    CartesianPoint reduced = null;
    if (shapeData.totalWeight() > 0) {
        final double first = shapeData.firstWeightedSum() / shapeData.totalWeight();
        final double second = shapeData.secondWeightedSum() / shapeData.totalWeight();
        reduced = new CartesianPoint(first, second);
    }
    return new InternalCartesianCentroid(name, reduced, count, shapeData, getMetadata());
}

@Override
protected String nameFirst() {
return "x";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import org.elasticsearch.common.geo.SpatialPoint;
import org.elasticsearch.geo.ShapeTestUtils;
import org.elasticsearch.geometry.Geometry;
import org.elasticsearch.geometry.LinearRing;
import org.elasticsearch.geometry.Polygon;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.lucene.spatial.CentroidCalculator;
import org.elasticsearch.lucene.spatial.DimensionalShapeType;
Expand Down Expand Up @@ -189,6 +191,65 @@ public void testSingleValuedField() throws Exception {
}
}

/**
 * Tests that when shapes with very different areas are in different segments (simulating different shards),
 * the aggregation produces an area-weighted centroid rather than a count-weighted one.
 * <p>
 * One large polygon is flushed on its own, followed by many small polygons. The expected centroid is
 * computed from {@link CentroidCalculator} weights, and both axes are first checked to be dominated by
 * the large polygon so that a count-weighted result could not satisfy the final assertions.
 */
public void testMultiSegmentAreaWeightedReduction() throws Exception {
    // Large polygon: 1000×1000 square, centroid at (500, 500), area = 1_000_000
    Polygon largePolygon = new Polygon(new LinearRing(new double[] { 0, 1000, 1000, 0, 0 }, new double[] { 0, 0, 1000, 1000, 0 }));

    // Small polygon: 1×1 square, centroid at (0.5, 0.5), area = 1
    Polygon smallPolygon = new Polygon(new LinearRing(new double[] { 0, 1, 1, 0, 0 }, new double[] { 0, 0, 1, 1, 0 }));
    int numSmallPolygons = 100;

    CentroidCalculator largeCalc = new CentroidCalculator();
    largeCalc.add(largePolygon);
    CentroidCalculator smallCalc = new CentroidCalculator();
    smallCalc.add(smallPolygon);

    double largeWeight = largeCalc.sumWeight();
    double smallWeight = smallCalc.sumWeight();
    double totalWeight = largeWeight + numSmallPolygons * smallWeight;
    double expectedX = (largeWeight * largeCalc.getX() + numSmallPolygons * smallWeight * smallCalc.getX()) / totalWeight;
    double expectedY = (largeWeight * largeCalc.getY() + numSmallPolygons * smallWeight * smallCalc.getY()) / totalWeight;

    // The count-weighted (buggy) result would be much closer to (0.5, 0.5); the area-weighted result near (500, 500).
    // Guard both axes so the expectation itself cannot silently degrade into the count-weighted value.
    assertTrue("Area-weighted centroid x " + expectedX + " should dominate", expectedX > 499);
    assertTrue("Area-weighted centroid y " + expectedY + " should dominate", expectedY > 499);

    try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir)) {
        // First segment: the single large polygon.
        Document doc = new Document();
        doc.add(GeoTestUtils.binaryCartesianShapeDocValuesField("field", largePolygon));
        w.addDocument(doc);
        w.flush();

        // Second segment: many small polygons.
        // NOTE(review): RandomIndexWriter may merge segments on its own; confirm the two flushes
        // really leave more than one segment, otherwise the cross-segment reduction is not exercised.
        for (int i = 0; i < numSmallPolygons; i++) {
            Document smallDoc = new Document();
            smallDoc.add(GeoTestUtils.binaryCartesianShapeDocValuesField("field", smallPolygon));
            w.addDocument(smallDoc);
        }
        w.flush();

        MappedFieldType fieldType = new ShapeFieldMapper.ShapeFieldType(
            "field",
            true,
            true,
            Orientation.RIGHT,
            null,
            false,
            Collections.emptyMap()
        );
        CartesianCentroidAggregationBuilder aggBuilder = new CartesianCentroidAggregationBuilder("my_agg").field("field");
        try (IndexReader reader = w.getReader()) {
            InternalCartesianCentroid result = searchAndReduce(reader, new AggTestConfig(aggBuilder, fieldType));
            assertNotNull(result.centroid());
            assertEquals(numSmallPolygons + 1, result.count());
            assertCentroid("x (area-weighted)", result.count(), result.centroid().getX(), expectedX);
            assertCentroid("y (area-weighted)", result.count(), result.centroid().getY(), expectedY);
        }
    }
}

private void assertCentroid(RandomIndexWriter w, CartesianPoint expectedCentroid) throws IOException {
MappedFieldType fieldType = new ShapeFieldMapper.ShapeFieldType(
"field",
Expand Down
Loading
Loading