|
7 | 7 |
|
8 | 8 | import uvicorn
|
9 | 9 | from fastapi import FastAPI
|
10 |
| -from lightning_utilities.core.imports import compare_version, module_available |
| 10 | +from lightning_utilities.core.imports import compare_version |
11 | 11 | from pydantic import BaseModel
|
12 | 12 |
|
13 | 13 | from lightning_app.core.work import LightningWork
|
|
16 | 16 |
|
17 | 17 | logger = Logger(__name__)
|
18 | 18 |
|
19 |
| -__doctest_skip__ = [] |
20 |
| -# Skip doctests if requirements aren't available |
21 |
| -if not module_available("lightning_api_access"): |
22 |
| - __doctest_skip__ += ["PythonServer", "PythonServer.*"] |
| 19 | +__doctest_skip__ = ["PythonServer", "PythonServer.*"] |
| 20 | + |
23 | 21 |
|
24 | 22 | # Skip doctests if requirements aren't available
|
25 | 23 | if not _is_torch_available():
|
@@ -72,7 +70,7 @@ class PythonServer(LightningWork, abc.ABC):
|
72 | 70 |
|
73 | 71 | _start_method = "spawn"
|
74 | 72 |
|
75 |
| - @requires(["torch", "lightning_api_access"]) |
| 73 | + @requires(["torch"]) |
76 | 74 | def __init__( # type: ignore
|
77 | 75 | self,
|
78 | 76 | input_type: type = _DefaultInputData,
|
@@ -193,29 +191,32 @@ def predict_fn(request: input_type): # type: ignore
|
193 | 191 | fastapi_app.post("/predict", response_model=output_type)(predict_fn)
|
194 | 192 |
|
195 | 193 | def configure_layout(self) -> None:
|
196 |
| - if module_available("lightning_api_access"): |
| 194 | + try: |
197 | 195 | from lightning_api_access import APIAccessFrontend
|
198 |
| - |
199 |
| - class_name = self.__class__.__name__ |
200 |
| - url = f"{self.url}/predict" |
201 |
| - |
202 |
| - try: |
203 |
| - request = self._get_sample_dict_from_datatype(self.configure_input_type()) |
204 |
| - response = self._get_sample_dict_from_datatype(self.configure_output_type()) |
205 |
| - except TypeError: |
206 |
| - return None |
207 |
| - |
208 |
| - return APIAccessFrontend( |
209 |
| - apis=[ |
210 |
| - { |
211 |
| - "name": class_name, |
212 |
| - "url": url, |
213 |
| - "method": "POST", |
214 |
| - "request": request, |
215 |
| - "response": response, |
216 |
| - } |
217 |
| - ] |
218 |
| - ) |
| 196 | + except ModuleNotFoundError: |
| 197 | + logger.warn("APIAccessFrontend not found. Please install lightning-api-access to enable the UI") |
| 198 | + return |
| 199 | + |
| 200 | + class_name = self.__class__.__name__ |
| 201 | + url = f"{self.url}/predict" |
| 202 | + |
| 203 | + try: |
| 204 | + request = self._get_sample_dict_from_datatype(self.configure_input_type()) |
| 205 | + response = self._get_sample_dict_from_datatype(self.configure_output_type()) |
| 206 | + except TypeError: |
| 207 | + return None |
| 208 | + |
| 209 | + return APIAccessFrontend( |
| 210 | + apis=[ |
| 211 | + { |
| 212 | + "name": class_name, |
| 213 | + "url": url, |
| 214 | + "method": "POST", |
| 215 | + "request": request, |
| 216 | + "response": response, |
| 217 | + } |
| 218 | + ] |
| 219 | + ) |
219 | 220 |
|
220 | 221 | def run(self, *args: Any, **kwargs: Any) -> Any:
|
221 | 222 | """Run method takes care of configuring and setting up a FastAPI server behind the scenes.
|
|
0 commit comments