
Commit 3e527f9

YannByron authored and danzhewuju committed
[spark] support to push down min/max aggregation (apache#5270)
1 parent dab3d60 commit 3e527f9

File tree

8 files changed, +572 -109 lines changed


paimon-core/src/main/java/org/apache/paimon/stats/SimpleStatsEvolution.java

Lines changed: 24 additions & 0 deletions
@@ -64,6 +64,30 @@ public SimpleStatsEvolution(
         this.emptyNullCounts = new GenericArray(new Object[fieldNames.size()]);
     }
 
+    public InternalRow evolution(InternalRow row, @Nullable List<String> denseFields) {
+        InternalRow result = row;
+
+        if (denseFields != null && denseFields.isEmpty()) {
+            result = emptyValues;
+        } else if (denseFields != null) {
+            int[] denseIndexMapping =
+                    indexMappings.computeIfAbsent(
+                            denseFields,
+                            k -> fieldNames.stream().mapToInt(denseFields::indexOf).toArray());
+            result = ProjectedRow.from(denseIndexMapping).replaceRow(result);
+        }
+
+        if (indexMapping != null) {
+            result = ProjectedRow.from(indexMapping).replaceRow(result);
+        }
+
+        if (castFieldGetters != null) {
+            result = CastedRow.from(castFieldGetters).replaceRow(result);
+        }
+
+        return result;
+    }
+
     public Result evolution(
             SimpleStats stats, @Nullable Long rowCount, @Nullable List<String> denseFields) {
         InternalRow minValues = stats.minValues();

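A note on the dense-field mapping above: when a file stored statistics only for a subset of columns (denseFields), the new overload maps each table field to its position in that dense stats row, with -1 for fields that have no stored stats. A minimal sketch of the mapping with hypothetical field names (not taken from this patch):

    // Table schema fields vs. the subset the file actually stored stats for.
    val fieldNames  = Seq("c_int", "c_long", "c_ts")
    val denseFields = Seq("c_int", "c_ts")

    // Mirrors fieldNames.stream().mapToInt(denseFields::indexOf).toArray():
    val mapping = fieldNames.map(f => denseFields.indexOf(f)).toArray // Array(0, -1, 1)

    // ProjectedRow.from(mapping) then reads c_int and c_ts from the dense stats
    // row; the -1 entry surfaces c_long, which has no stored stats, as null.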
paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java

Lines changed: 44 additions & 0 deletions
@@ -19,6 +19,7 @@
 package org.apache.paimon.table.source;
 
 import org.apache.paimon.data.BinaryRow;
+import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.io.DataFileMeta;
 import org.apache.paimon.io.DataFileMeta08Serializer;
 import org.apache.paimon.io.DataFileMeta09Serializer;
@@ -28,7 +29,12 @@
 import org.apache.paimon.io.DataInputViewStreamWrapper;
 import org.apache.paimon.io.DataOutputView;
 import org.apache.paimon.io.DataOutputViewStreamWrapper;
+import org.apache.paimon.predicate.CompareUtils;
+import org.apache.paimon.stats.SimpleStatsEvolution;
+import org.apache.paimon.stats.SimpleStatsEvolutions;
+import org.apache.paimon.types.DataField;
 import org.apache.paimon.utils.FunctionWithIOException;
+import org.apache.paimon.utils.InternalRowUtils;
 import org.apache.paimon.utils.SerializationUtils;
 
 import javax.annotation.Nullable;
@@ -141,6 +147,44 @@ public long mergedRowCount() {
         return partialMergedRowCount();
     }
 
+    public Object minValue(int fieldIndex, DataField dataField, SimpleStatsEvolutions evolutions) {
+        Object minValue = null;
+        for (DataFileMeta dataFile : dataFiles) {
+            SimpleStatsEvolution evolution = evolutions.getOrCreate(dataFile.schemaId());
+            InternalRow minValues =
+                    evolution.evolution(
+                            dataFile.valueStats().minValues(), dataFile.valueStatsCols());
+            Object other = InternalRowUtils.get(minValues, fieldIndex, dataField.type());
+            if (minValue == null) {
+                minValue = other;
+            } else if (other != null) {
+                if (CompareUtils.compareLiteral(dataField.type(), minValue, other) > 0) {
+                    minValue = other;
+                }
+            }
+        }
+        return minValue;
+    }
+
+    public Object maxValue(int fieldIndex, DataField dataField, SimpleStatsEvolutions evolutions) {
+        Object maxValue = null;
+        for (DataFileMeta dataFile : dataFiles) {
+            SimpleStatsEvolution evolution = evolutions.getOrCreate(dataFile.schemaId());
+            InternalRow maxValues =
+                    evolution.evolution(
+                            dataFile.valueStats().maxValues(), dataFile.valueStatsCols());
+            Object other = InternalRowUtils.get(maxValues, fieldIndex, dataField.type());
+            if (maxValue == null) {
+                maxValue = other;
+            } else if (other != null) {
+                if (CompareUtils.compareLiteral(dataField.type(), maxValue, other) < 0) {
+                    maxValue = other;
+                }
+            }
+        }
+        return maxValue;
+    }
+
     /**
      * Obtain merged row count as much as possible. There are two scenarios where accurate row count
      * can be calculated:
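
The two new accessors fold per-file column statistics into a single value for the split, running each file's stats through schema evolution first; files lacking stats for the column contribute null and are skipped by the null checks. A hedged usage sketch in Scala (the schemas lookup and schema id are illustrative, not part of the patch):

    import org.apache.paimon.stats.SimpleStatsEvolutions
    import org.apache.paimon.types.{DataField, IntType}

    // Assumes a DataSplit `split` and a schema-id -> fields lookup `schemas`,
    // e.g. a java.util.Map[java.lang.Long, java.util.List[DataField]].
    val evolutions = new SimpleStatsEvolutions(id => schemas.get(id), 1L)
    val intField = new DataField(0, "c_int", new IntType())

    val min = split.minValue(0, intField, evolutions) // null if no file has stats
    val max = split.maxValue(0, intField, evolutions)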

paimon-core/src/test/java/org/apache/paimon/table/source/SplitTest.java

Lines changed: 107 additions & 3 deletions
@@ -28,18 +28,30 @@
 import org.apache.paimon.io.DataOutputViewStreamWrapper;
 import org.apache.paimon.manifest.FileSource;
 import org.apache.paimon.stats.SimpleStats;
+import org.apache.paimon.stats.SimpleStatsEvolutions;
+import org.apache.paimon.types.BigIntType;
+import org.apache.paimon.types.DataField;
+import org.apache.paimon.types.DoubleType;
+import org.apache.paimon.types.FloatType;
+import org.apache.paimon.types.IntType;
+import org.apache.paimon.types.SmallIntType;
+import org.apache.paimon.types.TimestampType;
 import org.apache.paimon.utils.IOUtils;
 import org.apache.paimon.utils.InstantiationUtil;
 
 import org.junit.jupiter.api.Test;
 
+import javax.annotation.Nullable;
+
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.time.LocalDateTime;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.ThreadLocalRandom;
 
 import static org.apache.paimon.data.BinaryArray.fromLongArray;
@@ -84,6 +96,70 @@ public void testSplitMergedRowCount() {
         assertThat(split.mergedRowCount()).isEqualTo(5700L);
     }
 
+    @Test
+    public void testSplitMinMaxValue() {
+        Map<Long, List<DataField>> schemas = new HashMap<>();
+
+        Timestamp minTs = Timestamp.fromLocalDateTime(LocalDateTime.parse("2025-01-01T00:00:00"));
+        Timestamp maxTs1 = Timestamp.fromLocalDateTime(LocalDateTime.parse("2025-03-01T00:00:00"));
+        Timestamp maxTs2 = Timestamp.fromLocalDateTime(LocalDateTime.parse("2025-03-12T00:00:00"));
+        BinaryRow min1 = newBinaryRow(new Object[] {10, 123L, 888.0D, minTs});
+        BinaryRow max1 = newBinaryRow(new Object[] {99, 456L, 999.0D, maxTs1});
+        SimpleStats valueStats1 = new SimpleStats(min1, max1, fromLongArray(new Long[] {0L}));
+
+        BinaryRow min2 = newBinaryRow(new Object[] {5, 0L, 777.0D, minTs});
+        BinaryRow max2 = newBinaryRow(new Object[] {90, 789L, 899.0D, maxTs2});
+        SimpleStats valueStats2 = new SimpleStats(min2, max2, fromLongArray(new Long[] {0L}));
+
+        // test the common case.
+        DataFileMeta d1 = newDataFile(100, valueStats1, null);
+        DataFileMeta d2 = newDataFile(100, valueStats2, null);
+        DataSplit split1 = newDataSplit(true, Arrays.asList(d1, d2), null);
+
+        DataField intField = new DataField(0, "c_int", new IntType());
+        DataField longField = new DataField(1, "c_long", new BigIntType());
+        DataField doubleField = new DataField(2, "c_double", new DoubleType());
+        DataField tsField = new DataField(3, "c_ts", new TimestampType());
+        schemas.put(1L, Arrays.asList(intField, longField, doubleField, tsField));
+
+        SimpleStatsEvolutions evolutions = new SimpleStatsEvolutions(schemas::get, 1);
+        assertThat(split1.minValue(0, intField, evolutions)).isEqualTo(5);
+        assertThat(split1.maxValue(0, intField, evolutions)).isEqualTo(99);
+        assertThat(split1.minValue(1, longField, evolutions)).isEqualTo(0L);
+        assertThat(split1.maxValue(1, longField, evolutions)).isEqualTo(789L);
+        assertThat(split1.minValue(2, doubleField, evolutions)).isEqualTo(777D);
+        assertThat(split1.maxValue(2, doubleField, evolutions)).isEqualTo(999D);
+        assertThat(split1.minValue(3, tsField, evolutions)).isEqualTo(minTs);
+        assertThat(split1.maxValue(3, tsField, evolutions)).isEqualTo(maxTs2);
+
+        // test the case where a non-null valueStatsCols is provided and the file
+        // schema differs from the table schema.
+        BinaryRow min3 = newBinaryRow(new Object[] {10, 123L, minTs});
+        BinaryRow max3 = newBinaryRow(new Object[] {99, 456L, maxTs1});
+        SimpleStats valueStats3 = new SimpleStats(min3, max3, fromLongArray(new Long[] {0L}));
+        BinaryRow min4 = newBinaryRow(new Object[] {5, 0L, minTs});
+        BinaryRow max4 = newBinaryRow(new Object[] {90, 789L, maxTs2});
+        SimpleStats valueStats4 = new SimpleStats(min4, max4, fromLongArray(new Long[] {0L}));
+        List<String> valueStatsCols2 = Arrays.asList("c_int", "c_long", "c_ts");
+        DataFileMeta d3 = newDataFile(100, valueStats3, valueStatsCols2);
+        DataFileMeta d4 = newDataFile(100, valueStats4, valueStatsCols2);
+        DataSplit split2 = newDataSplit(true, Arrays.asList(d3, d4), null);
+
+        DataField smallField = new DataField(4, "c_small", new SmallIntType());
+        DataField floatField = new DataField(5, "c_float", new FloatType());
+        schemas.put(2L, Arrays.asList(intField, smallField, tsField, floatField));
+
+        evolutions = new SimpleStatsEvolutions(schemas::get, 2);
+        assertThat(split2.minValue(0, intField, evolutions)).isEqualTo(5);
+        assertThat(split2.maxValue(0, intField, evolutions)).isEqualTo(99);
+        assertThat(split2.minValue(1, smallField, evolutions)).isEqualTo(null);
+        assertThat(split2.maxValue(1, smallField, evolutions)).isEqualTo(null);
+        assertThat(split2.minValue(2, tsField, evolutions)).isEqualTo(minTs);
+        assertThat(split2.maxValue(2, tsField, evolutions)).isEqualTo(maxTs2);
+        assertThat(split2.minValue(3, floatField, evolutions)).isEqualTo(null);
+        assertThat(split2.maxValue(3, floatField, evolutions)).isEqualTo(null);
+    }
+
     @Test
     public void testSerializer() throws IOException {
         DataFileTestDataGenerator gen = DataFileTestDataGenerator.builder().build();
@@ -436,18 +512,23 @@ public void testSerializerCompatibleV5() throws Exception {
     }
 
     private DataFileMeta newDataFile(long rowCount) {
+        return newDataFile(rowCount, null, null);
+    }
+
+    private DataFileMeta newDataFile(
+            long rowCount, SimpleStats rowStats, @Nullable List<String> valueStatsCols) {
         return DataFileMeta.forAppend(
                 "my_data_file.parquet",
                 1024 * 1024,
                 rowCount,
-                null,
+                rowStats,
                 0L,
-                rowCount,
+                rowCount - 1,
                 1,
                 Collections.emptyList(),
                 null,
                 null,
-                null,
+                valueStatsCols,
                 null);
     }

@@ -467,4 +548,27 @@ private DataSplit newDataSplit(
         }
         return builder.build();
     }
+
+    private BinaryRow newBinaryRow(Object[] objs) {
+        BinaryRow row = new BinaryRow(objs.length);
+        BinaryRowWriter writer = new BinaryRowWriter(row);
+        writer.reset();
+        for (int i = 0; i < objs.length; i++) {
+            if (objs[i] instanceof Integer) {
+                writer.writeInt(i, (Integer) objs[i]);
+            } else if (objs[i] instanceof Long) {
+                writer.writeLong(i, (Long) objs[i]);
+            } else if (objs[i] instanceof Float) {
+                writer.writeFloat(i, (Float) objs[i]);
+            } else if (objs[i] instanceof Double) {
+                writer.writeDouble(i, (Double) objs[i]);
+            } else if (objs[i] instanceof Timestamp) {
+                writer.writeTimestamp(i, (Timestamp) objs[i], 5);
+            } else {
+                throw new UnsupportedOperationException("It's not supported.");
+            }
+        }
+        writer.complete();
+        return row;
+    }
 }

paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala

Lines changed: 23 additions & 17 deletions
@@ -19,8 +19,8 @@
 package org.apache.paimon.spark
 
 import org.apache.paimon.predicate.{PartitionPredicateVisitor, Predicate, PredicateBuilder}
-import org.apache.paimon.spark.aggregate.LocalAggregator
-import org.apache.paimon.table.Table
+import org.apache.paimon.spark.aggregate.{AggregatePushDownUtils, LocalAggregator}
+import org.apache.paimon.table.{FileStoreTable, Table}
 import org.apache.paimon.table.source.DataSplit
 
 import org.apache.spark.sql.PaimonUtils
@@ -101,13 +101,12 @@ class PaimonScanBuilder(table: Table)
       return true
     }
 
-    // Only support when there is no post scan predicates.
-    if (hasPostScanPredicates) {
+    if (!table.isInstanceOf[FileStoreTable]) {
       return false
     }
 
-    val aggregator = new LocalAggregator(table)
-    if (!aggregator.pushAggregation(aggregation)) {
+    // Only supported when there are no post-scan predicates.
+    if (hasPostScanPredicates) {
       return false
     }
 

@@ -116,19 +115,26 @@
       val pushedPartitionPredicate = PredicateBuilder.and(pushedPaimonPredicates.toList.asJava)
       readBuilder.withFilter(pushedPartitionPredicate)
     }
-    val dataSplits =
+    val dataSplits = if (AggregatePushDownUtils.hasMinMaxAggregation(aggregation)) {
+      readBuilder.newScan().plan().splits().asScala.map(_.asInstanceOf[DataSplit])
+    } else {
       readBuilder.dropStats().newScan().plan().splits().asScala.map(_.asInstanceOf[DataSplit])
-    if (!dataSplits.forall(_.mergedRowCountAvailable())) {
-      return false
     }
-    dataSplits.foreach(aggregator.update)
-    localScan = Some(
-      PaimonLocalScan(
-        aggregator.result(),
-        aggregator.resultSchema(),
-        table,
-        pushedPaimonPredicates))
-    true
+    if (AggregatePushDownUtils.canPushdownAggregation(table, aggregation, dataSplits.toSeq)) {
+      val aggregator = new LocalAggregator(table.asInstanceOf[FileStoreTable])
+      aggregator.initialize(aggregation)
+      dataSplits.foreach(aggregator.update)
+      localScan = Some(
+        PaimonLocalScan(
+          aggregator.result(),
+          aggregator.resultSchema(),
+          table,
+          pushedPaimonPredicates)
+      )
+      true
+    } else {
+      false
+    }
   }
 
   override def build(): Scan = {
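
AggregatePushDownUtils is among the eight changed files, but its diff is not part of this excerpt. Inferring only from the call sites above, a plausible sketch of the min/max detection (the body is an assumption, not the actual implementation):

    import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Max, Min}

    object AggregatePushDownUtilsSketch {
      // True if any pushed aggregate is MIN or MAX; in that case the scan above
      // keeps file statistics instead of calling dropStats().
      def hasMinMaxAggregation(aggregation: Aggregation): Boolean =
        aggregation.aggregateExpressions.exists {
          case _: Min | _: Max => true
          case _ => false
        }
    }

Presumably canPushdownAggregation then verifies that every split can serve the pushed aggregates from metadata alone, e.g. that merged row counts are available for COUNT and that min/max statistics exist for the referenced columns.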
