diff --git a/Sources/SparkConnect/DataFrameWriterV2.swift b/Sources/SparkConnect/DataFrameWriterV2.swift
index b2ad861..55b8503 100644
--- a/Sources/SparkConnect/DataFrameWriterV2.swift
+++ b/Sources/SparkConnect/DataFrameWriterV2.swift
@@ -72,11 +72,7 @@ public actor DataFrameWriterV2: Sendable {
   /// - Parameter columns: Columns to partition
   /// - Returns: A ``DataFrameWriterV2``.
   public func partitionBy(_ columns: String...) -> DataFrameWriterV2 {
-    self.partitioningColumns = columns.map {
-      var expr = Spark_Connect_Expression()
-      expr.expressionString = $0.toExpressionString
-      return expr
-    }
+    self.partitioningColumns = columns.map { $0.toExpression }
     return self
   }
 
diff --git a/Sources/SparkConnect/Extension.swift b/Sources/SparkConnect/Extension.swift
index e841fa4..6586cd7 100644
--- a/Sources/SparkConnect/Extension.swift
+++ b/Sources/SparkConnect/Extension.swift
@@ -126,6 +126,14 @@ extension String {
     return expression
   }
 
+  var toExpression: Spark_Connect_Expression {
+    var expressionString = ExpressionString()
+    expressionString.expression = self
+    var expression = Spark_Connect_Expression()
+    expression.expressionString = expressionString
+    return expression
+  }
+
   var toExplainMode: ExplainMode {
     let mode = switch self {
     case "codegen": ExplainMode.codegen
diff --git a/Sources/SparkConnect/GroupedData.swift b/Sources/SparkConnect/GroupedData.swift
index a460832..6a09fd5 100644
--- a/Sources/SparkConnect/GroupedData.swift
+++ b/Sources/SparkConnect/GroupedData.swift
@@ -32,16 +32,8 @@ public actor GroupedData {
     var aggregate = Aggregate()
     aggregate.input = await (self.df.getPlan() as! Plan).root
     aggregate.groupType = self.groupType
-    aggregate.groupingExpressions = self.groupingCols.map {
-      var expr = Spark_Connect_Expression()
-      expr.expressionString = $0.toExpressionString
-      return expr
-    }
-    aggregate.aggregateExpressions = exprs.map {
-      var expr = Spark_Connect_Expression()
-      expr.expressionString = $0.toExpressionString
-      return expr
-    }
+    aggregate.groupingExpressions = self.groupingCols.map { $0.toExpression }
+    aggregate.aggregateExpressions = exprs.map { $0.toExpression }
     var relation = Relation()
     relation.aggregate = aggregate
     var plan = Plan()
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index cb1c2e1..5c48c55 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -508,11 +508,7 @@ public actor SparkConnectClient {
   static func getProjectExprs(_ child: Relation, _ exprs: [String]) -> Plan {
     var project = Project()
     project.input = child
-    let expressions: [Spark_Connect_Expression] = exprs.map {
-      var expression = Spark_Connect_Expression()
-      expression.exprType = .expressionString($0.toExpressionString)
-      return expression
-    }
+    let expressions: [Spark_Connect_Expression] = exprs.map { $0.toExpression }
     project.expressions = expressions
     var relation = Relation()
     relation.project = project
@@ -908,11 +904,7 @@ public actor SparkConnectClient {
   ) -> Plan {
     var repartitionByExpression = RepartitionByExpression()
     repartitionByExpression.input = child
-    repartitionByExpression.partitionExprs = partitionExprs.map {
-      var expr = Spark_Connect_Expression()
-      expr.expressionString = $0.toExpressionString
-      return expr
-    }
+    repartitionByExpression.partitionExprs = partitionExprs.map { $0.toExpression }
     if let numPartitions {
       repartitionByExpression.numPartitions = numPartitions
     }
@@ -932,18 +924,10 @@ public actor SparkConnectClient {
   ) -> Plan {
     var unpivot = Spark_Connect_Unpivot()
     unpivot.input = child
-    unpivot.ids = ids.map {
-      var expr = Spark_Connect_Expression()
-      expr.expressionString = $0.toExpressionString
-      return expr
-    }
+    unpivot.ids = ids.map { $0.toExpression }
     if let values {
       var unpivotValues = Spark_Connect_Unpivot.Values()
-      unpivotValues.values = values.map {
-        var expr = Spark_Connect_Expression()
-        expr.expressionString = $0.toExpressionString
-        return expr
-      }
+      unpivotValues.values = values.map { $0.toExpression }
       unpivot.values = unpivotValues
     }
     unpivot.variableColumnName = variableColumnName
@@ -958,11 +942,7 @@ public actor SparkConnectClient {
   static func getTranspose(_ child: Relation, _ indexColumn: [String]) -> Plan {
     var transpose = Spark_Connect_Transpose()
     transpose.input = child
-    transpose.indexColumns = indexColumn.map {
-      var expr = Spark_Connect_Expression()
-      expr.expressionString = $0.toExpressionString
-      return expr
-    }
+    transpose.indexColumns = indexColumn.map { $0.toExpression }
     var relation = Relation()
     relation.transpose = transpose
     var plan = Plan()