Commit d9bb1b6

Use SparkUtils to improve SparkProcessingEngine
1 parent: 843e8bc

2 files changed: +40 -33 lines changed
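
Why safeCol: Spark's functions.col parses a dotted name such as "ds.ds2" as struct-field access (field ds2 inside a struct column ds), so a flat column whose name contains a dot fails to resolve. A minimal repro of the failure mode, assuming a plain SparkSession (this setup is illustrative, not from the repo):

    // "ds.ds2" is a flat column whose name happens to contain a dot
    Dataset<Row> df = spark.range(1).withColumn("ds.ds2", functions.lit(42));
    df.select(functions.col("ds.ds2"));   // throws AnalysisException: the name cannot be resolved
    df.select(functions.col("`ds.ds2`")); // resolves: backticks make the name literal

The changes below route column references through SparkUtils.safeCol instead of functions.col so that such names are quoted consistently.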


vtl-spark/src/main/java/fr/insee/vtl/spark/SparkProcessingEngine.java

Lines changed: 37 additions & 31 deletions
@@ -72,23 +72,24 @@ private static Column convertAggregation(String columnName, AggregationExpression
       throws UnsupportedOperationException {
     Column column;
     if (expression instanceof MinAggregationExpression) {
-      column = min(columnName);
+      column = min(SparkUtils.safeCol(columnName));
     } else if (expression instanceof MaxAggregationExpression) {
-      column = max(columnName);
+      column = max(SparkUtils.safeCol(columnName));
     } else if (expression instanceof AverageAggregationExpression) {
-      column = avg(columnName);
+      column = avg(SparkUtils.safeCol(columnName));
     } else if (expression instanceof SumAggregationExpression) {
-      column = sum(columnName);
+      column = sum(SparkUtils.safeCol(columnName));
     } else if (expression instanceof CountAggregationExpression) {
       column = count("*");
     } else if (expression instanceof MedianAggregationExpression) {
-      column = percentile_approx(col(columnName), lit(0.5), lit(DEFAULT_MEDIAN_ACCURACY));
+      column =
+          percentile_approx(SparkUtils.safeCol(columnName), lit(0.5), lit(DEFAULT_MEDIAN_ACCURACY));
     } else if (expression instanceof StdDevSampAggregationExpression) {
-      column = stddev_samp(columnName);
+      column = stddev_samp(SparkUtils.safeCol(columnName));
     } else if (expression instanceof VarPopAggregationExpression) {
-      column = var_pop(columnName);
+      column = var_pop(SparkUtils.safeCol(columnName));
     } else if (expression instanceof VarSampAggregationExpression) {
-      column = var_samp(columnName);
+      column = var_samp(SparkUtils.safeCol(columnName));
     } else {
       throw new UnsupportedOperationException("unknown aggregation " + expression.getClass());
     }
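
SparkUtils.safeCol itself is not shown in this diff. A plausible sketch, assuming the helper simply backtick-quotes the name so Spark treats it as one literal identifier (the real implementation may differ):

    import org.apache.spark.sql.Column;
    import static org.apache.spark.sql.functions.col;

    public final class SparkUtils {

      private SparkUtils() {}

      // Quote the whole name so "ds.ds2" resolves as a single column rather
      // than field ds2 of a struct ds. Embedded backticks are escaped by
      // doubling them, the identifier-quoting rule of Spark SQL.
      public static Column safeCol(String name) {
        return col("`" + name.replace("`", "``") + "`");
      }
    }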
@@ -132,7 +133,7 @@ private static WindowSpec buildWindowSpec(
   public static Seq<Column> colNameToCol(List<String> inputColNames) {
     List<Column> cols = new ArrayList<>();
     for (String colName : inputColNames) {
-      cols.add(col(colName));
+      cols.add(SparkUtils.safeCol(colName));
     }
     return JavaConverters.asScalaIteratorConverter(cols.iterator()).asScala().toSeq();
   }
@@ -142,9 +143,9 @@ public static Seq<Column> buildOrderCol(Map<String, Analytics.Order> orderCols)
     List<Column> orders = new ArrayList<>();
     for (Map.Entry<String, Analytics.Order> entry : orderCols.entrySet()) {
       if (entry.getValue().equals(Analytics.Order.DESC)) {
-        orders.add(col(entry.getKey()).desc());
+        orders.add(SparkUtils.safeCol(entry.getKey()).desc());
       } else {
-        orders.add(col(entry.getKey()));
+        orders.add(SparkUtils.safeCol(entry.getKey()));
       }
     }
     return JavaConverters.asScalaIteratorConverter(orders.iterator()).asScala().toSeq();
@@ -282,9 +283,9 @@ public Dataset<Row> rename(Dataset<Row> dataset, Map<String, String> fromTo) {
     List<Column> columns = new ArrayList<>();
     for (String name : dataset.columns()) {
       if (fromTo.containsKey(name)) {
-        columns.add(col(name).as(fromTo.get(name)));
+        columns.add(SparkUtils.safeCol(name).as(fromTo.get(name)));
       } else if (!fromTo.containsValue(name)) {
-        columns.add(col(name));
+        columns.add(SparkUtils.safeCol(name));
       }
     }
     return dataset.select(iterableAsScalaIterable(columns).toSeq());
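
Only the source reference needs quoting in rename: Column.as takes its argument literally, so an alias target containing a dot needs no escaping. Illustrative one-liner (not from the repo):

    df.select(SparkUtils.safeCol("old.name").as("new.name")); // yields a column literally named "new.name"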
@@ -366,7 +367,7 @@ public DatasetExpression executeAggr(
             .map(e -> convertAggregation(e.getKey(), e.getValue()))
             .collect(Collectors.toList());
     List<Column> groupByColumns =
-        groupBy.stream().map(name -> col(name)).collect(Collectors.toList());
+        groupBy.stream().map(SparkUtils::safeCol).collect(Collectors.toList());
     Dataset<Row> result =
         sparkDataset
             .getSparkDataset()
@@ -395,22 +396,23 @@ public DatasetExpression executeSimpleAnalytic(
     // step 2: call analytic func on window spec
     // 2.1 get all measurement column

+    Column safeCol = SparkUtils.safeCol(sourceColName);
+
     Column column =
         switch (function) {
-          case COUNT -> count(sourceColName).over(windowSpec);
-          case SUM -> sum(sourceColName).over(windowSpec);
-          case MIN -> min(sourceColName).over(windowSpec);
-          case MAX -> max(sourceColName).over(windowSpec);
-          case AVG -> avg(sourceColName).over(windowSpec);
+          case COUNT -> count(safeCol).over(windowSpec);
+          case SUM -> sum(safeCol).over(windowSpec);
+          case MIN -> min(safeCol).over(windowSpec);
+          case MAX -> max(safeCol).over(windowSpec);
+          case AVG -> avg(safeCol).over(windowSpec);
           case MEDIAN ->
-              percentile_approx(col(sourceColName), lit(0.5), lit(DEFAULT_MEDIAN_ACCURACY))
-                  .over(windowSpec);
-          case STDDEV_POP -> stddev_pop(sourceColName).over(windowSpec);
-          case STDDEV_SAMP -> stddev_samp(sourceColName).over(windowSpec);
-          case VAR_POP -> var_pop(sourceColName).over(windowSpec);
-          case VAR_SAMP -> var_samp(sourceColName).over(windowSpec);
-          case FIRST_VALUE -> first(sourceColName).over(windowSpec);
-          case LAST_VALUE -> last(sourceColName).over(windowSpec);
+              percentile_approx(safeCol, lit(0.5), lit(DEFAULT_MEDIAN_ACCURACY)).over(windowSpec);
+          case STDDEV_POP -> stddev_pop(safeCol).over(windowSpec);
+          case STDDEV_SAMP -> stddev_samp(safeCol).over(windowSpec);
+          case VAR_POP -> var_pop(safeCol).over(windowSpec);
+          case VAR_SAMP -> var_samp(safeCol).over(windowSpec);
+          case FIRST_VALUE -> first(safeCol).over(windowSpec);
+          case LAST_VALUE -> last(safeCol).over(windowSpec);
           default -> throw UNKNOWN_ANALYTIC_FUNCTION;
         };
     var result = sparkDataset.getSparkDataset().withColumn(targetColName, column);
@@ -461,8 +463,8 @@ public DatasetExpression executeRatioToReportAn(
     Dataset<Row> result =
         sparkDataset
             .getSparkDataset()
-            .withColumn(totalColName, sum(sourceColName).over(windowSpec))
-            .withColumn(targetColName, col(sourceColName).divide(col(totalColName)))
+            .withColumn(totalColName, sum(SparkUtils.safeCol(sourceColName)).over(windowSpec))
+            .withColumn(targetColName, SparkUtils.safeCol(sourceColName).divide(col(totalColName)))
             .drop(totalColName);
     // 2.3 without the calc clause, we need to overwrite the measure columns with the result column
     return new SparkDatasetExpression(new SparkDataset(result), dataset);
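
The divisor keeps a plain col(totalColName); presumably totalColName is generated by the engine and never contains a dot, though routing it through SparkUtils.safeCol as well would be harmless.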
@@ -1101,11 +1103,15 @@ public DatasetExpression executePivot(
     List<String> groupByIdentifiers = new ArrayList<>(dsExpr.getIdentifierNames());
     groupByIdentifiers.remove(idName);

-    Column[] groupByCols = groupByIdentifiers.stream().map(functions::col).toArray(Column[]::new);
+    Column[] groupByCols =
+        groupByIdentifiers.stream().map(SparkUtils::safeCol).toArray(Column[]::new);

     // TODO: fail if any values needs to be aggregated
     Dataset<Row> result =
-        sparkDataset.groupBy(groupByCols).pivot(idName).agg(functions.first(meName));
+        sparkDataset
+            .groupBy(groupByCols)
+            .pivot(SparkUtils.safeCol(idName))
+            .agg(functions.first(meName));

     return new SparkDatasetExpression(new SparkDataset(result), pos);
   }

vtl-spark/src/test/java/fr/insee/vtl/spark/processing.engine/analytic/AnalyticFirstTest.java

Lines changed: 3 additions & 2 deletions
@@ -81,9 +81,10 @@ public void testAnFirstPartitionOrderByDesc() throws ScriptException {
     +----+----+----+----+----+
     * */
     ScriptContext context = engine.getContext();
-    context.setAttribute("ds2", ds2, ScriptContext.ENGINE_SCOPE);
+    // will also test dot escaping
+    context.setAttribute("ds.ds2", ds2, ScriptContext.ENGINE_SCOPE);

-    engine.eval("res := first_value ( ds2 over ( partition by Id_1, Id_2 order by Year desc) );");
+    engine.eval("res := first_value ( ds.ds2 over ( partition by Id_1, Id_2 order by Year desc) );");
     assertThat(engine.getContext().getAttribute("res")).isInstanceOf(Dataset.class);

     /*
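
Binding the dataset under the dotted name "ds.ds2" runs the whole pipeline, from the script binding down to the Spark column handling above, with a dot in play, which is what the new inline comment means by testing dot escaping. The surrounding test presumably obtains the engine in the usual javax.script way (an assumption, matching the project's documented usage):

    ScriptEngine engine = new ScriptEngineManager().getEngineByName("vtl");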
