Skip to content
1 change: 1 addition & 0 deletions dev-requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ types-protobuf<5
types-croniter
types-decorator
types-mock
types-cachetools
autoflake

pillow
Expand Down
26 changes: 26 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Any, Dict, List, NamedTuple, Optional, Type, cast

import msgpack
from cachetools import LRUCache
from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import literals_pb2
from fsspec.asyn import _run_coros_in_chunks # pylint: disable=W0212
Expand Down Expand Up @@ -1174,6 +1175,7 @@ class TypeEngine(typing.Generic[T]):
_DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore
_ENUM_TRANSFORMER: TypeTransformer = EnumTransformer() # type: ignore
lazy_import_lock = threading.Lock()
_LITERAL_CACHE: LRUCache = LRUCache(maxsize=128)

@classmethod
def register(
Expand Down Expand Up @@ -1377,6 +1379,22 @@ def calculate_hash(cls, python_val: typing.Any, python_type: Type[T]) -> Optiona
break
return hsh

@classmethod
def _get_literal_cache_key(cls, python_val: typing.Any, python_type: Type[T]) -> Optional[tuple]:
    """
    Build the key used to look up a previously converted literal in ``_LITERAL_CACHE``.

    Both the value and its declared type are serialized with cloudpickle and each byte
    string is reduced to a SHA-256 hex digest. A cryptographic digest is used instead of
    the builtin ``hash()`` because ``hash()`` collapses the payload to a machine-word
    integer; a collision there would make the cache hand back a literal that belongs
    to a different value.

    :param python_val: The Python value that is about to be converted to a literal.
    :param python_type: The Python type the value is being converted as.
    :return: A ``(value_digest, type_digest)`` tuple, or ``None`` when either object
        cannot be pickled (the caller then skips the cache).
    """
    import hashlib

    import cloudpickle

    try:
        val_digest = hashlib.sha256(cloudpickle.dumps(python_val)).hexdigest()
        type_digest = hashlib.sha256(cloudpickle.dumps(python_type)).hexdigest()
    except Exception:
        # Best effort: an unpicklable value only disables caching for this call.
        # Log the value's type rather than the value itself, which may be very large
        # or contain data that should not end up in logs.
        logger.warning(
            "Could not hash python_val of type %s or python_type: %s",
            type(python_val),
            python_type,
        )
        return None

    return (val_digest, type_digest)

@classmethod
def to_literal(
cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type[T], expected: LiteralType
Expand All @@ -1386,6 +1404,10 @@ def to_literal(
to_literal function, and allowing this to_literal function, to then invoke yet another async function,
namely an async transformer.
"""
key = cls._get_literal_cache_key(python_val, python_type)
if key is not None and key in cls._LITERAL_CACHE:
return cls._LITERAL_CACHE[key]

from flytekit.core.promise import Promise

cls.to_literal_checks(python_val, python_type, expected)
Expand All @@ -1406,6 +1428,10 @@ def to_literal(

modify_literal_uris(lv)
lv.hash = cls.calculate_hash(python_val, python_type)

if key is not None:
cls._LITERAL_CACHE[key] = lv

return lv

@classmethod
Expand Down
6 changes: 6 additions & 0 deletions tests/flytekit/unit/core/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,19 @@ async def test_coroutine_batching_of_list_transformer():

lt = LiteralType(simple=SimpleType.INTEGER)
python_val = [MyInt(10), MyInt(11), MyInt(12), MyInt(13), MyInt(14)]
# Use a different python_val to avoid hitting the cache
python_val_2 = [MyInt(11), MyInt(10), MyInt(12), MyInt(13), MyInt(14)]
ctx = FlyteContext.current_context()

with mock.patch("flytekit.core.type_engine._TYPE_ENGINE_COROS_BATCH_SIZE", 2):
TypeEngine.to_literal(ctx, python_val, typing.List[MyInt], lt)

with mock.patch("flytekit.core.type_engine._TYPE_ENGINE_COROS_BATCH_SIZE", 5):
with pytest.raises(ValueError):
TypeEngine.to_literal(ctx, python_val_2, typing.List[MyInt], lt)

# Cache hit for python_val prevents async_to_literal calls, avoiding the batch size limit of 2 error defined in MyIntAsyncTransformer
with mock.patch("flytekit.core.type_engine._TYPE_ENGINE_COROS_BATCH_SIZE", 5):
TypeEngine.to_literal(ctx, python_val, typing.List[MyInt], lt)

del TypeEngine._REGISTRY[MyInt]
142 changes: 142 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3851,3 +3851,145 @@ async def test_dict_transformer_annotated_type():

literal3 = await TypeEngine.async_to_literal(ctx, nested_dict, nested_dict_type, expected_type)
assert literal3.map.literals["outer"].map.literals["inner"].scalar.primitive.integer == 42

@pytest.fixture(autouse=True)
def clear_type_engine_cache():
    """Empty the TypeEngine literal cache before the test body runs and again afterwards."""

    def _reset() -> None:
        TypeEngine._LITERAL_CACHE.clear()

    _reset()
    yield
    _reset()

def test_type_engine_cache_with_list():
    """Check cache hits, cache misses and the size bound of the literal cache for list values."""
    ctx = FlyteContext.current_context()
    python_val = [1, 2, 3, 4, 5]
    python_type = typing.List[int]
    expected = TypeEngine.to_literal_type(python_type)
    list_transformer = TypeEngine.get_transformer(typing.List[int])
    with mock.patch.object(
        list_transformer, "async_to_literal", wraps=list_transformer.async_to_literal
    ) as mock_async_to_literal:
        # First call: converts the value and stores the result in the cache.
        literal1 = TypeEngine.to_literal(ctx, python_val, python_type, expected)

        key = TypeEngine._get_literal_cache_key(python_val, python_type)
        assert key is not None
        assert key in TypeEngine._LITERAL_CACHE

        # Second call with the same list: must be answered from the cache.
        literal2 = TypeEngine.to_literal(ctx, python_val, python_type, expected)

        # The transformer was only invoked for the first call.
        assert mock_async_to_literal.call_count == 1

        assert literal1 is literal2

    # Different data must produce a different key and a separate cache entry.
    different_val = [2, 1, 3, 4, 5]
    literal3 = TypeEngine.to_literal(ctx, different_val, python_type, expected)
    key_different = TypeEngine._get_literal_cache_key(different_val, python_type)

    assert key_different is not None
    # Compare by value: two separately built tuples are never the same object, so an
    # identity check (`is not`) would pass even if the keys were equal.
    assert key_different != key
    assert key_different in TypeEngine._LITERAL_CACHE

    # Different inputs yield different literal objects.
    assert literal1 is not literal3

    # Insert more distinct values than the cache can hold to exercise eviction.
    for i in range(200):  # More than the configured maxsize of 128
        test_val = [i, i + 1, i + 2]
        test_type = typing.List[int]
        test_expected = TypeEngine.to_literal_type(test_type)
        TypeEngine.to_literal(ctx, test_val, test_type, test_expected)

    # The cache is full but has not grown past its maxsize.
    assert len(TypeEngine._LITERAL_CACHE) == 128

def test_type_engine_cache_with_dict():
    """Converting the same dict twice should hit the literal cache on the second call."""
    ctx = FlyteContext.current_context()
    python_type = typing.Dict[str, typing.List[int]]
    python_val = {"a": [1, 2, 3]}
    expected = TypeEngine.to_literal_type(python_type)
    dict_transformer = TypeEngine.get_transformer(typing.Dict[str, typing.List[int]])

    spy = mock.patch.object(
        dict_transformer,
        "async_to_literal",
        wraps=dict_transformer.async_to_literal,
    )
    with spy as mock_async_to_literal:
        # Initial conversion populates the cache.
        literal1 = TypeEngine.to_literal(ctx, python_val, python_type, expected)

        key = TypeEngine._get_literal_cache_key(python_val, python_type)
        assert key is not None
        assert key in TypeEngine._LITERAL_CACHE

        # Repeat the conversion with the same dict.
        literal2 = TypeEngine.to_literal(ctx, python_val, python_type, expected)

        # The wrapped transformer ran for the first conversion only.
        assert mock_async_to_literal.call_count == 1

        assert literal1 is literal2

def test_make_key_with_annotated_types():
    """An Annotated type and its bare counterpart must map to distinct cache keys."""
    annotated_val = [1, 2, 3]
    plain_type = typing.List[int]
    annotated_type = typing.Annotated[typing.List[int], "test_annotation"]

    key = TypeEngine._get_literal_cache_key(annotated_val, annotated_type)
    key_without_annotation = TypeEngine._get_literal_cache_key(annotated_val, plain_type)

    # Both types are hashable into a key, and the annotation changes the key.
    assert key is not None
    assert key_without_annotation is not None
    assert key != key_without_annotation

def test_type_engine_cache_with_pandas():
    """Converting the same DataFrame twice should invoke its transformer only once."""
    pd = pytest.importorskip("pandas")
    ctx = FlyteContext.current_context()

    df_type = pd.DataFrame
    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    df_expected = TypeEngine.to_literal_type(df_type)

    # Transformer registered for DataFrame values.
    df_transformer = TypeEngine._REGISTRY[pd.DataFrame]

    # Wrap (rather than replace) async_to_literal so real conversion still happens
    # while the number of invocations is recorded.
    spy = mock.patch.object(
        df_transformer,
        "async_to_literal",
        wraps=df_transformer.async_to_literal,
    )
    with spy as mock_async_to_literal:
        # Initial conversion.
        literal1 = TypeEngine.to_literal(ctx, df, df_type, df_expected)

        # Repeat the conversion with the same DataFrame.
        literal2 = TypeEngine.to_literal(ctx, df, df_type, df_expected)

        # Exactly one call reached the transformer.
        assert mock_async_to_literal.call_count == 1

        assert literal1 is literal2

def test_type_engine_cache_with_flytefile():
    """Converting the same local file path to a FlyteFile literal twice should hit the cache."""
    transformer = TypeEngine.get_transformer(FlyteFile)
    ctx = FlyteContext.current_context()
    lt = TypeEngine.to_literal_type(FlyteFile)

    # TemporaryDirectory removes the directory and its contents on exit, so the test
    # does not leave files behind the way a bare mkdtemp() would.
    with tempfile.TemporaryDirectory(prefix="temp_example_") as temp_dir:
        file_path = os.path.join(temp_dir, "file.txt")
        with open(file_path, "w") as file1:
            file1.write("hello world")

        # Wrap (not replace) async_to_literal so the real conversion still runs while
        # the number of invocations is recorded.
        with mock.patch.object(
            transformer, "async_to_literal", wraps=transformer.async_to_literal
        ) as mock_async_to_literal:
            # First call converts the path and stores the literal in the cache.
            lv1 = TypeEngine.to_literal(ctx, file_path, FlyteFile, lt)

            # Second call with the same file path.
            lv2 = TypeEngine.to_literal(ctx, file_path, FlyteFile, lt)

            # The transformer was invoked exactly once.
            assert mock_async_to_literal.call_count == 1

            assert lv1 is lv2
Loading