Commit 4926e13

Support Iceberg in read_from/write_from JobSocket (#147)
1 parent: 8dbdf4a

2 files changed: 35 additions, 2 deletions

spark_utils/common/functions.py

Lines changed: 6 additions & 1 deletion
@@ -76,7 +76,7 @@ def read_from_socket(
     :return: Spark dataframe
     """
     read_options = read_options or {}
-    if socket.data_format.startswith("hive"):
+    if socket.data_format.startswith("hive") or socket.data_format.startswith("iceberg"):
         return spark_session.table(socket.data_path)
 
     return spark_session.read.options(**read_options).format(socket.data_format).load(socket.data_path)
@@ -103,6 +103,11 @@ def write_to_socket(
     if partition_count:
         data = data.repartition(partition_count, *partition_by)
 
+    # ignore all external write options as the Iceberg writer will take care of those
+    if socket.data_format.startswith("iceberg"):
+        data.writeTo(socket.data_path).createOrReplace()
+        return
+
     writer = data.write.mode("overwrite").options(**write_options)
 
     if partition_by:
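
Taken together, the two hunks route Iceberg sockets around the generic v1 reader/writer: reads resolve the catalog-qualified table name via spark_session.table, and writes use Spark's DataFrameWriterV2 (writeTo(...).createOrReplace()), which creates or replaces the target table and therefore makes write_options, partition_by, and partition_count irrelevant. A minimal usage sketch, assuming a catalog named "iceberg" is registered on the session; the socket values are illustrative, not from the repo:

from pyspark.sql import SparkSession

from spark_utils.common.functions import read_from_socket, write_to_socket
from spark_utils.models.job_socket import JobSocket

spark = SparkSession.builder.getOrCreate()

# Illustrative socket: for Iceberg, data_path is a catalog-qualified
# table name ("catalog.schema.table"), not a filesystem path.
socket = JobSocket(alias="events", data_path="iceberg.db.events", data_format="iceberg")

df = spark.createDataFrame([{"C0": 1, "C1": "1231", "C2": 1.0}])

# Write path: dispatches to df.writeTo(socket.data_path).createOrReplace();
# write_options, partition_by and partition_count are ignored for Iceberg.
write_to_socket(data=df, socket=socket, write_options=None, partition_by=None, partition_count=None)

# Read path: dispatches to spark.table(socket.data_path).
df_back = read_from_socket(socket=socket, spark_session=spark, read_options=None)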

test/test_iceberg.py

Lines changed: 29 additions & 1 deletion
@@ -1,6 +1,9 @@
-import pytest
 from pyspark.sql import SparkSession, DataFrame
 
+from spark_utils.common.functions import write_to_socket, read_from_socket
+from spark_utils.models.job_socket import JobSocket
+from test.test_common_functions import are_dfs_equal
+
 
 def test_iceberg_rest_create_schema(iceberg_spark_session: SparkSession):
     try:
@@ -19,3 +22,28 @@ def test_iceberg_rest_create_table(iceberg_spark_session: SparkSession):
         assert rows.collect()[0].asDict() == {"C0": 1, "C1": "1231", "C2": 1.0}
     except BaseException as e:
         raise RuntimeError("Failed to create table") from e
+
+
+def test_write_to_socket(
+    iceberg_spark_session: SparkSession,
+):
+    output_socket = JobSocket(
+        alias="test",
+        data_path="iceberg.test.job_socket_write",
+        data_format="iceberg",
+    )
+    df = iceberg_spark_session.createDataFrame(
+        [{"C0": 1, "C1": "1231", "C2": 1.0}, {"C0": 2, "C1": "1232", "C2": 2.0}, {"C0": 3, "C1": "1233", "C2": 3.0}]
+    )
+
+    write_to_socket(
+        data=df,
+        socket=output_socket,
+        write_options=None,
+        partition_by=None,
+        partition_count=None,
+    )
+
+    df_read = read_from_socket(socket=output_socket, spark_session=iceberg_spark_session, read_options=None)
+
+    assert are_dfs_equal(df, df_read.select(df.columns))
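
The round-trip assertion relies on are_dfs_equal, imported from test_common_functions and not shown in this diff. For readers without the repo handy, a plausible stand-in (an assumption, not the repo's actual helper) compares schemas and row multisets:

from pyspark.sql import DataFrame

def are_dfs_equal(left: DataFrame, right: DataFrame) -> bool:
    # Hypothetical stand-in for test.test_common_functions.are_dfs_equal:
    # equal schemas, and no rows left over on either side of a multiset diff.
    return (
        left.schema == right.schema
        and left.exceptAll(right).isEmpty()
        and right.exceptAll(left).isEmpty()
    )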
