Skip to content

Commit eb87018

Browse files
Sherin Thomaspre-commit-ci[bot]rlizzoBorda
authored andcommitted
Sample datatype for Serve Component (#15623)
* introducing serve component * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * clean up tests * clean up tests * doctest * mypy * structure-fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleanup * cleanup * test fix * addition * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * requirements * getting future url * url for local * sample data typeg * changes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * prediction * updates * updates * manifest * fix type error * fixed test Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Rick Izzo <[email protected]> Co-authored-by: Jirka <[email protected]> (cherry picked from commit 136a090)
1 parent 0851202 commit eb87018

File tree

8 files changed

+96
-6
lines changed

8 files changed

+96
-6
lines changed

examples/app_server/app.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# !pip install torchvision pydantic
2+
import base64
3+
import io
4+
5+
import torch
6+
import torchvision
7+
from PIL import Image
8+
from pydantic import BaseModel
9+
10+
import lightning as L
11+
from lightning.app.components.serve import Image as InputImage
12+
from lightning.app.components.serve import PythonServer
13+
14+
15+
class PyTorchServer(PythonServer):
16+
def setup(self):
17+
self._model = torchvision.models.resnet18(pretrained=True)
18+
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19+
self._model.to(self._device)
20+
21+
def predict(self, request):
22+
image = base64.b64decode(request.image.encode("utf-8"))
23+
image = Image.open(io.BytesIO(image))
24+
transforms = torchvision.transforms.Compose(
25+
[
26+
torchvision.transforms.Resize(224),
27+
torchvision.transforms.ToTensor(),
28+
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
29+
]
30+
)
31+
image = transforms(image)
32+
image = image.to(self._device)
33+
prediction = self._model(image.unsqueeze(0))
34+
return {"prediction": prediction.argmax().item()}
35+
36+
37+
class OutputData(BaseModel):
38+
prediction: int
39+
40+
41+
component = PyTorchServer(input_type=InputImage, output_type=OutputData, cloud_compute=L.CloudCompute("gpu"))
42+
app = L.LightningApp(component)

src/lightning/__setup__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def _adjust_manifest(**kwargs: Any) -> None:
3535
"recursive-include requirements *.txt",
3636
"recursive-include src/lightning/app/ui *",
3737
"recursive-include src/lightning/cli/*-template *", # Add templates as build-in
38+
"include src/lightning/app/components/serve/catimage.png" + os.linesep,
3839
# fixme: this is strange, this shall work with setup find package - include
3940
"prune src/lightning_app",
4041
"prune src/lightning_lite",

src/lightning_app/__setup__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def _adjust_manifest(**__: Any) -> None:
5050
"recursive-exclude src *.md" + os.linesep,
5151
"recursive-exclude requirements *.txt" + os.linesep,
5252
"recursive-include src/lightning_app *.md" + os.linesep,
53+
"include src/lightning_app/components/serve/catimage.png" + os.linesep,
5354
"recursive-include requirements/app *.txt" + os.linesep,
5455
"recursive-include src/lightning_app/cli/*-template *" + os.linesep, # Add templates
5556
]

src/lightning_app/components/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from lightning_app.components.python.popen import PopenPythonScript
1010
from lightning_app.components.python.tracer import Code, TracerPythonScript
1111
from lightning_app.components.serve.gradio import ServeGradio
12-
from lightning_app.components.serve.python_server import PythonServer
12+
from lightning_app.components.serve.python_server import Image, Number, PythonServer
1313
from lightning_app.components.serve.serve import ModelInferenceAPI
1414
from lightning_app.components.serve.streamlit import ServeStreamlit
1515
from lightning_app.components.training import LightningTrainingComponent, PyTorchLightningScriptRunner
@@ -24,6 +24,8 @@
2424
"ServeStreamlit",
2525
"ModelInferenceAPI",
2626
"PythonServer",
27+
"Image",
28+
"Number",
2729
"MultiNode",
2830
"LiteMultiNode",
2931
"LightningTrainingComponent",
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from lightning_app.components.serve.gradio import ServeGradio
2-
from lightning_app.components.serve.python_server import PythonServer
2+
from lightning_app.components.serve.python_server import Image, Number, PythonServer
33
from lightning_app.components.serve.streamlit import ServeStreamlit
44

5-
__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer"]
5+
__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number"]
19.6 KB
Loading

src/lightning_app/components/serve/python_server.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import abc
2-
from typing import Any, Dict
2+
import base64
3+
from pathlib import Path
4+
from typing import Any, Dict, Optional
35

46
import uvicorn
57
from fastapi import FastAPI
@@ -12,6 +14,12 @@
1214
logger = Logger(__name__)
1315

1416

17+
def image_to_base64(image_path):
18+
with open(image_path, "rb") as image_file:
19+
encoded_string = base64.b64encode(image_file.read())
20+
return encoded_string.decode("UTF-8")
21+
22+
1523
class _DefaultInputData(BaseModel):
1624
payload: str
1725

@@ -20,6 +28,25 @@ class _DefaultOutputData(BaseModel):
2028
prediction: str
2129

2230

31+
class Image(BaseModel):
32+
image: Optional[str]
33+
34+
@staticmethod
35+
def _get_sample_data() -> Dict[Any, Any]:
36+
imagepath = Path(__file__).absolute().parent / "catimage.png"
37+
with open(imagepath, "rb") as image_file:
38+
encoded_string = base64.b64encode(image_file.read())
39+
return {"image": encoded_string.decode("UTF-8")}
40+
41+
42+
class Number(BaseModel):
43+
prediction: Optional[int]
44+
45+
@staticmethod
46+
def _get_sample_data() -> Dict[Any, Any]:
47+
return {"prediction": 463}
48+
49+
2350
class PythonServer(LightningWork, abc.ABC):
2451
def __init__( # type: ignore
2552
self,
@@ -110,6 +137,9 @@ def predict(self, request: Any) -> Any:
110137

111138
@staticmethod
112139
def _get_sample_dict_from_datatype(datatype: Any) -> dict:
140+
if hasattr(datatype, "_get_sample_data"):
141+
return datatype._get_sample_data()
142+
113143
datatype_props = datatype.schema()["properties"]
114144
out: Dict[str, Any] = {}
115145
for k, v in datatype_props.items():
@@ -141,7 +171,7 @@ def _attach_frontend(self, fastapi_app: FastAPI) -> None:
141171
url = self._future_url if self._future_url else self.url
142172
if not url:
143173
# if the url is still empty, point it to localhost
144-
url = f"http://127.0.0.1{self.port}"
174+
url = f"http://127.0.0.1:{self.port}"
145175
url = f"{url}/predict"
146176
datatype_parse_error = False
147177
try:

tests/tests_app/components/serve/test_python_server.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import multiprocessing as mp
22

3-
from lightning_app.components import PythonServer
3+
from lightning_app.components import Image, Number, PythonServer
44
from lightning_app.utilities.network import _configure_session, find_free_network_port
55

66

@@ -29,3 +29,17 @@ def test_python_server_component():
2929
res = session.post(f"http://127.0.0.1:{port}/predict", json={"payload": "test"})
3030
process.terminate()
3131
assert res.json()["prediction"] == "test"
32+
33+
34+
def test_image_sample_data():
35+
data = Image()._get_sample_data()
36+
assert isinstance(data, dict)
37+
assert "image" in data
38+
assert len(data["image"]) > 100
39+
40+
41+
def test_number_sample_data():
42+
data = Number()._get_sample_data()
43+
assert isinstance(data, dict)
44+
assert "prediction" in data
45+
assert data["prediction"] == 463

0 commit comments

Comments
 (0)