Commit 27f4b76

lazy load pandas and numpy to improve startup performance (#443)
1 parent b6eb9ab commit 27f4b76

21 files changed: +150, -56 lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions

@@ -1,3 +1,9 @@
+## 0.5.21
+
+### Fixes
+
+* **Lazy load pandas and numpy** to improve startup performance
+
 ## 0.5.20
 
 ### Features
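What the fix does, at a high level: module-level `import pandas` / `import numpy` statements are removed and the imports are performed inside the functions and methods that actually need them (the per-file diffs below show each site). A minimal sketch of the pattern, using a hypothetical helper name for illustration:

```python
# Before: importing this module pays the full pandas import cost up front.
#
#     import pandas as pd
#
#     def to_records(path):
#         return pd.read_csv(path).to_dict(orient="records")

# After: the cost is only paid on the first call that really needs pandas.
def to_records(path):  # hypothetical helper, not one of the functions in this diff
    import pandas as pd  # lazy import; cached in sys.modules after the first use

    return pd.read_csv(path).to_dict(orient="records")
```

Python caches the module in `sys.modules` after the first lazy import, so repeated calls pay no extra cost.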

requirements/common/base.in

Lines changed: 0 additions & 1 deletion

@@ -1,7 +1,6 @@
 -c constraints.txt
 
 python-dateutil
-pandas
 # Pydantic generic Secret only introduced in 2.7
 pydantic>=2.7
 dataclasses_json

requirements/common/extras.in

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+# Requirements file that is only used by extras
+# These requirements will not be installed by default(`unstructured-ingest`)
+# but will be installed when `unstructured-ingest[<any extra>]` is used
+pandas
+numpy

setup.py

Lines changed: 8 additions & 1 deletion

@@ -25,6 +25,8 @@
 
 from unstructured_ingest.__version__ import __version__
 
+_extra_reqs_filepath = "requirements/common/extras.in"
+
 
 def load_requirements(file: Union[str, Path]) -> List[str]:
     path = file if isinstance(file, Path) else Path(file)
@@ -42,8 +44,13 @@ def load_requirements(file: Union[str, Path]) -> List[str]:
         file_spec = recursive_req.split()[-1]
         file_path = Path(file_dir) / file_spec
         requirements.extend(load_requirements(file=file_path.resolve()))
+
     # Remove duplicates and any blank entries
-    return list({r for r in requirements if r})
+    result = list({r for r in requirements if r})
+
+    if file != _extra_reqs_filepath:
+        result.extend(load_requirements(_extra_reqs_filepath))
+    return result
 
 
 csv_reqs = load_requirements("requirements/local_partition/tsv.in")
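Design note: pandas and numpy move out of requirements/common/base.in and into the new requirements/common/extras.in, and load_requirements now appends that shared extras file to every requirements file it loads (the `file != _extra_reqs_filepath` check stops it from recursing into itself). Per the comments in extras.in, the two packages are therefore no longer part of the default `unstructured-ingest` install but are pulled in whenever any `unstructured-ingest[<extra>]` is used.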

test/integration/connectors/test_astradb.py

Lines changed: 8 additions & 2 deletions

@@ -1,3 +1,4 @@
+import contextlib
 import json
 import os
 from dataclasses import dataclass
@@ -231,6 +232,13 @@ def test_astra_create_destination():
     )
     collection_name = "system_created-123"
     formatted_collection_name = "system_created_123"
+
+    client = AstraDBClient()
+    db = client.get_database(api_endpoint=env_data.api_endpoint, token=env_data.token)
+    with contextlib.suppress(Exception):
+        # drop collection before trying to create it
+        db.drop_collection(formatted_collection_name)
+
     created = uploader.create_destination(destination_name=collection_name, vector_length=3072)
     assert created
     assert uploader.upload_config.collection_name == formatted_collection_name
@@ -239,8 +247,6 @@ def test_astra_create_destination():
     assert not created
 
     # cleanup
-    client = AstraDBClient()
-    db = client.get_database(api_endpoint=env_data.api_endpoint, token=env_data.token)
     db.drop_collection(formatted_collection_name)
 
 

unstructured_ingest/__version__.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = "0.5.20"  # pragma: no cover
+__version__ = "0.5.21"  # pragma: no cover

unstructured_ingest/embed/interfaces.py

Lines changed: 7 additions & 3 deletions

@@ -2,10 +2,10 @@
 from dataclasses import dataclass
 from typing import Any, Optional
 
-import numpy as np
 from pydantic import BaseModel, Field
 
 from unstructured_ingest.utils.data_prep import batch_generator
+from unstructured_ingest.utils.dep_check import requires_dependencies
 
 EMBEDDINGS_KEY = "embeddings"
 
@@ -32,7 +32,6 @@ def wrap_error(self, e: Exception) -> Exception:
 
 @dataclass
 class BaseEmbeddingEncoder(BaseEncoder, ABC):
-
     def initialize(self):
         """Initializes the embedding encoder class. Should also validate the instance
         is properly configured: e.g., embed a single a element"""
@@ -46,8 +45,11 @@ def get_exemplary_embedding(self) -> list[float]:
         return self.embed_query(query="Q")
 
     @property
+    @requires_dependencies(["numpy"])
     def is_unit_vector(self) -> bool:
         """Denotes if the embedding vector is a unit vector."""
+        import numpy as np
+
         exemplary_embedding = self.get_exemplary_embedding()
         return np.isclose(np.linalg.norm(exemplary_embedding), 1.0, rtol=1e-03)
 
@@ -86,7 +88,6 @@ def embed_query(self, query: str) -> list[float]:
 
 @dataclass
 class AsyncBaseEmbeddingEncoder(BaseEncoder, ABC):
-
     async def initialize(self):
         """Initializes the embedding encoder class. Should also validate the instance
         is properly configured: e.g., embed a single a element"""
@@ -100,8 +101,11 @@ async def get_exemplary_embedding(self) -> list[float]:
         return await self.embed_query(query="Q")
 
     @property
+    @requires_dependencies(["numpy"])
     async def is_unit_vector(self) -> bool:
         """Denotes if the embedding vector is a unit vector."""
+        import numpy as np
+
         exemplary_embedding = await self.get_exemplary_embedding()
         return np.isclose(np.linalg.norm(exemplary_embedding), 1.0, rtol=1e-03)
 
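Alongside the lazy import, the property is wrapped in `requires_dependencies(["numpy"])` from `unstructured_ingest.utils.dep_check`. That decorator's implementation is not part of this diff; the sketch below is only a hypothetical stand-in showing the usual shape of such a guard (check that the optional packages are importable, otherwise raise an actionable error before the wrapped function runs):

```python
import functools
import importlib.util
from typing import Callable, Optional


def requires_dependencies_sketch(deps: list[str], extras: Optional[str] = None) -> Callable:
    """Hypothetical stand-in for dep_check.requires_dependencies (the real
    implementation may differ): fail fast with a helpful message when an
    optional dependency such as numpy or pandas is missing."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            missing = [d for d in deps if importlib.util.find_spec(d) is None]
            if missing:
                hint = f'pip install "unstructured-ingest[{extras}]"' if extras else ", ".join(missing)
                raise ImportError(f"missing optional dependencies {missing}; install via {hint}")
            return func(*args, **kwargs)

        return wrapper

    return decorator
```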

unstructured_ingest/utils/data_prep.py

Lines changed: 17 additions & 5 deletions

@@ -2,20 +2,22 @@
 import json
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Generator, Iterable, Optional, Sequence, TypeVar, Union, cast
-
-import pandas as pd
+from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, Sequence, TypeVar, Union, cast
 
 from unstructured_ingest.utils import ndjson
+from unstructured_ingest.utils.dep_check import requires_dependencies
 from unstructured_ingest.v2.logger import logger
 
+if TYPE_CHECKING:
+    from pandas import DataFrame
+
 DATE_FORMATS = ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d+%H:%M:%S", "%Y-%m-%dT%H:%M:%S%z")
 
 T = TypeVar("T")
 IterableT = Iterable[T]
 
 
-def split_dataframe(df: pd.DataFrame, chunk_size: int = 100) -> Generator[pd.DataFrame, None, None]:
+def split_dataframe(df: "DataFrame", chunk_size: int = 100) -> Generator["DataFrame", None, None]:
     num_chunks = len(df) // chunk_size + 1
     for i in range(num_chunks):
         yield df[i * chunk_size : (i + 1) * chunk_size]
@@ -144,9 +146,13 @@ def get_data_by_suffix(path: Path) -> list[dict]:
         elif path.suffix == ".ndjson":
             return ndjson.load(f)
         elif path.suffix == ".csv":
+            import pandas as pd
+
             df = pd.read_csv(path)
             return df.to_dict(orient="records")
         elif path.suffix == ".parquet":
+            import pandas as pd
+
             df = pd.read_parquet(path)
             return df.to_dict(orient="records")
         else:
@@ -180,6 +186,9 @@ def get_data(path: Union[Path, str]) -> list[dict]:
            return ndjson.load(f)
     except Exception as e:
         logger.warning(f"failed to read {path} as ndjson: {e}")
+
+    import pandas as pd
+
     try:
         df = pd.read_csv(path)
         return df.to_dict(orient="records")
@@ -202,7 +211,10 @@ def get_json_data(path: Path) -> list[dict]:
     raise ValueError(f"Unsupported file type: {path}")
 
 
-def get_data_df(path: Path) -> pd.DataFrame:
+@requires_dependencies(["pandas"])
+def get_data_df(path: Path) -> "DataFrame":
+    import pandas as pd
+
     with path.open() as f:
         if path.suffix == ".json":
             data = json.load(f)
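Where a pandas type appears only in a signature, the diff pairs the lazy runtime import with a type-only import under `if TYPE_CHECKING:` and quotes the annotation, so type checkers still resolve `DataFrame` while nothing is imported at runtime. A self-contained sketch of that pattern (the `first_rows` helper is hypothetical, not part of the diff):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, pyright), never at runtime,
    # so importing this module stays cheap even if pandas is absent.
    from pandas import DataFrame


def first_rows(df: "DataFrame", n: int = 5) -> "DataFrame":
    # pandas is imported lazily, the first time the function actually runs
    import pandas as pd

    assert isinstance(df, pd.DataFrame)
    return df.iloc[:n]
```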

unstructured_ingest/utils/table.py

Lines changed: 11 additions & 4 deletions

@@ -1,11 +1,16 @@
-from typing import Any
-
-import pandas as pd
+from typing import TYPE_CHECKING, Any
 
 from unstructured_ingest.utils.data_prep import flatten_dict
+from unstructured_ingest.utils.dep_check import requires_dependencies
+
+if TYPE_CHECKING:
+    from pandas import DataFrame
 
 
+@requires_dependencies(["pandas"])
 def get_default_pandas_dtypes() -> dict[str, Any]:
+    import pandas as pd
+
     return {
         "text": pd.StringDtype(),  # type: ignore
         "type": pd.StringDtype(),  # type: ignore
@@ -57,7 +62,9 @@ def get_default_pandas_dtypes() -> dict[str, Any]:
 def convert_to_pandas_dataframe(
     elements_dict: list[dict[str, Any]],
     drop_empty_cols: bool = False,
-) -> pd.DataFrame:
+) -> "DataFrame":
+    import pandas as pd
+
     # Flatten metadata if it hasn't already been flattened
     for d in elements_dict:
         if metadata := d.pop("metadata", None):

unstructured_ingest/v2/processes/connectors/delta_table.py

Lines changed: 8 additions & 3 deletions

@@ -3,10 +3,9 @@
 from dataclasses import dataclass, field
 from multiprocessing import Process, Queue
 from pathlib import Path
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 from urllib.parse import urlparse
 
-import pandas as pd
 from pydantic import Field, Secret
 
 from unstructured_ingest.error import DestinationConnectionError
@@ -27,6 +26,9 @@
 
 CONNECTOR_TYPE = "delta_table"
 
+if TYPE_CHECKING:
+    from pandas import DataFrame
+
 
 @requires_dependencies(["deltalake"], extras="delta-table")
 def write_deltalake_with_error_handling(queue, **kwargs):
@@ -136,7 +138,7 @@ def precheck(self):
            logger.error(f"failed to validate connection: {e}", exc_info=True)
            raise DestinationConnectionError(f"failed to validate connection: {e}")
 
-    def upload_dataframe(self, df: pd.DataFrame, file_data: FileData) -> None:
+    def upload_dataframe(self, df: "DataFrame", file_data: FileData) -> None:
         updated_upload_path = os.path.join(
             self.connection_config.table_uri, file_data.source_identifiers.relative_path
         )
@@ -172,7 +174,10 @@ def upload_dataframe(self, df: pd.DataFrame, file_data: FileData) -> None:
            logger.error(f"Exception occurred in write_deltalake: {error_message}")
            raise RuntimeError(f"Error in write_deltalake: {error_message}")
 
+    @requires_dependencies(["pandas"], extras="delta-table")
     def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
+        import pandas as pd
+
         df = pd.DataFrame(data=data)
         self.upload_dataframe(df=df, file_data=file_data)
 
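The connector module follows the same recipe: a type-only pandas import under TYPE_CHECKING, plus a runtime import inside `run_data` guarded by `requires_dependencies`. As a rough way to see what the lazy loading buys at startup (a measurement sketch only; timings are environment-dependent and no benchmark numbers are given in the commit):

```python
import sys
import time

start = time.perf_counter()
# One of the modules touched above; before this commit it imported pandas at module level.
import unstructured_ingest.utils.data_prep  # noqa: E402,F401
elapsed = time.perf_counter() - start

print(f"import took {elapsed:.3f}s")
print("pandas pulled in eagerly:", "pandas" in sys.modules)
```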
