File tree Expand file tree Collapse file tree 4 files changed +47
-9
lines changed
Expand file tree Collapse file tree 4 files changed +47
-9
lines changed Original file line number Diff line number Diff line change @@ -150,16 +150,13 @@ jobs:
150150 run :
151151 working-directory : ./datajunction-clients/java
152152 steps :
153- - uses : actions/checkout@v3
153+ - uses : actions/checkout@v4
154154 - name : Set up JDK 17
155- uses : actions/setup-java@v3
155+ uses : actions/setup-java@v4
156156 with :
157157 java-version : ' 17'
158158 distribution : ' temurin'
159- - name : Official Gradle Wrapper Validation Action
160- uses : gradle/actions/wrapper-validation@v3
159+ - name : Setup Gradle
160+ uses : gradle/actions/setup-gradle@v4
161161 - name : Build with Gradle
162- uses : gradle/actions/setup-gradle@v3
163- with :
164- arguments : build -x test
165- build-root-directory : ./datajunction-clients/java
162+ run : ./gradlew build -x test
Original file line number Diff line number Diff line change @@ -22,8 +22,11 @@ def __init__(self, query_ast: ast.Query):
2222 self .handlers = {
2323 dj_functions .Sum : self ._simple_associative_agg ,
2424 dj_functions .Count : self ._simple_associative_agg ,
25+ dj_functions .CountIf : self ._simple_associative_agg ,
2526 dj_functions .Max : self ._simple_associative_agg ,
27+ dj_functions .MaxBy : self ._simple_associative_agg ,
2628 dj_functions .Min : self ._simple_associative_agg ,
29+ dj_functions .MinBy : self ._simple_associative_agg ,
2730 dj_functions .Avg : self ._avg ,
2831 dj_functions .AnyValue : self ._simple_associative_agg ,
2932 }
@@ -158,7 +161,7 @@ def update_ast(func, measures: list[Measure]):
158161 ),
159162 )
160163 elif (
161- func .function () == dj_functions .Count
164+ func .function () in ( dj_functions .Count , dj_functions . CountIf )
162165 and func .quantifier != ast .SetQuantifier .Distinct
163166 ):
164167 func .name .name = "SUM"
Original file line number Diff line number Diff line change @@ -1222,6 +1222,8 @@ class CountIf(Function):
12221222 count_if(expr) - Returns the number of true values in expr.
12231223 """
12241224
1225+ is_aggregation = True
1226+
12251227
12261228@CountIf .register # type: ignore
12271229def infer_type (arg : ct .BooleanType ) -> ct .IntegerType :
@@ -3244,6 +3246,8 @@ class MaxBy(Function):
32443246 max_by(val, key) - Returns the value of val corresponding to the maximum value of key.
32453247 """
32463248
3249+ is_aggregation = True
3250+
32473251
32483252@MaxBy .register # type: ignore
32493253def infer_type (val : ct .ColumnType , key : ct .ColumnType ) -> ct .ColumnType :
@@ -3334,6 +3338,8 @@ class MinBy(Function):
33343338 min_by(val, key) - Returns the value of val corresponding to the minimum value of key.
33353339 """
33363340
3341+ is_aggregation = True
3342+
33373343
33383344@MinBy .register # type: ignore
33393345def infer_type (val : ct .ColumnType , key : ct .ColumnType ) -> ct .ColumnType :
Original file line number Diff line number Diff line change @@ -618,6 +618,38 @@ def test_unsupported_aggregation_function():
618618 )
619619
620620
621+ def test_count_if ():
622+ """
623+ Test decomposition for count_if.
624+ """
625+ extractor = MeasureExtractor .from_query_string (
626+ "SELECT CAST(COUNT_IF(ARRAY_CONTAINS(field_a, 'xyz')) AS FLOAT) / COUNT(*) "
627+ "FROM parent_node" ,
628+ )
629+ measures , derived_sql = extractor .extract ()
630+ expected_measures = [
631+ Measure (
632+ name = "field_a_count_if_c1f2ed10" ,
633+ expression = "ARRAY_CONTAINS(field_a, 'xyz')" ,
634+ aggregation = "COUNT_IF" ,
635+ rule = AggregationRule (type = Aggregability .FULL ),
636+ ),
637+ Measure (
638+ name = "count_3389dae3" ,
639+ expression = "*" ,
640+ aggregation = "COUNT" ,
641+ rule = AggregationRule (type = Aggregability .FULL ),
642+ ),
643+ ]
644+ assert measures == expected_measures
645+ assert str (derived_sql ) == str (
646+ parse (
647+ "SELECT CAST(SUM(field_a_count_if_c1f2ed10) AS FLOAT) / SUM(count_3389dae3) "
648+ "FROM parent_node" ,
649+ ),
650+ )
651+
652+
621653def test_metric_query_with_aliases ():
622654 """
623655 Test behavior when the query contains unsupported aggregation functions. We just return an
You can’t perform that action at this time.
0 commit comments