Skip to content

Commit d4c5ff3

Browse files
ethanwharrislexierule
authored andcommitted
[App] Add support for plugins to return actions (#16832)
1 parent 3a7598d commit d4c5ff3

File tree

11 files changed

+168
-42
lines changed

11 files changed

+168
-42
lines changed

requirements/app/base.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
lightning-cloud>=0.5.26
1+
lightning-cloud>=0.5.27
22
packaging
33
typing-extensions>=4.0.0, <=4.4.0
44
deepdiff>=5.7.0, <6.2.4

src/lightning_app/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from lightning_app.core.flow import LightningFlow # noqa: E402
3333
from lightning_app.core.work import LightningWork # noqa: E402
3434
from lightning_app.perf import pdb # noqa: E402
35+
from lightning_app.plugin.plugin import LightningPlugin # noqa: E402
3536
from lightning_app.utilities.packaging.build_config import BuildConfig # noqa: E402
3637
from lightning_app.utilities.packaging.cloud_compute import CloudCompute # noqa: E402
3738

@@ -43,4 +44,4 @@
4344
_PACKAGE_ROOT = os.path.dirname(__file__)
4445
_PROJECT_ROOT = os.path.dirname(os.path.dirname(_PACKAGE_ROOT))
4546

46-
__all__ = ["LightningApp", "LightningFlow", "LightningWork", "BuildConfig", "CloudCompute", "pdb"]
47+
__all__ = ["LightningApp", "LightningFlow", "LightningWork", "LightningPlugin", "BuildConfig", "CloudCompute", "pdb"]

src/lightning_app/plugin/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import lightning_app.plugin.actions as actions
2+
from lightning_app.plugin.actions import NavigateTo, Toast, ToastSeverity
3+
from lightning_app.plugin.plugin import LightningPlugin
4+
5+
__all__ = ["LightningPlugin", "actions", "Toast", "ToastSeverity", "NavigateTo"]

src/lightning_app/plugin/actions.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright The Lightning AI team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from dataclasses import dataclass
15+
from enum import Enum
16+
from typing import Union
17+
18+
from lightning_cloud.openapi.models import V1CloudSpaceAppAction, V1CloudSpaceAppActionType
19+
20+
21+
class _Action:
22+
"""Actions are returned by `LightningPlugin` objects to perform actions in the UI."""
23+
24+
def to_spec(self) -> V1CloudSpaceAppAction:
25+
"""Convert this action to a ``V1CloudSpaceAppAction``"""
26+
raise NotImplementedError
27+
28+
29+
@dataclass
30+
class NavigateTo(_Action):
31+
"""The ``NavigateTo`` action can be used to navigate to a relative URL within the Lightning frontend.
32+
33+
Args:
34+
url: The relative URL to navigate to. E.g. ``/<username>/<project>``.
35+
"""
36+
37+
url: str
38+
39+
def to_spec(self) -> V1CloudSpaceAppAction:
40+
return V1CloudSpaceAppAction(
41+
type=V1CloudSpaceAppActionType.NAVIGATE_TO,
42+
content=self.url,
43+
)
44+
45+
46+
class ToastSeverity(Enum):
47+
ERROR = "error"
48+
INFO = "info"
49+
SUCCESS = "success"
50+
WARNING = "warning"
51+
52+
def __str__(self) -> str:
53+
return self.value
54+
55+
56+
@dataclass
57+
class Toast(_Action):
58+
"""The ``Toast`` action can be used to display a toast message to the user.
59+
60+
Args:
61+
severity: The severity level of the toast. One of: "error", "info", "success", "warning".
62+
message: The message body.
63+
"""
64+
65+
severity: Union[ToastSeverity, str]
66+
message: str
67+
68+
def to_spec(self) -> V1CloudSpaceAppAction:
69+
return V1CloudSpaceAppAction(
70+
type=V1CloudSpaceAppActionType.TOAST,
71+
content=f"{self.severity}:{self.message}",
72+
)

src/lightning_app/core/plugin.py renamed to src/lightning_app/plugin/plugin.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import tarfile
1616
import tempfile
1717
from pathlib import Path
18-
from typing import Dict, List, Optional
18+
from typing import Any, Dict, List, Optional
1919
from urllib.parse import urlparse
2020

2121
import requests
@@ -24,6 +24,8 @@
2424
from fastapi.middleware.cors import CORSMiddleware
2525
from pydantic import BaseModel
2626

27+
from lightning_app.core import constants
28+
from lightning_app.plugin.actions import _Action
2729
from lightning_app.utilities.app_helpers import Logger
2830
from lightning_app.utilities.component import _set_flow_context
2931
from lightning_app.utilities.enum import AppStage
@@ -41,16 +43,20 @@ def __init__(self) -> None:
4143
self.cloudspace_id = ""
4244
self.cluster_id = ""
4345

44-
def run(self, *args: str, **kwargs: str) -> None:
46+
def run(self, *args: str, **kwargs: str) -> Optional[List[_Action]]:
4547
"""Override with the logic to execute on the cloudspace."""
48+
raise NotImplementedError
4649

47-
def run_job(self, name: str, app_entrypoint: str, env_vars: Optional[Dict[str, str]] = None) -> None:
50+
def run_job(self, name: str, app_entrypoint: str, env_vars: Optional[Dict[str, str]] = None) -> str:
4851
"""Run a job in the cloudspace associated with this plugin.
4952
5053
Args:
5154
name: The name of the job.
5255
app_entrypoint: The path of the file containing the app to run.
5356
env_vars: Additional env vars to set when running the app.
57+
58+
Returns:
59+
The relative URL of the created job.
5460
"""
5561
from lightning_app.runners.cloud import CloudRuntime
5662

@@ -74,12 +80,14 @@ def run_job(self, name: str, app_entrypoint: str, env_vars: Optional[Dict[str, s
7480
# Used to indicate Lightning has been dispatched
7581
os.environ["LIGHTNING_DISPATCHED"] = "1"
7682

77-
runtime.cloudspace_dispatch(
83+
url = runtime.cloudspace_dispatch(
7884
project_id=self.project_id,
7985
cloudspace_id=self.cloudspace_id,
8086
name=name,
8187
cluster_id=self.cluster_id,
8288
)
89+
# Return a relative URL so it can be used with the NavigateTo action.
90+
return url.replace(constants.get_lightning_cloud_url(), "")
8391

8492
def _setup(
8593
self,
@@ -101,7 +109,7 @@ class _Run(BaseModel):
101109
plugin_arguments: Dict[str, str]
102110

103111

104-
def _run_plugin(run: _Run) -> List:
112+
def _run_plugin(run: _Run) -> Dict[str, Any]:
105113
"""Create a run with the given name and entrypoint under the cloudspace with the given ID."""
106114
with tempfile.TemporaryDirectory() as tmpdir:
107115
download_path = os.path.join(tmpdir, "source.tar.gz")
@@ -115,6 +123,9 @@ def _run_plugin(run: _Run) -> List:
115123

116124
response = requests.get(source_code_url)
117125

126+
# TODO: Backoff retry a few times in case the URL is flaky
127+
response.raise_for_status()
128+
118129
with open(download_path, "wb") as f:
119130
f.write(response.content)
120131
except Exception as e:
@@ -152,17 +163,15 @@ def _run_plugin(run: _Run) -> List:
152163
cloudspace_id=run.cloudspace_id,
153164
cluster_id=run.cluster_id,
154165
)
155-
plugin.run(**run.plugin_arguments)
166+
actions = plugin.run(**run.plugin_arguments) or []
167+
return {"actions": [action.to_spec().to_dict() for action in actions]}
156168
except Exception as e:
157169
raise HTTPException(
158170
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error running plugin: {str(e)}."
159171
)
160172
finally:
161173
os.chdir(cwd)
162174

163-
# TODO: Return actions from the plugin here
164-
return []
165-
166175

167176
def _start_plugin_server(host: str, port: int) -> None:
168177
"""Start the plugin server which can be used to dispatch apps or run plugins."""

src/lightning_app/runners/cloud.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def cloudspace_dispatch(
196196
cloudspace_id: str,
197197
name: str,
198198
cluster_id: str,
199-
):
199+
) -> str:
200200
"""Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties
201201
such as the project and cluster IDs that are instead passed directly.
202202
@@ -210,12 +210,15 @@ def cloudspace_dispatch(
210210
ApiException: If there was an issue in the backend.
211211
RuntimeError: If there are validation errors.
212212
ValueError: If there are validation errors.
213+
214+
Returns:
215+
The URL of the created job.
213216
"""
214217
# Dispatch in four phases: resolution, validation, spec creation, API transactions
215218
# Resolution
216219
root = self._resolve_root()
217220
repo = self._resolve_repo(root)
218-
self._resolve_cloudspace(project_id, cloudspace_id)
221+
project = self._resolve_project(project_id=project_id)
219222
existing_instances = self._resolve_run_instances_by_name(project_id, name)
220223
name = self._resolve_run_name(name, existing_instances)
221224
queue_server_type = self._resolve_queue_server_type()
@@ -240,7 +243,7 @@ def cloudspace_dispatch(
240243
run = self._api_create_run(project_id, cloudspace_id, run_body)
241244
self._api_package_and_upload_repo(repo, run)
242245

243-
self._api_create_run_instance(
246+
run_instance = self._api_create_run_instance(
244247
cluster_id,
245248
project_id,
246249
name,
@@ -251,6 +254,8 @@ def cloudspace_dispatch(
251254
env_vars,
252255
)
253256

257+
return self._get_app_url(project, run_instance, "logs" if run.is_headless else "web-ui")
258+
254259
def dispatch(
255260
self,
256261
name: str = "",
@@ -451,16 +456,9 @@ def _resolve_repo(
451456

452457
return LocalSourceCodeDir(path=root, ignore_functions=ignore_functions)
453458

454-
def _resolve_project(self) -> V1Membership:
459+
def _resolve_project(self, project_id: Optional[str] = None) -> V1Membership:
455460
"""Determine the project to run on, choosing a default if multiple projects are found."""
456-
return _get_project(self.backend.client)
457-
458-
def _resolve_cloudspace(self, project_id: str, cloudspace_id: str) -> V1CloudSpace:
459-
"""Get a cloudspace by project / cloudspace ID."""
460-
return self.backend.client.cloud_space_service_get_cloud_space(
461-
project_id=project_id,
462-
id=cloudspace_id,
463-
)
461+
return _get_project(self.backend.client, project_id=project_id)
464462

465463
def _resolve_existing_cloudspaces(self, project_id: str, cloudspace_name: str) -> List[V1CloudSpace]:
466464
"""Lists all the cloudspaces with a name matching the provided cloudspace name."""

src/lightning_app/utilities/cloud.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import os
16+
from typing import Optional
1617

1718
from lightning_cloud.openapi import V1Membership
1819

@@ -22,10 +23,11 @@
2223
from lightning_app.utilities.network import LightningClient
2324

2425

25-
def _get_project(
26-
client: LightningClient, project_id: str = LIGHTNING_CLOUD_PROJECT_ID, verbose: bool = True
27-
) -> V1Membership:
26+
def _get_project(client: LightningClient, project_id: Optional[str] = None, verbose: bool = True) -> V1Membership:
2827
"""Get a project membership for the user from the backend."""
28+
if project_id is None:
29+
project_id = LIGHTNING_CLOUD_PROJECT_ID
30+
2931
projects = client.projects_service_list_memberships()
3032
if project_id is not None:
3133
for membership in projects.memberships:

src/lightning_app/utilities/load_app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
if TYPE_CHECKING:
2727
from lightning_app import LightningApp, LightningFlow, LightningWork
28-
from lightning_app.core.plugin import LightningPlugin
28+
from lightning_app.plugin.plugin import LightningPlugin
2929

3030
from lightning_app.utilities.app_helpers import _mock_missing_imports, Logger
3131

@@ -85,7 +85,7 @@ def _load_objects_from_file(
8585

8686

8787
def _load_plugin_from_file(filepath: str) -> "LightningPlugin":
88-
from lightning_app.core.plugin import LightningPlugin
88+
from lightning_app.plugin.plugin import LightningPlugin
8989

9090
# TODO: Plugin should be run in the context of the created main module here
9191
plugins, _ = _load_objects_from_file(filepath, LightningPlugin, raise_exception=True, mock_imports=False)

tests/tests_app/plugin/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)