Skip to content

Commit f3f5102

Browse files
thetorpedodog authored and ihnorton committed
Avoid importing Pandas until we actually use it.
Importing Pandas takes a significant amount of time; nearly half a second in UDFs. Previously, we imported it eagerly if it was present, even though outside of a couple functions it was only used for type annotations. This change ensures that we only import it exactly when we actually use it.
1 parent e24fbce commit f3f5102

File tree

2 files changed

+45
-26
lines changed

2 files changed

+45
-26
lines changed

tiledb/multirange_indexing.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import importlib.util
12
import json
23
import time
34
import weakref
@@ -8,6 +9,7 @@
89
from dataclasses import dataclass
910
from numbers import Real
1011
from typing import (
12+
TYPE_CHECKING,
1113
Any,
1214
Dict,
1315
Iterator,
@@ -30,19 +32,14 @@
3032
from .query_condition import QueryCondition
3133
from .subarray import Subarray
3234

33-
current_timer: ContextVar[str] = ContextVar("timer_scope")
34-
35-
try:
35+
if TYPE_CHECKING:
36+
# We don't want to import these eagerly since importing Pandas in particular
37+
# can add around half a second of import time even if we never use it.
38+
import pandas
3639
import pyarrow
37-
from pyarrow import Table
38-
except ImportError:
39-
pyarrow = Table = None
4040

41-
try:
42-
import pandas as pd
43-
from pandas import DataFrame
44-
except ImportError:
45-
DataFrame = None
41+
42+
current_timer: ContextVar[str] = ContextVar("timer_scope")
4643

4744

4845
# sentinel value to denote selecting an empty range
@@ -373,8 +370,12 @@ def __init__(
373370
# we need to use a Query in order to get coords for a dense array
374371
if not query:
375372
query = QueryProxy(array, coords=True)
376-
if use_arrow is None:
377-
use_arrow = pyarrow is not None
373+
use_arrow = (
374+
bool(importlib.util.find_spec("pyarrow"))
375+
if use_arrow is None
376+
else use_arrow
377+
)
378+
378379
# TODO: currently there is lack of support for Arrow list types. This prevents
379380
# multi-value attributes, asides from strings, from being queried properly.
380381
# Until list attributes are supported in core, error with a clear message.
@@ -390,12 +391,15 @@ def __init__(
390391
)
391392
super().__init__(array, query, use_arrow, preload_metadata=True)
392393

393-
def _run_query(self) -> Union[DataFrame, Table]:
394+
def _run_query(self) -> Union["pandas.DataFrame", "pyarrow.Table"]:
395+
import pandas
396+
import pyarrow
397+
394398
if self.pyquery is not None:
395399
self.pyquery.submit()
396400

397401
if self.pyquery is None:
398-
df = DataFrame(self._empty_results)
402+
df = pandas.DataFrame(self._empty_results)
399403
elif self.use_arrow:
400404
with timing("buffer_conversion_time"):
401405
table = self.pyquery._buffers_to_pa_table()
@@ -417,14 +421,14 @@ def _run_query(self) -> Union[DataFrame, Table]:
417421
# converting all integers with NULLs to float64:
418422
# https://arrow.apache.org/docs/python/pandas.html#arrow-pandas-conversion
419423
extended_dtype_mapping = {
420-
pyarrow.int8(): pd.Int8Dtype(),
421-
pyarrow.int16(): pd.Int16Dtype(),
422-
pyarrow.int32(): pd.Int32Dtype(),
423-
pyarrow.int64(): pd.Int64Dtype(),
424-
pyarrow.uint8(): pd.UInt8Dtype(),
425-
pyarrow.uint16(): pd.UInt16Dtype(),
426-
pyarrow.uint32(): pd.UInt32Dtype(),
427-
pyarrow.uint64(): pd.UInt64Dtype(),
424+
pyarrow.int8(): pandas.Int8Dtype(),
425+
pyarrow.int16(): pandas.Int16Dtype(),
426+
pyarrow.int32(): pandas.Int32Dtype(),
427+
pyarrow.int64(): pandas.Int64Dtype(),
428+
pyarrow.uint8(): pandas.UInt8Dtype(),
429+
pyarrow.uint16(): pandas.UInt16Dtype(),
430+
pyarrow.uint32(): pandas.UInt32Dtype(),
431+
pyarrow.uint64(): pandas.UInt64Dtype(),
428432
}
429433
dtype = extended_dtype_mapping[pa_attr.type]
430434
else:
@@ -463,7 +467,7 @@ def _run_query(self) -> Union[DataFrame, Table]:
463467

464468
df = table.to_pandas()
465469
else:
466-
df = DataFrame(_get_pyquery_results(self.pyquery, self.array.schema))
470+
df = pandas.DataFrame(_get_pyquery_results(self.pyquery, self.array.schema))
467471

468472
with timing("pandas_index_update_time"):
469473
return _update_df_from_meta(df, self.array.meta, self.query.index_col)
@@ -663,8 +667,10 @@ def _get_empty_results(
663667

664668

665669
def _update_df_from_meta(
666-
df: DataFrame, array_meta: Metadata, index_col: Union[List[str], bool, None] = True
667-
) -> DataFrame:
670+
df: "pandas.DataFrame",
671+
array_meta: Metadata,
672+
index_col: Union[List[str], bool, None] = True,
673+
) -> "pandas.DataFrame":
668674
col_dtypes = {}
669675
if "__pandas_attribute_repr" in array_meta:
670676
attr_dtypes = json.loads(array_meta["__pandas_attribute_repr"])

tiledb/tests/test_basic_import.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import ast
2+
import subprocess
3+
import sys
4+
5+
6+
def test_dont_import_pandas() -> None:
7+
"""Verifies that when we import TileDB, we don't import Pandas eagerly."""
8+
# Get a list of all modules from a completely fresh interpreter.
9+
all_mods_str = subprocess.check_output(
10+
(sys.executable, "-c", "import sys, tiledb; print(list(sys.modules))")
11+
)
12+
all_mods = ast.literal_eval(all_mods_str.decode())
13+
assert "pandas" not in all_mods

0 commit comments

Comments (0)