diff --git a/dev-requirements.in b/dev-requirements.in index 34f11ff34f..cef6ce1929 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -46,6 +46,7 @@ types-protobuf<5 types-croniter types-decorator types-mock +types-cachetools autoflake pillow diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 58ba0b8556..2895147a06 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -21,6 +21,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, Type, cast import msgpack +from cachetools import LRUCache from dataclasses_json import DataClassJsonMixin, dataclass_json from flyteidl.core import literals_pb2 from fsspec.asyn import _run_coros_in_chunks # pylint: disable=W0212 @@ -1174,6 +1175,7 @@ class TypeEngine(typing.Generic[T]): _DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore _ENUM_TRANSFORMER: TypeTransformer = EnumTransformer() # type: ignore lazy_import_lock = threading.Lock() + _LITERAL_CACHE: LRUCache = LRUCache(maxsize=128) @classmethod def register( @@ -1377,6 +1379,22 @@ def calculate_hash(cls, python_val: typing.Any, python_type: Type[T]) -> Optiona break return hsh + @classmethod + def _get_literal_cache_key(cls, python_val: typing.Any, python_type: Type[T]) -> Optional[tuple]: + import cloudpickle + + val_hash: int + type_hash: int + try: + val_hash = hash(cloudpickle.dumps(python_val)) + type_hash = hash(cloudpickle.dumps(python_type)) + + return (val_hash, type_hash) + + except Exception: + logger.warning(f"Could not hash python_val: {python_val} or python_type: {python_type}") + return None + @classmethod def to_literal( cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type[T], expected: LiteralType @@ -1386,6 +1404,10 @@ def to_literal( to_literal function, and allowing this to_literal function, to then invoke yet another async function, namely an async transformer. """ + key = cls._get_literal_cache_key(python_val, python_type) + if key is not None and key in cls._LITERAL_CACHE: + return cls._LITERAL_CACHE[key] + from flytekit.core.promise import Promise cls.to_literal_checks(python_val, python_type, expected) @@ -1406,6 +1428,10 @@ def to_literal( modify_literal_uris(lv) lv.hash = cls.calculate_hash(python_val, python_type) + + if key is not None: + cls._LITERAL_CACHE[key] = lv + return lv @classmethod diff --git a/tests/flytekit/unit/core/test_list.py b/tests/flytekit/unit/core/test_list.py index 96ee2efe78..75917c3a0c 100644 --- a/tests/flytekit/unit/core/test_list.py +++ b/tests/flytekit/unit/core/test_list.py @@ -73,6 +73,8 @@ async def test_coroutine_batching_of_list_transformer(): lt = LiteralType(simple=SimpleType.INTEGER) python_val = [MyInt(10), MyInt(11), MyInt(12), MyInt(13), MyInt(14)] + # Use the different python_val to avoid hitting the cache + python_val_2 = [MyInt(11), MyInt(10), MyInt(12), MyInt(13), MyInt(14)] ctx = FlyteContext.current_context() with mock.patch("flytekit.core.type_engine._TYPE_ENGINE_COROS_BATCH_SIZE", 2): @@ -80,6 +82,10 @@ async def test_coroutine_batching_of_list_transformer(): with mock.patch("flytekit.core.type_engine._TYPE_ENGINE_COROS_BATCH_SIZE", 5): with pytest.raises(ValueError): + TypeEngine.to_literal(ctx, python_val_2, typing.List[MyInt], lt) + + # Cache hit for python_val prevents async_to_literal calls, avoiding the batch size limit of 2 error defined in MyIntAsyncTransformer + with mock.patch("flytekit.core.type_engine._TYPE_ENGINE_COROS_BATCH_SIZE", 5): TypeEngine.to_literal(ctx, python_val, typing.List[MyInt], lt) del TypeEngine._REGISTRY[MyInt] diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 93d5d6af67..7e04ab0214 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3851,3 +3851,145 @@ async def test_dict_transformer_annotated_type(): literal3 = await TypeEngine.async_to_literal(ctx, nested_dict, nested_dict_type, expected_type) assert literal3.map.literals["outer"].map.literals["inner"].scalar.primitive.integer == 42 + +@pytest.fixture(autouse=True) +def clear_type_engine_cache(): + """Clear TypeEngine cache before and after each test""" + TypeEngine._LITERAL_CACHE.clear() + yield + TypeEngine._LITERAL_CACHE.clear() + +def test_type_engine_cache_with_list(): + ctx = FlyteContext.current_context() + python_val = [1, 2, 3, 4, 5] + python_type = typing.List[int] + expected = TypeEngine.to_literal_type(python_type) + list_transformer = TypeEngine.get_transformer(typing.List[int]) + with mock.patch.object(list_transformer, 'async_to_literal', + wraps=list_transformer.async_to_literal) as mock_async_to_literal: + + # First call + literal1 = TypeEngine.to_literal(ctx, python_val, python_type, expected) + + key = TypeEngine._get_literal_cache_key(python_val, python_type) + assert key is not None + assert key in TypeEngine._LITERAL_CACHE + + # Second call with same DataFrame + literal2 = TypeEngine.to_literal(ctx, python_val, python_type, expected) + + # Verify async_to_literal was only called once + assert mock_async_to_literal.call_count == 1 + + assert literal1 is literal2 + + # Test with different data - should not use cache + different_val = [2, 1, 3, 4, 5] + literal3 = TypeEngine.to_literal(ctx, different_val, python_type, expected) + key_different = TypeEngine._get_literal_cache_key(different_val, python_type) + + assert key_different is not key + assert key_different is not None + assert key_different in TypeEngine._LITERAL_CACHE + + # Verify different literals are different objects + assert literal1 is not literal3 + + # Add many different values to test cache size limit + for i in range(200): # More than the default maxsize of 128 + test_val = [i, i+1, i+2] + test_type = typing.List[int] + test_expected = TypeEngine.to_literal_type(test_type) + TypeEngine.to_literal(ctx, test_val, test_type, test_expected) + + # Cache should not exceed maxsize + assert len(TypeEngine._LITERAL_CACHE) == 128 + +def test_type_engine_cache_with_dict(): + ctx = FlyteContext.current_context() + python_val = {"a": [1, 2, 3]} + python_type = typing.Dict[str, typing.List[int]] + expected = TypeEngine.to_literal_type(python_type) + dict_transformer = TypeEngine.get_transformer(typing.Dict[str, typing.List[int]]) + with mock.patch.object(dict_transformer, 'async_to_literal', + wraps=dict_transformer.async_to_literal) as mock_async_to_literal: + + # First call + literal1 = TypeEngine.to_literal(ctx, python_val, python_type, expected) + + key = TypeEngine._get_literal_cache_key(python_val, python_type) + assert key is not None + assert key in TypeEngine._LITERAL_CACHE + + # Second call with same DataFrame + literal2 = TypeEngine.to_literal(ctx, python_val, python_type, expected) + + # Verify async_to_literal was only called once + assert mock_async_to_literal.call_count == 1 + + assert literal1 is literal2 + +def test_make_key_with_annotated_types(): + # Test with Annotated type + annotated_val = [1, 2, 3] + annotated_type = typing.Annotated[typing.List[int], "test_annotation"] + + key = TypeEngine._get_literal_cache_key(annotated_val, annotated_type) + key_without_annotation = TypeEngine._get_literal_cache_key(annotated_val, typing.List[int]) + # Should handle Annotated types correctly + assert key is not None + assert key_without_annotation is not None + assert key != key_without_annotation + +def test_type_engine_cache_with_pandas(): + pd = pytest.importorskip("pandas") + ctx = FlyteContext.current_context() + # Create DataFrame + df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df_type = pd.DataFrame + df_expected = TypeEngine.to_literal_type(df_type) + + # Get the transformer for DataFrame + df_transformer = TypeEngine._REGISTRY[pd.DataFrame] + + # Mock the async_to_literal method with wraps to track calls + with mock.patch.object(df_transformer, 'async_to_literal', + wraps=df_transformer.async_to_literal) as mock_async_to_literal: + + # First call + literal1 = TypeEngine.to_literal(ctx, df, df_type, df_expected) + + # Second call with same DataFrame + literal2 = TypeEngine.to_literal(ctx, df, df_type, df_expected) + + # Verify async_to_literal was called + assert mock_async_to_literal.call_count == 1 + + assert literal1 is literal2 + +def test_type_engine_cache_with_flytefile(): + + transformer = TypeEngine.get_transformer(FlyteFile) + ctx = FlyteContext.current_context() + + temp_dir = tempfile.mkdtemp(prefix="temp_example_") + file_path = os.path.join(temp_dir, "file.txt") + with open(file_path, "w") as file1: + file1.write("hello world") + + lt = TypeEngine.to_literal_type(FlyteFile) + + # Mock the file upload + with mock.patch.object(transformer, 'async_to_literal', + wraps=transformer.async_to_literal) as mock_async_to_literal: + + # Test 1: Upload local file to remote + lv1 = TypeEngine.to_literal(ctx, file_path, FlyteFile, lt) + + # Second call with same DataFrame + lv2 = TypeEngine.to_literal(ctx, file_path, FlyteFile, lt) + + # Verify async_to_literal was called + assert mock_async_to_literal.call_count == 1 + + assert lv1 is lv2