Skip to content

Commit d7b2c35

Browse files
[SQL/PPL] Fix the count(*) and dc(field) to be capped at MAX_INTEGER #4416 (#4418)
Co-authored-by: Aaron Alvarez <[email protected]>
1 parent 3e95147 commit d7b2c35

File tree

14 files changed

+129
-116
lines changed

14 files changed

+129
-116
lines changed

core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunctions.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ private static DefaultFunctionResolver count() {
9292
new FunctionSignature(functionName, Collections.singletonList(type)),
9393
type ->
9494
(functionProperties, arguments) ->
95-
new CountAggregator(arguments, INTEGER))));
95+
new CountAggregator(arguments, LONG))));
9696
return functionResolver;
9797
}
9898

core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public String toString() {
4444

4545
/** Count State. */
4646
protected static class CountState implements AggregationState {
47-
protected int count;
47+
protected long count;
4848

4949
CountState() {
5050
this.count = 0;
@@ -56,7 +56,7 @@ public void count(ExprValue value) {
5656

5757
@Override
5858
public ExprValue result() {
59-
return ExprValueUtils.integerValue(count);
59+
return ExprValueUtils.longValue(count);
6060
}
6161
}
6262

core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1283,7 +1283,7 @@ public void named_aggregator_with_condition() {
12831283
emptyList()),
12841284
DSL.named(
12851285
"count(string_value) filter(where integer_value > 1)",
1286-
DSL.ref("count(string_value) filter(where integer_value > 1)", INTEGER))),
1286+
DSL.ref("count(string_value) filter(where integer_value > 1)", LONG))),
12871287
AstDSL.project(
12881288
AstDSL.agg(
12891289
AstDSL.relation("schema"),

core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,37 +29,37 @@ class CountAggregatorTest extends AggregationTest {
2929
@Test
3030
public void count_integer_field_expression() {
3131
ExprValue result = aggregation(DSL.count(DSL.ref("integer_value", INTEGER)), tuples);
32-
assertEquals(4, result.value());
32+
assertEquals(4L, result.value());
3333
}
3434

3535
@Test
3636
public void count_long_field_expression() {
3737
ExprValue result = aggregation(DSL.count(DSL.ref("long_value", LONG)), tuples);
38-
assertEquals(4, result.value());
38+
assertEquals(4L, result.value());
3939
}
4040

4141
@Test
4242
public void count_float_field_expression() {
4343
ExprValue result = aggregation(DSL.count(DSL.ref("float_value", FLOAT)), tuples);
44-
assertEquals(4, result.value());
44+
assertEquals(4L, result.value());
4545
}
4646

4747
@Test
4848
public void count_double_field_expression() {
4949
ExprValue result = aggregation(DSL.count(DSL.ref("double_value", DOUBLE)), tuples);
50-
assertEquals(4, result.value());
50+
assertEquals(4L, result.value());
5151
}
5252

5353
@Test
5454
public void count_date_field_expression() {
5555
ExprValue result = aggregation(DSL.count(DSL.ref("date_value", DATE)), tuples);
56-
assertEquals(4, result.value());
56+
assertEquals(4L, result.value());
5757
}
5858

5959
@Test
6060
public void count_timestamp_field_expression() {
6161
ExprValue result = aggregation(DSL.count(DSL.ref("timestamp_value", TIMESTAMP)), tuples);
62-
assertEquals(4, result.value());
62+
assertEquals(4L, result.value());
6363
}
6464

6565
@Test
@@ -68,34 +68,33 @@ public void count_arithmetic_expression() {
6868
aggregation(
6969
DSL.count(
7070
DSL.multiply(
71-
DSL.ref("integer_value", INTEGER),
72-
DSL.literal(ExprValueUtils.integerValue(10)))),
71+
DSL.ref("long_value", LONG), DSL.literal(ExprValueUtils.longValue(10L)))),
7372
tuples);
74-
assertEquals(4, result.value());
73+
assertEquals(4L, result.value());
7574
}
7675

7776
@Test
7877
public void count_string_field_expression() {
7978
ExprValue result = aggregation(DSL.count(DSL.ref("string_value", STRING)), tuples);
80-
assertEquals(4, result.value());
79+
assertEquals(4L, result.value());
8180
}
8281

8382
@Test
8483
public void count_boolean_field_expression() {
8584
ExprValue result = aggregation(DSL.count(DSL.ref("boolean_value", BOOLEAN)), tuples);
86-
assertEquals(1, result.value());
85+
assertEquals(1L, result.value());
8786
}
8887

8988
@Test
9089
public void count_struct_field_expression() {
9190
ExprValue result = aggregation(DSL.count(DSL.ref("struct_value", STRUCT)), tuples);
92-
assertEquals(1, result.value());
91+
assertEquals(1L, result.value());
9392
}
9493

9594
@Test
9695
public void count_array_field_expression() {
9796
ExprValue result = aggregation(DSL.count(DSL.ref("array_value", ARRAY)), tuples);
98-
assertEquals(1, result.value());
97+
assertEquals(1L, result.value());
9998
}
10099

101100
@Test
@@ -105,14 +104,14 @@ public void filtered_count() {
105104
DSL.count(DSL.ref("integer_value", INTEGER))
106105
.condition(DSL.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))),
107106
tuples);
108-
assertEquals(3, result.value());
107+
assertEquals(3L, result.value());
109108
}
110109

111110
@Test
112111
public void distinct_count() {
113112
ExprValue result =
114113
aggregation(DSL.distinctCount(DSL.ref("integer_value", INTEGER)), tuples_with_duplicates);
115-
assertEquals(3, result.value());
114+
assertEquals(3L, result.value());
116115
}
117116

118117
@Test
@@ -122,47 +121,47 @@ public void filtered_distinct_count() {
122121
DSL.distinctCount(DSL.ref("integer_value", INTEGER))
123122
.condition(DSL.greater(DSL.ref("double_value", DOUBLE), DSL.literal(1d))),
124123
tuples_with_duplicates);
125-
assertEquals(2, result.value());
124+
assertEquals(2L, result.value());
126125
}
127126

128127
@Test
129128
public void distinct_count_map() {
130129
ExprValue result =
131130
aggregation(DSL.distinctCount(DSL.ref("struct_value", STRUCT)), tuples_with_duplicates);
132-
assertEquals(3, result.value());
131+
assertEquals(3L, result.value());
133132
}
134133

135134
@Test
136135
public void distinct_count_array() {
137136
ExprValue result =
138137
aggregation(DSL.distinctCount(DSL.ref("array_value", ARRAY)), tuples_with_duplicates);
139-
assertEquals(3, result.value());
138+
assertEquals(3L, result.value());
140139
}
141140

142141
@Test
143142
public void count_with_missing() {
144143
ExprValue result =
145144
aggregation(DSL.count(DSL.ref("integer_value", INTEGER)), tuples_with_null_and_missing);
146-
assertEquals(2, result.value());
145+
assertEquals(2L, result.value());
147146
}
148147

149148
@Test
150149
public void count_with_null() {
151150
ExprValue result =
152151
aggregation(DSL.count(DSL.ref("double_value", DOUBLE)), tuples_with_null_and_missing);
153-
assertEquals(2, result.value());
152+
assertEquals(2L, result.value());
154153
}
155154

156155
@Test
157156
public void count_star_with_null_and_missing() {
158157
ExprValue result = aggregation(DSL.count(DSL.literal("*")), tuples_with_null_and_missing);
159-
assertEquals(3, result.value());
158+
assertEquals(3L, result.value());
160159
}
161160

162161
@Test
163162
public void count_literal_with_null_and_missing() {
164163
ExprValue result = aggregation(DSL.count(DSL.literal(1)), tuples_with_null_and_missing);
165-
assertEquals(3, result.value());
164+
assertEquals(3L, result.value());
166165
}
167166

168167
@Test

integ-test/src/test/java/org/opensearch/sql/bwc/SQLBackwardsCompatibilityIT.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,13 +187,21 @@ private void verifySQLQueries(String endpoint) throws IOException {
187187
executeSQLQuery(
188188
endpoint,
189189
"SELECT COUNT(*) FILTER(WHERE age > 35) FROM " + TestsConstants.TEST_INDEX_ACCOUNT);
190-
verifySchema(filterResponse, schema("COUNT(*) FILTER(WHERE age > 35)", null, "integer"));
190+
// Accept both integer and long types for backwards compatibility
191+
String actualType =
192+
(String) filterResponse.getJSONArray("schema").getJSONObject(0).query("/type");
193+
String expectedType = actualType.equals("integer") ? "integer" : "long";
194+
verifySchema(filterResponse, schema("COUNT(*) FILTER(WHERE age > 35)", null, expectedType));
191195
verifyDataRows(filterResponse, rows(238));
192196

193197
JSONObject aggResponse =
194198
executeSQLQuery(
195199
endpoint, "SELECT COUNT(DISTINCT age) FROM " + TestsConstants.TEST_INDEX_ACCOUNT);
196-
verifySchema(aggResponse, schema("COUNT(DISTINCT age)", null, "integer"));
200+
// Accept both integer and long types for backwards compatibility
201+
String actualType2 =
202+
(String) aggResponse.getJSONArray("schema").getJSONObject(0).query("/type");
203+
String expectedType2 = actualType2.equals("integer") ? "integer" : "long";
204+
verifySchema(aggResponse, schema("COUNT(DISTINCT age)", null, expectedType2));
197205
verifyDataRows(aggResponse, rows(21));
198206

199207
JSONObject groupByResponse =

integ-test/src/test/java/org/opensearch/sql/correctness/runner/resultset/DBResult.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ public class DBResult {
4040
*/
4141
private static final Set<String> VARCHAR = ImmutableSet.of("CHARACTER VARYING", "VARCHAR");
4242

43+
/**
44+
* Possible types for integer numbers.<br>
45+
* Different databases may return INTEGER or BIGINT for count operations.
46+
*/
47+
private static final Set<String> INTEGER_TYPES = ImmutableSet.of("INTEGER", "BIGINT");
48+
4349
/** Database name for display */
4450
private final String databaseName;
4551

@@ -80,6 +86,8 @@ public void addColumn(String name, String type) {
8086
type = FLOAT_TYPES.toString();
8187
} else if (VARCHAR.contains(type)) {
8288
type = "VARCHAR";
89+
} else if (INTEGER_TYPES.contains(type)) {
90+
type = INTEGER_TYPES.toString();
8391
}
8492
schema.add(new Type(StringUtils.toUpper(name), type));
8593
}

integ-test/src/test/java/org/opensearch/sql/correctness/runner/resultset/Row.java

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,19 @@
1010
import java.util.ArrayList;
1111
import java.util.Collection;
1212
import java.util.List;
13-
import lombok.EqualsAndHashCode;
13+
import java.util.Objects;
1414
import lombok.Getter;
1515
import lombok.ToString;
1616

1717
/** Row in result set. */
18-
@EqualsAndHashCode
1918
@ToString
2019
@Getter
2120
public class Row implements Comparable<Row> {
2221

2322
private final Collection<Object> values;
2423

2524
public Row() {
26-
this(new ArrayList<>()); // values in order by default
25+
this(new ArrayList<>());
2726
}
2827

2928
public Row(Collection<Object> values) {
@@ -37,7 +36,7 @@ public void add(Object value) {
3736
private Object roundFloatNum(Object value) {
3837
if (value instanceof Float) {
3938
BigDecimal decimal = BigDecimal.valueOf((Float) value).setScale(2, RoundingMode.CEILING);
40-
value = decimal.doubleValue(); // Convert to double too
39+
value = decimal.doubleValue();
4140
} else if (value instanceof Double) {
4241
BigDecimal decimal = BigDecimal.valueOf((Double) value).setScale(2, RoundingMode.CEILING);
4342
value = decimal.doubleValue();
@@ -70,8 +69,54 @@ public int compareTo(Row other) {
7069
if (result != 0) {
7170
return result;
7271
}
73-
} // Ignore incomparable field silently?
72+
}
7473
}
7574
return 0;
7675
}
76+
77+
@Override
78+
public boolean equals(Object o) {
79+
if (this == o) return true;
80+
if (!(o instanceof Row)) return false;
81+
Row other = (Row) o;
82+
return valuesEqual(this.values, other.values);
83+
}
84+
85+
private boolean valuesEqual(Collection<Object> values1, Collection<Object> values2) {
86+
if (values1.size() != values2.size()) return false;
87+
88+
List<Object> list1 = new ArrayList<>(values1);
89+
List<Object> list2 = new ArrayList<>(values2);
90+
91+
for (int i = 0; i < list1.size(); i++) {
92+
if (!isValueEqual(list1.get(i), list2.get(i))) {
93+
return false;
94+
}
95+
}
96+
return true;
97+
}
98+
99+
private boolean isValueEqual(Object val1, Object val2) {
100+
if (Objects.equals(val1, val2)) return true;
101+
102+
if (isIntegerOrLong(val1) && isIntegerOrLong(val2)) {
103+
return ((Number) val1).longValue() == ((Number) val2).longValue();
104+
}
105+
106+
return false;
107+
}
108+
109+
private boolean isIntegerOrLong(Object value) {
110+
return value instanceof Integer || value instanceof Long;
111+
}
112+
113+
@Override
114+
public int hashCode() {
115+
116+
List<Object> normalizedValues = new ArrayList<>();
117+
for (Object value : values) {
118+
normalizedValues.add(value instanceof Integer ? ((Integer) value).longValue() : value);
119+
}
120+
return normalizedValues.hashCode();
121+
}
77122
}

integ-test/src/test/java/org/opensearch/sql/legacy/AggregationExpressionIT.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ public void groupByDateShouldPass() {
204204
Index.BANK.getName()));
205205

206206
verifySchema(
207-
response, schema("birthdate", null, "timestamp"), schema("count(*)", "count", "integer"));
207+
response, schema("birthdate", null, "timestamp"), schema("count(*)", "count", "long"));
208208
verifyDataRows(response, rows("2018-06-23 00:00:00", 1));
209209
}
210210

@@ -220,9 +220,7 @@ public void groupByDateWithAliasShouldPass() {
220220
Index.BANK.getName()));
221221

222222
verifySchema(
223-
response,
224-
schema("birthdate", "birth", "timestamp"),
225-
schema("count(*)", "count", "integer"));
223+
response, schema("birthdate", "birth", "timestamp"), schema("count(*)", "count", "long"));
226224
verifyDataRows(response, rows("2018-06-23 00:00:00", 1));
227225
}
228226

integ-test/src/test/java/org/opensearch/sql/ppl/DateTimeImplementationIT.java

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,7 @@ public void testSpanDatetimeWithCustomFormat() throws IOException {
188188
String.format(
189189
"source=%s | eval a = 1 | stats count() as cnt by span(yyyy-MM-dd, 1d) as span",
190190
TEST_INDEX_DATE_FORMATS));
191-
verifySchema(
192-
result,
193-
isCalciteEnabled() ? schema("cnt", null, "bigint") : schema("cnt", null, "int"),
194-
schema("span", null, "date"));
191+
verifySchema(result, schema("cnt", null, "bigint"), schema("span", null, "date"));
195192
verifyDataRows(result, rows(2, "1984-04-12"));
196193
}
197194

@@ -202,10 +199,7 @@ public void testSpanDatetimeWithEpochMillisFormat() throws IOException {
202199
String.format(
203200
"source=%s | eval a = 1 | stats count() as cnt by span(epoch_millis, 1d) as span",
204201
TEST_INDEX_DATE_FORMATS));
205-
verifySchema(
206-
result,
207-
isCalciteEnabled() ? schema("cnt", null, "bigint") : schema("cnt", null, "int"),
208-
schema("span", null, "timestamp"));
202+
verifySchema(result, schema("cnt", null, "bigint"), schema("span", null, "timestamp"));
209203
verifyDataRows(result, rows(2, "1984-04-12 00:00:00"));
210204
}
211205

@@ -217,10 +211,7 @@ public void testSpanDatetimeWithDisjunctiveDifferentFormats() throws IOException
217211
"source=%s | eval a = 1 | stats count() as cnt by span(yyyy-MM-dd_OR_epoch_millis,"
218212
+ " 1d) as span",
219213
TEST_INDEX_DATE_FORMATS));
220-
verifySchema(
221-
result,
222-
isCalciteEnabled() ? schema("cnt", null, "bigint") : schema("cnt", null, "int"),
223-
schema("span", null, "timestamp"));
214+
verifySchema(result, schema("cnt", null, "bigint"), schema("span", null, "timestamp"));
224215
verifyDataRows(result, rows(2, "1984-04-12 00:00:00"));
225216
}
226217
}

integ-test/src/test/java/org/opensearch/sql/ppl/ObjectFieldOperateIT.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,7 @@ public void group_object_field_in_stats() throws IOException {
5252
JSONObject result =
5353
executeQuery(
5454
String.format("source=%s | stats count() by city.name", TEST_INDEX_DEEP_NESTED));
55-
verifySchema(
56-
result,
57-
schema("count()", isCalciteEnabled() ? "bigint" : "int"),
58-
schema("city.name", "string"));
55+
verifySchema(result, schema("count()", "bigint"), schema("city.name", "string"));
5956
verifyDataRows(result, rows(1, "Seattle"));
6057
}
6158

0 commit comments

Comments
 (0)