Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/data_processing/deltalake_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
idle_ttl=300, # Keep workers alive for 5 minutes after idle
),
image=image,
cache=flyte.Cache("auto", "1.1"),
cache=flyte.Cache("auto", "1.2"),
)

# Non-reusable environment for orchestration tasks
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ dev = [
"ipywidgets>=8.1.7",
"mypy>=1.16.0",
"kubernetes",
"textual>=0.80",
]

[tool.setuptools]
Expand Down
217 changes: 1 addition & 216 deletions src/flyte/_cache/local_cache.py
Original file line number Diff line number Diff line change
@@ -1,216 +1 @@
import sqlite3
from pathlib import Path

try:
import aiosqlite

HAS_AIOSQLITE = True
except ImportError:
HAS_AIOSQLITE = False

from flyteidl2.task import common_pb2

from flyte._internal.runtime import convert
from flyte._logging import logger
from flyte.config import auto

DEFAULT_CACHE_DIR = "~/.flyte"
CACHE_LOCATION = "local-cache/cache.db"


class LocalTaskCache:
    """
    Persistent store able to cache the result of local task executions in a SQLite database.

    All state (the connection handles and the ``_initialized`` flag) lives on the class itself, so every caller
    in the process shares the same cache. When ``aiosqlite`` could be imported (``HAS_AIOSQLITE``) the async
    connection ``_conn`` is used; otherwise every operation goes through the ``sqlite3`` connection
    ``_conn_sync``.
    """

    # Schema shared by both backends: one row per cache key holding the serialized outputs.
    _CREATE_TABLE_SQL = """
        CREATE TABLE IF NOT EXISTS task_cache (
            key TEXT PRIMARY KEY,
            value BLOB
        )
    """

    _conn: "aiosqlite.Connection | None" = None
    _conn_sync: sqlite3.Connection | None = None
    _initialized: bool = False

    @staticmethod
    def _get_cache_path() -> str:
        """
        Get the cache database path, creating its parent directory if needed.

        The database is placed at ``CACHE_LOCATION`` below the parent directory of ``config.source`` when the
        auto-detected config has a source, and below ``DEFAULT_CACHE_DIR`` (with ``~`` expanded) otherwise.

        :return: The database path as a string.
        """
        config = auto()
        if config.source:
            cache_dir = config.source.parent
        else:
            cache_dir = Path(DEFAULT_CACHE_DIR).expanduser()

        cache_path = cache_dir / CACHE_LOCATION
        # Ensure the directory exists before SQLite tries to create the file in it.
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        logger.info(f"Use local cache path: {cache_path}")
        return str(cache_path)

    @staticmethod
    def _decode_row(row) -> convert.Outputs | None:
        """
        Convert a fetched ``(value,)`` row into ``convert.Outputs``.

        :param row: Result of ``fetchone()`` for a ``SELECT value ...`` query, or a falsy value on a miss.
        :return: The parsed outputs, or ``None`` when no row matched.
        """
        if not row:
            return None
        outputs = common_pb2.Outputs()
        outputs.ParseFromString(row[0])
        return convert.Outputs(proto_outputs=outputs)

    @staticmethod
    async def initialize():
        """Initialize the cache with a database connection; does nothing when already initialized."""
        if LocalTaskCache._initialized:
            return
        if HAS_AIOSQLITE:
            await LocalTaskCache._initialize_async()
        else:
            LocalTaskCache._initialize_sync()

    @staticmethod
    async def _initialize_async():
        """Open the async connection, create the table if missing and publish the connection on the class."""
        db_path = LocalTaskCache._get_cache_path()
        conn = await aiosqlite.connect(db_path)
        try:
            await conn.execute(LocalTaskCache._CREATE_TABLE_SQL)
            await conn.commit()
        except BaseException:
            # Do not leak the half-initialized connection when schema creation fails or is cancelled.
            await conn.close()
            raise
        LocalTaskCache._conn = conn
        LocalTaskCache._initialized = True

    @staticmethod
    def _initialize_sync():
        """Open the sync connection, create the table if missing and publish the connection on the class."""
        db_path = LocalTaskCache._get_cache_path()
        conn = sqlite3.connect(db_path)
        try:
            conn.execute(LocalTaskCache._CREATE_TABLE_SQL)
            conn.commit()
        except BaseException:
            # Do not leak the half-initialized connection when schema creation fails.
            conn.close()
            raise
        LocalTaskCache._conn_sync = conn
        LocalTaskCache._initialized = True

    @staticmethod
    async def clear():
        """Clear all cache entries, initializing the cache first if necessary."""
        # initialize() is a no-op once the cache is set up, so it is safe to call unconditionally.
        await LocalTaskCache.initialize()

        if HAS_AIOSQLITE:
            await LocalTaskCache._clear_async()
        else:
            LocalTaskCache._clear_sync()

    @staticmethod
    async def _clear_async():
        """
        Clear all cache entries (async).

        :raises RuntimeError: If the async connection has not been set up.
        """
        conn = LocalTaskCache._conn
        if conn is None:
            raise RuntimeError("Cache not properly initialized")
        await conn.execute("DELETE FROM task_cache")
        await conn.commit()

    @staticmethod
    def _clear_sync():
        """
        Clear all cache entries (sync).

        :raises RuntimeError: If the sync connection has not been set up.
        """
        conn = LocalTaskCache._conn_sync
        if conn is None:
            raise RuntimeError("Cache not properly initialized")
        conn.execute("DELETE FROM task_cache")
        conn.commit()

    @staticmethod
    async def get(cache_key: str) -> convert.Outputs | None:
        """
        Look up the cached outputs stored under ``cache_key``, initializing the cache first if necessary.

        :param cache_key: Key the outputs were stored under.
        :return: The cached outputs, or ``None`` when there is no entry for the key.
        """
        await LocalTaskCache.initialize()

        if HAS_AIOSQLITE:
            return await LocalTaskCache._get_async(cache_key)
        return LocalTaskCache._get_sync(cache_key)

    @staticmethod
    async def _get_async(cache_key: str) -> convert.Outputs | None:
        """
        Get cache entry (async).

        :raises RuntimeError: If the async connection has not been set up.
        """
        conn = LocalTaskCache._conn
        if conn is None:
            raise RuntimeError("Cache not properly initialized")

        async with conn.execute("SELECT value FROM task_cache WHERE key = ?", (cache_key,)) as cursor:
            row = await cursor.fetchone()
        return LocalTaskCache._decode_row(row)

    @staticmethod
    def _get_sync(cache_key: str) -> convert.Outputs | None:
        """
        Get cache entry (sync).

        :raises RuntimeError: If the sync connection has not been set up.
        """
        conn = LocalTaskCache._conn_sync
        if conn is None:
            raise RuntimeError("Cache not properly initialized")

        cursor = conn.execute("SELECT value FROM task_cache WHERE key = ?", (cache_key,))
        try:
            row = cursor.fetchone()
        finally:
            # Release the cursor explicitly, mirroring the ``async with`` used by the async path.
            cursor.close()
        return LocalTaskCache._decode_row(row)

    @staticmethod
    async def set(
        cache_key: str,
        value: convert.Outputs,
    ) -> None:
        """
        Store ``value`` under ``cache_key``, replacing any existing entry; initializes the cache if necessary.

        :param cache_key: Key to store the outputs under.
        :param value: Outputs whose ``proto_outputs`` message is serialized into the database.
        """
        await LocalTaskCache.initialize()

        if HAS_AIOSQLITE:
            await LocalTaskCache._set_async(cache_key, value)
        else:
            LocalTaskCache._set_sync(cache_key, value)

    @staticmethod
    async def _set_async(
        cache_key: str,
        value: convert.Outputs,
    ) -> None:
        """
        Set cache entry (async).

        :raises RuntimeError: If the async connection has not been set up.
        """
        conn = LocalTaskCache._conn
        if conn is None:
            raise RuntimeError("Cache not properly initialized")

        output_bytes = value.proto_outputs.SerializeToString()
        await conn.execute(
            "INSERT OR REPLACE INTO task_cache (key, value) VALUES (?, ?)", (cache_key, output_bytes)
        )
        await conn.commit()

    @staticmethod
    def _set_sync(
        cache_key: str,
        value: convert.Outputs,
    ) -> None:
        """
        Set cache entry (sync).

        :raises RuntimeError: If the sync connection has not been set up.
        """
        conn = LocalTaskCache._conn_sync
        if conn is None:
            raise RuntimeError("Cache not properly initialized")

        output_bytes = value.proto_outputs.SerializeToString()
        conn.execute(
            "INSERT OR REPLACE INTO task_cache (key, value) VALUES (?, ?)", (cache_key, output_bytes)
        )
        conn.commit()

    @staticmethod
    async def close():
        """Close the database connection of whichever backend is in use."""
        if HAS_AIOSQLITE:
            await LocalTaskCache._close_async()
        else:
            LocalTaskCache._close_sync()

    @staticmethod
    async def _close_async():
        """Close async database connection; does nothing when no connection is open."""
        conn = LocalTaskCache._conn
        if not conn:
            return
        try:
            await conn.close()
        finally:
            # Reset state even if close() raises so a later initialize() opens a fresh connection.
            LocalTaskCache._conn = None
            LocalTaskCache._initialized = False

    @staticmethod
    def _close_sync():
        """Close sync database connection; does nothing when no connection is open."""
        conn = LocalTaskCache._conn_sync
        if not conn:
            return
        try:
            conn.close()
        finally:
            # Reset state even if close() raises so a later initialize() opens a fresh connection.
            LocalTaskCache._conn_sync = None
            LocalTaskCache._initialized = False
from flyte._persistence._task_cache import LocalTaskCache # noqa: F401
13 changes: 13 additions & 0 deletions src/flyte/_initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class CommonInit:
batch_size: int = 1000
source_config_path: Optional[Path] = None # Only used for documentation
sync_local_sys_paths: bool = True
local_persistence: bool = False


@dataclass(init=True, kw_only=True, repr=True, eq=True, frozen=True)
Expand Down Expand Up @@ -162,6 +163,7 @@ async def init(
source_config_path: Optional[Path] = None,
sync_local_sys_paths: bool = True,
load_plugin_type_transformers: bool = True,
local_persistence: bool = False,
) -> None:
"""
Initialize the Flyte system with the given configuration. This method should be called before any other Flyte
Expand Down Expand Up @@ -204,6 +206,7 @@ async def init(
into the remote container (default: True).
:param load_plugin_type_transformers: If enabled (default True), load the type transformer plugins registered under
the "flyte.plugins.types" entry point group.
:param local_persistence: Whether to enable SQLite persistence for local run metadata (default: False).
:return: None
"""
from flyte._utils import org_from_endpoint, sanitize_endpoint
Expand Down Expand Up @@ -254,6 +257,7 @@ async def init(
images=images or {},
source_config_path=source_config_path,
sync_local_sys_paths=sync_local_sys_paths,
local_persistence=local_persistence,
)


Expand Down Expand Up @@ -346,6 +350,7 @@ async def init_from_config(
storage=storage,
source_config_path=cfg_path,
sync_local_sys_paths=sync_local_sys_paths,
local_persistence=cfg.local.persistence,
)


Expand Down Expand Up @@ -590,6 +595,14 @@ def is_initialized() -> bool:
return _get_init_config() is not None


def is_persistence_enabled() -> bool:
    """
    Report whether local run persistence is enabled.

    :return: ``False`` when no init config is available, otherwise the config's ``local_persistence`` value.
    """
    init_config = _get_init_config()
    return False if init_config is None else init_config.local_persistence


def initialize_in_cluster() -> None:
"""
Initialize the system for in-cluster execution. This is a placeholder function and does not perform any actions.
Expand Down
Loading
Loading