Skip to content
1 change: 1 addition & 0 deletions dev-requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ types-protobuf<5
types-croniter
types-decorator
types-mock
types-cachetools
autoflake

pillow
Expand Down
32 changes: 32 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Any, Dict, List, NamedTuple, Optional, Type, cast

import msgpack
from cachetools import LRUCache
from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import literals_pb2
from fsspec.asyn import _run_coros_in_chunks # pylint: disable=W0212
Expand Down Expand Up @@ -1174,6 +1175,7 @@ class TypeEngine(typing.Generic[T]):
_DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore
_ENUM_TRANSFORMER: TypeTransformer = EnumTransformer() # type: ignore
lazy_import_lock = threading.Lock()
_CACHE: LRUCache = LRUCache(maxsize=128)

@classmethod
def register(
Expand Down Expand Up @@ -1377,6 +1379,28 @@ def calculate_hash(cls, python_val: typing.Any, python_type: Type[T]) -> Optiona
break
return hsh

@classmethod
def make_key(cls, python_val: typing.Any, python_type: Type[T]) -> Optional[tuple]:
import cloudpickle

val_hash: typing.Any
type_hash: typing.Any
try:
try:
val_hash = hash(python_val)
except Exception:
val_hash = hash(cloudpickle.dumps(python_val))

try:
type_hash = hash(python_type)
except Exception:
type_hash = hash(cloudpickle.dumps(python_type))

return (val_hash, type_hash)

except Exception:
return None

@classmethod
def to_literal(
cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type[T], expected: LiteralType
Expand All @@ -1386,6 +1410,10 @@ def to_literal(
to_literal function, and allowing this to_literal function, to then invoke yet another async function,
namely an async transformer.
"""
key = cls.make_key(python_val, python_type)
if key is not None and key in cls._CACHE:
return cls._CACHE[key]

from flytekit.core.promise import Promise

cls.to_literal_checks(python_val, python_type, expected)
Expand All @@ -1406,6 +1434,10 @@ def to_literal(

modify_literal_uris(lv)
lv.hash = cls.calculate_hash(python_val, python_type)

if key is not None:
cls._CACHE[key] = lv

return lv

@classmethod
Expand Down
4 changes: 3 additions & 1 deletion tests/flytekit/unit/core/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,15 @@ async def test_coroutine_batching_of_list_transformer():

lt = LiteralType(simple=SimpleType.INTEGER)
python_val = [MyInt(10), MyInt(11), MyInt(12), MyInt(13), MyInt(14)]
# Use the different python_val to avoid hitting the cache
python_val_2 = [MyInt(11), MyInt(10), MyInt(12), MyInt(13), MyInt(14)]
ctx = FlyteContext.current_context()

with mock.patch("flytekit.core.type_engine._TYPE_ENGINE_COROS_BATCH_SIZE", 2):
TypeEngine.to_literal(ctx, python_val, typing.List[MyInt], lt)

with mock.patch("flytekit.core.type_engine._TYPE_ENGINE_COROS_BATCH_SIZE", 5):
with pytest.raises(ValueError):
TypeEngine.to_literal(ctx, python_val, typing.List[MyInt], lt)
TypeEngine.to_literal(ctx, python_val_2, typing.List[MyInt], lt)

del TypeEngine._REGISTRY[MyInt]
100 changes: 100 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3851,3 +3851,103 @@ async def test_dict_transformer_annotated_type():

literal3 = await TypeEngine.async_to_literal(ctx, nested_dict, nested_dict_type, expected_type)
assert literal3.map.literals["outer"].map.literals["inner"].scalar.primitive.integer == 42

def test_type_engine_cache():
# Clear cache before test
TypeEngine._CACHE.clear()

# Test data
ctx = FlyteContext.current_context()
python_val = [1, 2, 3, 4, 5]
python_type = typing.List[int]
expected = TypeEngine.to_literal_type(python_type)

# First call - should not use cache
literal1 = TypeEngine.to_literal(ctx, python_val, python_type, expected)

# Verify cache is populated
key = TypeEngine.make_key(python_val, python_type)
assert key is not None
assert key in TypeEngine._CACHE

# Second call with same parameters - should use cache
literal2 = TypeEngine.to_literal(ctx, python_val, python_type, expected)

# Verify both literals are identical (same object from cache)
assert literal1 is literal2

# Test with different data - should not use cache
different_val = [2, 1, 3, 4, 5]
literal3 = TypeEngine.to_literal(ctx, different_val, python_type, expected)
key_different = TypeEngine.make_key(different_val, python_type)

assert key_different is not key
assert key_different is not None
assert key_different in TypeEngine._CACHE

# Verify different literals are different objects
assert literal1 is not literal3

# Test cache with unhashable objects
python_val = {"a": [1, 2, 3]}
python_type = typing.Dict[str, typing.List[int]]
expected = TypeEngine.to_literal_type(python_type)

# First call
literal4 = TypeEngine.to_literal(ctx, python_val, python_type, expected)
key = TypeEngine.make_key(python_val, python_type)
assert key is not None
assert key in TypeEngine._CACHE

# Second call with same unhashable data
literal5 = TypeEngine.to_literal(ctx, python_val, python_type, expected)

# Should be the same object since unhashable objects will fallback to cloudpickle to hash
assert literal4 is literal5

# Add many different values to test cache size limit
for i in range(200): # More than the default maxsize of 128
test_val = [i, i+1, i+2]
test_type = typing.List[int]
test_expected = TypeEngine.to_literal_type(test_type)
TypeEngine.to_literal(ctx, test_val, test_type, test_expected)

# Cache should not exceed maxsize
assert len(TypeEngine._CACHE) == 128

# Test cache with pandas DataFrame (if available)
try:
import pandas as pd

# Create DataFrame
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
df_type = pd.DataFrame
df_expected = TypeEngine.to_literal_type(df_type)

# First call
literal6 = TypeEngine.to_literal(ctx, df, df_type, df_expected)

# Second call with same DataFrame
literal7 = TypeEngine.to_literal(ctx, df, df_type, df_expected)

# Should be same object (DataFrame should be hashable via cloudpickle)
assert literal6 is literal7

except ImportError:
# Skip pandas test if not available
pass

# Clean up
TypeEngine._CACHE.clear()

def test_make_key_with_annotated_types():
# Test with Annotated type
annotated_val = [1, 2, 3]
annotated_type = typing.Annotated[typing.List[int], "test_annotation"]

key = TypeEngine.make_key(annotated_val, annotated_type)
key_without_annotation = TypeEngine.make_key(annotated_val, typing.List[int])
# Should handle Annotated types correctly
assert key is not None
assert key_without_annotation is not None
assert key != key_without_annotation
Loading