Skip to content

Commit 03cec6b

Browse files
Fix lazy import (#2987) (#2990)
Signed-off-by: Kevin Su <[email protected]> Co-authored-by: Kevin Su <[email protected]>
1 parent 0b4a60a commit 03cec6b

File tree

8 files changed

+70
-77
lines changed

8 files changed

+70
-77
lines changed

flytekit/core/type_engine.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,7 +1146,6 @@ class TypeEngine(typing.Generic[T]):
11461146
_RESTRICTED_TYPES: typing.List[type] = []
11471147
_DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore
11481148
_ENUM_TRANSFORMER: TypeTransformer = EnumTransformer() # type: ignore
1149-
has_lazy_import = False
11501149
lazy_import_lock = threading.Lock()
11511150

11521151
@classmethod
@@ -1255,16 +1254,7 @@ def lazy_import_transformers(cls):
12551254
# Avoid a race condition where concurrent threads may exit lazy_import_transformers before the transformers
12561255
# have been imported. This could be implemented without a lock if you assume python assignments are atomic
12571256
# and re-registering transformers is acceptable, but I decided to play it safe.
1258-
if cls.has_lazy_import:
1259-
return
1260-
cls.has_lazy_import = True
1261-
from flytekit.types.structured import (
1262-
register_arrow_handlers,
1263-
register_bigquery_handlers,
1264-
register_pandas_handlers,
1265-
register_snowflake_handlers,
1266-
)
1267-
from flytekit.types.structured.structured_dataset import DuplicateHandlerError
1257+
from flytekit.types.structured import lazy_import_structured_dataset_handler
12681258

12691259
if is_imported("tensorflow"):
12701260
from flytekit.extras import tensorflow # noqa: F401
@@ -1279,29 +1269,11 @@ def lazy_import_transformers(cls):
12791269
from flytekit.types.schema.types_pandas import PandasSchemaReader, PandasSchemaWriter # noqa: F401
12801270
except ValueError:
12811271
logger.debug("Transformer for pandas is already registered.")
1282-
try:
1283-
register_pandas_handlers()
1284-
except DuplicateHandlerError:
1285-
logger.debug("Transformer for pandas is already registered.")
1286-
if is_imported("pyarrow"):
1287-
try:
1288-
register_arrow_handlers()
1289-
except DuplicateHandlerError:
1290-
logger.debug("Transformer for arrow is already registered.")
1291-
if is_imported("google.cloud.bigquery"):
1292-
try:
1293-
register_bigquery_handlers()
1294-
except DuplicateHandlerError:
1295-
logger.debug("Transformer for bigquery is already registered.")
12961272
if is_imported("numpy"):
12971273
from flytekit.types import numpy # noqa: F401
12981274
if is_imported("PIL"):
12991275
from flytekit.types.file import image # noqa: F401
1300-
if is_imported("snowflake.connector"):
1301-
try:
1302-
register_snowflake_handlers()
1303-
except DuplicateHandlerError:
1304-
logger.debug("Transformer for snowflake is already registered.")
1276+
lazy_import_structured_dataset_handler()
13051277

13061278
@classmethod
13071279
def to_literal_type(cls, python_type: Type[T]) -> LiteralType:

flytekit/lazy_import/lazy_module.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
import sys
33
import types
44

5-
LAZY_MODULES = []
65

6+
class _LazyModule(types.ModuleType):
7+
"""
8+
`lazy_module` returns an instance of this class if the module is not found in the python environment.
9+
"""
710

8-
class LazyModule(types.ModuleType):
911
def __init__(self, module_name: str):
1012
super().__init__(module_name)
1113
self._module_name = module_name
@@ -17,8 +19,12 @@ def __getattribute__(self, attr):
1719
def is_imported(module_name):
1820
"""
1921
This function is used to check if a module has been imported by the regular import.
22+
Return false if module is lazy imported and not used yet.
2023
"""
21-
return module_name in sys.modules and module_name not in LAZY_MODULES
24+
return (
25+
module_name in sys.modules
26+
and object.__getattribute__(lazy_module(module_name), "__class__").__name__ != "_LazyModule"
27+
)
2228

2329

2430
def lazy_module(fullname):
@@ -37,11 +43,12 @@ def lazy_module(fullname):
3743
if spec is None or spec.loader is None:
3844
# Return a lazy module if the module is not found in the python environment,
3945
# so that we can raise a proper error when the user tries to access an attribute in the module.
40-
return LazyModule(fullname)
46+
# The reason to do this is because importlib.util.LazyLoader still requires
47+
# the module to be installed even if you don't use it.
48+
return _LazyModule(fullname)
4149
loader = importlib.util.LazyLoader(spec.loader)
4250
spec.loader = loader
4351
module = importlib.util.module_from_spec(spec)
4452
sys.modules[fullname] = module
45-
LAZY_MODULES.append(module)
4653
loader.exec_module(module)
4754
return module

flytekit/types/structured/__init__.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313
"""
1414

1515
from flytekit.deck.renderer import ArrowRenderer, TopFrameRenderer
16+
from flytekit.lazy_import.lazy_module import is_imported
1617
from flytekit.loggers import logger
1718

1819
from .structured_dataset import (
20+
DuplicateHandlerError,
1921
StructuredDataset,
2022
StructuredDatasetDecoder,
2123
StructuredDatasetEncoder,
22-
StructuredDatasetMetadata,
2324
StructuredDatasetTransformerEngine,
24-
StructuredDatasetType,
2525
)
2626

2727

@@ -84,3 +84,27 @@ def register_snowflake_handlers():
8484
"We won't register snowflake handler for structured dataset because "
8585
"we can't find package snowflake-connector-python"
8686
)
87+
88+
89+
def lazy_import_structured_dataset_handler():
90+
if is_imported("pandas"):
91+
try:
92+
register_pandas_handlers()
93+
register_csv_handlers()
94+
except DuplicateHandlerError:
95+
logger.debug("Transformer for pandas is already registered.")
96+
if is_imported("pyarrow"):
97+
try:
98+
register_arrow_handlers()
99+
except DuplicateHandlerError:
100+
logger.debug("Transformer for arrow is already registered.")
101+
if is_imported("google.cloud.bigquery"):
102+
try:
103+
register_bigquery_handlers()
104+
except DuplicateHandlerError:
105+
logger.debug("Transformer for bigquery is already registered.")
106+
if is_imported("snowflake.connector"):
107+
try:
108+
register_snowflake_handlers()
109+
except DuplicateHandlerError:
110+
logger.debug("Transformer for snowflake is already registered.")

flytekit/types/structured/structured_dataset.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,18 @@ def literal(self) -> Optional[literals.StructuredDataset]:
170170
return self._literal_sd
171171

172172
def open(self, dataframe_type: Type[DF]):
173+
from flytekit.types.structured import lazy_import_structured_dataset_handler
174+
175+
"""
176+
Load the handler if needed. For the use case like:
177+
@task
178+
def t1(sd: StructuredDataset):
179+
import pandas as pd
180+
sd.open(pd.DataFrame).all()
181+
182+
pandas is imported inside the task, so pandnas handler won't be loaded during deserialization in type engine.
183+
"""
184+
lazy_import_structured_dataset_handler()
173185
self._dataframe_type = dataframe_type
174186
return self
175187

tests/flytekit/unit/core/test_generice_idl_type_engine.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3439,35 +3439,6 @@ def outer_workflow(input: OuterWorkflowInput) -> OuterWorkflowOutput:
34393439
assert none_value_output is None, f"None value was {none_value_output}, not None as expected"
34403440

34413441

3442-
@pytest.mark.serial
3443-
def test_lazy_import_transformers_concurrently():
3444-
# Ensure that next call to TypeEngine.lazy_import_transformers doesn't skip the import. Mark as serial to ensure
3445-
# this achieves what we expect.
3446-
TypeEngine.has_lazy_import = False
3447-
3448-
# Configure the mocks similar to https://stackoverflow.com/questions/29749193/python-unit-testing-with-two-mock-objects-how-to-verify-call-order
3449-
after_import_mock, mock_register = mock.Mock(), mock.Mock()
3450-
mock_wrapper = mock.Mock()
3451-
mock_wrapper.mock_register = mock_register
3452-
mock_wrapper.after_import_mock = after_import_mock
3453-
3454-
with mock.patch.object(StructuredDatasetTransformerEngine, "register", new=mock_register):
3455-
def run():
3456-
TypeEngine.lazy_import_transformers()
3457-
after_import_mock()
3458-
3459-
N = 5
3460-
with ThreadPoolExecutor(max_workers=N) as executor:
3461-
futures = [executor.submit(run) for _ in range(N)]
3462-
[f.result() for f in futures]
3463-
3464-
# Assert that all the register calls come before anything else.
3465-
assert mock_wrapper.mock_calls[-N:] == [mock.call.after_import_mock()] * N
3466-
expected_number_of_register_calls = len(mock_wrapper.mock_calls) - N
3467-
assert all([mock_call[0] == "mock_register" for mock_call in
3468-
mock_wrapper.mock_calls[:expected_number_of_register_calls]])
3469-
3470-
34713442
@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10, 585 requires >=3.9")
34723443
def test_option_list_with_pipe():
34733444
pt = list[int] | None

tests/flytekit/unit/core/test_type_engine.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3451,10 +3451,6 @@ def outer_workflow(input: OuterWorkflowInput) -> OuterWorkflowOutput:
34513451

34523452
@pytest.mark.serial
34533453
def test_lazy_import_transformers_concurrently():
3454-
# Ensure that next call to TypeEngine.lazy_import_transformers doesn't skip the import. Mark as serial to ensure
3455-
# this achieves what we expect.
3456-
TypeEngine.has_lazy_import = False
3457-
34583454
# Configure the mocks similar to https://stackoverflow.com/questions/29749193/python-unit-testing-with-two-mock-objects-how-to-verify-call-order
34593455
after_import_mock, mock_register = mock.Mock(), mock.Mock()
34603456
mock_wrapper = mock.Mock()
@@ -3471,11 +3467,11 @@ def run():
34713467
futures = [executor.submit(run) for _ in range(N)]
34723468
[f.result() for f in futures]
34733469

3474-
# Assert that all the register calls come before anything else.
3475-
assert mock_wrapper.mock_calls[-N:] == [mock.call.after_import_mock()] * N
3470+
assert mock_wrapper.mock_calls[-1] == mock.call.after_import_mock()
34763471
expected_number_of_register_calls = len(mock_wrapper.mock_calls) - N
3472+
assert sum([mock_call[0] == "mock_register" for mock_call in mock_wrapper.mock_calls]) == expected_number_of_register_calls
34773473
assert all([mock_call[0] == "mock_register" for mock_call in
3478-
mock_wrapper.mock_calls[:expected_number_of_register_calls]])
3474+
mock_wrapper.mock_calls[:int(len(mock_wrapper.mock_calls)/N)-1]])
34793475

34803476

34813477
@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10, 585 requires >=3.9")
Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
1-
import pytest
1+
import sys
2+
from unittest.mock import Mock
23

3-
from flytekit.lazy_import.lazy_module import LazyModule, lazy_module
4+
import pytest
5+
from flytekit.lazy_import.lazy_module import _LazyModule, lazy_module, is_imported
46

57

68
def test_lazy_module():
79
mod = lazy_module("click")
810
assert mod.__name__ == "click"
911
mod = lazy_module("fake_module")
10-
assert isinstance(mod, LazyModule)
12+
13+
sys.modules["fake_module"] = mod
14+
assert not is_imported("fake_module")
15+
assert isinstance(mod, _LazyModule)
1116
with pytest.raises(ImportError, match="Module fake_module is not yet installed."):
1217
print(mod.attr)
18+
19+
non_lazy_module = Mock()
20+
non_lazy_module.__name__ = 'NonLazyModule'
21+
sys.modules["fake_module"] = non_lazy_module
22+
assert is_imported("fake_module")
23+
24+
assert is_imported("dataclasses")

tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from flytekit.models.literals import StructuredDatasetMetadata
2323
from flytekit.models.types import LiteralType, SchemaType, SimpleType, StructuredDatasetType
2424
from flytekit.tools.translator import get_serializable
25-
from flytekit.types.file import FlyteFile
2625
from flytekit.types.structured.structured_dataset import (
2726
PARQUET,
2827
StructuredDataset,

0 commit comments

Comments
 (0)