Skip to content

Commit 2682ed2

Browse files
authored
Feat: Adds interop with Arrow library using new method Dataset.to_arrow() (#281)
1 parent 3e1ffb0 commit 2682ed2

File tree

8 files changed

+261
-107
lines changed

8 files changed

+261
-107
lines changed

airbyte/caches/base.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from typing import TYPE_CHECKING, Any, Optional, final
88

99
import pandas as pd
10+
import pyarrow as pa
11+
import pyarrow.dataset as ds
1012
from pydantic import Field, PrivateAttr
1113

1214
from airbyte_protocol.models import ConfiguredAirbyteCatalog
@@ -19,6 +21,7 @@
1921
from airbyte._future_cdk.state_writers import StdOutStateWriter
2022
from airbyte.caches._catalog_backend import CatalogBackendBase, SqlCatalogBackend
2123
from airbyte.caches._state_backend import SqlStateBackend
24+
from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE
2225
from airbyte.datasets._sql import CachedDataset
2326

2427

@@ -146,6 +149,38 @@ def get_pandas_dataframe(
146149
engine = self.get_sql_engine()
147150
return pd.read_sql_table(table_name, engine, schema=self.schema_name)
148151

152+
def get_arrow_dataset(
    self,
    stream_name: str,
    *,
    max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
) -> ds.Dataset:
    """Return an Arrow Dataset with the stream's data.

    The underlying table is read in chunks of at most `max_chunk_size`
    rows so that tables larger than available memory can be converted
    incrementally.

    Args:
        stream_name: Name of the stream whose cached table should be read.
        max_chunk_size: Maximum number of records per Arrow record batch.

    Returns:
        ds.Dataset: An in-memory Arrow dataset built from the record batches.
    """
    table_name = self._read_processor.get_sql_table_name(stream_name)
    engine = self.get_sql_engine()

    # Read the table in chunks to handle large tables that do not fit in memory.
    pandas_chunks = pd.read_sql_table(
        table_name=table_name,
        con=engine,
        schema=self.schema_name,
        chunksize=max_chunk_size,
    )

    arrow_batches_list = []
    arrow_schema = None

    for pandas_chunk in pandas_chunks:
        if arrow_schema is None:
            # Derive the schema once, from the first chunk, so every batch
            # is converted consistently.
            arrow_schema = pa.Schema.from_pandas(pandas_chunk)

        # Convert each pandas chunk to an Arrow record batch.
        arrow_batch = pa.RecordBatch.from_pandas(pandas_chunk, schema=arrow_schema)
        arrow_batches_list.append(arrow_batch)

    if not arrow_batches_list:
        # Empty table: the chunked read may yield no chunks at all, and
        # `ds.dataset([])` cannot infer a schema. Fall back to a single
        # (empty) full read so the column structure is preserved.
        empty_frame = pd.read_sql_table(table_name, engine, schema=self.schema_name)
        return ds.dataset(pa.Table.from_pandas(empty_frame))

    return ds.dataset(arrow_batches_list)
183+
149184
@final
150185
@property
151186
def streams(self) -> dict[str, CachedDataset]:

airbyte/caches/bigquery.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,24 @@
2323
from airbyte.caches.base import (
2424
CacheBase,
2525
)
26+
from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE
2627

2728

2829
class BigQueryCache(BigQueryConfig, CacheBase):
    """The BigQuery cache implementation."""

    # SQL processor used by the cache for reads/writes against BigQuery.
    _sql_processor_class: type[BigQuerySqlProcessor] = PrivateAttr(default=BigQuerySqlProcessor)

    def get_arrow_dataset(
        self,
        stream_name: str,
        *,
        max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
    ) -> None:
        """Raises NotImplementedError; BigQuery doesn't support `pd.read_sql_table`.

        See: https://github.com/airbytehq/PyAirbyte/issues/165

        Raises:
            NotImplementedError: Always; BigQuery does not support this operation.
        """
        # NOTE: the two literals below are joined by implicit concatenation;
        # the trailing ". " keeps the message readable ("...to_arrow. Please...").
        raise NotImplementedError(
            "BigQuery doesn't currently support to_arrow. "
            "Please consider using a different cache implementation for these functionalities."
        )

airbyte/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,6 @@
3838
3939
Specific caches may override this value with a different schema name.
4040
"""
41+
42+
# Default `max_chunk_size` for `get_arrow_dataset()` / `to_arrow()`.
DEFAULT_ARROW_MAX_CHUNK_SIZE = 100_000
"""The default number of records to include in each batch of an Arrow dataset."""

airbyte/datasets/_base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,15 @@
66
from typing import TYPE_CHECKING, Any, cast
77

88
from pandas import DataFrame
9+
from pyarrow.dataset import Dataset
910

1011
from airbyte._util.document_rendering import DocumentRenderer
12+
from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE
1113

1214

1315
if TYPE_CHECKING:
16+
from pyarrow.dataset import Dataset
17+
1418
from airbyte_protocol.models import ConfiguredAirbyteStream
1519

1620
from airbyte.documents import Document
@@ -37,6 +41,17 @@ def to_pandas(self) -> DataFrame:
3741
# duck typing is correct for this use case.
3842
return DataFrame(cast(Iterator[dict[str, Any]], self))
3943

44+
def to_arrow(
    self,
    *,
    max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
) -> Dataset:
    """Convert the dataset to a PyArrow ``Dataset``.

    The base class provides no generic conversion; concrete subclasses
    are expected to override this method.

    Args:
        max_chunk_size: Maximum number of records per batch of the
            resulting Arrow dataset.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    raise NotImplementedError("Not implemented in base class")
54+
4055
def to_documents(
4156
self,
4257
title_property: str | None = None,

airbyte/datasets/_sql.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@
1212

1313
from airbyte_protocol.models.airbyte_protocol import ConfiguredAirbyteStream
1414

15+
from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE
1516
from airbyte.datasets._base import DatasetBase
1617

1718

1819
if TYPE_CHECKING:
1920
from collections.abc import Iterator
2021

2122
from pandas import DataFrame
23+
from pyarrow.dataset import Dataset
2224
from sqlalchemy import Table
2325
from sqlalchemy.sql import ClauseElement
2426
from sqlalchemy.sql.selectable import Selectable
@@ -102,6 +104,13 @@ def __len__(self) -> int:
102104
def to_pandas(self) -> DataFrame:
103105
return self._cache.get_pandas_dataframe(self._stream_name)
104106

107+
def to_arrow(
    self,
    *,
    max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
) -> Dataset:
    """Return this stream's data as a PyArrow ``Dataset``.

    Delegates to the cache, which performs the chunked read/conversion.
    """
    return self._cache.get_arrow_dataset(
        stream_name=self._stream_name,
        max_chunk_size=max_chunk_size,
    )
113+
105114
def with_filter(self, *filter_expressions: ClauseElement | str) -> SQLDataset:
106115
"""Filter the dataset by a set of column values.
107116
@@ -166,6 +175,26 @@ def to_pandas(self) -> DataFrame:
166175
"""Return the underlying dataset data as a pandas DataFrame."""
167176
return self._cache.get_pandas_dataframe(self._stream_name)
168177

178+
@overrides
def to_arrow(
    self,
    *,
    max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
) -> Dataset:
    """Return an Arrow Dataset containing the data from this stream.

    The stream is determined by this dataset's bound stream name; the
    cache performs the actual read and conversion.

    Args:
        max_chunk_size (int): Max number of records to include in each batch
            of the resulting pyarrow dataset.

    Returns:
        pa.dataset.Dataset: Arrow Dataset containing the stream's data.
    """
    return self._cache.get_arrow_dataset(
        stream_name=self._stream_name,
        max_chunk_size=max_chunk_size,
    )
197+
169198
def to_sql_table(self) -> Table:
170199
"""Return the underlying SQL table as a SQLAlchemy Table object."""
171200
return self._cache.processor.get_sql_table(self.stream_name)

0 commit comments

Comments
 (0)