Commit c653da4

Author: maxim-lixakov
Message: [DOP-21665] - add spark dialect extension to clickhouse
1 parent: 66ea992

File tree: 5 files changed, +10 −17 lines
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+Add `spark-dialect-extension <https://github.com/MobileTeleSystems/spark-dialect-extension/>`_

syncmaster/worker/handlers/db/clickhouse.py

Lines changed: 4 additions & 0 deletions

@@ -23,6 +23,10 @@ class ClickhouseHandler(DBHandler):
     transfer_dto: ClickhouseTransferDTO

     def connect(self, spark: SparkSession):
+        ClickhouseDialectRegistry = (
+            spark._jvm.io.github.mtsongithub.doetl.sparkdialectextensions.clickhouse.ClickhouseDialectRegistry
+        )
+        ClickhouseDialectRegistry.register()
         self.connection = Clickhouse(
             host=self.connection_dto.host,
             port=self.connection_dto.port,
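For orientation: the handler reaches the JVM-side registry through py4j before opening the onETL connection, so the custom dialect is in place for every subsequent JDBC read/write. A minimal standalone sketch of the same call (session setup elided; it assumes the dialect-extension jar is already on the Spark classpath):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()  # dialect-extension jar must already be on the classpath

    # py4j exposes JVM classes under spark._jvm; this is the same class path the handler uses
    registry = spark._jvm.io.github.mtsongithub.doetl.sparkdialectextensions.clickhouse.ClickhouseDialectRegistry
    registry.register()  # from here on, Spark JDBC traffic to ClickHouse goes through the custom dialect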

syncmaster/worker/spark.py

Lines changed: 4 additions & 2 deletions

@@ -50,8 +50,10 @@ def get_packages(db_type: str) -> list[str]:
     if db_type == "oracle":
         return Oracle.get_packages()
     if db_type == "clickhouse":
-        # TODO: add https://github.com/MobileTeleSystems/spark-dialect-extension/ to spark jars
-        return Clickhouse.get_packages()
+        return [
+            "io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2",
+            *Clickhouse.get_packages(),
+        ]
     if db_type == "mssql":
         return MSSQL.get_packages()
     if db_type == "mysql":
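The strings returned here are Maven coordinates that the worker ultimately hands to Spark. A rough sketch of how such a list is consumed (joining it into spark.jars.packages is an assumption about the surrounding worker code, which this diff does not show; the example list contents are illustrative):

    from pyspark.sql import SparkSession

    packages = get_packages(db_type="clickhouse")
    # e.g. ["io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2", ...Clickhouse.get_packages() items]

    spark = (
        SparkSession.builder
        .config("spark.jars.packages", ",".join(packages))  # Spark resolves these from Maven at startup
        .getOrCreate()
    )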

tests/test_integration/test_run_transfer/conftest.py

Lines changed: 1 addition & 0 deletions

@@ -78,6 +78,7 @@ def spark(settings: Settings, request: FixtureRequest) -> SparkSession:
         maven_packages.extend(Oracle.get_packages())

     if "clickhouse" in markers:
+        maven_packages.append("io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2")
         maven_packages.extend(Clickhouse.get_packages())

     if "mssql" in markers:

tests/test_integration/test_run_transfer/test_clickhouse.py

Lines changed: 0 additions & 15 deletions

@@ -6,7 +6,6 @@
 from onetl.connection import Clickhouse
 from onetl.db import DBReader
 from pyspark.sql import DataFrame
-from pyspark.sql.functions import col, date_trunc
 from sqlalchemy.ext.asyncio import AsyncSession

 from syncmaster.db.models import Connection, Group, Queue, Status, Transfer

@@ -117,8 +116,6 @@ async def test_run_transfer_postgres_to_clickhouse(
         table=f"{clickhouse.user}.target_table",
     )
     df = reader.run()
-    # as spark truncates milliseconds while writing to clickhouse: https://onetl.readthedocs.io/en/latest/connection/db_connection/clickhouse/types.html#id10
-    init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))
     for field in init_df.schema:
         df = df.withColumn(field.name, df[field.name].cast(field.dataType))

@@ -169,11 +166,6 @@ async def test_run_transfer_postgres_to_clickhouse_mixed_naming(
     assert df.columns != init_df_with_mixed_column_naming.columns
     assert df.columns == [column.lower() for column in init_df_with_mixed_column_naming.columns]

-    # as spark truncates milliseconds while writing to clickhouse: https://onetl.readthedocs.io/en/latest/connection/db_connection/clickhouse/types.html#id10
-    init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.withColumn(
-        "Registered At",
-        date_trunc("second", col("Registered At")),
-    )
     for field in init_df_with_mixed_column_naming.schema:
         df = df.withColumn(field.name, df[field.name].cast(field.dataType))

@@ -222,8 +214,6 @@ async def test_run_transfer_clickhouse_to_postgres(
     )
     df = reader.run()

-    # as spark truncates milliseconds while writing to clickhouse: https://onetl.readthedocs.io/en/latest/connection/db_connection/clickhouse/types.html#id10
-    init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))
     for field in init_df.schema:
         df = df.withColumn(field.name, df[field.name].cast(field.dataType))

@@ -275,11 +265,6 @@ async def test_run_transfer_clickhouse_to_postgres_mixed_naming(
     assert df.columns != init_df_with_mixed_column_naming.columns
     assert df.columns == [column.lower() for column in init_df_with_mixed_column_naming.columns]

-    # as spark truncates milliseconds while writing to clickhouse: https://onetl.readthedocs.io/en/latest/connection/db_connection/clickhouse/types.html#id10
-    init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.withColumn(
-        "Registered At",
-        date_trunc("second", col("Registered At")),
-    )
     for field in init_df_with_mixed_column_naming.schema:
         df = df.withColumn(field.name, df[field.name].cast(field.dataType))
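All four tests drop the same workaround: with Spark's default JDBC dialect, writes to ClickHouse truncated milliseconds, so the expected dataframe had to be truncated the same way before comparison. The removed pattern looked like this (reproduced only to illustrate what the registered dialect makes unnecessary, on the assumption that timestamps now round-trip with their sub-second part intact):

    from pyspark.sql.functions import col, date_trunc

    # old workaround: align expectations with Spark's millisecond truncation on write
    init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))

With the dialect extension registered, init_df and the read-back df can be compared directly after the usual per-field dtype casts.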
