1717 */
1818package org .apache .wayang .api .sql .calcite .converter .functions ;
1919
20- import java .util . Arrays ;
20+ import java .math . BigDecimal ;
2121import java .util .List ;
2222import java .util .Optional ;
2323import java .util .function .BiFunction ;
24- import java .util .stream .Collectors ;
2524
2625import org .apache .calcite .rel .core .AggregateCall ;
2726import org .apache .calcite .runtime .SqlFunctions ;
3130
3231public class AggregateFunction
3332 implements FunctionDescriptor .SerializableBinaryOperator <Record > {
34- final List <SqlKind > aggregateKinds ;
33+ private final List <SqlKind > aggregateKinds ;
3534
3635 public AggregateFunction (final List <AggregateCall > aggregateCalls ) {
3736 this .aggregateKinds = aggregateCalls .stream ()
38- .map (call -> call .getAggregation ().getKind ())
39- . collect ( Collectors . toList () );
37+ .map (call -> call .getAggregation ().getKind ())
38+ . toList ();
4039 }
4140
4241 @ Override
@@ -56,15 +55,15 @@ public Record apply(final Record record1, final Record record2) {
5655
5756 switch (kind ) {
5857 case SUM :
59- resValues [counter ] = this .castAndMap (field1 , field2 , null , Long ::sum , Integer ::sum , Double ::sum );
58+ resValues [counter ] = this .castAndMap (field1 , field2 , null , Long ::sum , Integer ::sum , Double ::sum , BigDecimal :: add );
6059 break ;
6160 case MIN :
6261 resValues [counter ] = this .castAndMap (field1 , field2 , SqlFunctions ::least , SqlFunctions ::least ,
63- SqlFunctions ::least , SqlFunctions ::least );
62+ SqlFunctions ::least , SqlFunctions ::least , SqlFunctions :: least );
6463 break ;
6564 case MAX :
6665 resValues [counter ] = this .castAndMap (field1 , field2 , SqlFunctions ::greatest , SqlFunctions ::greatest ,
67- SqlFunctions ::greatest , SqlFunctions ::greatest );
66+ SqlFunctions ::greatest , SqlFunctions ::greatest , SqlFunctions :: greatest );
6867 break ;
6968 case COUNT :
7069 // since aggregates inject an extra column for counting before,
@@ -76,9 +75,7 @@ public Record apply(final Record record1, final Record record2) {
7675 resValues [counter ] = count ;
7776 break ;
7877 case AVG :
79- assert (field1 instanceof Integer && field2 instanceof Integer )
80- : "Expected to find integers for count but found: " + field1 + " and " + field2 ;
81- final Object avg = Integer .class .cast (field1 ) + Integer .class .cast (field2 );
78+ final Object avg = this .castAndMap (field1 , field2 , null , Long ::sum , Integer ::sum , Double ::sum , BigDecimal ::add );
8279
8380 resValues [counter ] = avg ;
8481
@@ -95,6 +92,7 @@ public Record apply(final Record record1, final Record record2) {
9592 return new Record (resValues );
9693 }
9794
95+
9896 /**
9997 * Handles casts for the record class for each interior type.
10098 *
@@ -110,7 +108,8 @@ private Object castAndMap(final Object a, final Object b,
110108 final BiFunction <String , String , Object > stringMap ,
111109 final BiFunction <Long , Long , Object > longMap ,
112110 final BiFunction <Integer , Integer , Object > integerMap ,
113- final BiFunction <Double , Double , Object > doubleMap ) {
111+ final BiFunction <Double , Double , Object > doubleMap ,
112+ final BiFunction <BigDecimal , BigDecimal , Object > bigDecimalMap ) {
114113 // support operations between null and any
115114 // class
116115 if ((a == null || b == null ) || (a .getClass () == b .getClass ())) {
@@ -122,19 +121,16 @@ private Object castAndMap(final Object a, final Object b,
122121 // force .getClass() to be safe so
123122 // we can pass null objects to
124123 // .apply methods.
125- switch (aWrapped .orElse (bWrapped .orElse ("" )).getClass ().getSimpleName ()) {
126- case "String" :
127- return stringMap .apply ((String ) a , (String ) b );
128- case "Long" :
129- return longMap .apply ((Long ) a , (Long ) b );
130- case "Integer" :
131- return integerMap .apply ((Integer ) a , (Integer ) b );
132- case "Double" :
133- return doubleMap .apply ((Double ) a , (Double ) b );
134- default :
135- throw new IllegalStateException ("Unsupported operation between: " + aWrapped .getClass ().toString ()
136- + " and: " + bWrapped .getClass ().toString ());
137- }
124+ return switch (aWrapped .orElse (bWrapped .orElse ("" )).getClass ().getSimpleName ()) {
125+ case "String" -> stringMap .apply ((String ) a , (String ) b );
126+ case "Long" -> longMap .apply ((Long ) a , (Long ) b );
127+ case "Integer" -> integerMap .apply ((Integer ) a , (Integer ) b );
128+ case "Double" -> doubleMap .apply ((Double ) a , (Double ) b );
129+ case "BigDecimal" -> bigDecimalMap .apply ((BigDecimal ) a , (BigDecimal ) b );
130+ default -> throw new IllegalStateException ("Unsupported operation between: "
131+ + aWrapped .getClass ().toString ()
132+ + " and: " + bWrapped .getClass ().toString ());
133+ };
138134 }
139135 throw new IllegalStateException ("Unsupported operation between: " + a .getClass ().getSimpleName () + " and: "
140136 + b .getClass ().getSimpleName ());
0 commit comments