Skip to content

Commit 92fe188

Browse files
authored
[App] Release 1.8.3post1 (#15820)
* [App] Enable Python Server and Gradio Serve to run on accelerated device such as GPU CUDA / MPS (#15813) * bump version to 1.8.3.post1 * update changelog
1 parent 655ade6 commit 92fe188

File tree

7 files changed

+59
-4
lines changed

7 files changed

+59
-4
lines changed

src/lightning/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
version = "1.8.3.post0"
1+
version = "1.8.3.post1"

src/lightning_app/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1919

2020
- Fixed debugging with VSCode IDE ([#15747](https://github.com/Lightning-AI/lightning/pull/15747))
2121
- Fixed setting property to the `LightningFlow` ([#15750](https://github.com/Lightning-AI/lightning/pull/15750))
22+
- Fixed the PyTorch Inference locally on GPU ([#15813](https://github.com/Lightning-AI/lightning/pull/15813))
2223

2324

2425
## [1.8.2] - 2022-11-17

src/lightning_app/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
version = "1.8.3.post0"
1+
version = "1.8.3.post1"

src/lightning_app/components/serve/gradio.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import abc
2+
import os
23
from functools import partial
34
from types import ModuleType
45
from typing import Any, List, Optional
56

7+
from lightning_app.components.serve.python_server import _PyTorchSpawnRunExecutor, WorkRunExecutor
68
from lightning_app.core.work import LightningWork
79
from lightning_app.utilities.imports import _is_gradio_available, requires
810

@@ -39,6 +41,10 @@ def __init__(self, *args, **kwargs):
3941
assert self.inputs
4042
assert self.outputs
4143
self._model = None
44+
# Note: Enables running inference on GPUs.
45+
self._run_executor_cls = (
46+
WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
47+
)
4248

4349
@property
4450
def model(self):

src/lightning_app/components/serve/python_server.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import abc
22
import base64
3+
import os
34
from pathlib import Path
45
from typing import Any, Dict, Optional
56

@@ -9,12 +10,54 @@
910
from pydantic import BaseModel
1011
from starlette.staticfiles import StaticFiles
1112

13+
from lightning_app.core.queues import MultiProcessQueue
1214
from lightning_app.core.work import LightningWork
1315
from lightning_app.utilities.app_helpers import Logger
16+
from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver
1417

1518
logger = Logger(__name__)
1619

1720

21+
class _PyTorchSpawnRunExecutor(WorkRunExecutor):
22+
23+
"""This Executor enables to move PyTorch tensors on GPU.
24+
25+
Without this executor, it would raise the following exception:
26+
RuntimeError: Cannot re-initialize CUDA in forked subprocess.
27+
To use CUDA with multiprocessing, you must use the 'spawn' start method
28+
"""
29+
30+
enable_start_observer: bool = False
31+
32+
def __call__(self, *args: Any, **kwargs: Any):
33+
import torch
34+
35+
with self.enable_spawn():
36+
queue = self.delta_queue if isinstance(self.delta_queue, MultiProcessQueue) else self.delta_queue.to_dict()
37+
torch.multiprocessing.spawn(
38+
self.dispatch_run,
39+
args=(self.__class__, self.work, queue, args, kwargs),
40+
nprocs=1,
41+
)
42+
43+
@staticmethod
44+
def dispatch_run(local_rank, cls, work, delta_queue, args, kwargs):
45+
if local_rank == 0:
46+
if isinstance(delta_queue, dict):
47+
delta_queue = cls.process_queue(delta_queue)
48+
work._request_queue = cls.process_queue(work._request_queue)
49+
work._response_queue = cls.process_queue(work._response_queue)
50+
51+
state_observer = WorkStateObserver(work, delta_queue=delta_queue)
52+
state_observer.start()
53+
_proxy_setattr(work, delta_queue, state_observer)
54+
55+
unwrap(work.run)(*args, **kwargs)
56+
57+
if local_rank == 0:
58+
state_observer.join(0)
59+
60+
1861
class _DefaultInputData(BaseModel):
1962
payload: str
2063

@@ -106,6 +149,11 @@ def predict(self, request):
106149
self._input_type = input_type
107150
self._output_type = output_type
108151

152+
# Note: Enables running inference on GPUs.
153+
self._run_executor_cls = (
154+
WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
155+
)
156+
109157
def setup(self, *args, **kwargs) -> None:
110158
"""This method is called before the server starts. Override this if you need to download the model or
111159
initialize the weights, setting up pipelines etc.

src/lightning_lite/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
version = "1.8.3.post0"
1+
version = "1.8.3.post1"

src/pytorch_lightning/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
version = "1.8.3.post0"
1+
version = "1.8.3.post1"

0 commit comments

Comments
 (0)