
Commit e952c8c

Author: Ilyas Gasanov (committed)

[DOP-21444] Add Excel integration tests

1 parent f8103ae commit e952c8c

File tree: 8 files changed, +137 -36 lines changed

poetry.lock

Lines changed: 11 additions & 11 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@ python-jose = { version = "^3.3.0", extras = ["cryptography"], optional = true }
 jinja2 = { version = "^3.1.4", optional = true }
 python-multipart = { version = ">=0.0.9,<0.0.18", optional = true }
 celery = { version = "^5.4.0", optional = true }
+pyspark = { version = ">=3.5.1,<3.5.2", optional = true }
 onetl = { version = "^0.12.0", extras = ["spark"], optional = true }
 pyyaml = {version = "*", optional = true}
 # due to not supporting MacOS 14.x https://www.psycopg.org/psycopg3/docs/news.html#psycopg-3-1-20

syncmaster/dto/transfers.py

Lines changed: 13 additions & 9 deletions
@@ -4,7 +4,7 @@
 from dataclasses import dataclass
 from typing import ClassVar

-from onetl.file.format import CSV, JSON, JSONLine
+from onetl.file.format import CSV, JSON, Excel, JSONLine


 @dataclass
@@ -20,10 +20,17 @@ class DBTransferDTO(TransferDTO):
 @dataclass
 class FileTransferDTO(TransferDTO):
     directory_path: str
-    file_format: CSV | JSONLine | JSON
+    file_format: CSV | JSONLine | JSON | Excel
     options: dict
     df_schema: dict | None = None

+    _format_parsers = {
+        "csv": CSV,
+        "jsonline": JSONLine,
+        "json": JSON,
+        "excel": Excel,
+    }
+
     def __post_init__(self):
         if isinstance(self.file_format, dict):
             self.file_format = self._get_format(self.file_format.copy())
@@ -32,13 +39,10 @@ def __post_init__(self):

     def _get_format(self, file_format: dict):
         file_type = file_format.pop("type", None)
-        if file_type == "csv":
-            return CSV.parse_obj(file_format)
-        if file_type == "jsonline":
-            return JSONLine.parse_obj(file_format)
-        if file_type == "json":
-            return JSON.parse_obj(file_format)
-        raise ValueError("Unknown file type")
+        parser_class = self._format_parsers.get(file_type)
+        if parser_class is not None:
+            return parser_class.parse_obj(file_format)
+        raise ValueError(f"Unknown file type: {file_type}")


 @dataclass
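Note (not part of the diff): a minimal standalone sketch of the dispatch pattern introduced in _get_format, assuming the pydantic-style parse_obj API already used above; the payload dict is hypothetical.

from onetl.file.format import CSV, JSON, Excel, JSONLine

# same type-keyed lookup as FileTransferDTO._format_parsers
_format_parsers = {"csv": CSV, "jsonline": JSONLine, "json": JSON, "excel": Excel}

def get_format(file_format: dict):
    # pop "type" and resolve the parser class, as FileTransferDTO._get_format does
    file_type = file_format.pop("type", None)
    parser_class = _format_parsers.get(file_type)
    if parser_class is not None:
        return parser_class.parse_obj(file_format)
    raise ValueError(f"Unknown file type: {file_type}")

excel_format = get_format({"type": "excel", "header": True})  # hypothetical payload -> Excel(header=True)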

syncmaster/worker/handlers/file/s3.py

Lines changed: 19 additions & 1 deletion
@@ -6,12 +6,13 @@
 from typing import TYPE_CHECKING

 from onetl.connection import SparkS3
+from onetl.file import FileDFReader

 from syncmaster.dto.connections import S3ConnectionDTO
 from syncmaster.worker.handlers.file.base import FileHandler

 if TYPE_CHECKING:
-    from pyspark.sql import SparkSession
+    from pyspark.sql import DataFrame, SparkSession


 class S3Handler(FileHandler):
@@ -29,3 +30,20 @@ def connect(self, spark: SparkSession):
             extra=self.connection_dto.additional_params,
             spark=spark,
         ).check()
+
+    def read(self) -> DataFrame:
+        from pyspark.sql.types import StructType
+
+        options = {}
+        if self.transfer_dto.file_format.__class__.__name__ == "Excel":
+            options = {"inferSchema": True}
+
+        reader = FileDFReader(
+            connection=self.connection,
+            format=self.transfer_dto.file_format,
+            source_path=self.transfer_dto.directory_path,
+            df_schema=StructType.fromJson(self.transfer_dto.df_schema) if self.transfer_dto.df_schema else None,
+            options={**options, **self.transfer_dto.options},
+        )
+
+        return reader.run()
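Note (not part of the diff): a hedged sketch of the same FileDFReader call pattern used by S3Handler.read() above; the connection object and directory path are placeholders, and inferSchema is enabled to mirror the handler's Excel branch.

from onetl.file import FileDFReader
from onetl.file.format import Excel

def read_excel_directory(connection, directory_path: str):
    # mirrors S3Handler.read() for the Excel case: let Spark infer column types
    reader = FileDFReader(
        connection=connection,  # e.g. an already-checked SparkS3 connection
        format=Excel(header=True, inferSchema=True),
        source_path=directory_path,
        options={},
    )
    return reader.run()  # returns a pyspark DataFrame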

syncmaster/worker/spark.py

Lines changed: 2 additions & 1 deletion
@@ -37,6 +37,7 @@ def get_worker_spark_session(

 def get_packages(db_type: str) -> list[str]:
     from onetl.connection import MSSQL, Clickhouse, Oracle, Postgres, SparkS3
+    from onetl.file.format import Excel

     if db_type == "postgres":
         return Postgres.get_packages()
@@ -51,7 +52,7 @@ def get_packages(db_type: str) -> list[str]:
         import pyspark

         spark_version = pyspark.__version__
-        return SparkS3.get_packages(spark_version=spark_version)
+        return SparkS3.get_packages(spark_version=spark_version) + Excel.get_packages(spark_version=spark_version)

     # If the database type does not require downloading .jar packages
     return []
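Note (not part of the diff): a hedged sketch of how the Maven coordinates returned by get_packages are typically handed to Spark via the standard spark.jars.packages option, assuming the branch above is selected by db_type == "s3"; the app name is a placeholder.

from pyspark.sql import SparkSession

from syncmaster.worker.spark import get_packages

packages = get_packages("s3")  # SparkS3 + spark-excel Maven coordinates
spark = (
    SparkSession.builder
    .appName("syncmaster-worker")  # placeholder app name
    .config("spark.jars.packages", ",".join(packages))
    .getOrCreate()
)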

tests/test_integration/test_run_transfer/conftest.py

Lines changed: 37 additions & 12 deletions
@@ -3,14 +3,15 @@
 import os
 import secrets
 from collections import namedtuple
-from pathlib import Path, PurePosixPath
+from pathlib import Path, PosixPath, PurePosixPath

 import pyspark
 import pytest
 import pytest_asyncio
 from onetl.connection import MSSQL, Clickhouse, Hive, Oracle, Postgres, SparkS3
+from onetl.connection.file_connection.s3 import S3
 from onetl.db import DBWriter
-from onetl.file.format import CSV, JSON, JSONLine
+from onetl.file.format import CSV, JSON, Excel, JSONLine
 from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql.types import (
     DateType,
@@ -82,6 +83,7 @@ def spark(settings: Settings, request: FixtureRequest) -> SparkSession:

     if "s3" in markers:
         maven_packages.extend(SparkS3.get_packages(spark_version=pyspark.__version__))
+        maven_packages.extend(Excel.get_packages(spark_version=pyspark.__version__))
         excluded_packages.extend(
             [
                 "com.google.cloud.bigdataoss:gcs-connector",
@@ -427,12 +429,22 @@ def s3_file_df_connection(s3_file_connection, spark, s3_server):


 @pytest.fixture(scope="session")
-def prepare_s3(resource_path, s3_file_connection, s3_file_df_connection_with_path: tuple[SparkS3, PurePosixPath]):
-    logger.info("START PREPARE HDFS")
-    connection, upload_to = s3_file_df_connection_with_path
-    files = upload_files(resource_path, upload_to, s3_file_connection)
-    logger.info("END PREPARE HDFS")
-    return connection, upload_to, files
+def prepare_s3(
+    resource_path: PosixPath,
+    s3_file_connection: S3,
+    s3_file_df_connection_with_path: tuple[SparkS3, PurePosixPath],
+):
+    logger.info("START PREPARE S3")
+    connection, remote_path = s3_file_df_connection_with_path
+
+    s3_file_connection.remove_dir(remote_path, recursive=True)
+    files = upload_files(resource_path, remote_path, s3_file_connection)
+
+    yield connection, remote_path, files
+
+    logger.info("START POST-CLEANUP S3")
+    s3_file_connection.remove_dir(remote_path, recursive=True)
+    logger.info("END POST-CLEANUP S3")


 @pytest.fixture(scope="session")
@@ -600,14 +612,14 @@ def prepare_clickhouse(
         pass

     def fill_with_data(df: DataFrame):
-        logger.info("START PREPARE ORACLE")
+        logger.info("START PREPARE CLICKHOUSE")
         db_writer = DBWriter(
            connection=onetl_conn,
            target=f"{clickhouse.user}.source_table",
            options=Clickhouse.WriteOptions(createTableOptions="ENGINE = Memory"),
        )
        db_writer.run(df)
-        logger.info("END PREPARE ORACLE")
+        logger.info("END PREPARE CLICKHOUSE")

    yield onetl_conn, fill_with_data

@@ -666,7 +678,7 @@ def fill_with_data(df: DataFrame):
        pass


-@pytest.fixture(params=[("csv", {}), ("jsonline", {}), ("json", {})])
+@pytest.fixture(params=[("csv", {}), ("jsonline", {}), ("json", {}), ("excel", {})])
 def source_file_format(request: FixtureRequest):
     name, params = request.param
     if name == "csv":
@@ -690,10 +702,17 @@ def source_file_format(request: FixtureRequest):
             **params,
         )

+    if name == "excel":
+        return "excel", Excel(
+            header=True,
+            inferSchema=True,
+            **params,
+        )
+
     raise ValueError(f"Unsupported file format: {name}")


-@pytest.fixture(params=[("csv", {}), ("jsonline", {})])
+@pytest.fixture(params=[("csv", {}), ("jsonline", {}), ("excel", {})])
 def target_file_format(request: FixtureRequest):
     name, params = request.param
     if name == "csv":
@@ -712,6 +731,12 @@ def target_file_format(request: FixtureRequest):
             **params,
         )

+    if name == "excel":
+        return "excel", Excel(
+            header=False,
+            **params,
+        )
+
     raise ValueError(f"Unsupported file format: {name}")


tests/test_integration/test_run_transfer/test_hdfs.py

Lines changed: 27 additions & 1 deletion
@@ -8,6 +8,7 @@
 from onetl.db import DBReader
 from onetl.file import FileDFReader
 from pyspark.sql import DataFrame
+from pyspark.sql.functions import col, date_format, date_trunc, to_timestamp
 from pytest import FixtureRequest
 from sqlalchemy.ext.asyncio import AsyncSession

@@ -37,6 +38,7 @@ async def hdfs_to_postgres(
     file_format_flavor: str,
 ):
     format_name, file_format = source_file_format
+    format_name_in_path = "xlsx" if format_name == "excel" else format_name
     _, source_path, _ = prepare_hdfs

     result = await create_transfer(
@@ -47,7 +49,7 @@
         target_connection_id=postgres_connection.id,
         source_params={
             "type": "hdfs",
-            "directory_path": os.fspath(source_path / "file_df_connection" / format_name / file_format_flavor),
+            "directory_path": os.fspath(source_path / "file_df_connection" / format_name_in_path / file_format_flavor),
             "file_format": {
                 "type": format_name,
                 **file_format.dict(),
@@ -121,6 +123,11 @@ async def postgres_to_hdfs(
             "without_compression",
             id="jsonline",
         ),
+        pytest.param(
+            ("excel", {}),
+            "with_header",
+            id="excel",
+        ),
     ],
     indirect=["source_file_format", "file_format_flavor"],
 )
@@ -135,6 +142,7 @@ async def test_run_transfer_hdfs_to_postgres(
 ):
     # Arrange
     postgres, _ = prepare_postgres
+    file_format, _ = source_file_format

     # Act
     result = await client.post(
@@ -164,6 +172,12 @@ async def test_run_transfer_hdfs_to_postgres(
         table="public.target_table",
     )
     df = reader.run()
+
+    # as initial datetime values in excel are rounded to milliseconds
+    if file_format == "excel":
+        df = df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))
+        init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))
+
     for field in init_df.schema:
         df = df.withColumn(field.name, df[field.name].cast(field.dataType))

@@ -183,6 +197,11 @@ async def test_run_transfer_hdfs_to_postgres(
             "without_compression",
             id="jsonline",
         ),
+        pytest.param(
+            ("excel", {}),
+            "with_header",
+            id="excel",
+        ),
     ],
     indirect=["target_file_format", "file_format_flavor"],
 )
@@ -235,6 +254,13 @@ async def test_run_transfer_postgres_to_hdfs(
     )
     df = reader.run()

+    # when spark writes to excel, the datetime precision is truncated to milliseconds (without rounding)
+    if format_name == "excel":
+        init_df = init_df.withColumn(
+            "REGISTERED_AT",
+            to_timestamp(date_format(col("REGISTERED_AT"), "yyyy-MM-dd HH:mm:ss.SSS")),
+        )
+
     for field in init_df.schema:
         df = df.withColumn(field.name, df[field.name].cast(field.dataType))

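Note (not part of the diff): a hedged sketch of the timestamp normalization used in the Excel branches above, written as a helper; a DataFrame is passed through it before comparison because the Excel round-trip keeps at most millisecond precision. The column name mirrors the tests.

from pyspark.sql import DataFrame
from pyspark.sql.functions import col, date_format, to_timestamp

def truncate_to_millis(df: DataFrame, column: str = "REGISTERED_AT") -> DataFrame:
    # format to millisecond precision and parse back, dropping sub-millisecond digits
    return df.withColumn(
        column,
        to_timestamp(date_format(col(column), "yyyy-MM-dd HH:mm:ss.SSS")),
    )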
