
Commit 525d635

Add Excel integration tests (#149)
1 parent 38fc8ad commit 525d635

7 files changed: +181 -29 lines

syncmaster/dto/transfers.py

Lines changed: 13 additions & 9 deletions
@@ -4,7 +4,7 @@
 from dataclasses import dataclass
 from typing import ClassVar

-from onetl.file.format import CSV, JSON, JSONLine
+from onetl.file.format import CSV, JSON, Excel, JSONLine


 @dataclass
@@ -20,10 +20,17 @@ class DBTransferDTO(TransferDTO):
 @dataclass
 class FileTransferDTO(TransferDTO):
     directory_path: str
-    file_format: CSV | JSONLine | JSON
+    file_format: CSV | JSONLine | JSON | Excel
     options: dict
     df_schema: dict | None = None

+    _format_parsers = {
+        "csv": CSV,
+        "jsonline": JSONLine,
+        "json": JSON,
+        "excel": Excel,
+    }
+
     def __post_init__(self):
         if isinstance(self.file_format, dict):
             self.file_format = self._get_format(self.file_format.copy())
@@ -32,13 +39,10 @@ def __post_init__(self):

     def _get_format(self, file_format: dict):
         file_type = file_format.pop("type", None)
-        if file_type == "csv":
-            return CSV.parse_obj(file_format)
-        if file_type == "jsonline":
-            return JSONLine.parse_obj(file_format)
-        if file_type == "json":
-            return JSON.parse_obj(file_format)
-        raise ValueError("Unknown file type")
+        parser_class = self._format_parsers.get(file_type)
+        if parser_class is not None:
+            return parser_class.parse_obj(file_format)
+        raise ValueError(f"Unknown file type: {file_type}")


 @dataclass
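
The refactored _get_format resolves the parser class from a mapping instead of walking an if-chain. A minimal standalone sketch of the same dispatch (the helper name parse_file_format is hypothetical; it mirrors FileTransferDTO._get_format and assumes onetl is installed):

    from onetl.file.format import CSV, JSON, Excel, JSONLine

    _format_parsers = {"csv": CSV, "jsonline": JSONLine, "json": JSON, "excel": Excel}

    def parse_file_format(payload: dict):
        # Pop "type" so only format-specific keys reach the parser
        file_type = payload.pop("type", None)
        parser_class = _format_parsers.get(file_type)
        if parser_class is None:
            raise ValueError(f"Unknown file type: {file_type}")
        return parser_class.parse_obj(payload)

    # e.g. parse_file_format({"type": "excel", "header": True}) returns an Excel(header=True) instance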

syncmaster/worker/handlers/file/s3.py

Lines changed: 19 additions & 1 deletion
@@ -6,12 +6,13 @@
 from typing import TYPE_CHECKING

 from onetl.connection import SparkS3
+from onetl.file import FileDFReader

 from syncmaster.dto.connections import S3ConnectionDTO
 from syncmaster.worker.handlers.file.base import FileHandler

 if TYPE_CHECKING:
-    from pyspark.sql import SparkSession
+    from pyspark.sql import DataFrame, SparkSession


 class S3Handler(FileHandler):
@@ -29,3 +30,20 @@ def connect(self, spark: SparkSession):
             extra=self.connection_dto.additional_params,
             spark=spark,
         ).check()
+
+    def read(self) -> DataFrame:
+        from pyspark.sql.types import StructType
+
+        options = {}
+        if self.transfer_dto.file_format.__class__.__name__ == "Excel":
+            options = {"inferSchema": True}
+
+        reader = FileDFReader(
+            connection=self.connection,
+            format=self.transfer_dto.file_format,
+            source_path=self.transfer_dto.directory_path,
+            df_schema=StructType.fromJson(self.transfer_dto.df_schema) if self.transfer_dto.df_schema else None,
+            options={**options, **self.transfer_dto.options},
+        )
+
+        return reader.run()
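
The new read() builds on onetl's FileDFReader. For context, a comparable standalone call with the Excel format might look like this (a sketch only: the host, bucket, credentials, and path are placeholders, and the session is assumed to already have the SparkS3 and spark-excel jars on its classpath):

    from onetl.connection import SparkS3
    from onetl.file import FileDFReader
    from onetl.file.format import Excel
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()  # jars/packages configuration omitted here

    # Hypothetical connection parameters
    s3 = SparkS3(
        host="s3.localdomain",
        port=9000,
        bucket="syncmaster",
        access_key="access",
        secret_key="secret",
        protocol="http",
        spark=spark,
    ).check()

    reader = FileDFReader(
        connection=s3,
        format=Excel(header=True, inferSchema=True),
        source_path="/data/file_df_connection/xlsx/with_header",
    )
    df = reader.run()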

syncmaster/worker/spark.py

Lines changed: 6 additions & 1 deletion
@@ -37,6 +37,7 @@ def get_worker_spark_session(

 def get_packages(db_type: str) -> list[str]:
     from onetl.connection import MSSQL, Clickhouse, MySQL, Oracle, Postgres, SparkS3
+    from onetl.file.format import Excel

     if db_type == "postgres":
         return Postgres.get_packages()
@@ -53,7 +54,11 @@ def get_packages(db_type: str) -> list[str]:
         import pyspark

         spark_version = pyspark.__version__
-        return SparkS3.get_packages(spark_version=spark_version)
+        # see supported versions from https://mvnrepository.com/artifact/com.crealytics/spark-excel
+        return SparkS3.get_packages(spark_version=spark_version) + Excel.get_packages(spark_version="3.5.1")
+    if db_type == "hdfs":
+        # see supported versions from https://mvnrepository.com/artifact/com.crealytics/spark-excel
+        return Excel.get_packages(spark_version="3.5.1")

     # If the database type does not require downloading .jar packages
     return []
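
A rough illustration of how the returned package list is wired into a Spark session, in the spirit of the conftest fixture further down (simplified; the real get_worker_spark_session applies more settings, and the app name is a placeholder):

    from pyspark.sql import SparkSession

    from syncmaster.worker.spark import get_packages  # module path follows this file's location

    maven_packages = get_packages(db_type="s3")  # SparkS3 jars plus spark-excel jars

    spark = (
        SparkSession.builder
        .appName("syncmaster-worker")
        .config("spark.jars.packages", ",".join(maven_packages))
        .getOrCreate()
    )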

tests/test_integration/test_run_transfer/conftest.py

Lines changed: 84 additions & 12 deletions
@@ -3,14 +3,15 @@
 import os
 import secrets
 from collections import namedtuple
-from pathlib import Path, PurePosixPath
+from pathlib import Path, PosixPath, PurePosixPath

 import pyspark
 import pytest
 import pytest_asyncio
 from onetl.connection import MSSQL, Clickhouse, Hive, MySQL, Oracle, Postgres, SparkS3
+from onetl.connection.file_connection.s3 import S3
 from onetl.db import DBWriter
-from onetl.file.format import CSV, JSON, JSONLine
+from onetl.file.format import CSV, JSON, Excel, JSONLine
 from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql.types import (
     DateType,
@@ -112,6 +113,10 @@ def spark(settings: Settings, request: FixtureRequest) -> SparkSession:
         )
     )

+    if "hdfs" in markers or "s3" in markers:
+        # see supported versions from https://mvnrepository.com/artifact/com.crealytics/spark-excel
+        maven_packages.extend(Excel.get_packages(spark_version="3.5.1"))
+
     if maven_packages:
         spark = spark.config("spark.jars.packages", ",".join(maven_packages))

@@ -462,12 +467,22 @@ def s3_file_df_connection(s3_file_connection, spark, s3_server):


 @pytest.fixture(scope="session")
-def prepare_s3(resource_path, s3_file_connection, s3_file_df_connection_with_path: tuple[SparkS3, PurePosixPath]):
-    logger.info("START PREPARE HDFS")
-    connection, upload_to = s3_file_df_connection_with_path
-    files = upload_files(resource_path, upload_to, s3_file_connection)
-    logger.info("END PREPARE HDFS")
-    return connection, upload_to, files
+def prepare_s3(
+    resource_path: PosixPath,
+    s3_file_connection: S3,
+    s3_file_df_connection_with_path: tuple[SparkS3, PurePosixPath],
+):
+    logger.info("START PREPARE S3")
+    connection, remote_path = s3_file_df_connection_with_path
+
+    s3_file_connection.remove_dir(remote_path, recursive=True)
+    files = upload_files(resource_path, remote_path, s3_file_connection)
+
+    yield connection, remote_path, files
+
+    logger.info("START POST-CLEANUP S3")
+    s3_file_connection.remove_dir(remote_path, recursive=True)
+    logger.info("END POST-CLEANUP S3")


 @pytest.fixture(scope="session")
@@ -635,14 +650,14 @@ def prepare_clickhouse(
         pass

     def fill_with_data(df: DataFrame):
-        logger.info("START PREPARE ORACLE")
+        logger.info("START PREPARE CLICKHOUSE")
        db_writer = DBWriter(
             connection=onetl_conn,
             target=f"{clickhouse.user}.source_table",
             options=Clickhouse.WriteOptions(createTableOptions="ENGINE = Memory"),
         )
         db_writer.run(df)
-        logger.info("END PREPARE ORACLE")
+        logger.info("END PREPARE CLICKHOUSE")

     yield onetl_conn, fill_with_data

@@ -745,7 +760,51 @@ def fill_with_data(df: DataFrame):
         pass


-@pytest.fixture(params=[("csv", {}), ("jsonline", {}), ("json", {})])
+@pytest.fixture
+def prepare_mysql(
+    mysql_for_conftest: MySQLConnectionDTO,
+    spark: SparkSession,
+):
+    mysql = mysql_for_conftest
+    onetl_conn = MySQL(
+        host=mysql.host,
+        port=mysql.port,
+        user=mysql.user,
+        password=mysql.password,
+        database=mysql.database_name,
+        spark=spark,
+    ).check()
+    try:
+        onetl_conn.execute(f"DROP TABLE IF EXISTS {mysql.database_name}.source_table")
+    except Exception:
+        pass
+    try:
+        onetl_conn.execute(f"DROP TABLE IF EXISTS {mysql.database_name}.target_table")
+    except Exception:
+        pass
+
+    def fill_with_data(df: DataFrame):
+        logger.info("START PREPARE MYSQL")
+        db_writer = DBWriter(
+            connection=onetl_conn,
+            target=f"{mysql.database_name}.source_table",
+        )
+        db_writer.run(df)
+        logger.info("END PREPARE MYSQL")
+
+    yield onetl_conn, fill_with_data
+
+    try:
+        onetl_conn.execute(f"DROP TABLE IF EXISTS {mysql.database_name}.source_table")
+    except Exception:
+        pass
+    try:
+        onetl_conn.execute(f"DROP TABLE IF EXISTS {mysql.database_name}.target_table")
+    except Exception:
+        pass
+
+
+@pytest.fixture(params=[("csv", {}), ("jsonline", {}), ("json", {}), ("excel", {})])
 def source_file_format(request: FixtureRequest):
     name, params = request.param
     if name == "csv":
@@ -769,10 +828,17 @@ def source_file_format(request: FixtureRequest):
             **params,
         )

+    if name == "excel":
+        return "excel", Excel(
+            header=True,
+            inferSchema=True,
+            **params,
+        )
+
     raise ValueError(f"Unsupported file format: {name}")


-@pytest.fixture(params=[("csv", {}), ("jsonline", {})])
+@pytest.fixture(params=[("csv", {}), ("jsonline", {}), ("excel", {})])
 def target_file_format(request: FixtureRequest):
     name, params = request.param
     if name == "csv":
@@ -791,6 +857,12 @@ def target_file_format(request: FixtureRequest):
             **params,
         )

+    if name == "excel":
+        return "excel", Excel(
+            header=False,
+            **params,
+        )
+
     raise ValueError(f"Unsupported file format: {name}")

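Both the reworked prepare_s3 and the new prepare_mysql follow pytest's yield-fixture pattern: defensive pre-cleanup, yield the resources to the test, then clean up again afterwards. A generic, self-contained sketch of that pattern (illustrative only, not repo code):

    import shutil
    import tempfile
    from pathlib import Path

    import pytest

    @pytest.fixture
    def prepared_dir():
        # setup: create a scratch directory and seed it with test data
        path = Path(tempfile.mkdtemp(prefix="syncmaster-test-"))
        (path / "data.csv").write_text("id,name\n1,alice\n")

        yield path  # the test body runs here

        # teardown: always runs after the test, like the remove_dir / DROP TABLE calls above
        shutil.rmtree(path, ignore_errors=True)

    def test_uses_prepared_dir(prepared_dir):
        assert (prepared_dir / "data.csv").exists()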

tests/test_integration/test_run_transfer/test_hdfs.py

Lines changed: 27 additions & 1 deletion
@@ -8,6 +8,7 @@
 from onetl.db import DBReader
 from onetl.file import FileDFReader
 from pyspark.sql import DataFrame
+from pyspark.sql.functions import col, date_format, date_trunc, to_timestamp
 from pytest import FixtureRequest
 from sqlalchemy.ext.asyncio import AsyncSession

@@ -37,6 +38,7 @@ async def hdfs_to_postgres(
     file_format_flavor: str,
 ):
     format_name, file_format = source_file_format
+    format_name_in_path = "xlsx" if format_name == "excel" else format_name
     _, source_path, _ = prepare_hdfs

     result = await create_transfer(
@@ -47,7 +49,7 @@
         target_connection_id=postgres_connection.id,
         source_params={
             "type": "hdfs",
-            "directory_path": os.fspath(source_path / "file_df_connection" / format_name / file_format_flavor),
+            "directory_path": os.fspath(source_path / "file_df_connection" / format_name_in_path / file_format_flavor),
             "file_format": {
                 "type": format_name,
                 **file_format.dict(),
@@ -121,6 +123,11 @@ async def postgres_to_hdfs(
             "without_compression",
             id="jsonline",
         ),
+        pytest.param(
+            ("excel", {}),
+            "with_header",
+            id="excel",
+        ),
     ],
     indirect=["source_file_format", "file_format_flavor"],
 )
@@ -135,6 +142,7 @@ async def test_run_transfer_hdfs_to_postgres(
 ):
     # Arrange
     postgres, _ = prepare_postgres
+    file_format, _ = source_file_format

     # Act
     result = await client.post(
@@ -164,6 +172,12 @@
         table="public.target_table",
     )
     df = reader.run()
+
+    # as Excel does not support datetime values with precision greater than milliseconds
+    if file_format == "excel":
+        df = df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))
+        init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))
+
     for field in init_df.schema:
         df = df.withColumn(field.name, df[field.name].cast(field.dataType))

@@ -183,6 +197,11 @@ async def test_run_transfer_hdfs_to_postgres(
             "without_compression",
             id="jsonline",
         ),
+        pytest.param(
+            ("excel", {}),
+            "with_header",
+            id="excel",
+        ),
     ],
     indirect=["target_file_format", "file_format_flavor"],
 )
@@ -235,6 +254,13 @@ async def test_run_transfer_postgres_to_hdfs(
     )
     df = reader.run()

+    # as Excel does not support datetime values with precision greater than milliseconds
+    if format_name == "excel":
+        init_df = init_df.withColumn(
+            "REGISTERED_AT",
+            to_timestamp(date_format(col("REGISTERED_AT"), "yyyy-MM-dd HH:mm:ss.SSS")),
+        )
+
     for field in init_df.schema:
         df = df.withColumn(field.name, df[field.name].cast(field.dataType))

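The two Excel branches handle timestamp precision differently: hdfs-to-postgres truncates both frames to whole seconds, while postgres-to-hdfs keeps milliseconds by formatting and re-parsing REGISTERED_AT. A small standalone pyspark sketch of that second trick (the sample value is made up):

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, date_format, to_timestamp

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    df = spark.createDataFrame([("2024-01-01 12:00:00.123456",)], ["REGISTERED_AT"])
    df = df.withColumn("REGISTERED_AT", to_timestamp(col("REGISTERED_AT")))

    # Drop microseconds but keep milliseconds, matching what Excel can store
    df = df.withColumn(
        "REGISTERED_AT",
        to_timestamp(date_format(col("REGISTERED_AT"), "yyyy-MM-dd HH:mm:ss.SSS")),
    )
    df.show(truncate=False)  # 2024-01-01 12:00:00.123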

tests/test_integration/test_run_transfer/test_mssql.py

Lines changed: 4 additions & 4 deletions
@@ -118,7 +118,7 @@ async def test_run_transfer_postgres_to_mssql(
     )
     df = reader.run()

-    # as spark rounds datetime up to milliseconds while writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5
+    # as spark rounds datetime to nearest 3.33 milliseconds when writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5
     df = df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))
     init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))

@@ -173,7 +173,7 @@ async def test_run_transfer_postgres_to_mssql_mixed_naming(
     assert df.columns != init_df_with_mixed_column_naming.columns
     assert df.columns == [column.lower() for column in init_df_with_mixed_column_naming.columns]

-    # as spark rounds datetime up to milliseconds while writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5
+    # as spark rounds datetime to nearest 3.33 milliseconds when writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5
     df = df.withColumn("Registered At", date_trunc("second", col("Registered At")))
     init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.withColumn(
         "Registered At",
@@ -228,7 +228,7 @@ async def test_run_transfer_mssql_to_postgres(
     )
     df = reader.run()

-    # as spark rounds datetime up to milliseconds while writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5
+    # as spark rounds datetime to nearest 3.33 milliseconds when writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5
     df = df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))
     init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))

@@ -283,7 +283,7 @@ async def test_run_transfer_mssql_to_postgres_mixed_naming(
     assert df.columns != init_df_with_mixed_column_naming.columns
     assert df.columns == [column.lower() for column in init_df_with_mixed_column_naming.columns]

-    # as spark rounds datetime up to milliseconds while writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5
+    # as spark rounds datetime to nearest 3.33 milliseconds when writing to mssql: https://onetl.readthedocs.io/en/latest/connection/db_connection/mssql/types.html#id5
     df = df.withColumn("Registered At", date_trunc("second", col("Registered At")))
     init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.withColumn(
         "Registered At",
