Skip to content

Commit 9a01ece

Browse files
andersbogsnesAnders Bogsnes
authored andcommitted
Add import check for optional dependency on pyiceberg_core (apache#2221)
Added a NotInstalledException check when pyiceberg_core is imported for pyarrow_transforms This will raise a helpful error message when endusers try to use methods that depend on pyiceberg_core but haven't installed the optional dependency <!-- Thanks for opening a pull request! --> Closes apache#1987 # Rationale for this change If an enduser hasn't installed the `pyiceberg_core` optional dependency, using pyarrow transforms will crash with an unhelpful error. This PR gives the enduser a nicer error message that informs them to install the optional dependency # Are these changes tested? Yes, a test was added to verify the behaviour # Are there any user-facing changes? <!-- In the case of user-facing changes, please add the changelog label. --> --------- Co-authored-by: Anders Bogsnes <[email protected]>
1 parent 25e5792 commit 9a01ece

File tree

2 files changed

+37
-11
lines changed

2 files changed

+37
-11
lines changed

pyiceberg/transforms.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818
import base64
1919
import datetime as py_datetime
20+
import importlib
2021
import struct
22+
import types
2123
from abc import ABC, abstractmethod
2224
from enum import IntEnum
2325
from functools import singledispatch
@@ -28,6 +30,7 @@
2830
import mmh3
2931
from pydantic import Field, PositiveInt, PrivateAttr
3032

33+
from pyiceberg.exceptions import NotInstalledError
3134
from pyiceberg.expressions import (
3235
BoundEqualTo,
3336
BoundGreaterThan,
@@ -106,6 +109,17 @@
106109
TRUNCATE_PARSER = ParseNumberFromBrackets(TRUNCATE)
107110

108111

112+
def _try_import(module_name: str, extras_name: Optional[str] = None) -> types.ModuleType:
113+
try:
114+
return importlib.import_module(module_name)
115+
except ImportError:
116+
if extras_name:
117+
msg = f'{module_name} needs to be installed. pip install "pyiceberg[{extras_name}]"'
118+
else:
119+
msg = f"{module_name} needs to be installed."
120+
raise NotInstalledError(msg) from None
121+
122+
109123
def _transform_literal(func: Callable[[L], L], lit: Literal[L]) -> Literal[L]:
110124
"""Small helper to upwrap the value from the literal, and wrap it again."""
111125
return literal(func(lit.value))
@@ -382,8 +396,7 @@ def __repr__(self) -> str:
382396
return f"BucketTransform(num_buckets={self._num_buckets})"
383397

384398
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
385-
from pyiceberg_core import transform as pyiceberg_core_transform
386-
399+
pyiceberg_core_transform = _try_import("pyiceberg_core", extras_name="pyiceberg-core").transform
387400
return _pyiceberg_transform_wrapper(pyiceberg_core_transform.bucket, self._num_buckets)
388401

389402
@property
@@ -509,9 +522,8 @@ def __repr__(self) -> str:
509522
return "YearTransform()"
510523

511524
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
512-
import pyarrow as pa
513-
from pyiceberg_core import transform as pyiceberg_core_transform
514-
525+
pa = _try_import("pyarrow")
526+
pyiceberg_core_transform = _try_import("pyiceberg_core", extras_name="pyiceberg-core").transform
515527
return _pyiceberg_transform_wrapper(pyiceberg_core_transform.year, expected_type=pa.int32())
516528

517529

@@ -570,8 +582,8 @@ def __repr__(self) -> str:
570582
return "MonthTransform()"
571583

572584
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
573-
import pyarrow as pa
574-
from pyiceberg_core import transform as pyiceberg_core_transform
585+
pa = _try_import("pyarrow")
586+
pyiceberg_core_transform = _try_import("pyiceberg_core", extras_name="pyiceberg-core").transform
575587

576588
return _pyiceberg_transform_wrapper(pyiceberg_core_transform.month, expected_type=pa.int32())
577589

@@ -639,8 +651,8 @@ def __repr__(self) -> str:
639651
return "DayTransform()"
640652

641653
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
642-
import pyarrow as pa
643-
from pyiceberg_core import transform as pyiceberg_core_transform
654+
pa = _try_import("pyarrow", extras_name="pyarrow")
655+
pyiceberg_core_transform = _try_import("pyiceberg_core", extras_name="pyiceberg-core").transform
644656

645657
return _pyiceberg_transform_wrapper(pyiceberg_core_transform.day, expected_type=pa.int32())
646658

@@ -692,7 +704,7 @@ def __repr__(self) -> str:
692704
return "HourTransform()"
693705

694706
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
695-
from pyiceberg_core import transform as pyiceberg_core_transform
707+
pyiceberg_core_transform = _try_import("pyiceberg_core", extras_name="pyiceberg-core").transform
696708

697709
return _pyiceberg_transform_wrapper(pyiceberg_core_transform.hour)
698710

@@ -915,7 +927,7 @@ def __repr__(self) -> str:
915927
return f"TruncateTransform(width={self._width})"
916928

917929
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
918-
from pyiceberg_core import transform as pyiceberg_core_transform
930+
pyiceberg_core_transform = _try_import("pyiceberg_core", extras_name="pyiceberg-core").transform
919931

920932
return _pyiceberg_transform_wrapper(pyiceberg_core_transform.truncate, self._width)
921933

tests/test_transforms.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
RootModel,
3131
WithJsonSchema,
3232
)
33+
from pytest_mock import MockFixture
3334

35+
from pyiceberg.exceptions import NotInstalledError
3436
from pyiceberg.expressions import (
3537
AlwaysFalse,
3638
BooleanExpression,
@@ -1668,3 +1670,15 @@ def test_truncate_pyarrow_transforms(
16681670
) -> None:
16691671
transform: Transform[Any, Any] = TruncateTransform(width=width)
16701672
assert expected == transform.pyarrow_transform(source_type)(input_arr)
1673+
1674+
1675+
@pytest.mark.parametrize(
1676+
"transform", [BucketTransform(num_buckets=5), TruncateTransform(width=5), YearTransform(), MonthTransform(), DayTransform()]
1677+
)
1678+
def test_calling_pyarrow_transform_without_pyiceberg_core_installed_correctly_raises_not_imported_error(
1679+
transform, mocker: MockFixture
1680+
) -> None:
1681+
mocker.patch.dict("sys.modules", {"pyiceberg_core": None})
1682+
1683+
with pytest.raises(NotInstalledError):
1684+
transform.pyarrow_transform(StringType())

0 commit comments

Comments
 (0)