Skip to content
16 changes: 5 additions & 11 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1383,22 +1383,16 @@ def calculate_hash(cls, python_val: typing.Any, python_type: Type[T]) -> Optiona
def make_key(cls, python_val: typing.Any, python_type: Type[T]) -> Optional[tuple]:
import cloudpickle

val_hash: typing.Any
type_hash: typing.Any
val_hash: int
type_hash: int
try:
try:
val_hash = hash(python_val)
except Exception:
val_hash = hash(cloudpickle.dumps(python_val))

try:
type_hash = hash(python_type)
except Exception:
type_hash = hash(cloudpickle.dumps(python_type))
val_hash = hash(cloudpickle.dumps(python_val))
type_hash = hash(cloudpickle.dumps(python_type))

return (val_hash, type_hash)

except Exception:
logger.warning(f"Could not hash python_val: {python_val} or python_type: {python_type}")
return None

@classmethod
Expand Down
3 changes: 3 additions & 0 deletions tests/flytekit/unit/core/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,7 @@ async def test_coroutine_batching_of_list_transformer():
with pytest.raises(ValueError):
TypeEngine.to_literal(ctx, python_val_2, typing.List[MyInt], lt)

with mock.patch("flytekit.core.type_engine._TYPE_ENGINE_COROS_BATCH_SIZE", 5):
TypeEngine.to_literal(ctx, python_val, typing.List[MyInt], lt)

del TypeEngine._REGISTRY[MyInt]
142 changes: 94 additions & 48 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3852,29 +3852,37 @@ async def test_dict_transformer_annotated_type():
literal3 = await TypeEngine.async_to_literal(ctx, nested_dict, nested_dict_type, expected_type)
assert literal3.map.literals["outer"].map.literals["inner"].scalar.primitive.integer == 42

def test_type_engine_cache():
# Clear cache before test
@pytest.fixture(autouse=True)
def clear_type_engine_cache():
    """Empty ``TypeEngine._CACHE`` before and after each test.

    Declared with ``autouse=True`` so tests do not have to request it by
    name. The code before ``yield`` runs as setup and the code after it runs
    as teardown.
    """
    # Setup: start from an empty cache so entries left by an earlier test
    # cannot satisfy a lookup in this one.
    TypeEngine._CACHE.clear()
    yield
    # Teardown: drop whatever this test cached so it cannot leak into the next.
    TypeEngine._CACHE.clear()

# Test data
def test_type_engine_cache_with_list():
ctx = FlyteContext.current_context()
python_val = [1, 2, 3, 4, 5]
python_type = typing.List[int]
expected = TypeEngine.to_literal_type(python_type)
list_transformer = TypeEngine.get_transformer(typing.List[int])
with mock.patch.object(list_transformer, 'async_to_literal',
wraps=list_transformer.async_to_literal) as mock_async_to_literal:

# First call - should not use cache
literal1 = TypeEngine.to_literal(ctx, python_val, python_type, expected)
# First call
literal1 = TypeEngine.to_literal(ctx, python_val, python_type, expected)
assert mock_async_to_literal.call_count == 1

# Verify cache is populated
key = TypeEngine.make_key(python_val, python_type)
assert key is not None
assert key in TypeEngine._CACHE
key = TypeEngine.make_key(python_val, python_type)
assert key is not None
assert key in TypeEngine._CACHE

# Second call with same parameters - should use cache
literal2 = TypeEngine.to_literal(ctx, python_val, python_type, expected)
# Second call with same DataFrame
literal2 = TypeEngine.to_literal(ctx, python_val, python_type, expected)

# Verify both literals are identical (same object from cache)
assert literal1 is literal2
# Verify async_to_literal was called
assert mock_async_to_literal.call_count == 1

assert literal1 is literal2

# Test with different data - should not use cache
different_val = [2, 1, 3, 4, 5]
Expand All @@ -3888,23 +3896,6 @@ def test_type_engine_cache():
# Verify different literals are different objects
assert literal1 is not literal3

# Test cache with unhashable objects
python_val = {"a": [1, 2, 3]}
python_type = typing.Dict[str, typing.List[int]]
expected = TypeEngine.to_literal_type(python_type)

# First call
literal4 = TypeEngine.to_literal(ctx, python_val, python_type, expected)
key = TypeEngine.make_key(python_val, python_type)
assert key is not None
assert key in TypeEngine._CACHE

# Second call with same unhashable data
literal5 = TypeEngine.to_literal(ctx, python_val, python_type, expected)

# Should be the same object since unhashable objects will fallback to cloudpickle to hash
assert literal4 is literal5

# Add many different values to test cache size limit
for i in range(200): # More than the default maxsize of 128
test_val = [i, i+1, i+2]
Expand All @@ -3915,30 +3906,30 @@ def test_type_engine_cache():
# Cache should not exceed maxsize
assert len(TypeEngine._CACHE) == 128

# Test cache with pandas DataFrame (if available)
try:
import pandas as pd

# Create DataFrame
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
df_type = pd.DataFrame
df_expected = TypeEngine.to_literal_type(df_type)
def test_type_engine_cache_with_dict():
ctx = FlyteContext.current_context()
python_val = {"a": [1, 2, 3]}
python_type = typing.Dict[str, typing.List[int]]
expected = TypeEngine.to_literal_type(python_type)
dict_transformer = TypeEngine.get_transformer(typing.Dict[str, typing.List[int]])
with mock.patch.object(dict_transformer, 'async_to_literal',
wraps=dict_transformer.async_to_literal) as mock_async_to_literal:

# First call
literal6 = TypeEngine.to_literal(ctx, df, df_type, df_expected)
literal1 = TypeEngine.to_literal(ctx, python_val, python_type, expected)
assert mock_async_to_literal.call_count == 1

# Second call with same DataFrame
literal7 = TypeEngine.to_literal(ctx, df, df_type, df_expected)
key = TypeEngine.make_key(python_val, python_type)
assert key is not None
assert key in TypeEngine._CACHE

# Should be same object (DataFrame should be hashable via cloudpickle)
assert literal6 is literal7
# Second call with same DataFrame
literal2 = TypeEngine.to_literal(ctx, python_val, python_type, expected)

except ImportError:
# Skip pandas test if not available
pass
# Verify async_to_literal was called
assert mock_async_to_literal.call_count == 1

# Clean up
TypeEngine._CACHE.clear()
assert literal1 is literal2

def test_make_key_with_annotated_types():
# Test with Annotated type
Expand All @@ -3951,3 +3942,58 @@ def test_make_key_with_annotated_types():
assert key is not None
assert key_without_annotation is not None
assert key != key_without_annotation

def test_type_engine_cache_with_pandas():
    """Converting the same DataFrame twice must reuse the cached literal.

    The DataFrame transformer's ``async_to_literal`` is wrapped (not replaced)
    so the real conversion still runs while the number of invocations is
    counted. The second ``to_literal`` call with the same value and type must
    leave the count unchanged and return the very same literal object.
    """
    # pandas is an optional dependency: skip instead of erroring when it is
    # not installed.
    pd = pytest.importorskip("pandas")

    ctx = FlyteContext.current_context()
    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    df_type = pd.DataFrame
    df_expected = TypeEngine.to_literal_type(df_type)

    # Transformer registered for DataFrame; looked up directly in the registry.
    df_transformer = TypeEngine._REGISTRY[pd.DataFrame]

    # ``wraps=`` keeps the original behaviour and only records the calls.
    with mock.patch.object(
        df_transformer, 'async_to_literal', wraps=df_transformer.async_to_literal
    ) as mock_async_to_literal:
        # First call: the transformer has to do the conversion.
        literal1 = TypeEngine.to_literal(ctx, df, df_type, df_expected)
        assert mock_async_to_literal.call_count == 1

        # Second call with the same DataFrame.
        literal2 = TypeEngine.to_literal(ctx, df, df_type, df_expected)

        # The transformer was NOT invoked again: the count is still 1.
        assert mock_async_to_literal.call_count == 1

    # Identity, not just equality: the same object was returned both times.
    assert literal1 is literal2

def test_type_engine_cache_with_flytefile():
    """Converting the same local path to a FlyteFile literal twice must hit the cache.

    The FlyteFile transformer's ``async_to_literal`` is wrapped so the real
    conversion still runs while invocations are counted. The second
    ``to_literal`` call with the same path must not invoke it again and must
    return the very same literal object.
    """
    transformer = TypeEngine.get_transformer(FlyteFile)
    ctx = FlyteContext.current_context()

    # TemporaryDirectory removes the scratch directory and its file when the
    # block exits, even if an assertion fails. A bare mkdtemp() would leave
    # the directory behind after every run.
    with tempfile.TemporaryDirectory(prefix="temp_example_") as temp_dir:
        file_path = os.path.join(temp_dir, "file.txt")
        with open(file_path, "w") as file1:
            file1.write("hello world")

        lt = TypeEngine.to_literal_type(FlyteFile)

        # ``wraps=`` keeps the original behaviour and only records the calls.
        with mock.patch.object(
            transformer, 'async_to_literal', wraps=transformer.async_to_literal
        ) as mock_async_to_literal:
            # First call: the transformer has to do the conversion.
            lv1 = TypeEngine.to_literal(ctx, file_path, FlyteFile, lt)
            assert mock_async_to_literal.call_count == 1

            # Second call with the same file path.
            lv2 = TypeEngine.to_literal(ctx, file_path, FlyteFile, lt)

            # The transformer was NOT invoked again: the count is still 1.
            assert mock_async_to_literal.call_count == 1

        # Identity, not just equality: the same object was returned both times.
        assert lv1 is lv2