@@ -72,23 +72,24 @@ private static Column convertAggregation(String columnName, AggregationExpressio
       throws UnsupportedOperationException {
     Column column;
     if (expression instanceof MinAggregationExpression) {
-      column = min(columnName);
+      column = min(SparkUtils.safeCol(columnName));
     } else if (expression instanceof MaxAggregationExpression) {
-      column = max(columnName);
+      column = max(SparkUtils.safeCol(columnName));
     } else if (expression instanceof AverageAggregationExpression) {
-      column = avg(columnName);
+      column = avg(SparkUtils.safeCol(columnName));
     } else if (expression instanceof SumAggregationExpression) {
-      column = sum(columnName);
+      column = sum(SparkUtils.safeCol(columnName));
     } else if (expression instanceof CountAggregationExpression) {
       column = count("*");
     } else if (expression instanceof MedianAggregationExpression) {
-      column = percentile_approx(col(columnName), lit(0.5), lit(DEFAULT_MEDIAN_ACCURACY));
+      column =
+          percentile_approx(SparkUtils.safeCol(columnName), lit(0.5), lit(DEFAULT_MEDIAN_ACCURACY));
     } else if (expression instanceof StdDevSampAggregationExpression) {
-      column = stddev_samp(columnName);
+      column = stddev_samp(SparkUtils.safeCol(columnName));
     } else if (expression instanceof VarPopAggregationExpression) {
-      column = var_pop(columnName);
+      column = var_pop(SparkUtils.safeCol(columnName));
     } else if (expression instanceof VarSampAggregationExpression) {
-      column = var_samp(columnName);
+      column = var_samp(SparkUtils.safeCol(columnName));
     } else {
       throw new UnsupportedOperationException("unknown aggregation " + expression.getClass());
     }
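
SparkUtils.safeCol itself is not part of this diff. Presumably it wraps functions.col with backtick quoting, so that column names containing dots or spaces resolve as literal identifiers instead of being parsed as struct-field paths. A minimal sketch of such a helper, assuming exactly that behavior (the real implementation may differ):

import static org.apache.spark.sql.functions.col;

import org.apache.spark.sql.Column;

public final class SparkUtils {

  private SparkUtils() {}

  // Hypothetical reconstruction: backtick-quote the whole name so Spark's
  // parser treats dots and spaces as part of the identifier, doubling any
  // embedded backticks to keep the quoting well-formed.
  public static Column safeCol(String columnName) {
    return col("`" + columnName.replace("`", "``") + "`");
  }
}

Note that the count("*") branch is left untouched: "*" is the count-all marker, not a column name, so it needs no escaping.
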
@@ -132,7 +133,7 @@ private static WindowSpec buildWindowSpec(
   public static Seq<Column> colNameToCol(List<String> inputColNames) {
     List<Column> cols = new ArrayList<>();
     for (String colName : inputColNames) {
-      cols.add(col(colName));
+      cols.add(SparkUtils.safeCol(colName));
     }
     return JavaConverters.asScalaIteratorConverter(cols.iterator()).asScala().toSeq();
   }
@@ -142,9 +143,9 @@ public static Seq<Column> buildOrderCol(Map<String, Analytics.Order> orderCols)
     List<Column> orders = new ArrayList<>();
     for (Map.Entry<String, Analytics.Order> entry : orderCols.entrySet()) {
       if (entry.getValue().equals(Analytics.Order.DESC)) {
-        orders.add(col(entry.getKey()).desc());
+        orders.add(SparkUtils.safeCol(entry.getKey()).desc());
       } else {
-        orders.add(col(entry.getKey()));
+        orders.add(SparkUtils.safeCol(entry.getKey()));
       }
     }
     return JavaConverters.asScalaIteratorConverter(orders.iterator()).asScala().toSeq();
@@ -282,9 +283,9 @@ public Dataset<Row> rename(Dataset<Row> dataset, Map<String, String> fromTo) {
     List<Column> columns = new ArrayList<>();
     for (String name : dataset.columns()) {
       if (fromTo.containsKey(name)) {
-        columns.add(col(name).as(fromTo.get(name)));
+        columns.add(SparkUtils.safeCol(name).as(fromTo.get(name)));
       } else if (!fromTo.containsValue(name)) {
-        columns.add(col(name));
+        columns.add(SparkUtils.safeCol(name));
       }
     }
     return dataset.select(iterableAsScalaIterable(columns).toSeq());
@@ -366,7 +367,7 @@ public DatasetExpression executeAggr(
             .map(e -> convertAggregation(e.getKey(), e.getValue()))
             .collect(Collectors.toList());
     List<Column> groupByColumns =
-        groupBy.stream().map(name -> col(name)).collect(Collectors.toList());
+        groupBy.stream().map(SparkUtils::safeCol).collect(Collectors.toList());
     Dataset<Row> result =
         sparkDataset
             .getSparkDataset()
@@ -395,22 +396,23 @@ public DatasetExpression executeSimpleAnalytic(
     // step 2: call analytic func on window spec
     // 2.1 get all measurement column

+    Column safeCol = SparkUtils.safeCol(sourceColName);
+
     Column column =
         switch (function) {
-          case COUNT -> count(sourceColName).over(windowSpec);
-          case SUM -> sum(sourceColName).over(windowSpec);
-          case MIN -> min(sourceColName).over(windowSpec);
-          case MAX -> max(sourceColName).over(windowSpec);
-          case AVG -> avg(sourceColName).over(windowSpec);
+          case COUNT -> count(safeCol).over(windowSpec);
+          case SUM -> sum(safeCol).over(windowSpec);
+          case MIN -> min(safeCol).over(windowSpec);
+          case MAX -> max(safeCol).over(windowSpec);
+          case AVG -> avg(safeCol).over(windowSpec);
           case MEDIAN ->
-              percentile_approx(col(sourceColName), lit(0.5), lit(DEFAULT_MEDIAN_ACCURACY))
-                  .over(windowSpec);
-          case STDDEV_POP -> stddev_pop(sourceColName).over(windowSpec);
-          case STDDEV_SAMP -> stddev_samp(sourceColName).over(windowSpec);
-          case VAR_POP -> var_pop(sourceColName).over(windowSpec);
-          case VAR_SAMP -> var_samp(sourceColName).over(windowSpec);
-          case FIRST_VALUE -> first(sourceColName).over(windowSpec);
-          case LAST_VALUE -> last(sourceColName).over(windowSpec);
+              percentile_approx(safeCol, lit(0.5), lit(DEFAULT_MEDIAN_ACCURACY)).over(windowSpec);
+          case STDDEV_POP -> stddev_pop(safeCol).over(windowSpec);
+          case STDDEV_SAMP -> stddev_samp(safeCol).over(windowSpec);
+          case VAR_POP -> var_pop(safeCol).over(windowSpec);
+          case VAR_SAMP -> var_samp(safeCol).over(windowSpec);
+          case FIRST_VALUE -> first(safeCol).over(windowSpec);
+          case LAST_VALUE -> last(safeCol).over(windowSpec);
           default -> throw UNKNOWN_ANALYTIC_FUNCTION;
         };
     var result = sparkDataset.getSparkDataset().withColumn(targetColName, column);
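
Hoisting the escaped column into a local (Column safeCol = ...) rather than calling SparkUtils.safeCol in every arm is a tidy refactor on top of the fix: a Spark Column is an immutable expression tree, so one instance can be reused safely across all the switch arms, and each case now fits on a single line.
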
@@ -461,8 +463,8 @@ public DatasetExpression executeRatioToReportAn(
     Dataset<Row> result =
         sparkDataset
             .getSparkDataset()
-            .withColumn(totalColName, sum(sourceColName).over(windowSpec))
-            .withColumn(targetColName, col(sourceColName).divide(col(totalColName)))
+            .withColumn(totalColName, sum(SparkUtils.safeCol(sourceColName)).over(windowSpec))
+            .withColumn(targetColName, SparkUtils.safeCol(sourceColName).divide(col(totalColName)))
             .drop(totalColName);
     // 2.3 without the calc clause, we need to overwrite the measure columns with the result column
     return new SparkDatasetExpression(new SparkDataset(result), dataset);
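
One asymmetry worth noting in this hunk: the numerator is escaped, but the divisor keeps plain col(totalColName). That stays safe only as long as totalColName is a plain internal name; the helper column is created and dropped within this same chain, but if its name were ever derived from sourceColName it would inherit the same special characters and want the same safeCol treatment.
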
@@ -1101,11 +1103,15 @@ public DatasetExpression executePivot(
     List<String> groupByIdentifiers = new ArrayList<>(dsExpr.getIdentifierNames());
     groupByIdentifiers.remove(idName);

-    Column[] groupByCols = groupByIdentifiers.stream().map(functions::col).toArray(Column[]::new);
+    Column[] groupByCols =
+        groupByIdentifiers.stream().map(SparkUtils::safeCol).toArray(Column[]::new);

     // TODO: fail if any values needs to be aggregated
     Dataset<Row> result =
-        sparkDataset.groupBy(groupByCols).pivot(idName).agg(functions.first(meName));
+        sparkDataset
+            .groupBy(groupByCols)
+            .pivot(SparkUtils.safeCol(idName))
+            .agg(functions.first(meName));

     return new SparkDatasetExpression(new SparkDataset(result), pos);
   }
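
For context, the failure mode this whole change set guards against: functions.col parses dots as struct-field access, so a flat column literally named me.1 cannot be resolved by col("me.1") and Spark throws an AnalysisException, while the backtick-quoted form resolves. A small illustration with hypothetical data, assuming the safeCol behavior sketched earlier:

Dataset<Row> ds = spark.sql("SELECT 1 AS `me.1`");

ds.select(col("me.1"));                // AnalysisException: cannot resolve 'me.1'
ds.select(col("`me.1`"));              // ok: the dot is part of the name
ds.select(SparkUtils.safeCol("me.1")); // ok: equivalent to the quoted form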