
Commit ff43e6e

Update HiveToDynamoDBOperator to support Polars (#54221)
1 parent: 67f5565
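
This commit adds a `df_type` parameter to `HiveToDynamoDBOperator`, so the Hive result set can be materialized as either a pandas or a Polars DataFrame before being written to DynamoDB. An illustrative usage sketch follows; the query, table name, and values shown are placeholders, not taken from the commit:

# Illustrative only: the SQL, table name, and keys below are placeholders.
from airflow.providers.amazon.aws.transfers.hive_to_dynamodb import HiveToDynamoDBOperator

hive_to_ddb = HiveToDynamoDBOperator(
    task_id="hive_to_dynamodb",
    sql="SELECT id, name FROM src",
    table_name="target_table",
    table_keys=["id"],
    df_type="polars",  # new in this commit; defaults to "pandas"
)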

File tree

2 files changed: +51 −3 lines changed


providers/amazon/src/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py

Lines changed: 9 additions & 3 deletions
@@ -21,7 +21,7 @@
 
 import json
 from collections.abc import Callable, Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
 
 from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
 from airflow.providers.amazon.version_compat import BaseOperator
@@ -53,6 +53,7 @@ class HiveToDynamoDBOperator(BaseOperator):
     :param hiveserver2_conn_id: Reference to the
         :ref: `Hive Server2 thrift service connection id <howto/connection:hiveserver2>`.
     :param aws_conn_id: aws connection
+    :param df_type: DataFrame type to use ("pandas" or "polars").
     """
 
     template_fields: Sequence[str] = ("sql",)
@@ -73,6 +74,7 @@ def __init__(
         schema: str = "default",
         hiveserver2_conn_id: str = "hiveserver2_default",
         aws_conn_id: str | None = "aws_default",
+        df_type: Literal["pandas", "polars"] = "pandas",
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -86,14 +88,15 @@ def __init__(
         self.schema = schema
         self.hiveserver2_conn_id = hiveserver2_conn_id
         self.aws_conn_id = aws_conn_id
+        self.df_type = df_type
 
     def execute(self, context: Context):
         hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
 
         self.log.info("Extracting data from Hive")
         self.log.info(self.sql)
 
-        data = hive.get_df(self.sql, schema=self.schema, df_type="pandas")
+        data = hive.get_df(self.sql, schema=self.schema, df_type=self.df_type)
         dynamodb = DynamoDBHook(
             aws_conn_id=self.aws_conn_id,
             table_name=self.table_name,
@@ -104,7 +107,10 @@ def execute(self, context: Context):
         self.log.info("Inserting rows into dynamodb")
 
         if self.pre_process is None:
-            dynamodb.write_batch_data(json.loads(data.to_json(orient="records")))
+            if self.df_type == "polars":
+                dynamodb.write_batch_data(data.to_dicts())  # type:ignore[operator]
+            elif self.df_type == "pandas":
+                dynamodb.write_batch_data(json.loads(data.to_json(orient="records")))  # type:ignore[union-attr]
         else:
             dynamodb.write_batch_data(
                 self.pre_process(data=data, args=self.pre_process_args, kwargs=self.pre_process_kwargs)
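
Both branches hand `write_batch_data` the same list-of-dicts payload: Polars exposes it directly via `to_dicts()`, while the pandas path round-trips through JSON, which also coerces NaN and numpy scalars into plain JSON-compatible Python values. A minimal sketch of the equivalence:

# Sketch: both conversion paths yield the same list-of-dicts shape.
import json

import pandas as pd
import polars as pl

rows = {"id": ["1"], "name": ["sid"]}
assert json.loads(pd.DataFrame(rows).to_json(orient="records")) == pl.DataFrame(rows).to_dicts()
# -> [{"id": "1", "name": "sid"}]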

providers/amazon/tests/unit/amazon/aws/transfers/test_hive_to_dynamodb.py

Lines changed: 42 additions & 0 deletions
@@ -22,6 +22,8 @@
 from unittest import mock
 
 import pandas as pd
+import polars as pl
+import pytest
 from moto import mock_aws
 
 import airflow.providers.amazon.aws.transfers.hive_to_dynamodb
@@ -110,3 +112,43 @@ def test_pre_process_records_with_schema(self, mock_get_df):
         table = self.hook.get_conn().Table("test_airflow")
         table.meta.client.get_waiter("table_exists").wait(TableName="test_airflow")
         assert table.item_count == 1
+
+    @pytest.mark.parametrize("df_type", ["pandas", "polars"])
+    @mock_aws
+    def test_df_type_parameter(self, df_type):
+        if df_type == "polars" and pl is None:
+            pytest.skip("Polars not installed")
+
+        if df_type == "pandas":
+            test_df = pd.DataFrame(data=[("1", "sid")], columns=["id", "name"])
+        else:
+            test_df = pl.DataFrame({"id": ["1"], "name": ["sid"]})
+
+        with mock.patch(
+            "airflow.providers.apache.hive.hooks.hive.HiveServer2Hook.get_df",
+            return_value=test_df,
+        ) as mock_get_df:
+            self.hook.get_conn().create_table(
+                TableName="test_airflow",
+                KeySchema=[
+                    {"AttributeName": "id", "KeyType": "HASH"},
+                ],
+                AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}],
+                ProvisionedThroughput={"ReadCapacityUnits": 10, "WriteCapacityUnits": 10},
+            )
+
+            operator = airflow.providers.amazon.aws.transfers.hive_to_dynamodb.HiveToDynamoDBOperator(
+                sql=self.sql,
+                table_name="test_airflow",
+                task_id="hive_to_dynamodb_check",
+                table_keys=["id"],
+                df_type=df_type,
+                dag=self.dag,
+            )
+
+            operator.execute(None)
+            mock_get_df.assert_called_once_with(self.sql, schema="default", df_type=df_type)
+
+            table = self.hook.get_conn().Table("test_airflow")
+            table.meta.client.get_waiter("table_exists").wait(TableName="test_airflow")
+            assert table.item_count == 1
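
One thing to note: the skip guard `if df_type == "polars" and pl is None` only fires if `polars` is imported optionally, but the hunk above adds a plain top-level `import polars as pl`, so a missing Polars install would fail at module import before the skip runs. A hypothetical guarded-import pattern (not part of this commit) under which the skip would be effective:

# Hypothetical optional-import sketch; the commit imports polars directly.
try:
    import polars as pl
except ImportError:
    pl = None  # test_df_type_parameter then skips the "polars" case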
