
Commit e2e1bd5

Ilyas Gasanov committed
[DOP-21444] Add Excel integration tests
1 parent f8103ae commit e2e1bd5

File tree

9 files changed: +150, -41 lines


poetry.lock

Lines changed: 11 additions & 11 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@ python-jose = { version = "^3.3.0", extras = ["cryptography"], optional = true }
 jinja2 = { version = "^3.1.4", optional = true }
 python-multipart = { version = ">=0.0.9,<0.0.18", optional = true }
 celery = { version = "^5.4.0", optional = true }
+pyspark = { version = ">=3.5.1,<3.5.2", optional = true }
 onetl = { version = "^0.12.0", extras = ["spark"], optional = true }
 pyyaml = {version = "*", optional = true}
 # due to not supporting MacOS 14.x https://www.psycopg.org/psycopg3/docs/news.html#psycopg-3-1-20

syncmaster/dto/transfers.py

Lines changed: 13 additions & 9 deletions
@@ -4,7 +4,7 @@
 from dataclasses import dataclass
 from typing import ClassVar
 
-from onetl.file.format import CSV, JSON, JSONLine
+from onetl.file.format import CSV, JSON, Excel, JSONLine
 
 
 @dataclass
@@ -20,10 +20,17 @@ class DBTransferDTO(TransferDTO):
 @dataclass
 class FileTransferDTO(TransferDTO):
     directory_path: str
-    file_format: CSV | JSONLine | JSON
+    file_format: CSV | JSONLine | JSON | Excel
     options: dict
     df_schema: dict | None = None
 
+    _format_parsers = {
+        "csv": CSV,
+        "jsonline": JSONLine,
+        "json": JSON,
+        "excel": Excel,
+    }
+
     def __post_init__(self):
         if isinstance(self.file_format, dict):
             self.file_format = self._get_format(self.file_format.copy())
@@ -32,13 +39,10 @@ def __post_init__(self):
 
     def _get_format(self, file_format: dict):
         file_type = file_format.pop("type", None)
-        if file_type == "csv":
-            return CSV.parse_obj(file_format)
-        if file_type == "jsonline":
-            return JSONLine.parse_obj(file_format)
-        if file_type == "json":
-            return JSON.parse_obj(file_format)
-        raise ValueError("Unknown file type")
+        parser_class = self._format_parsers.get(file_type)
+        if parser_class is not None:
+            return parser_class.parse_obj(file_format)
+        raise ValueError(f"Unknown file type: {file_type}")
 
 
 @dataclass

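The refactor replaces the chained `if` branches with a single lookup table, so supporting a new format only means adding an entry to `_format_parsers`. A standalone sketch of that dispatch (the option dicts below are illustrative, not taken from the repository):

# Sketch of the lookup-based dispatch in FileTransferDTO._get_format (requires onetl)
from onetl.file.format import CSV, JSON, Excel, JSONLine

_format_parsers = {"csv": CSV, "jsonline": JSONLine, "json": JSON, "excel": Excel}

def get_format(file_format: dict):
    file_type = file_format.pop("type", None)
    parser_class = _format_parsers.get(file_type)
    if parser_class is not None:
        return parser_class.parse_obj(file_format)
    raise ValueError(f"Unknown file type: {file_type}")

print(get_format({"type": "excel", "header": True}))   # Excel(header=True, ...)
print(get_format({"type": "csv", "delimiter": ";"}))   # CSV(delimiter=';', ...)
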
syncmaster/worker/handlers/file/s3.py

Lines changed: 19 additions & 1 deletion
@@ -6,12 +6,13 @@
 from typing import TYPE_CHECKING
 
 from onetl.connection import SparkS3
+from onetl.file import FileDFReader
 
 from syncmaster.dto.connections import S3ConnectionDTO
 from syncmaster.worker.handlers.file.base import FileHandler
 
 if TYPE_CHECKING:
-    from pyspark.sql import SparkSession
+    from pyspark.sql import DataFrame, SparkSession
 
 
 class S3Handler(FileHandler):
@@ -29,3 +30,20 @@ def connect(self, spark: SparkSession):
             extra=self.connection_dto.additional_params,
             spark=spark,
         ).check()
+
+    def read(self) -> DataFrame:
+        from pyspark.sql.types import StructType
+
+        options = {}
+        if self.transfer_dto.file_format.__class__.__name__ == "Excel":
+            options = {"inferSchema": True}
+
+        reader = FileDFReader(
+            connection=self.connection,
+            format=self.transfer_dto.file_format,
+            source_path=self.transfer_dto.directory_path,
+            df_schema=StructType.fromJson(self.transfer_dto.df_schema) if self.transfer_dto.df_schema else None,
+            options={**options, **self.transfer_dto.options},
+        )
+
+        return reader.run()

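In `read()`, `inferSchema=True` is only a default for Excel sources: the transfer's own options are merged last, so a caller-supplied value wins. A minimal illustration of that merge order (the values are invented for the example):

# Dict-merge order: the right-hand dict wins, so user options override the injected default
defaults = {"inferSchema": True}
user_options = {"inferSchema": False, "header": True}  # illustrative caller options
merged = {**defaults, **user_options}
print(merged)  # {'inferSchema': False, 'header': True}
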
syncmaster/worker/spark.py

Lines changed: 7 additions & 1 deletion
@@ -37,6 +37,7 @@ def get_worker_spark_session(
 
 def get_packages(db_type: str) -> list[str]:
     from onetl.connection import MSSQL, Clickhouse, Oracle, Postgres, SparkS3
+    from onetl.file.format import Excel
 
     if db_type == "postgres":
         return Postgres.get_packages()
@@ -51,7 +52,12 @@ def get_packages(db_type: str) -> list[str]:
         import pyspark
 
         spark_version = pyspark.__version__
-        return SparkS3.get_packages(spark_version=spark_version)
+        return SparkS3.get_packages(spark_version=spark_version) + Excel.get_packages(spark_version=spark_version)
+    if db_type == "hdfs":
+        import pyspark
+
+        spark_version = pyspark.__version__
+        return Excel.get_packages(spark_version=spark_version)
 
     # If the database type does not require downloading .jar packages
     return []

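`get_packages` returns Maven coordinates that have to reach Spark before the session is built. A rough sketch of that wiring (the builder settings below are an assumption for illustration, not the worker's actual configuration):

# Sketch: passing get_packages() output to a Spark session builder
# (the exact coordinates depend on the installed onetl and pyspark versions)
from pyspark.sql import SparkSession

from syncmaster.worker.spark import get_packages

packages = get_packages(db_type="s3")  # SparkS3 jars plus the Excel format's jars
spark = (
    SparkSession.builder
    .appName("syncmaster-worker")  # illustrative app name
    .config("spark.jars.packages", ",".join(packages))
    .getOrCreate()
)
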
tests/test_integration/test_run_transfer/conftest.py

Lines changed: 39 additions & 12 deletions
@@ -3,14 +3,15 @@
 import os
 import secrets
 from collections import namedtuple
-from pathlib import Path, PurePosixPath
+from pathlib import Path, PosixPath, PurePosixPath
 
 import pyspark
 import pytest
 import pytest_asyncio
 from onetl.connection import MSSQL, Clickhouse, Hive, Oracle, Postgres, SparkS3
+from onetl.connection.file_connection.s3 import S3
 from onetl.db import DBWriter
-from onetl.file.format import CSV, JSON, JSONLine
+from onetl.file.format import CSV, JSON, Excel, JSONLine
 from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql.types import (
     DateType,
@@ -107,6 +108,9 @@ def spark(settings: Settings, request: FixtureRequest) -> SparkSession:
         )
     )
 
+    if "hdfs" in markers or "s3" in markers:
+        maven_packages.extend(Excel.get_packages(spark_version=pyspark.__version__))
+
     if maven_packages:
         spark = spark.config("spark.jars.packages", ",".join(maven_packages))
 
@@ -427,12 +431,22 @@ def s3_file_df_connection(s3_file_connection, spark, s3_server):
 
 
 @pytest.fixture(scope="session")
-def prepare_s3(resource_path, s3_file_connection, s3_file_df_connection_with_path: tuple[SparkS3, PurePosixPath]):
-    logger.info("START PREPARE HDFS")
-    connection, upload_to = s3_file_df_connection_with_path
-    files = upload_files(resource_path, upload_to, s3_file_connection)
-    logger.info("END PREPARE HDFS")
-    return connection, upload_to, files
+def prepare_s3(
+    resource_path: PosixPath,
+    s3_file_connection: S3,
+    s3_file_df_connection_with_path: tuple[SparkS3, PurePosixPath],
+):
+    logger.info("START PREPARE S3")
+    connection, remote_path = s3_file_df_connection_with_path
+
+    s3_file_connection.remove_dir(remote_path, recursive=True)
+    files = upload_files(resource_path, remote_path, s3_file_connection)
+
+    yield connection, remote_path, files
+
+    logger.info("START POST-CLEANUP S3")
+    s3_file_connection.remove_dir(remote_path, recursive=True)
+    logger.info("END POST-CLEANUP S3")
 
 
 @pytest.fixture(scope="session")
@@ -600,14 +614,14 @@ def prepare_clickhouse(
         pass
 
     def fill_with_data(df: DataFrame):
-        logger.info("START PREPARE ORACLE")
+        logger.info("START PREPARE CLICKHOUSE")
         db_writer = DBWriter(
             connection=onetl_conn,
             target=f"{clickhouse.user}.source_table",
             options=Clickhouse.WriteOptions(createTableOptions="ENGINE = Memory"),
         )
         db_writer.run(df)
-        logger.info("END PREPARE ORACLE")
+        logger.info("END PREPARE CLICKHOUSE")
 
     yield onetl_conn, fill_with_data
 
@@ -666,7 +680,7 @@ def fill_with_data(df: DataFrame):
         pass
 
 
-@pytest.fixture(params=[("csv", {}), ("jsonline", {}), ("json", {})])
+@pytest.fixture(params=[("csv", {}), ("jsonline", {}), ("json", {}), ("excel", {})])
 def source_file_format(request: FixtureRequest):
     name, params = request.param
     if name == "csv":
@@ -690,10 +704,17 @@ def source_file_format(request: FixtureRequest):
             **params,
         )
 
+    if name == "excel":
+        return "excel", Excel(
+            header=True,
+            inferSchema=True,
+            **params,
+        )
+
     raise ValueError(f"Unsupported file format: {name}")
 
 
-@pytest.fixture(params=[("csv", {}), ("jsonline", {})])
+@pytest.fixture(params=[("csv", {}), ("jsonline", {}), ("excel", {})])
 def target_file_format(request: FixtureRequest):
     name, params = request.param
     if name == "csv":
@@ -712,6 +733,12 @@ def target_file_format(request: FixtureRequest):
             **params,
         )
 
+    if name == "excel":
+        return "excel", Excel(
+            header=False,
+            **params,
+        )
+
     raise ValueError(f"Unsupported file format: {name}")

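`prepare_s3` switches from a return-style fixture to a yield-style one, so the remote directory is wiped before upload and again after the whole session. A self-contained sketch of that setup/teardown shape (the fixture below uses a local temp directory instead of S3, purely for illustration):

# Yield-based setup/teardown pattern, mirrored with a local temp directory
import pytest

@pytest.fixture(scope="session")
def prepared_dir(tmp_path_factory):
    path = tmp_path_factory.mktemp("uploads")  # setup: create and fill the directory
    (path / "example.csv").write_text("id,name\n1,alpha\n")
    yield path                                 # tests run with the prepared path
    for file in path.iterdir():                # teardown: post-cleanup after the session
        file.unlink()
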
tests/test_integration/test_run_transfer/test_hdfs.py

Lines changed: 27 additions & 1 deletion
@@ -8,6 +8,7 @@
 from onetl.db import DBReader
 from onetl.file import FileDFReader
 from pyspark.sql import DataFrame
+from pyspark.sql.functions import col, date_format, date_trunc, to_timestamp
 from pytest import FixtureRequest
 from sqlalchemy.ext.asyncio import AsyncSession
 
@@ -37,6 +38,7 @@ async def hdfs_to_postgres(
     file_format_flavor: str,
 ):
     format_name, file_format = source_file_format
+    format_name_in_path = "xlsx" if format_name == "excel" else format_name
     _, source_path, _ = prepare_hdfs
 
     result = await create_transfer(
@@ -47,7 +49,7 @@ async def hdfs_to_postgres(
         target_connection_id=postgres_connection.id,
         source_params={
             "type": "hdfs",
-            "directory_path": os.fspath(source_path / "file_df_connection" / format_name / file_format_flavor),
+            "directory_path": os.fspath(source_path / "file_df_connection" / format_name_in_path / file_format_flavor),
             "file_format": {
                 "type": format_name,
                 **file_format.dict(),
@@ -121,6 +123,11 @@ async def postgres_to_hdfs(
             "without_compression",
             id="jsonline",
         ),
+        pytest.param(
+            ("excel", {}),
+            "with_header",
+            id="excel",
+        ),
     ],
     indirect=["source_file_format", "file_format_flavor"],
 )
@@ -135,6 +142,7 @@ async def test_run_transfer_hdfs_to_postgres(
 ):
     # Arrange
     postgres, _ = prepare_postgres
+    file_format, _ = source_file_format
 
     # Act
     result = await client.post(
@@ -164,6 +172,12 @@ async def test_run_transfer_hdfs_to_postgres(
         table="public.target_table",
     )
     df = reader.run()
+
+    # as initial datetime values in excel are rounded to milliseconds
+    if file_format == "excel":
+        df = df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))
+        init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))
+
     for field in init_df.schema:
         df = df.withColumn(field.name, df[field.name].cast(field.dataType))
 
@@ -183,6 +197,11 @@ async def test_run_transfer_hdfs_to_postgres(
             "without_compression",
             id="jsonline",
         ),
+        pytest.param(
+            ("excel", {}),
+            "with_header",
+            id="excel",
+        ),
     ],
     indirect=["target_file_format", "file_format_flavor"],
 )
@@ -235,6 +254,13 @@ async def test_run_transfer_postgres_to_hdfs(
     )
     df = reader.run()
 
+    # when spark writes to excel, the datetime precision is truncated to milliseconds (without rounding)
+    if format_name == "excel":
+        init_df = init_df.withColumn(
+            "REGISTERED_AT",
+            to_timestamp(date_format(col("REGISTERED_AT"), "yyyy-MM-dd HH:mm:ss.SSS")),
+        )
+
     for field in init_df.schema:
         df = df.withColumn(field.name, df[field.name].cast(field.dataType))
 

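Both Excel assertions work around the format's limited timestamp precision: the read test compares at second precision, and the write test truncates the expected data to milliseconds. A small local-Spark sketch of the millisecond truncation used in the postgres-to-hdfs test (the timestamp value is made up):

# Sketch: truncating a timestamp to millisecond precision before comparison
# (needs a local Spark session; the value below is illustrative)
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, date_format, to_timestamp

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [("2021-01-01 12:00:00.123456",)], ["REGISTERED_AT"]
).withColumn("REGISTERED_AT", to_timestamp(col("REGISTERED_AT")))

truncated = df.withColumn(
    "REGISTERED_AT",
    to_timestamp(date_format(col("REGISTERED_AT"), "yyyy-MM-dd HH:mm:ss.SSS")),
)
truncated.show(truncate=False)  # 2021-01-01 12:00:00.123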