diff --git a/pylzy/lzy/api/v1/workflow.py b/pylzy/lzy/api/v1/workflow.py index 16df24963..bc010f31a 100644 --- a/pylzy/lzy/api/v1/workflow.py +++ b/pylzy/lzy/api/v1/workflow.py @@ -11,6 +11,8 @@ TypeVar, cast, Set, ) from ai.lzy.v1.whiteboard.whiteboard_pb2 import Whiteboard +from grpc.aio import AioRpcError + from lzy.api.v1.env import Env from lzy.api.v1.provisioning import Provisioning from lzy.api.v1.snapshot import Snapshot, DefaultSnapshot @@ -41,15 +43,15 @@ def get_active(cls) -> Optional["LzyWorkflow"]: return cls.instance def __init__( - self, - name: str, - owner: "Lzy", - env: Env, - provisioning: Provisioning, - auto_py_env: PyEnv, - *, - eager: bool = False, - interactive: bool = True + self, + name: str, + owner: "Lzy", + env: Env, + provisioning: Provisioning, + auto_py_env: PyEnv, + *, + eager: bool = False, + interactive: bool = True ): if not is_name_valid(name): raise ValueError(f"Invalid workflow name. Name can contain only {NAME_VALID_SYMBOLS}") @@ -145,6 +147,8 @@ def __enter__(self) -> "LzyWorkflow": self.__snapshot = DefaultSnapshot(self.owner.serializer_registry, storage_uri, self.owner.storage_client, self.owner.storage_name) return self + except AioRpcError as e: + raise e except Exception as e: try: self.__abort() @@ -167,6 +171,8 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None: finally: if exc_type is None: self.__destroy() + elif exc_type is AioRpcError: + raise exc_val else: try: self.__abort()