
Commit 18785ca

Fix issue with column rename for column with special chars (#136)

Parent: 0e13822

2 files changed: 25 additions & 19 deletions

spark_utils/dataframes/functions.py

Lines changed: 9 additions & 19 deletions

@@ -23,7 +23,7 @@
 """
 Helper functions for Spark Dataframes
 """
-
+import re
 from typing import List, Iterator, Tuple
 
 from datetime import datetime
@@ -130,25 +130,15 @@ def rename_column(name: str) -> str:
     :return:
     """
 
-    illegals = [
-        " ",
-        ",",
-        ";",
-        "{",
-        "}",
-        "(",
-        ")",
-        "\t",
-        "=",
-        "/",
-        "\\",
-        ".",
-    ]
+    return re.sub(r"\W+", "", name)
+
 
-    for illegal in illegals:
-        name = name.replace(illegal, "")
+def safe_encode(column_name: str) -> str:
+    """
+    Adds `` around the column name so columns with unsupported chars are resolved
+    """
 
-    return name
+    return f"`{column_name}`"
 
 
 def rename_columns(dataframe: DataFrame) -> DataFrame:
@@ -158,7 +148,7 @@ def rename_columns(dataframe: DataFrame) -> DataFrame:
     :param dataframe: Source dataframe
     :return: Dataframe with renamed columns
     """
-    return dataframe.select([col(c).alias(rename_column(c)) for c in dataframe.columns])
+    return dataframe.select([col(safe_encode(c)).alias(rename_column(c)) for c in dataframe.columns])
 
 
 def _max_timestamp(dataframe: DataFrame, timestamp_column: str, timestamp_column_format: str) -> datetime:
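
For context on the fix: Spark treats a bare "." in a column reference as struct-field access, so col("a.bc") cannot resolve a dataframe column whose literal name is "a.bc"; backtick-quoting, which safe_encode now applies, makes Spark read the whole string as a single column name. A minimal sketch of the failure mode, assuming a local SparkSession (the sample column name is illustrative, not taken from the commit):

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.master("local[1]").getOrCreate()

# A column whose literal name contains a dot.
df = spark.createDataFrame([(1,)], ["a.bc"])

# df.select(col("a.bc"))         # fails: "." is parsed as struct-field access
df.select(col("`a.bc`")).show()  # backticks resolve the literal column name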

test/test_common_functions.py

Lines changed: 16 additions & 0 deletions

@@ -7,6 +7,7 @@
 from pyspark.sql import DataFrame
 from pyspark.sql import SparkSession
 
+from spark_utils.dataframes.functions import rename_column
 from spark_utils.models.job_socket import JobSocket
 from spark_utils.common.functions import read_from_socket, write_to_socket
 
@@ -97,3 +98,18 @@ def test_job_socket_serialize(sep: str, test_base_path: str):
     )
 
     assert socket.serialize(separator=sep) == f"{socket.alias}{sep}{socket.data_path}{sep}{socket.data_format}"
+
+
+@pytest.mark.parametrize(
+    "funky_name, expected_name",
+    [
+        ("a--bc", "abc"),
+        (".abc", "abc"),
+        ("a bc", "abc"),
+        ("a\\bc", "abc"),
+        ("a/bc", "abc"),
+        ("a\t{};,bc", "abc"),
+    ],
+)
+def test_column_rename(funky_name: str, expected_name: str):
+    assert expected_name == rename_column(funky_name)
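
Putting both helpers together, rename_columns now resolves the original name via safe_encode and aliases it with the regex-cleaned name, where re.sub(r"\W+", "", name) strips every run of non-word characters (keeping letters, digits, and underscores). A hedged end-to-end sketch, again assuming a local SparkSession and illustrative column names:

from pyspark.sql import SparkSession

from spark_utils.dataframes.functions import rename_columns

spark = SparkSession.builder.master("local[1]").getOrCreate()

df = spark.createDataFrame([(1, 2)], ["a--bc", "x.y z"])

renamed = rename_columns(df)
print(renamed.columns)  # expected: ['abc', 'xyz']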
