Skip to content

Commit 26a8a69

Browse files
committed
HHH-18795 Add JSON aggregate support for CockroachDB
1 parent 243306d commit 26a8a69

File tree

6 files changed

+342
-5
lines changed

6 files changed

+342
-5
lines changed

hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/CockroachLegacyDialect.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import org.hibernate.boot.model.FunctionContributions;
2727
import org.hibernate.boot.model.TypeContributions;
2828
import org.hibernate.dialect.*;
29+
import org.hibernate.dialect.aggregate.AggregateSupport;
30+
import org.hibernate.dialect.aggregate.CockroachDBAggregateSupport;
2931
import org.hibernate.dialect.function.CommonFunctionFactory;
3032
import org.hibernate.dialect.function.FormatFunction;
3133
import org.hibernate.dialect.function.PostgreSQLTruncFunction;
@@ -699,6 +701,11 @@ public NationalizationSupport getNationalizationSupport() {
699701
return NationalizationSupport.IMPLICIT;
700702
}
701703

704+
@Override
705+
public AggregateSupport getAggregateSupport() {
706+
return CockroachDBAggregateSupport.valueOf( this );
707+
}
708+
702709
@Override
703710
public int getMaxIdentifierLength() {
704711
return 63;

hibernate-community-dialects/src/main/java/org/hibernate/community/dialect/H2LegacyDialect.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import org.hibernate.boot.model.FunctionContributions;
2121
import org.hibernate.boot.model.TypeContributions;
2222
import org.hibernate.dialect.*;
23+
import org.hibernate.dialect.aggregate.AggregateSupport;
24+
import org.hibernate.dialect.aggregate.H2AggregateSupport;
2325
import org.hibernate.dialect.function.CommonFunctionFactory;
2426
import org.hibernate.dialect.identity.H2FinalTableIdentityColumnSupport;
2527
import org.hibernate.dialect.identity.H2IdentityColumnSupport;
@@ -301,6 +303,11 @@ public void contributeTypes(TypeContributions typeContributions, ServiceRegistry
301303
jdbcTypeRegistry.addDescriptor( OrdinalEnumJdbcType.INSTANCE );
302304
}
303305

306+
@Override
307+
public AggregateSupport getAggregateSupport() {
308+
return H2AggregateSupport.valueOf( this );
309+
}
310+
304311
@Override
305312
public int getDefaultStatementBatchSize() {
306313
return 15;

hibernate-core/src/main/java/org/hibernate/dialect/CockroachDialect.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import org.hibernate.QueryTimeoutException;
2727
import org.hibernate.boot.model.FunctionContributions;
2828
import org.hibernate.boot.model.TypeContributions;
29+
import org.hibernate.dialect.aggregate.AggregateSupport;
30+
import org.hibernate.dialect.aggregate.CockroachDBAggregateSupport;
2931
import org.hibernate.dialect.function.CommonFunctionFactory;
3032
import org.hibernate.dialect.function.FormatFunction;
3133
import org.hibernate.dialect.function.PostgreSQLTruncFunction;
@@ -667,6 +669,11 @@ public NationalizationSupport getNationalizationSupport() {
667669
return NationalizationSupport.IMPLICIT;
668670
}
669671

672+
@Override
673+
public AggregateSupport getAggregateSupport() {
674+
return CockroachDBAggregateSupport.valueOf( this );
675+
}
676+
670677
@Override
671678
public int getMaxIdentifierLength() {
672679
return 63;
Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
/*
2+
* SPDX-License-Identifier: LGPL-2.1-or-later
3+
* Copyright Red Hat Inc. and Hibernate Authors
4+
*/
5+
package org.hibernate.dialect.aggregate;
6+
7+
import org.hibernate.dialect.Dialect;
8+
import org.hibernate.internal.util.StringHelper;
9+
import org.hibernate.mapping.Column;
10+
import org.hibernate.metamodel.mapping.JdbcMapping;
11+
import org.hibernate.metamodel.mapping.SelectableMapping;
12+
import org.hibernate.metamodel.mapping.SelectablePath;
13+
import org.hibernate.metamodel.mapping.SqlTypedMapping;
14+
import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
15+
import org.hibernate.sql.ast.SqlAstTranslator;
16+
import org.hibernate.sql.ast.spi.SqlAppender;
17+
import org.hibernate.type.BasicPluralType;
18+
import org.hibernate.type.spi.TypeConfiguration;
19+
20+
import java.util.LinkedHashMap;
21+
import java.util.Map;
22+
23+
import static org.hibernate.type.SqlTypes.ARRAY;
24+
import static org.hibernate.type.SqlTypes.BIGINT;
25+
import static org.hibernate.type.SqlTypes.BINARY;
26+
import static org.hibernate.type.SqlTypes.BOOLEAN;
27+
import static org.hibernate.type.SqlTypes.DOUBLE;
28+
import static org.hibernate.type.SqlTypes.FLOAT;
29+
import static org.hibernate.type.SqlTypes.INTEGER;
30+
import static org.hibernate.type.SqlTypes.JSON;
31+
import static org.hibernate.type.SqlTypes.JSON_ARRAY;
32+
import static org.hibernate.type.SqlTypes.LONG32VARBINARY;
33+
import static org.hibernate.type.SqlTypes.SMALLINT;
34+
import static org.hibernate.type.SqlTypes.TINYINT;
35+
import static org.hibernate.type.SqlTypes.VARBINARY;
36+
37+
public class CockroachDBAggregateSupport extends AggregateSupportImpl {
38+
39+
private static final AggregateSupport INSTANCE = new CockroachDBAggregateSupport();
40+
41+
public static AggregateSupport valueOf(Dialect dialect) {
42+
return CockroachDBAggregateSupport.INSTANCE;
43+
}
44+
45+
@Override
46+
public String aggregateComponentCustomReadExpression(
47+
String template,
48+
String placeholder,
49+
String aggregateParentReadExpression,
50+
String columnExpression,
51+
int aggregateColumnTypeCode,
52+
SqlTypedMapping column) {
53+
switch ( aggregateColumnTypeCode ) {
54+
case JSON_ARRAY:
55+
case JSON:
56+
switch ( column.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode() ) {
57+
case JSON:
58+
case JSON_ARRAY:
59+
return template.replace(
60+
placeholder,
61+
aggregateParentReadExpression + "->'" + columnExpression + "'"
62+
);
63+
case BINARY:
64+
case VARBINARY:
65+
case LONG32VARBINARY:
66+
// We encode binary data as hex, so we have to decode here
67+
return template.replace(
68+
placeholder,
69+
"decode(" + aggregateParentReadExpression + "->>'" + columnExpression + "','hex')"
70+
);
71+
case ARRAY:
72+
final BasicPluralType<?, ?> pluralType = (BasicPluralType<?, ?>) column.getJdbcMapping();
73+
switch ( pluralType.getElementType().getJdbcType().getDefaultSqlTypeCode() ) {
74+
case BOOLEAN:
75+
case TINYINT:
76+
case SMALLINT:
77+
case INTEGER:
78+
case BIGINT:
79+
case FLOAT:
80+
case DOUBLE:
81+
// For types that are natively supported in jsonb we can use jsonb_array_elements,
82+
// but note that we can't use that for string types,
83+
// because casting a jsonb[] to text[] will not omit the quotes of the jsonb text values
84+
return template.replace(
85+
placeholder,
86+
"cast(array(select jsonb_array_elements(" + aggregateParentReadExpression + "->'" + columnExpression + "')) as " + column.getColumnDefinition() + ')'
87+
);
88+
case BINARY:
89+
case VARBINARY:
90+
case LONG32VARBINARY:
91+
// We encode binary data as hex, so we have to decode here
92+
return template.replace(
93+
placeholder,
94+
"array(select decode(jsonb_array_elements_text(" + aggregateParentReadExpression + "->'" + columnExpression + "'),'hex'))"
95+
);
96+
default:
97+
return template.replace(
98+
placeholder,
99+
"cast(array(select jsonb_array_elements_text(" + aggregateParentReadExpression + "->'" + columnExpression + "')) as " + column.getColumnDefinition() + ')'
100+
);
101+
}
102+
default:
103+
return template.replace(
104+
placeholder,
105+
"cast(" + aggregateParentReadExpression + "->>'" + columnExpression + "' as " + column.getColumnDefinition() + ')'
106+
);
107+
}
108+
}
109+
throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateColumnTypeCode );
110+
}
111+
112+
private static String jsonCustomWriteExpression(String customWriteExpression, JdbcMapping jdbcMapping) {
113+
final int sqlTypeCode = jdbcMapping.getJdbcType().getDefaultSqlTypeCode();
114+
switch ( sqlTypeCode ) {
115+
case BINARY:
116+
case VARBINARY:
117+
case LONG32VARBINARY:
118+
// We encode binary data as hex
119+
return "to_jsonb(encode(" + customWriteExpression + ",'hex'))";
120+
case ARRAY:
121+
final BasicPluralType<?, ?> pluralType = (BasicPluralType<?, ?>) jdbcMapping;
122+
switch ( pluralType.getElementType().getJdbcType().getDefaultSqlTypeCode() ) {
123+
case BINARY:
124+
case VARBINARY:
125+
case LONG32VARBINARY:
126+
// We encode binary data as hex
127+
return "to_jsonb(array(select encode(unnest(" + customWriteExpression + "),'hex')))";
128+
default:
129+
return "to_jsonb(" + customWriteExpression + ")";
130+
}
131+
default:
132+
return "to_jsonb(" + customWriteExpression + ")";
133+
}
134+
}
135+
136+
@Override
137+
public String aggregateComponentAssignmentExpression(
138+
String aggregateParentAssignmentExpression,
139+
String columnExpression,
140+
int aggregateColumnTypeCode,
141+
Column column) {
142+
switch ( aggregateColumnTypeCode ) {
143+
case JSON:
144+
case JSON_ARRAY:
145+
// For JSON we always have to replace the whole object
146+
return aggregateParentAssignmentExpression;
147+
}
148+
throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateColumnTypeCode );
149+
}
150+
151+
@Override
152+
public boolean requiresAggregateCustomWriteExpressionRenderer(int aggregateSqlTypeCode) {
153+
switch ( aggregateSqlTypeCode ) {
154+
case JSON:
155+
return true;
156+
}
157+
return false;
158+
}
159+
160+
@Override
161+
public WriteExpressionRenderer aggregateCustomWriteExpressionRenderer(
162+
SelectableMapping aggregateColumn,
163+
SelectableMapping[] columnsToUpdate,
164+
TypeConfiguration typeConfiguration) {
165+
final int aggregateSqlTypeCode = aggregateColumn.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode();
166+
switch ( aggregateSqlTypeCode ) {
167+
case JSON:
168+
return jsonAggregateColumnWriter( aggregateColumn, columnsToUpdate );
169+
}
170+
throw new IllegalArgumentException( "Unsupported aggregate SQL type: " + aggregateSqlTypeCode );
171+
}
172+
173+
private WriteExpressionRenderer jsonAggregateColumnWriter(
174+
SelectableMapping aggregateColumn,
175+
SelectableMapping[] columns) {
176+
return new RootJsonWriteExpression( aggregateColumn, columns );
177+
}
178+
179+
interface JsonWriteExpression {
180+
void append(
181+
SqlAppender sb,
182+
String path,
183+
SqlAstTranslator<?> translator,
184+
AggregateColumnWriteExpression expression);
185+
}
186+
private static class AggregateJsonWriteExpression implements JsonWriteExpression {
187+
private final LinkedHashMap<String, JsonWriteExpression> subExpressions = new LinkedHashMap<>();
188+
189+
protected void initializeSubExpressions(SelectableMapping[] columns) {
190+
for ( SelectableMapping column : columns ) {
191+
final SelectablePath selectablePath = column.getSelectablePath();
192+
final SelectablePath[] parts = selectablePath.getParts();
193+
AggregateJsonWriteExpression currentAggregate = this;
194+
for ( int i = 1; i < parts.length - 1; i++ ) {
195+
currentAggregate = (AggregateJsonWriteExpression) currentAggregate.subExpressions.computeIfAbsent(
196+
parts[i].getSelectableName(),
197+
k -> new AggregateJsonWriteExpression()
198+
);
199+
}
200+
final String customWriteExpression = column.getWriteExpression();
201+
currentAggregate.subExpressions.put(
202+
parts[parts.length - 1].getSelectableName(),
203+
new BasicJsonWriteExpression(
204+
column,
205+
jsonCustomWriteExpression( customWriteExpression, column.getJdbcMapping() )
206+
)
207+
);
208+
}
209+
}
210+
211+
@Override
212+
public void append(
213+
SqlAppender sb,
214+
String path,
215+
SqlAstTranslator<?> translator,
216+
AggregateColumnWriteExpression expression) {
217+
sb.append( "||jsonb_build_object" );
218+
char separator = '(';
219+
for ( Map.Entry<String, JsonWriteExpression> entry : subExpressions.entrySet() ) {
220+
final String column = entry.getKey();
221+
final JsonWriteExpression value = entry.getValue();
222+
final String subPath = path + "->'" + column + "'";
223+
sb.append( separator );
224+
if ( value instanceof AggregateJsonWriteExpression ) {
225+
sb.append( '\'' );
226+
sb.append( column );
227+
sb.append( "',coalesce(" );
228+
sb.append( subPath );
229+
sb.append( ",'{}')" );
230+
value.append( sb, subPath, translator, expression );
231+
}
232+
else {
233+
value.append( sb, subPath, translator, expression );
234+
}
235+
separator = ',';
236+
}
237+
sb.append( ')' );
238+
}
239+
}
240+
241+
private static class RootJsonWriteExpression extends AggregateJsonWriteExpression
242+
implements WriteExpressionRenderer {
243+
private final boolean nullable;
244+
private final String path;
245+
246+
RootJsonWriteExpression(SelectableMapping aggregateColumn, SelectableMapping[] columns) {
247+
this.nullable = aggregateColumn.isNullable();
248+
this.path = aggregateColumn.getSelectionExpression();
249+
initializeSubExpressions( columns );
250+
}
251+
252+
@Override
253+
public void render(
254+
SqlAppender sqlAppender,
255+
SqlAstTranslator<?> translator,
256+
AggregateColumnWriteExpression aggregateColumnWriteExpression,
257+
String qualifier) {
258+
final String basePath;
259+
if ( qualifier == null || qualifier.isBlank() ) {
260+
basePath = path;
261+
}
262+
else {
263+
basePath = qualifier + "." + path;
264+
}
265+
if ( nullable ) {
266+
sqlAppender.append( "coalesce(" );
267+
sqlAppender.append( basePath );
268+
sqlAppender.append( ",'{}')" );
269+
}
270+
else {
271+
sqlAppender.append( basePath );
272+
}
273+
append( sqlAppender, basePath, translator, aggregateColumnWriteExpression );
274+
}
275+
}
276+
private static class BasicJsonWriteExpression implements JsonWriteExpression {
277+
278+
private final SelectableMapping selectableMapping;
279+
private final String customWriteExpressionStart;
280+
private final String customWriteExpressionEnd;
281+
282+
BasicJsonWriteExpression(SelectableMapping selectableMapping, String customWriteExpression) {
283+
this.selectableMapping = selectableMapping;
284+
if ( customWriteExpression.equals( "?" ) ) {
285+
this.customWriteExpressionStart = "";
286+
this.customWriteExpressionEnd = "";
287+
}
288+
else {
289+
final String[] parts = StringHelper.split( "?", customWriteExpression );
290+
assert parts.length == 2;
291+
this.customWriteExpressionStart = parts[0];
292+
this.customWriteExpressionEnd = parts[1];
293+
}
294+
}
295+
296+
@Override
297+
public void append(
298+
SqlAppender sb,
299+
String path,
300+
SqlAstTranslator<?> translator,
301+
AggregateColumnWriteExpression expression) {
302+
sb.append( '\'' );
303+
sb.append( selectableMapping.getSelectableName() );
304+
sb.append( "'," );
305+
sb.append( customWriteExpressionStart );
306+
// We use NO_UNTYPED here so that expressions which require type inference are casted explicitly,
307+
// since we don't know how the custom write expression looks like where this is embedded,
308+
// so we have to be pessimistic and avoid ambiguities
309+
translator.render( expression.getValueExpression( selectableMapping ), SqlAstNodeRenderingMode.NO_UNTYPED );
310+
sb.append( customWriteExpressionEnd );
311+
}
312+
}
313+
314+
}

hibernate-core/src/main/java/org/hibernate/dialect/aggregate/H2AggregateSupport.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public class H2AggregateSupport extends AggregateSupportImpl {
3838
public static @Nullable AggregateSupport valueOf(Dialect dialect) {
3939
return dialect.getVersion().isSameOrAfter( 2, 2, 220 )
4040
? H2AggregateSupport.INSTANCE
41-
: null;
41+
: AggregateSupportImpl.INSTANCE;
4242
}
4343

4444
@Override

0 commit comments

Comments
 (0)