Skip to content

Commit bd5679e

Browse files
authored
Add instructions for how to handle failures coming from cache policies (#3135)
* Add instructions for how to handle failures coming from cache policies Signed-off-by: Eduardo Apolinario <eapolinario@users.noreply.github.com> * Add Cache, CachePolicy, and VersionParameters to the API Signed-off-by: Eduardo Apolinario <eapolinario@users.noreply.github.com> * Give an example of how to set version in Cache Signed-off-by: Eduardo Apolinario <eapolinario@users.noreply.github.com> * Simplify error message matching Signed-off-by: Eduardo Apolinario <eapolinario@users.noreply.github.com> --------- Signed-off-by: Eduardo Apolinario <eapolinario@users.noreply.github.com> Co-authored-by: Eduardo Apolinario <eapolinario@users.noreply.github.com>
1 parent 57c7c7e commit bd5679e

File tree

3 files changed

+24
-2
lines changed

3 files changed

+24
-2
lines changed

flytekit/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@
178178
:toctree: generated/
179179
180180
HashMethod
181+
Cache
182+
CachePolicy
183+
VersionParameters
181184
182185
Artifacts
183186
=========

flytekit/core/cache.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,13 @@ def get_version(self, params: VersionParameters) -> str:
8585
return self.version
8686

8787
task_hash = ""
88-
for cache_instance in self._policies:
89-
task_hash += cache_instance.get_version(self.salt, params)
88+
for policy in self._policies:
89+
try:
90+
task_hash += policy.get_version(self.salt, params)
91+
except Exception as e:
92+
raise ValueError(
93+
f"Failed to generate version for cache policy {policy}. Please consider setting the version in the Cache definition, e.g. Cache(version='v1.2.3')"
94+
) from e
9095

9196
hash_obj = hashlib.sha256(task_hash.encode())
9297
return hash_obj.hexdigest()

tests/flytekit/unit/core/test_cache.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ def get_version(self, salt: str, params: VersionParameters) -> str:
1212
return salt
1313

1414

15+
class ExceptionCachePolicy(CachePolicy):
16+
def get_version(self, salt: str, params: VersionParameters) -> str:
17+
raise Exception("This is an exception")
18+
19+
1520
@pytest.fixture
1621
def default_serialization_settings():
1722
default_image = Image(name="default", fqn="full/name", tag="some-tag")
@@ -167,3 +172,12 @@ def t_cached_explicit_version(a: int) -> int:
167172
serialized_t_cached_explicit_version = get_serializable_task(OrderedDict(), default_serialization_settings, t_cached_explicit_version)
168173
assert serialized_t_cached_explicit_version.template.metadata.discoverable == True
169174
assert serialized_t_cached_explicit_version.template.metadata.discovery_version == "a-version"
175+
176+
177+
def test_cache_policy_exception(default_serialization_settings):
178+
# Set the address of the ExceptionCachePolicy in the error message so that the test is robust to changes in the
179+
# address of the ExceptionCachePolicy class
180+
with pytest.raises(ValueError, match="Failed to generate version for cache policy"):
181+
@task(cache=Cache(policies=ExceptionCachePolicy()))
182+
def t_cached(a: int) -> int:
183+
return a

0 commit comments

Comments
 (0)