+.. displayitem::
+ :header: 2D Parallelism
+ :description: Combine Tensor Parallelism with FSDP (2D Parallel) to train efficiently on 100s of GPUs
+ :col_css: col-md-12
+ :button_link: ../advanced/model_parallel/tp_fsdp.html
+ :height: 100
+
.. displayitem::
:header: Accelerators
:description: Accelerators connect the Trainer to hardware to train faster
@@ -230,7 +239,7 @@ Glossary
:header: Model Parallelism
:description: A way to scale training that splits a model between multiple devices.
:col_css: col-md-12
- :button_link: ../advanced/model_parallel.html
+ :button_link: ../advanced/model_parallel/index.html
:height: 100
.. displayitem::
@@ -331,6 +340,13 @@ Glossary
:button_link: ../clouds/cluster_advanced.html
:height: 100
+.. displayitem::
+ :header: Tensor Parallelism
+ :description: Parallelize the computation of model layers across multiple GPUs, reducing memory usage and communication overhead
+ :col_css: col-md-12
+ :button_link: ../advanced/model_parallel/tp.html
+ :height: 100
+
.. displayitem::
:header: Transfer learning
:description: Using pre-trained models to improve learning
diff --git a/docs/source-pytorch/levels/advanced_level_21.rst b/docs/source-pytorch/levels/advanced_level_21.rst
index 6c07d1d037465..68262ca9e223c 100644
--- a/docs/source-pytorch/levels/advanced_level_21.rst
+++ b/docs/source-pytorch/levels/advanced_level_21.rst
@@ -25,9 +25,9 @@ Scale to billions of parameters with multiple distributed strategies.
.. displayitem::
:header: Train models with billions of parameters
- :description: Scale to billions of params on GPUs with FSDP or Deepspeed.
+ :description: Scale to billions of params on GPUs with FSDP, TP or Deepspeed.
:col_css: col-md-6
- :button_link: ../advanced/model_parallel.html
+ :button_link: ../advanced/model_parallel/index.html
:height: 150
:tag: advanced
diff --git a/docs/source-pytorch/past_versions.rst b/docs/source-pytorch/past_versions.rst
index 95ef7255af099..76b93b763c99e 100644
--- a/docs/source-pytorch/past_versions.rst
+++ b/docs/source-pytorch/past_versions.rst
@@ -3,7 +3,7 @@ Past PyTorch Lightning versions
PyTorch Lightning :doc:`evolved over time `. Here's the history of versions with links to their respective docs.
-To help you with keeping up to spead, check :doc:`Migration guide `.
+To help you with keeping up to speed, check :doc:`Migration guide `.
.. list-table:: Past versions
:widths: 5 50 30 15
diff --git a/docs/source-pytorch/upgrade/sections/2_0_regular.rst b/docs/source-pytorch/upgrade/sections/2_0_regular.rst
index 192f20bc669b9..2f94ef7ab66fd 100644
--- a/docs/source-pytorch/upgrade/sections/2_0_regular.rst
+++ b/docs/source-pytorch/upgrade/sections/2_0_regular.rst
@@ -6,7 +6,7 @@
- Then
- Ref
- * - used PyTorch 3.11
+ * - used PyTorch 1.11
- upgrade to PyTorch 2.1 or higher
- `PR18691`_
diff --git a/docs/source-pytorch/versioning.rst b/docs/source-pytorch/versioning.rst
index 96038e63cf807..d923b01c7edb3 100644
--- a/docs/source-pytorch/versioning.rst
+++ b/docs/source-pytorch/versioning.rst
@@ -16,8 +16,6 @@ A Lightning release number is in the format of ``MAJOR.MINOR.PATCH``.
With every release, we publish a changelog where we list additions, removals, deprecations, changed functionality and fixes.
-The ``lightning.app`` package is an exception to this rule, as it may contain any change with or without deprecations in any of the releases.
-
API Stability
*************
@@ -81,6 +79,18 @@ The table below indicates the coverage of tested versions in our CI. Versions ou
- ``torch``
- ``torchmetrics``
- Python
+ * - 2.4
+ - 2.4
+ - 2.4
+ - ≥2.1, ≤2.4
+ - ≥0.7.0
+ - ≥3.9, ≤3.12
+ * - 2.3
+ - 2.3
+ - 2.3
+ - ≥2.0, ≤2.3
+ - ≥0.7.0
+ - ≥3.8, ≤3.11
* - 2.2
- 2.2
- 2.2
diff --git a/docs/source-pytorch/visualize/supported_exp_managers.rst b/docs/source-pytorch/visualize/supported_exp_managers.rst
index 42a0e6c9a85ed..e26514e9747c4 100644
--- a/docs/source-pytorch/visualize/supported_exp_managers.rst
+++ b/docs/source-pytorch/visualize/supported_exp_managers.rst
@@ -134,7 +134,7 @@ Here's the full documentation for the :class:`~lightning.pytorch.loggers.TensorB
Weights and Biases
==================
-To use `Weights and Biases `_ (wandb) first install the wandb package:
+To use `Weights and Biases `_ (wandb) first install the wandb package:
.. code-block:: bash
diff --git a/examples/README.md b/examples/README.md
index 6c51e07ae0d7b..f796375e2a088 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -22,11 +22,3 @@ In this folder, we have 2 simple examples that showcase the power of the Lightni
- [Image Classifier](pytorch/basics/backbone_image_classifier.py) (trains arbitrary datasets with arbitrary backbones).
- [Autoencoder](pytorch/basics/autoencoder.py)
-
-______________________________________________________________________
-
-## Domain Examples
-
-This folder contains older examples. You should instead use the examples
-in [Lightning Bolts](https://lightning.ai/docs/pytorch/stable/ecosystem/bolts.html)
-for advanced use cases.
diff --git a/examples/app/argparse/app.py b/examples/app/argparse/app.py
deleted file mode 100644
index 5fa8039908eb3..0000000000000
--- a/examples/app/argparse/app.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import argparse
-
-from lightning.app import CloudCompute, LightningApp, LightningFlow, LightningWork
-
-
-class Work(LightningWork):
- def __init__(self, cloud_compute):
- super().__init__(cloud_compute=cloud_compute)
-
- def run(self):
- pass
-
-
-class Flow(LightningFlow):
- def __init__(self, cloud_compute):
- super().__init__()
- self.work = Work(cloud_compute)
-
- def run(self):
- assert self.work.cloud_compute.name == "gpu", self.work.cloud_compute.name
- self.stop()
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("--use_gpu", action="store_true", default=False, help="Whether to use GPU in the cloud")
- hparams = parser.parse_args()
- app = LightningApp(Flow(CloudCompute("gpu" if hparams.use_gpu else "cpu")))
diff --git a/examples/app/boring/.gitignore b/examples/app/boring/.gitignore
deleted file mode 100644
index 94018704d9f90..0000000000000
--- a/examples/app/boring/.gitignore
+++ /dev/null
@@ -1,10 +0,0 @@
-lightning_logs
-*.pt
-.storage/
-.shared/
-data
-*.ckpt
-redis-stable
-node_modules
-*.rdb
-boring_file.txt
diff --git a/examples/app/boring/app.py b/examples/app/boring/app.py
deleted file mode 100644
index 0dfaedfae0107..0000000000000
--- a/examples/app/boring/app.py
+++ /dev/null
@@ -1,61 +0,0 @@
-import os
-
-from lightning.app import CloudCompute, LightningApp, LightningFlow, LightningWork
-from lightning.app.components import TracerPythonScript
-from lightning.app.storage.path import Path
-
-FILE_CONTENT = """
-Hello there!
-This tab is currently an IFrame of the FastAPI Server running in `DestinationFileAndServeWork`.
-Also, the content of this file was created in `SourceFileWork` and then transferred to `DestinationFileAndServeWork`.
-Are you already 🤯 ? Stick with us, this is only the beginning. Lightning is 🚀.
-"""
-
-
-class SourceFileWork(LightningWork):
- def __init__(self, cloud_compute: CloudCompute = CloudCompute(), **kwargs):
- super().__init__(parallel=True, **kwargs, cloud_compute=cloud_compute)
- self.boring_path = None
-
- def run(self):
- # This should be used as a REFERENCE to the file.
- self.boring_path = "lit://boring_file.txt"
- with open(self.boring_path, "w", encoding="utf-8") as f:
- f.write(FILE_CONTENT)
-
-
-class DestinationFileAndServeWork(TracerPythonScript):
- def run(self, path: Path):
- assert path.exists()
- self.script_args += [f"--filepath={path}", f"--host={self.host}", f"--port={self.port}"]
- super().run()
-
-
-class BoringApp(LightningFlow):
- def __init__(self):
- super().__init__()
- self.source_work = SourceFileWork()
- self.dest_work = DestinationFileAndServeWork(
- script_path=os.path.join(os.path.dirname(__file__), "scripts/serve.py"),
- port=1111,
- parallel=False, # runs until killed.
- cloud_compute=CloudCompute(),
- raise_exception=True,
- )
-
- @property
- def ready(self) -> bool:
- return self.dest_work.is_running
-
- def run(self):
- self.source_work.run()
- if self.source_work.has_succeeded:
- # the flow passes the file from one work to another.
- self.dest_work.run(self.source_work.boring_path)
- self.stop("Boring App End")
-
- def configure_layout(self):
- return {"name": "Boring Tab", "content": self.dest_work.url + "/file"}
-
-
-app = LightningApp(BoringApp())
diff --git a/examples/app/boring/app_dynamic.py b/examples/app/boring/app_dynamic.py
deleted file mode 100644
index b08b8cf5ce10d..0000000000000
--- a/examples/app/boring/app_dynamic.py
+++ /dev/null
@@ -1,72 +0,0 @@
-import os
-
-from lightning.app import CloudCompute, LightningApp, LightningFlow, LightningWork
-from lightning.app.components import TracerPythonScript
-from lightning.app.storage.path import Path
-from lightning.app.structures import Dict
-
-FILE_CONTENT = """
-Hello there!
-This tab is currently an IFrame of the FastAPI Server running in `DestinationFileAndServeWork`.
-Also, the content of this file was created in `SourceFileWork` and then transferred to `DestinationFileAndServeWork`.
-Are you already 🤯 ? Stick with us, this is only the beginning. Lightning is 🚀.
-"""
-
-
-class SourceFileWork(LightningWork):
- def __init__(self, cloud_compute: CloudCompute = CloudCompute(), **kwargs):
- super().__init__(parallel=True, **kwargs, cloud_compute=cloud_compute)
- self.boring_path = None
-
- def run(self):
- # This should be used as a REFERENCE to the file.
- self.boring_path = "lit://boring_file.txt"
- with open(self.boring_path, "w") as f:
- f.write(FILE_CONTENT)
-
-
-class DestinationFileAndServeWork(TracerPythonScript):
- def run(self, path: Path):
- assert path.exists()
- self.script_args += [f"--filepath={path}", f"--host={self.host}", f"--port={self.port}"]
- super().run()
-
-
-class BoringApp(LightningFlow):
- def __init__(self):
- super().__init__()
- self.dict = Dict()
-
- @property
- def ready(self) -> bool:
- if "dst_w" in self.dict:
- return self.dict["dst_w"].url != ""
- return False
-
- def run(self):
- # create dynamically the source_work at runtime
- if "src_w" not in self.dict:
- self.dict["src_w"] = SourceFileWork()
-
- self.dict["src_w"].run()
-
- if self.dict["src_w"].has_succeeded:
- # create dynamically the dst_w at runtime
- if "dst_w" not in self.dict:
- self.dict["dst_w"] = DestinationFileAndServeWork(
- script_path=os.path.join(os.path.dirname(__file__), "scripts/serve.py"),
- port=1111,
- parallel=False, # runs until killed.
- cloud_compute=CloudCompute(),
- raise_exception=True,
- )
-
- # the flow passes the file from one work to another.
- self.dict["dst_w"].run(self.dict["src_w"].boring_path)
- self.stop("Boring App End")
-
- def configure_layout(self):
- return {"name": "Boring Tab", "content": self.dict["dst_w"].url + "/file" if "dst_w" in self.dict else ""}
-
-
-app = LightningApp(BoringApp(), log_level="debug")
diff --git a/examples/app/boring/scripts/__init__.py b/examples/app/boring/scripts/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/examples/app/boring/scripts/serve.py b/examples/app/boring/scripts/serve.py
deleted file mode 100644
index dedd6013985ca..0000000000000
--- a/examples/app/boring/scripts/serve.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import argparse
-import os
-
-import uvicorn
-from fastapi import FastAPI
-from fastapi.requests import Request
-from fastapi.responses import HTMLResponse
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser("Server Parser")
- parser.add_argument("--filepath", type=str, help="Where to find the `filepath`")
- parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host`")
- parser.add_argument("--port", type=int, default="8888", help="Server port`")
- hparams = parser.parse_args()
-
- fastapi_service = FastAPI()
-
- if not os.path.exists(str(hparams.filepath)):
- content = ["The file wasn't transferred"]
- else:
- with open(hparams.filepath) as fo:
- content = fo.readlines() # read the file received from SourceWork.
-
- @fastapi_service.get("/file")
- async def get_file_content(request: Request, response_class=HTMLResponse):
- lines = "\n".join(["<p>" + line + "</p>" for line in content])
- return HTMLResponse(f"<html><head></head><body><ul>{lines}</ul></body></html>")
-
- uvicorn.run(app=fastapi_service, host=hparams.host, port=hparams.port)
diff --git a/examples/app/commands_and_api/.lightningignore b/examples/app/commands_and_api/.lightningignore
deleted file mode 100644
index f7275bbbd035b..0000000000000
--- a/examples/app/commands_and_api/.lightningignore
+++ /dev/null
@@ -1 +0,0 @@
-venv/
diff --git a/examples/app/commands_and_api/app.py b/examples/app/commands_and_api/app.py
deleted file mode 100644
index 3f59c117c4180..0000000000000
--- a/examples/app/commands_and_api/app.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from command import CustomCommand, CustomConfig
-from lightning import LightningFlow
-from lightning.app.api import Get, Post
-from lightning.app.core.app import LightningApp
-
-
-async def handler():
- print("Has been called")
- return "Hello World !"
-
-
-class ChildFlow(LightningFlow):
- def nested_command(self, name: str):
- """A nested command."""
- print(f"Hello {name}")
-
- def configure_commands(self):
- return [{"nested_command": self.nested_command}]
-
-
-class FlowCommands(LightningFlow):
- def __init__(self):
- super().__init__()
- self.names = []
- self.child_flow = ChildFlow()
-
- def run(self):
- if self.names:
- print(self.names)
-
- def command_without_client(self, name: str):
- """A command without a client."""
- self.names.append(name)
-
- def command_with_client(self, config: CustomConfig):
- self.names.append(config.name)
-
- def configure_commands(self):
- commands = [
- {"command_without_client": self.command_without_client},
- {"command_with_client": CustomCommand(self.command_with_client)},
- ]
- return commands + self.child_flow.configure_commands()
-
- def configure_api(self):
- return [
- Post("/user/command_without_client", self.command_without_client),
- Get("/pure_function", handler),
- ]
-
-
-app = LightningApp(FlowCommands(), log_level="debug")
diff --git a/examples/app/commands_and_api/command.py b/examples/app/commands_and_api/command.py
deleted file mode 100644
index e2dd26f684b03..0000000000000
--- a/examples/app/commands_and_api/command.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from argparse import ArgumentParser
-
-from lightning.app.utilities.commands import ClientCommand
-from pydantic import BaseModel
-
-
-class CustomConfig(BaseModel):
- name: str
-
-
-class CustomCommand(ClientCommand):
- description = "A command with a client."
-
- def run(self):
- parser = ArgumentParser()
- parser.add_argument("--name", type=str)
- args = parser.parse_args()
- self.invoke_handler(config=CustomConfig(name=args.name))
diff --git a/examples/app/components/python/__init__.py b/examples/app/components/python/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/examples/app/components/python/app.py b/examples/app/components/python/app.py
deleted file mode 100644
index 944cb7de2995d..0000000000000
--- a/examples/app/components/python/app.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import os
-from pathlib import Path
-
-from lightning.app import LightningApp, LightningFlow
-
-from examples.components.python.component_tracer import PLTracerPythonScript
-
-
-class RootFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- script_path = Path(__file__).parent / "pl_script.py"
- self.tracer_python_script = PLTracerPythonScript(script_path)
-
- def run(self):
- assert os.getenv("GLOBAL_RANK", "0") == "0"
- if not self.tracer_python_script.has_started:
- self.tracer_python_script.run()
- if self.tracer_python_script.has_succeeded:
- self.stop("tracer script succeed")
- if self.tracer_python_script.has_failed:
- self.stop("tracer script failed")
-
-
-app = LightningApp(RootFlow())
diff --git a/examples/app/components/python/component_popen.py b/examples/app/components/python/component_popen.py
deleted file mode 100644
index bc70b9f47b16d..0000000000000
--- a/examples/app/components/python/component_popen.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from pathlib import Path
-
-from lightning.app.components import PopenPythonScript
-
-if __name__ == "__main__":
- comp = PopenPythonScript(Path(__file__).parent / "pl_script.py")
- comp.run()
diff --git a/examples/app/components/python/component_tracer.py b/examples/app/components/python/component_tracer.py
deleted file mode 100644
index 3e2e96f38a7f3..0000000000000
--- a/examples/app/components/python/component_tracer.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from lightning.app.components import TracerPythonScript
-from lightning.app.storage.path import Path
-from lightning.app.utilities.tracer import Tracer
-from lightning.pytorch import Trainer
-
-
-class PLTracerPythonScript(TracerPythonScript):
- """This component can be used for ANY PyTorch Lightning script to track its progress and extract its best model
- path."""
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- # Define the component state.
- self.global_step = None
- self.best_model_path = None
-
- def configure_tracer(self) -> Tracer:
- from lightning.pytorch.callbacks import Callback
-
- class MyInjectedCallback(Callback):
- def __init__(self, lightning_work):
- self.lightning_work = lightning_work
-
- def on_train_start(self, trainer, pl_module) -> None:
- print("This code doesn't belong to the script but was injected.")
- print("Even the Lightning Work is available and state transfer works !")
- print(self.lightning_work)
-
- def on_batch_train_end(self, trainer, *_) -> None:
- # On every batch end, collects some information.
- # This is communicated automatically to the rest of the app,
- # so you can track your training in real time in the Lightning App UI.
- self.lightning_work.global_step = trainer.global_step
- best_model_path = trainer.checkpoint_callback.best_model_path
- if best_model_path:
- self.lightning_work.best_model_path = Path(best_model_path)
-
- # This hook would be called every time
- # before a Trainer `__init__` method is called.
-
- def trainer_pre_fn(trainer, *args, **kwargs):
- kwargs["callbacks"] = kwargs.get("callbacks", []) + [MyInjectedCallback(self)]
- return {}, args, kwargs
-
- tracer = super().configure_tracer()
- tracer.add_traced(Trainer, "__init__", pre_fn=trainer_pre_fn)
- return tracer
-
-
-if __name__ == "__main__":
- comp = PLTracerPythonScript(Path(__file__).parent / "pl_script.py")
- res = comp.run()
diff --git a/examples/app/components/python/pl_script.py b/examples/app/components/python/pl_script.py
deleted file mode 100644
index 75538daf4bed2..0000000000000
--- a/examples/app/components/python/pl_script.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from lightning.pytorch import Trainer
-from lightning.pytorch.demos.boring_classes import BoringModel
-
-if __name__ == "__main__":
- model = BoringModel()
- trainer = Trainer(max_epochs=1, accelerator="cpu", devices=2, strategy="ddp")
- trainer.fit(model)
- trainer.validate(model)
- trainer.test(model)
- trainer.predict(model)
diff --git a/examples/app/components/serve/gradio/app.py b/examples/app/components/serve/gradio/app.py
deleted file mode 100644
index ec07e4ba99c06..0000000000000
--- a/examples/app/components/serve/gradio/app.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from functools import partial
-
-import gradio as gr
-import requests
-import torch
-from lightning.app import LightningApp, LightningFlow
-from lightning.app.components import ServeGradio
-from PIL import Image
-
-
-# Credit to @akhaliq for his inspiring work.
-# Find his original code there: https://huggingface.co/spaces/akhaliq/AnimeGANv2/blob/main/app.py
-class AnimeGANv2UI(ServeGradio):
- inputs = gr.inputs.Image(type="pil")
- outputs = gr.outputs.Image(type="pil")
- elon = "https://upload.wikimedia.org/wikipedia/commons/thumb/3/34/Elon_Musk_Royal_Society_%28crop2%29.jpg/330px-Elon_Musk_Royal_Society_%28crop2%29.jpg"
- img = Image.open(requests.get(elon, stream=True).raw)
- img.save("elon.jpg")
- examples = [["elon.jpg"]]
-
- def __init__(self):
- super().__init__()
- self.ready = False
-
- def predict(self, img):
- return self.model(img=img)
-
- def build_model(self):
- repo = "AK391/animegan2-pytorch:main"
- model = torch.hub.load(repo, "generator", device="cpu")
- face2paint = torch.hub.load(repo, "face2paint", size=512, device="cpu")
- self.ready = True
- return partial(face2paint, model=model)
-
-
-class RootFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.demo = AnimeGANv2UI()
-
- def run(self):
- self.demo.run()
-
- def configure_layout(self):
- tabs = []
- if self.demo.ready:
- tabs.append({"name": "Home", "content": self.demo})
- return tabs
-
-
-app = LightningApp(RootFlow())
diff --git a/examples/app/components/serve/gradio/beyonce.jpg b/examples/app/components/serve/gradio/beyonce.jpg
deleted file mode 100644
index 68b6084475b01..0000000000000
Binary files a/examples/app/components/serve/gradio/beyonce.jpg and /dev/null differ
diff --git a/examples/app/components/serve/gradio/requirements.txt b/examples/app/components/serve/gradio/requirements.txt
deleted file mode 100644
index 25aceddaba262..0000000000000
--- a/examples/app/components/serve/gradio/requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-gradio
diff --git a/examples/app/dag/.gitignore b/examples/app/dag/.gitignore
deleted file mode 100644
index fcb9fa9fe4d29..0000000000000
--- a/examples/app/dag/.gitignore
+++ /dev/null
@@ -1,6 +0,0 @@
-df_data
-df_target
-X_train
-X_test
-y_train
-y_test
diff --git a/examples/app/dag/.lightningignore b/examples/app/dag/.lightningignore
deleted file mode 100644
index 78ae49041d0b6..0000000000000
--- a/examples/app/dag/.lightningignore
+++ /dev/null
@@ -1,8 +0,0 @@
-*df_data*
-*df_target*
-*X_train*
-*X_test*
-*y_train*
-*y_test*
-*.shared*
-*.storage*
diff --git a/examples/app/dag/app.py b/examples/app/dag/app.py
deleted file mode 100644
index 7e94f64b62c5d..0000000000000
--- a/examples/app/dag/app.py
+++ /dev/null
@@ -1,130 +0,0 @@
-import os
-from importlib import import_module
-
-import numpy as np
-import pandas as pd
-from lightning.app import LightningApp, LightningFlow, LightningWork
-from lightning.app.components import TracerPythonScript
-from lightning.app.storage import Payload
-from lightning.app.structures import Dict, List
-from sklearn import datasets
-from sklearn.metrics import mean_squared_error
-
-
-def get_path(path):
- return os.path.join(os.path.dirname(__file__), path)
-
-
-class GetDataWork(LightningWork):
- """This component is responsible to download some data and store them with a PayLoad."""
-
- def __init__(self):
- super().__init__()
- self.df_data = None
- self.df_target = None
-
- def run(self):
- print("Starting data collection...")
- data = datasets.fetch_california_housing(data_home=get_path("data"))
- self.df_data = Payload(pd.DataFrame(data["data"], columns=data["feature_names"]))
- self.df_target = Payload(pd.DataFrame(data["target"], columns=["MedHouseVal"]))
- print("Finished data collection.")
-
-
-class ModelWork(LightningWork):
- """This component is receiving some data and train a sklearn model."""
-
- def __init__(self, model_path: str, parallel: bool):
- super().__init__(parallel=parallel)
- self.model_path, self.model_name = model_path.split(".")
- self.test_rmse = None
-
- def run(self, X_train: Payload, X_test: Payload, y_train: Payload, y_test: Payload):
- print(f"Starting training and evaluating {self.model_name}...")
- module = import_module(f"sklearn.{self.model_path}")
- model = getattr(module, self.model_name)()
- model.fit(X_train.value, y_train.value.ravel())
- y_test_prediction = model.predict(X_test.value)
- self.test_rmse = np.sqrt(mean_squared_error(y_test.value, y_test_prediction))
- print(f"Finished training and evaluating {self.model_name}.")
-
-
-class DAG(LightningFlow):
- """This component is a DAG."""
-
- def __init__(self, models_paths: list):
- super().__init__()
- # Step 1: Create a work to get the data.
- self.data_collector = GetDataWork()
-
- # Step 2: Create a tracer component. This is used to execute python script
- # and collect any outputs from its globals as Payloads.
- self.processing = TracerPythonScript(
- get_path("processing.py"),
- outputs=["X_train", "X_test", "y_train", "y_test"],
- )
-
- # Step 3: Create the work to train the models_paths in parallel.
- self.dict = Dict(**{
- model_path.split(".")[-1]: ModelWork(model_path, parallel=True) for model_path in models_paths
- })
-
- # Step 4: Some element to track components progress.
- self.has_completed = False
- self.metrics = {}
-
- def run(self):
- # Step 1 and 2: Download and process the data.
- self.data_collector.run()
- self.processing.run(
- df_data=self.data_collector.df_data,
- df_target=self.data_collector.df_target,
- )
-
- # Step 3: Launch n models training in parallel.
- for model, work in self.dict.items():
- work.run(
- X_train=self.processing.X_train,
- X_test=self.processing.X_test,
- y_train=self.processing.y_train,
- y_test=self.processing.y_test,
- )
- if work.test_rmse: # Use the state to control when to collect and stop.
- self.metrics[model] = work.test_rmse
- work.stop() # Stop the model work to reduce cost
-
- # Step 4: Print the score of each model when they are all finished.
- if len(self.metrics) == len(self.dict):
- print(self.metrics)
- self.has_completed = True
-
-
-class ScheduledDAG(LightningFlow):
- def __init__(self, dag_cls, **dag_kwargs):
- super().__init__()
- self.dags = List()
- self._dag_cls = dag_cls
- self.dag_kwargs = dag_kwargs
-
- def run(self):
- """Example of scheduling an infinite number of DAG runs continuously."""
- # Step 1: Every minute, create and launch a new DAG.
- if self.schedule("* * * * *"):
- print("Launching a new DAG")
- self.dags.append(self._dag_cls(**self.dag_kwargs))
-
- for dag in self.dags:
- if not dag.has_completed:
- dag.run()
-
-
-app = LightningApp(
- ScheduledDAG(
- DAG,
- models_paths=[
- "svm.SVR",
- "linear_model.LinearRegression",
- "tree.DecisionTreeRegressor",
- ],
- ),
-)
diff --git a/examples/app/dag/processing.py b/examples/app/dag/processing.py
deleted file mode 100644
index 245377fa8cbaa..0000000000000
--- a/examples/app/dag/processing.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import random
-
-from sklearn.model_selection import train_test_split
-from sklearn.preprocessing import MinMaxScaler
-
-print("Starting processing ...")
-scaler = MinMaxScaler()
-
-X_train, X_test, y_train, y_test = train_test_split(
- df_data.values, df_target.values, test_size=0.20, random_state=random.randint(0, 42)
-)
-X_train = scaler.fit_transform(X_train)
-X_test = scaler.transform(X_test)
-print("Finished processing.")
diff --git a/examples/app/dag/requirements.txt b/examples/app/dag/requirements.txt
deleted file mode 100644
index f669f518e7389..0000000000000
--- a/examples/app/dag/requirements.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-scikit-learn
-pandas
diff --git a/examples/app/display_name/.lightningignore b/examples/app/display_name/.lightningignore
deleted file mode 100644
index f7275bbbd035b..0000000000000
--- a/examples/app/display_name/.lightningignore
+++ /dev/null
@@ -1 +0,0 @@
-venv/
diff --git a/examples/app/display_name/app.py b/examples/app/display_name/app.py
deleted file mode 100644
index a3a36b8afb02c..0000000000000
--- a/examples/app/display_name/app.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from lightning.app import LightningApp, LightningFlow, LightningWork
-
-
-class Work(LightningWork):
- def __init__(self, start_with_flow=True):
- super().__init__(start_with_flow=start_with_flow)
-
- def run(self):
- pass
-
-
-class Flow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.w = Work()
- self.w1 = Work(start_with_flow=False)
- self.w.display_name = "My Custom Name" # Not supported yet
- self.w1.display_name = "My Custom Name 1"
-
- def run(self):
- self.w.run()
- self.w1.run()
-
-
-app = LightningApp(Flow())
diff --git a/examples/app/drive/.gitignore b/examples/app/drive/.gitignore
deleted file mode 100644
index eaa5fa8755fc2..0000000000000
--- a/examples/app/drive/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-a.txt
diff --git a/examples/app/drive/app.py b/examples/app/drive/app.py
deleted file mode 100644
index e56636ce13887..0000000000000
--- a/examples/app/drive/app.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import os
-
-from lightning.app import LightningApp, LightningFlow, LightningWork
-from lightning.app.storage import Drive
-
-
-class Work_1(LightningWork):
- def run(self, drive: Drive):
- # 1. Create a file.
- with open("a.txt", "w") as f:
- f.write("Hello World !")
-
- # 2. Put the file into the drive.
- drive.put("a.txt")
-
- # 3. Delete the locally.
- os.remove("a.txt")
-
-
-class Work_2(LightningWork):
- def __init__(self):
- super().__init__()
-
- def run(self, drive: Drive):
- print(drive.list(".")) # Prints ["a.txt"]
-
- print(os.path.exists("a.txt")) # Prints False
-
- drive.get("a.txt") # Transfer the file from this drive to the local filesystem.
-
- print(os.path.exists("a.txt")) # Prints True
-
- with open("a.txt") as f:
- print(f.readlines()[0]) # Prints Hello World !
-
-
-class Flow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.drive_1 = Drive("lit://drive_1")
- self.work_1 = Work_1()
- self.work_2 = Work_2()
-
- def run(self):
- # Pass the drive to both works.
- self.work_1.run(self.drive_1)
- self.work_2.run(self.drive_1)
- self.stop("Application End!")
-
-
-app = LightningApp(Flow())
diff --git a/examples/app/hpo/README.md b/examples/app/hpo/README.md
deleted file mode 100644
index b9b648bcaf14f..0000000000000
--- a/examples/app/hpo/README.md
+++ /dev/null
@@ -1,64 +0,0 @@
-# Build a Lightning Hyperparameter Optimization (HPO) App
-
-## A bit of background
-
-Traditionally, developing machine learning (ML) products requires choosing among a large space of
-hyperparameters while creating and training the ML models. Hyperparameter optimization
-(HPO) aims to find a well-performing hyperparameter configuration for a given ML model
-on a dataset at hand, including the ML model,
-its hyperparameters, and other data processing steps.
-
-HPOs free the human expert from a tedious and error-prone, manual hyperparameter tuning process.
-
-As an example, in the famous [scikit-learn](https://scikit-learn.org/stable/) library,
-hyperparameters are passed as arguments to the constructor of
-the estimator classes such as `C` kernel for
-[Support Vector Classifier](https://scikit-learn.org/stable/modules/classes.html?highlight=svm#module-sklearn.svm), etc.
-
-It is possible and recommended to search the hyperparameter space for the best validation score.
-
-An HPO search consists of:
-
-- an objective method
-- a defined parameter space
-- a method for searching or sampling candidates
-
-A naive method for sampling candidates is grid search, which exhaustively considers all
-hyperparameter combinations from a user-specified grid.
-
-Fortunately, HPO is an active area of research, and many methods have been developed to
-optimize the time required to get strong candidates.
-
-In the following tutorial, you will learn how to use Lightning together with [Optuna](https://optuna.org/).
-
-[Optuna](https://optuna.org/) is an open source HPO framework to automate hyperparameter search.
-Out-of-the-box, it provides efficient algorithms to search large spaces and prune unpromising trials for faster results.
-
-First, you will learn about the best practices on how to implement HPO without the Lightning Framework.
-Secondly, we will dive into a working HPO application with Lightning, and finally create a neat
-[HiPlot UI](https://facebookresearch.github.io/hiplot/_static/demo/demo_basic_usage.html?hip.filters=%5B%5D&hip.color_by=%22dropout%22&hip.PARALLEL_PLOT.order=%5B%22uid%22%2C%22dropout%22%2C%22lr%22%2C%22loss%22%2C%22optimizer%22%5D)
-for our application.
-
-## Getting started
-
-### Step 1: Download the data
-
-```bash
-python download_data.py
-```
-
-### Step 2: Run the HPO Lightning App without an UI
-
-```bash
-lightning run app app_wo_ui.py
-```
-
-### Step 3: Run the HPO Lightning App with a HiPlot UI in Streamlit
-
-```bash
-lightning run app app_wi_ui.py
-```
-
-## Learn More
-
-In the documentation, search for `Build a Sweep App`.
diff --git a/examples/app/hpo/app_wi_ui.py b/examples/app/hpo/app_wi_ui.py
deleted file mode 100644
index e6ef3f3a1b983..0000000000000
--- a/examples/app/hpo/app_wi_ui.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from pathlib import Path
-
-import optuna
-from hyperplot import HiPlotFlow
-from lightning.app import CloudCompute, LightningApp, LightningFlow
-from lightning.app.structures import Dict
-from objective import ObjectiveWork
-
-
-class RootHPOFlow(LightningFlow):
- def __init__(self, script_path, data_dir, total_trials, simultaneous_trials):
- super().__init__()
- self.script_path = script_path
- self.data_dir = data_dir
- self.total_trials = total_trials
- self.simultaneous_trials = simultaneous_trials
- self.num_trials = simultaneous_trials
- self._study = optuna.create_study()
- self.ws = Dict()
- self.hi_plot = HiPlotFlow()
-
- def run(self):
- if self.num_trials >= self.total_trials:
- self.stop()
-
- has_told_study = []
-
- for trial_idx in range(self.num_trials):
- work_name = f"objective_work_{trial_idx}"
- if work_name not in self.ws:
- objective_work = ObjectiveWork(
- script_path=self.script_path,
- data_dir=self.data_dir,
- cloud_compute=CloudCompute("cpu"),
- )
- self.ws[work_name] = objective_work
- if not self.ws[work_name].has_started:
- trial = self._study.ask(ObjectiveWork.distributions())
- self.ws[work_name].run(trial_id=trial._trial_id, **trial.params)
-
- if self.ws[work_name].metric and not self.ws[work_name].has_told_study:
- self.hi_plot.data.append({"x": -1 * self.ws[work_name].metric, **self.ws[work_name].params})
- self._study.tell(self.ws[work_name].trial_id, self.ws[work_name].metric)
- self.ws[work_name].has_told_study = True
-
- has_told_study.append(self.ws[work_name].has_told_study)
-
- if all(has_told_study):
- self.num_trials += self.simultaneous_trials
-
-
-if __name__ == "__main__":
- app = LightningApp(
- RootHPOFlow(
- script_path=str(Path(__file__).parent / "pl_script.py"),
- data_dir="data/hymenoptera_data_version_0",
- total_trials=6,
- simultaneous_trials=2,
- )
- )
diff --git a/examples/app/hpo/app_wo_ui.py b/examples/app/hpo/app_wo_ui.py
deleted file mode 100644
index e20318347a1f7..0000000000000
--- a/examples/app/hpo/app_wo_ui.py
+++ /dev/null
@@ -1,57 +0,0 @@
-from pathlib import Path
-
-import optuna
-from lightning.app import CloudCompute, LightningApp, LightningFlow
-from lightning.app.structures import Dict
-from objective import ObjectiveWork
-
-
-class RootHPOFlow(LightningFlow):
- def __init__(self, script_path, data_dir, total_trials, simultaneous_trials):
- super().__init__()
- self.script_path = script_path
- self.data_dir = data_dir
- self.total_trials = total_trials
- self.simultaneous_trials = simultaneous_trials
- self.num_trials = simultaneous_trials
- self._study = optuna.create_study()
- self.ws = Dict()
-
- def run(self):
- if self.num_trials >= self.total_trials:
- self.stop()
-
- has_told_study = []
-
- for trial_idx in range(self.num_trials):
- work_name = f"objective_work_{trial_idx}"
- if work_name not in self.ws:
- objective_work = ObjectiveWork(
- script_path=self.script_path,
- data_dir=self.data_dir,
- cloud_compute=CloudCompute("cpu"),
- )
- self.ws[work_name] = objective_work
- if not self.ws[work_name].has_started:
- trial = self._study.ask(ObjectiveWork.distributions())
- self.ws[work_name].run(trial_id=trial._trial_id, **trial.params)
-
- if self.ws[work_name].metric and not self.ws[work_name].has_told_study:
- self._study.tell(self.ws[work_name].trial_id, self.ws[work_name].metric)
- self.ws[work_name].has_told_study = True
-
- has_told_study.append(self.ws[work_name].has_told_study)
-
- if all(has_told_study):
- self.num_trials += self.simultaneous_trials
-
-
-if __name__ == "__main__":
- app = LightningApp(
- RootHPOFlow(
- script_path=str(Path(__file__).parent / "pl_script.py"),
- data_dir="data/hymenoptera_data_version_0",
- total_trials=6,
- simultaneous_trials=2,
- )
- )
diff --git a/examples/app/hpo/download_data.py b/examples/app/hpo/download_data.py
deleted file mode 100644
index d82b86a9dee95..0000000000000
--- a/examples/app/hpo/download_data.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from utils import download_data
-
-data_dir = "hymenoptera_data_version_0"
-download_url = f"https://pl-flash-data.s3.amazonaws.com/{data_dir}.zip"
-download_data(download_url, "./data")
diff --git a/examples/app/hpo/hyperplot.py b/examples/app/hpo/hyperplot.py
deleted file mode 100644
index 8ff238ce38985..0000000000000
--- a/examples/app/hpo/hyperplot.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from lightning.app import LightningFlow
-from lightning.app.frontend import StreamlitFrontend
-from lightning.app.utilities.state import AppState
-
-
-class HiPlotFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.data = []
-
- def run(self):
- pass
-
- def configure_layout(self):
- return StreamlitFrontend(render_fn=render_fn)
-
-
-def render_fn(state: AppState):
- import json
-
- import hiplot as hip
- import streamlit as st
- from streamlit_autorefresh import st_autorefresh
-
- st.set_page_config(layout="wide")
- st_autorefresh(interval=1000, limit=None, key="refresh")
-
- if not state.data:
- st.write("No data available yet ! Stay tuned")
- return
-
- xp = hip.Experiment.from_iterable(state.data)
- ret_val = xp.to_streamlit(ret="selected_uids", key="hip").display()
- st.markdown("hiplot returned " + json.dumps(ret_val))
diff --git a/examples/app/hpo/objective.py b/examples/app/hpo/objective.py
deleted file mode 100644
index e320b66217db1..0000000000000
--- a/examples/app/hpo/objective.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import os
-import tempfile
-from datetime import datetime
-from typing import Optional
-
-import pandas as pd
-import torch
-from lightning.app import CloudCompute
-from lightning.app.components import TracerPythonScript
-from optuna.distributions import CategoricalDistribution, LogUniformDistribution
-from torchmetrics import Accuracy
-
-
-class ObjectiveWork(TracerPythonScript):
- def __init__(self, script_path: str, data_dir: str, cloud_compute: Optional[CloudCompute]):
- timestamp = datetime.now().strftime("%H:%M:%S")
- tmpdir = tempfile.TemporaryDirectory().name
- submission_path = os.path.join(tmpdir, f"{timestamp}.csv")
- best_model_path = os.path.join(tmpdir, f"{timestamp}.model.pt")
- super().__init__(
- script_path,
- script_args=[
- f"--train_data_path={data_dir}/train",
- f"--test_data_path={data_dir}/test",
- f"--submission_path={submission_path}",
- f"--best_model_path={best_model_path}",
- ],
- cloud_compute=cloud_compute,
- )
- self.data_dir = data_dir
- self.best_model_path = best_model_path
- self.submission_path = submission_path
- self.metric = None
- self.trial_id = None
- self.metric = None
- self.params = None
- self.has_told_study = False
-
- def run(self, trial_id: int, **params):
- self.trial_id = trial_id
- self.params = params
- self.script_args.extend([f"--{k}={v}" for k, v in params.items()])
- super().run()
- self.compute_metric()
-
- def _to_labels(self, path: str):
- return torch.from_numpy(pd.read_csv(path).label.values)
-
- def compute_metric(self):
- self.metric = -1 * float(
- Accuracy(task="binary")(
- self._to_labels(self.submission_path),
- self._to_labels(f"{self.data_dir}/ground_truth.csv"),
- )
- )
-
- @staticmethod
- def distributions():
- return {
- "backbone": CategoricalDistribution(["resnet18", "resnet34"]),
- "learning_rate": LogUniformDistribution(0.0001, 0.1),
- }
diff --git a/examples/app/hpo/pl_script.py b/examples/app/hpo/pl_script.py
deleted file mode 100644
index bbc453798431a..0000000000000
--- a/examples/app/hpo/pl_script.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import argparse
-import os
-
-import pandas as pd
-import torch
-from flash import Trainer
-from flash.image import ImageClassificationData, ImageClassifier
-
-# Parse arguments provided by the Work.
-parser = argparse.ArgumentParser()
-parser.add_argument("--train_data_path", type=str, required=True)
-parser.add_argument("--submission_path", type=str, required=True)
-parser.add_argument("--test_data_path", type=str, required=True)
-parser.add_argument("--best_model_path", type=str, required=True)
-# Optional
-parser.add_argument("--backbone", type=str, default="resnet18")
-parser.add_argument("--learning_rate", type=float, default=0.01)
-args = parser.parse_args()
-
-
-datamodule = ImageClassificationData.from_folders(
- train_folder=args.train_data_path,
- batch_size=8,
-)
-
-model = ImageClassifier(datamodule.num_classes, backbone=args.backbone)
-trainer = Trainer(fast_dev_run=True)
-trainer.fit(model, datamodule=datamodule)
-trainer.save_checkpoint(args.best_model_path)
-
-datamodule = ImageClassificationData.from_folders(
- predict_folder=args.test_data_path,
- batch_size=8,
-)
-
-predictions = Trainer().predict(model, datamodule=datamodule)
-submission_data = [
- {"filename": os.path.basename(p["metadata"]["filepath"]), "label": torch.argmax(p["preds"]).item()}
- for batch in predictions
- for p in batch
-]
-df = pd.DataFrame(submission_data)
-df.to_csv(args.submission_path, index=False)
diff --git a/examples/app/hpo/requirements.txt b/examples/app/hpo/requirements.txt
deleted file mode 100644
index bd85880da2237..0000000000000
--- a/examples/app/hpo/requirements.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-optuna
-lightning-flash[image,serve] == 0.7.0
-hiplot
diff --git a/examples/app/hpo/utils.py b/examples/app/hpo/utils.py
deleted file mode 100644
index a07ae73f8fd3e..0000000000000
--- a/examples/app/hpo/utils.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import os
-import os.path
-import tarfile
-import zipfile
-
-import requests
-
-
-def download_data(url: str, path: str = "data/", verbose: bool = False) -> None:
- """Download file with progressbar.
-
- # Code taken from: https://gist.github.com/ruxi/5d6803c116ec1130d484a4ab8c00c603
- # __author__ = "github.com/ruxi"
- # __license__ = "MIT"
-
- Usage:
- download_data('http://web4host.net/5MB.zip')
-
- """
- if url == "NEED_TO_BE_CREATED":
- raise NotImplementedError
-
- if not os.path.exists(path):
- os.makedirs(path)
- local_filename = os.path.join(path, url.split("/")[-1])
- r = requests.get(url, stream=True, verify=False)
- file_size = int(r.headers["Content-Length"]) if "Content-Length" in r.headers else 0
- chunk_size = 1024
- num_bars = int(file_size / chunk_size)
- if verbose:
- print({"file_size": file_size})
- print({"num_bars": num_bars})
-
- if not os.path.exists(local_filename):
- with open(local_filename, "wb") as fp:
- for chunk in r.iter_content(chunk_size=chunk_size):
- fp.write(chunk) # type: ignore
-
- def extract_tarfile(file_path: str, extract_path: str, mode: str):
- if os.path.exists(file_path):
- with tarfile.open(file_path, mode=mode) as tar_ref:
- for member in tar_ref.getmembers():
- try:
- tar_ref.extract(member, path=extract_path, set_attrs=False)
- except PermissionError:
- raise PermissionError(f"Could not extract tar file {file_path}")
-
- if ".zip" in local_filename:
- if os.path.exists(local_filename):
- with zipfile.ZipFile(local_filename, "r") as zip_ref:
- zip_ref.extractall(path) # noqa: S202
- elif local_filename.endswith(".tar.gz") or local_filename.endswith(".tgz"):
- extract_tarfile(local_filename, path, "r:gz")
- elif local_filename.endswith(".tar.bz2") or local_filename.endswith(".tbz"):
- extract_tarfile(local_filename, path, "r:bz2")
diff --git a/examples/app/installation_commands/app.py b/examples/app/installation_commands/app.py
deleted file mode 100644
index 526fcfef64413..0000000000000
--- a/examples/app/installation_commands/app.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# EXAMPLE COMPONENT: RUN A SCRIPT
-# app.py
-# !echo "I am installing a dependency not declared in a requirements file"
-# !pip install lmdb
-import lmdb
-from lightning.app import CloudCompute, LightningApp, LightningFlow, LightningWork
-
-
-class YourComponent(LightningWork):
- def run(self):
- print(lmdb.version())
- print("lmdb successfully installed")
- print("Accessing a module in a Work or Flow body works!")
-
-
-class RootFlow(LightningFlow):
- def __init__(self, work):
- super().__init__()
- self.work = work
-
- def run(self):
- self.work.run()
-
-
-print(f"Accessing an object in main code body works!: version = {lmdb.version()}")
-
-
-# run on a cloud machine
-compute = CloudCompute("cpu")
-worker = YourComponent(cloud_compute=compute)
-app = LightningApp(RootFlow(worker))
diff --git a/examples/app/interruptible/app.py b/examples/app/interruptible/app.py
deleted file mode 100644
index a44fcf4dca3ed..0000000000000
--- a/examples/app/interruptible/app.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from time import sleep
-
-from lightning.app import CloudCompute, LightningApp, LightningFlow, LightningWork
-
-
-class Work(LightningWork):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.counter = 0
-
- def run(self):
- while True:
- print(self.counter)
- self.counter += 1
- sleep(1)
-
-
-class Flow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.w = Work(
- cloud_compute=CloudCompute("gpu", interruptible=True),
- start_with_flow=False,
- parallel=True,
- )
-
- def run(self):
- self.w.run()
- print(self.w.counter)
-
-
-app = LightningApp(Flow())
diff --git a/examples/app/justpy/app.py b/examples/app/justpy/app.py
deleted file mode 100644
index a4c9abc4cda1d..0000000000000
--- a/examples/app/justpy/app.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from typing import Callable
-
-from lightning import LightningApp, LightningFlow
-from lightning.app.frontend import JustPyFrontend
-
-
-class Flow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.counter = 0
-
- def run(self):
- print(self.counter)
-
- def configure_layout(self):
- return JustPyFrontend(render_fn=render_fn)
-
-
-def render_fn(get_state: Callable) -> Callable:
- import justpy as jp
-
- def webpage():
- wp = jp.QuasarPage(dark=True)
- d = jp.Div(classes="q-pa-md q-gutter-sm", a=wp)
- container = jp.QBtn(color="primary", text="Counter: 0")
-
- async def click(*_):
- state = get_state()
- state.counter += 1
- container.text = f"Counter: {state.counter}"
-
- button = jp.QBtn(color="primary", text="Click Me!", click=click)
-
- d.add(button)
- d.add(container)
-
- return wp
-
- return webpage
-
-
-app = LightningApp(Flow())
diff --git a/examples/app/justpy/requirements.txt b/examples/app/justpy/requirements.txt
deleted file mode 100644
index 5f69409a4e4bb..0000000000000
--- a/examples/app/justpy/requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-justpy
diff --git a/examples/app/layout/app.py b/examples/app/layout/app.py
deleted file mode 100644
index c57ab2eff78e7..0000000000000
--- a/examples/app/layout/app.py
+++ /dev/null
@@ -1,101 +0,0 @@
-"""An example showcasing how `configure_layout` can be used to nest user interfaces of different flows.
-
-Run the app:
-
-lightning run app examples/app/layout/app.py
-
-This starts one server for each flow that returns a UI. Access the UI at the link printed in the terminal.
-
-"""
-
-import os
-from time import sleep
-
-from lightning.app import LightningApp, LightningFlow
-from lightning.app.frontend import StaticWebFrontend, StreamlitFrontend
-
-
-class C11(LightningFlow):
- def __init__(self):
- super().__init__()
- self.message = "Hello Streamlit!"
-
- def run(self):
- pass
-
- def configure_layout(self):
- return StreamlitFrontend(render_fn=render_c11)
-
-
-def render_c11(state):
- import streamlit as st
-
- st.write(state.message)
-
-
-class C21(LightningFlow):
- def __init__(self):
- super().__init__()
-
- def run(self):
- pass
-
- def configure_layout(self):
- return StaticWebFrontend(os.path.join(os.path.dirname(__file__), "ui1"))
-
-
-class C22(LightningFlow):
- def __init__(self):
- super().__init__()
-
- def run(self):
- pass
-
- def configure_layout(self):
- return StaticWebFrontend(os.path.join(os.path.dirname(__file__), "ui2"))
-
-
-class C1(LightningFlow):
- def __init__(self):
- super().__init__()
- self.c11 = C11()
-
- def run(self):
- pass
-
-
-class C2(LightningFlow):
- def __init__(self):
- super().__init__()
- self.c21 = C21()
- self.c22 = C22()
-
- def run(self):
- pass
-
- def configure_layout(self):
- return [
- {"name": "one", "content": self.c21},
- {"name": "two", "content": self.c22},
- ]
-
-
-class Root(LightningFlow):
- def __init__(self):
- super().__init__()
- self.c1 = C1()
- self.c2 = C2()
-
- def run(self):
- sleep(10)
- self.stop("Layout End")
-
- def configure_layout(self):
- return [
- {"name": "one", "content": self.c1.c11},
- {"name": "two", "content": self.c2},
- {"name": "three", "content": "https://lightning.ai"},
- ]
-
-
-app = LightningApp(Root())
diff --git a/examples/app/layout/requirements.txt b/examples/app/layout/requirements.txt
deleted file mode 100644
index 12a4706528df6..0000000000000
--- a/examples/app/layout/requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-streamlit
diff --git a/examples/app/layout/ui1/index.html b/examples/app/layout/ui1/index.html
deleted file mode 100644
index 11668dee1b911..0000000000000
--- a/examples/app/layout/ui1/index.html
+++ /dev/null
@@ -1,10 +0,0 @@
-
-
-
-
- One
-
-
- One
-
-
diff --git a/examples/app/layout/ui2/index.html b/examples/app/layout/ui2/index.html
deleted file mode 100644
index 7398be1f7630d..0000000000000
--- a/examples/app/layout/ui2/index.html
+++ /dev/null
@@ -1,10 +0,0 @@
-
-
-
-
- Two
-
-
- Two
-
-
diff --git a/examples/app/mount/app.py b/examples/app/mount/app.py
deleted file mode 100644
index b7e7c4df4746e..0000000000000
--- a/examples/app/mount/app.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import os
-
-from lightning.app import CloudCompute, LightningApp, LightningFlow, LightningWork
-from lightning.app.storage import Mount
-
-
-class Work(LightningWork):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- def run(self):
- files = os.listdir("/content/esRedditJson/")
- for file in files:
- print(file)
- assert "esRedditJson1" in files
-
-
-class Flow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.work_1 = Work(
- cloud_compute=CloudCompute(
- mounts=Mount(
- source="s3://ryft-public-sample-data/esRedditJson/",
- mount_path="/content/esRedditJson/",
- ),
- )
- )
-
- def run(self):
- self.work_1.run()
-
-
-app = LightningApp(Flow())
diff --git a/examples/app/multi_node/README.md b/examples/app/multi_node/README.md
deleted file mode 100644
index aef152444f4a4..0000000000000
--- a/examples/app/multi_node/README.md
+++ /dev/null
@@ -1,51 +0,0 @@
-# Lightning & Multi Node Training
-
-Lightning makes multi-node training simple by providing an interface to orchestrate compute and data.
-
-## Multi Node with raw PyTorch
-
-You can run multi-node training with raw PyTorch using the following commands.
-
-Here is an example where you spawn your processes yourself.
-
-```bash
-lightning run app train_pytorch.py
-```
-
-or you can use the built-in component for it.
-
-```bash
-lightning run app train_pytorch_spawn.py
-```
-
-## Multi Node with raw PyTorch + Fabric
-
-You can run multi-node training with raw PyTorch and Fabric using the following command.
-
-```bash
-lightning run app train_fabric.py
-```
-
-With Fabric, you retain control over your training loops while getting access to all Lightning distributed strategies with minimal code changes.
-
-## Multi Node with Lightning Trainer
-
-Lightning supports running Lightning Trainer from a script or within a Lightning Work.
-
-You can either run a script directly
-
-```bash
-lightning run app train_lt_script.py
-```
-
-or run your code within a Work.
-
-```bash
-lightning run app train_lt.py
-```
-
-## Multi Node with any framework
-
-```bash
-lightning run app train_any.py
-```
diff --git a/examples/app/multi_node/pl_boring_script.py b/examples/app/multi_node/pl_boring_script.py
deleted file mode 100644
index f14809354f405..0000000000000
--- a/examples/app/multi_node/pl_boring_script.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from lightning.pytorch import Trainer
-from lightning.pytorch.demos.boring_classes import BoringModel
-
-if __name__ == "__main__":
- model = BoringModel()
- trainer = Trainer(max_epochs=1)
- trainer.fit(model)
diff --git a/examples/app/multi_node/requirements.txt b/examples/app/multi_node/requirements.txt
deleted file mode 100644
index 12c6d5d5eac2a..0000000000000
--- a/examples/app/multi_node/requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-torch
diff --git a/examples/app/multi_node/train_any.py b/examples/app/multi_node/train_any.py
deleted file mode 100644
index b3c89ad534f43..0000000000000
--- a/examples/app/multi_node/train_any.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from lightning.app import CloudCompute, LightningApp, LightningWork
-from lightning.app.components import MultiNode
-
-
-class AnyDistributedComponent(LightningWork):
- def run(
- self,
- main_address: str,
- main_port: int,
- num_nodes: int,
- node_rank: int,
- ):
- print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {num_nodes} {node_rank}.")
-
-
-app = LightningApp(
- MultiNode(
- AnyDistributedComponent,
- num_nodes=2,
- cloud_compute=CloudCompute("gpu"),
- )
-)
diff --git a/examples/app/multi_node/train_fabric.py b/examples/app/multi_node/train_fabric.py
deleted file mode 100644
index 2379c491f89aa..0000000000000
--- a/examples/app/multi_node/train_fabric.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import torch
-from lightning.app import CloudCompute, LightningApp, LightningWork
-from lightning.app.components import FabricMultiNode
-from lightning.fabric import Fabric
-
-
-class FabricPyTorchDistributed(LightningWork):
- def run(self):
- # 1. Prepare the model
- model = torch.nn.Sequential(
- torch.nn.Linear(1, 1),
- torch.nn.ReLU(),
- torch.nn.Linear(1, 1),
- )
-
- # 2. Create Fabric.
- fabric = Fabric(strategy="ddp", precision="16-mixed")
- model, optimizer = fabric.setup(model, torch.optim.SGD(model.parameters(), lr=0.01))
- criterion = torch.nn.MSELoss()
-
- # 3. Train the model for 1000 steps.
- for step in range(1000):
- model.zero_grad()
- x = torch.tensor([0.8]).to(fabric.device)
- target = torch.tensor([1.0]).to(fabric.device)
- output = model(x)
- loss = criterion(output, target)
- print(f"global_rank: {fabric.global_rank} step: {step} loss: {loss}")
- fabric.backward(loss)
- optimizer.step()
-
-
-# 8 GPUs: (2 nodes of 4 x v100)
-app = LightningApp(
- FabricMultiNode(
- FabricPyTorchDistributed,
- cloud_compute=CloudCompute("gpu-fast-multi"), # 4 x V100
- num_nodes=2,
- )
-)
diff --git a/examples/app/multi_node/train_lt.py b/examples/app/multi_node/train_lt.py
deleted file mode 100644
index 23a2863e757c7..0000000000000
--- a/examples/app/multi_node/train_lt.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# app.py
-from lightning.app import CloudCompute, LightningApp, LightningWork
-from lightning.app.components import LightningTrainerMultiNode
-from lightning.pytorch import Trainer
-from lightning.pytorch.demos.boring_classes import BoringModel
-
-
-class LightningTrainerDistributed(LightningWork):
- def run(self):
- model = BoringModel()
- trainer = Trainer(max_epochs=10, strategy="ddp")
- trainer.fit(model)
-
-
-# 8 GPUs: (2 nodes of 4 x v100)
-component = LightningTrainerMultiNode(
- LightningTrainerDistributed,
- num_nodes=2,
- cloud_compute=CloudCompute("gpu-fast-multi"), # 4 x v100
-)
-app = LightningApp(component)
diff --git a/examples/app/multi_node/train_lt_script.py b/examples/app/multi_node/train_lt_script.py
deleted file mode 100644
index 7f89bc95e9b17..0000000000000
--- a/examples/app/multi_node/train_lt_script.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from lightning.app import CloudCompute, LightningApp
-from lightning.app.components import LightningTrainerScript
-
-# 8 GPUs: (2 nodes of 4 x v100)
-app = LightningApp(
- LightningTrainerScript(
- "pl_boring_script.py",
- num_nodes=2,
- cloud_compute=CloudCompute("gpu-fast-multi"), # 4 x v100
- ),
-)
diff --git a/examples/app/multi_node/train_pytorch.py b/examples/app/multi_node/train_pytorch.py
deleted file mode 100644
index a1c7fb8eac207..0000000000000
--- a/examples/app/multi_node/train_pytorch.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# app.py
-# ! pip install torch
-import torch
-from lightning.app import CloudCompute, LightningApp, LightningWork
-from lightning.app.components import MultiNode
-from torch.nn.parallel.distributed import DistributedDataParallel
-
-
-def distributed_train(local_rank: int, main_address: str, main_port: int, num_nodes: int, node_rank: int, nprocs: int):
- # 1. SET UP DISTRIBUTED ENVIRONMENT
- global_rank = local_rank + node_rank * nprocs
- world_size = num_nodes * nprocs
-
- if torch.distributed.is_available() and not torch.distributed.is_initialized():
- torch.distributed.init_process_group(
- "nccl" if torch.cuda.is_available() else "gloo",
- rank=global_rank,
- world_size=world_size,
- init_method=f"tcp://{main_address}:{main_port}",
- )
-
- # 2. PREPARE DISTRIBUTED MODEL
- model = torch.nn.Linear(32, 2)
- device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
- model = DistributedDataParallel(model.to(device), device_ids=[local_rank] if torch.cuda.is_available() else None)
-
- # 3. SETUP LOSS AND OPTIMIZER
- criterion = torch.nn.MSELoss()
- optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
-
- # 4. TRAIN THE MODEL FOR 50 STEPS
- for step in range(50):
- model.zero_grad()
- x = torch.randn(64, 32).to(device)
- output = model(x)
- loss = criterion(output, torch.ones_like(output))
- print(f"global_rank: {global_rank} step: {step} loss: {loss}")
- loss.backward()
- optimizer.step()
-
- # 5. VERIFY ALL COPIES OF THE MODEL HAVE THE SAME WEIGHTS AT END OF TRAINING
- weight = model.module.weight.clone()
- torch.distributed.all_reduce(weight)
- assert torch.equal(model.module.weight, weight / world_size)
-
- print("Multi Node Distributed Training Done!")
-
-
-class PyTorchDistributed(LightningWork):
- def run(self, main_address: str, main_port: int, num_nodes: int, node_rank: int):
- nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1
- torch.multiprocessing.spawn(
- distributed_train, args=(main_address, main_port, num_nodes, node_rank, nprocs), nprocs=nprocs
- )
-
-
-# 8 GPUs: (2 nodes x 4 V100)
-compute = CloudCompute("gpu-fast-multi") # 4 x v100
-component = MultiNode(PyTorchDistributed, num_nodes=2, cloud_compute=compute)
-app = LightningApp(component)
diff --git a/examples/app/multi_node/train_pytorch_spawn.py b/examples/app/multi_node/train_pytorch_spawn.py
deleted file mode 100644
index 8febfe5dcf696..0000000000000
--- a/examples/app/multi_node/train_pytorch_spawn.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import torch
-from lightning.app import CloudCompute, LightningApp, LightningWork
-from lightning.app.components import PyTorchSpawnMultiNode
-from torch.nn.parallel.distributed import DistributedDataParallel
-
-
-class PyTorchDistributed(LightningWork):
- def run(
- self,
- world_size: int,
- node_rank: int,
- global_rank: int,
- local_rank: int,
- ):
- # 1. Prepare the model
- model = torch.nn.Sequential(
- torch.nn.Linear(1, 1),
- torch.nn.ReLU(),
- torch.nn.Linear(1, 1),
- )
-
- # 2. Setup distributed training
- device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
- model = DistributedDataParallel(
- model.to(device), device_ids=[local_rank] if torch.cuda.is_available() else None
- )
-
- # 3. Prepare loss and optimizer
- criterion = torch.nn.MSELoss()
- optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
-
- # 4. Train the model for 1000 steps.
- for step in range(1000):
- model.zero_grad()
- x = torch.tensor([0.8]).to(device)
- target = torch.tensor([1.0]).to(device)
- output = model(x)
- loss = criterion(output, target)
- print(f"global_rank: {global_rank} step: {step} loss: {loss}")
- loss.backward()
- optimizer.step()
-
-
-# 8 GPUs: (2 nodes x 4 V100)
-app = LightningApp(
- PyTorchSpawnMultiNode(
- PyTorchDistributed,
- num_nodes=2,
- cloud_compute=CloudCompute("gpu-fast-multi"), # 4 x v100
- )
-)
diff --git a/examples/app/payload/app.py b/examples/app/payload/app.py
deleted file mode 100644
index c92f589b088cd..0000000000000
--- a/examples/app/payload/app.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from lightning.app import LightningApp, LightningFlow, LightningWork
-from lightning.app.storage import Payload
-
-
-class SourceFileWriterWork(LightningWork):
- def __init__(self):
- super().__init__()
- self.value = None
-
- def run(self):
- self.value = Payload(42)
-
-
-class DestinationWork(LightningWork):
- def run(self, payload):
- assert payload.value == 42
-
-
-class RootFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.src = SourceFileWriterWork()
- self.dst = DestinationWork()
-
- def run(self):
- self.src.run()
- self.dst.run(self.src.value)
- self.stop("Application End!")
-
-
-app = LightningApp(RootFlow())
diff --git a/examples/app/pickle_or_not/app.py b/examples/app/pickle_or_not/app.py
deleted file mode 100644
index aa7c3b01323da..0000000000000
--- a/examples/app/pickle_or_not/app.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import logging
-
-from lightning.app import LightningApp, LightningFlow, LightningWork
-
-logger = logging.getLogger(__name__)
-
-
-class PickleChecker(LightningWork):
- def run(self, pickle_image: bytes):
- parsed = self.parse_image(pickle_image)
- if parsed == b"it is a pickle":
- return True
- if parsed == b"it is not a pickle":
- return False
- raise Exception("Couldn't parse the image")
-
- @staticmethod
- def parse_image(image_str: bytes):
- return image_str
-
-
-class Slack(LightningFlow):
- def __init__(self):
- super().__init__()
-
- @staticmethod
- def send_message(message):
- logger.info(f"Sending message: {message}")
-
- def run(self):
- pass
-
-
-class RootComponent(LightningFlow):
- def __init__(self):
- super().__init__()
- self.pickle_checker = PickleChecker()
- self.slack = Slack()
- self.counter = 3
-
- def run(self):
- if self.counter > 0:
- logger.info(f"Running the app {self.counter}")
- image_str = b"it is not a pickle"
- if self.pickle_checker.run(image_str):
- self.slack.send_message("It's a pickle!")
- else:
- self.slack.send_message("It's not a pickle!")
- self.counter -= 1
- else:
- self.stop("Pickle or Not End")
-
-
-app = LightningApp(RootComponent())
diff --git a/examples/app/pickle_or_not/requirements.txt b/examples/app/pickle_or_not/requirements.txt
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/examples/app/server/app.py b/examples/app/server/app.py
deleted file mode 100644
index 97030179ffa78..0000000000000
--- a/examples/app/server/app.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# !pip install torchvision pydantic
-import base64
-import io
-
-import torch
-import torchvision
-from lightning.app import CloudCompute, LightningApp
-from lightning.app.components.serve import Image as InputImage
-from lightning.app.components.serve import PythonServer
-from PIL import Image
-from pydantic import BaseModel
-
-
-class PyTorchServer(PythonServer):
- def setup(self):
- self._model = torchvision.models.resnet18(pretrained=True)
- self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- self._model.to(self._device)
-
- def predict(self, request):
- image = base64.b64decode(request.image.encode("utf-8"))
- image = Image.open(io.BytesIO(image))
- transforms = torchvision.transforms.Compose([
- torchvision.transforms.Resize(224),
- torchvision.transforms.ToTensor(),
- torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
- ])
- image = transforms(image)
- image = image.to(self._device)
- prediction = self._model(image.unsqueeze(0))
- return {"prediction": prediction.argmax().item()}
-
-
-class OutputData(BaseModel):
- prediction: int
-
-
-component = PyTorchServer(input_type=InputImage, output_type=OutputData, cloud_compute=CloudCompute("gpu"))
-app = LightningApp(component)
diff --git a/examples/app/server_with_auto_scaler/app.py b/examples/app/server_with_auto_scaler/app.py
deleted file mode 100644
index 1320da6745fa6..0000000000000
--- a/examples/app/server_with_auto_scaler/app.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# ! pip install torch torchvision
-from typing import List
-
-import torch
-import torchvision
-from lightning.app import CloudCompute, LightningApp
-from pydantic import BaseModel
-
-
-class BatchRequestModel(BaseModel):
- inputs: List[app.components.Image]
-
-
-class BatchResponse(BaseModel):
- outputs: List[app.components.Number]
-
-
-class PyTorchServer(app.components.PythonServer):
- def __init__(self, *args, **kwargs):
- super().__init__(
- input_type=BatchRequestModel,
- output_type=BatchResponse,
- *args,
- **kwargs,
- )
-
- def setup(self):
- if torch.cuda.is_available():
- self._device = torch.device("cuda:0")
- elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
- self._device = torch.device("mps")
- else:
- self._device = torch.device("cpu")
- self._model = torchvision.models.resnet18(pretrained=True).to(self._device)
-
- def predict(self, requests: BatchRequestModel):
- transforms = torchvision.transforms.Compose([
- torchvision.transforms.Resize(224),
- torchvision.transforms.ToTensor(),
- torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
- ])
- images = []
- for request in requests.inputs:
- image = app.components.serve.types.image.Image.deserialize(request.image)
- image = transforms(image).unsqueeze(0)
- images.append(image)
- images = torch.cat(images)
- images = images.to(self._device)
- predictions = self._model(images)
- results = predictions.argmax(1).cpu().numpy().tolist()
- return BatchResponse(outputs=[{"prediction": pred} for pred in results])
-
-
-class MyAutoScaler(app.components.AutoScaler):
- def scale(self, replicas: int, metrics: dict) -> int:
- pending_requests = metrics["pending_requests"]
- active_or_pending_works = replicas + metrics["pending_works"]
-
- if active_or_pending_works == 0:
- return 1 if pending_requests > 0 else 0
-
- pending_requests_per_running_or_pending_work = pending_requests / active_or_pending_works
-
- # scale out if the number of pending requests exceeds max batch size.
- max_requests_per_work = self.max_batch_size
- if pending_requests_per_running_or_pending_work >= max_requests_per_work:
- return replicas + 1
-
- # scale in if the number of pending requests is below 25% of max_requests_per_work
- min_requests_per_work = max_requests_per_work * 0.25
- if pending_requests_per_running_or_pending_work < min_requests_per_work:
- return replicas - 1
-
- return replicas
-
-
-app = LightningApp(
- MyAutoScaler(
- # work class and args
- PyTorchServer,
- cloud_compute=CloudCompute("gpu"),
- # autoscaler specific args
- min_replicas=1,
- max_replicas=4,
- scale_out_interval=10,
- scale_in_interval=10,
- endpoint="predict",
- input_type=app.components.Image,
- output_type=app.components.Number,
- timeout_batching=1,
- max_batch_size=8,
- )
-)
diff --git a/examples/app/template_streamlit_ui/app.py b/examples/app/template_streamlit_ui/app.py
deleted file mode 100644
index 21a13036aa782..0000000000000
--- a/examples/app/template_streamlit_ui/app.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from lightning.app import LightningApp, LightningFlow
-from lightning.app.frontend import StreamlitFrontend
-from lightning.app.utilities.state import AppState
-
-
-class StreamlitUI(LightningFlow):
- def __init__(self):
- super().__init__()
- self.message_to_print = "Hello World!"
- self.should_print = False
-
- def configure_layout(self):
- return StreamlitFrontend(render_fn=render_fn)
-
-
-def render_fn(state: AppState):
- import streamlit as st
-
- should_print = st.button("Should print to the terminal ?")
-
- if should_print:
- state.should_print = not state.should_print
-
- st.write("Currently printing." if state.should_print else "Currently waiting to print.")
-
-
-class HelloWorld(LightningFlow):
- def __init__(self):
- super().__init__()
- self.counter = 0
- self.streamlit_ui = StreamlitUI()
-
- def run(self):
- self.streamlit_ui.run()
- if self.streamlit_ui.should_print:
- print(f"{self.counter}: {self.streamlit_ui.message_to_print}")
- self.counter += 1
- self.streamlit_ui.should_print = False
-
- def configure_layout(self):
- return [{"name": "StreamLitUI", "content": self.streamlit_ui}]
-
-
-app = LightningApp(HelloWorld())
diff --git a/examples/app/template_streamlit_ui/requirements.txt b/examples/app/template_streamlit_ui/requirements.txt
deleted file mode 100644
index 12a4706528df6..0000000000000
--- a/examples/app/template_streamlit_ui/requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-streamlit
diff --git a/examples/app/v0/.gitignore b/examples/app/v0/.gitignore
deleted file mode 100644
index 186149fa056fe..0000000000000
--- a/examples/app/v0/.gitignore
+++ /dev/null
@@ -1,2 +0,0 @@
-.storage
-.lightning
diff --git a/examples/app/v0/README.md b/examples/app/v0/README.md
deleted file mode 100644
index 516283ae9cedd..0000000000000
--- a/examples/app/v0/README.md
+++ /dev/null
@@ -1,18 +0,0 @@
-# v0 app
-
-This app is a flow-only app with nothing fancy.
-This is meant to present the basic functionalities of the lightning framework.
-
-## Starting it
-
-Local
-
-```bash
-lightning run app app.py
-```
-
-Cloud
-
-```bash
-lightning run app app.py --cloud
-```
diff --git a/examples/app/v0/app.py b/examples/app/v0/app.py
deleted file mode 100644
index d1cbb41c6dc10..0000000000000
--- a/examples/app/v0/app.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# v0_app.py
-import os
-from datetime import datetime
-from time import sleep
-
-from lightning.app import LightningApp, LightningFlow
-from lightning.app.frontend import StaticWebFrontend
-
-
-class Word(LightningFlow):
- def __init__(self, letter):
- super().__init__()
- self.letter = letter
- self.repeats = letter
-
- def run(self):
- self.repeats += self.letter
-
- def configure_layout(self):
- return StaticWebFrontend(os.path.join(os.path.dirname(__file__), f"ui/{self.letter}"))
-
-
-class V0App(LightningFlow):
- def __init__(self):
- super().__init__()
- self.aas = Word("a")
- self.bbs = Word("b")
- self.counter = 0
-
- def run(self):
- now = datetime.now()
- now = now.strftime("%H:%M:%S")
- log = {"time": now, "a": self.aas.repeats, "b": self.bbs.repeats}
- print(log)
- self.aas.run()
- self.bbs.run()
-
- sleep(2.0)
- self.counter += 1
-
- def configure_layout(self):
- tab1 = {"name": "Tab_1", "content": self.aas}
- tab2 = {"name": "Tab_2", "content": self.bbs}
- tab3 = {"name": "Tab_3", "content": "https://tensorboard.dev/experiment/8m1aX0gcQ7aEmH0J7kbBtg/#scalars"}
-
- return [tab1, tab2, tab3]
-
-
-app = LightningApp(V0App(), log_level="debug")
diff --git a/examples/app/v0/emulate_ui.py b/examples/app/v0/emulate_ui.py
deleted file mode 100644
index 1d42c1cdf4c52..0000000000000
--- a/examples/app/v0/emulate_ui.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from time import sleep
-
-import requests
-from lightning.app.utilities.state import headers_for
-
-headers = headers_for({})
-headers["X-Lightning-Type"] = "DEFAULT"
-
-res = requests.get("http://127.0.0.1:7501/state", headers=headers)
-
-
-res = requests.post("http://127.0.0.1:7501/state", json={"stage": "running"}, headers=headers)
-print(res)
-
-sleep(10)
-
-res = requests.post("http://127.0.0.1:7501/state", json={"stage": "stopping"}, headers=headers)
-print(res)
diff --git a/examples/app/v0/requirements.txt b/examples/app/v0/requirements.txt
deleted file mode 100644
index edfce786a4d18..0000000000000
--- a/examples/app/v0/requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-py
diff --git a/examples/app/v0/ui/a/index.html b/examples/app/v0/ui/a/index.html
deleted file mode 100644
index 6ddb9a5a1323c..0000000000000
--- a/examples/app/v0/ui/a/index.html
+++ /dev/null
@@ -1 +0,0 @@
-Hello from component A
diff --git a/examples/app/v0/ui/b/index.html b/examples/app/v0/ui/b/index.html
deleted file mode 100644
index 3bfd9e24cb7f7..0000000000000
--- a/examples/app/v0/ui/b/index.html
+++ /dev/null
@@ -1 +0,0 @@
-Hello from component B
diff --git a/examples/app/works_on_default_machine/app_v2.py b/examples/app/works_on_default_machine/app_v2.py
deleted file mode 100644
index 191070041b866..0000000000000
--- a/examples/app/works_on_default_machine/app_v2.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from fastapi import FastAPI
-from fastapi.middleware.cors import CORSMiddleware
-from lightning import CloudCompute, LightningApp, LightningFlow, LightningWork
-from uvicorn import run
-
-
-class Work(LightningWork):
- def __init__(self, **kwargs):
- super().__init__(parallel=True, **kwargs)
-
- def run(self):
- fastapi_service = FastAPI()
-
- fastapi_service.add_middleware(
- CORSMiddleware,
- allow_origins=["*"],
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
- )
-
- @fastapi_service.get("/")
- def get_root():
- return {"Hello Word!"}
-
- run(fastapi_service, host=self.host, port=self.port)
-
-
-class Flow(LightningFlow):
- def __init__(self):
- super().__init__()
- # In the Cloud: All the works defined without passing explicitly a CloudCompute object
- # are running on the default machine.
- # This would apply to `work_a`, `work_b` and the dynamically created `work_d`.
-
- self.work_a = Work()
- self.work_b = Work()
-
- self.work_c = Work(cloud_compute=CloudCompute(name="cpu-small"))
-
- def run(self):
- if not hasattr(self, "work_d"):
- self.work_d = Work()
-
- for work in self.works():
- work.run()
-
- def configure_layout(self):
- return [{"name": w.name, "content": w} for i, w in enumerate(self.works())]
-
-
-app = LightningApp(Flow(), log_level="debug")
diff --git a/examples/app/works_on_default_machine/requirements.txt b/examples/app/works_on_default_machine/requirements.txt
deleted file mode 100644
index 12a4706528df6..0000000000000
--- a/examples/app/works_on_default_machine/requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-streamlit
diff --git a/examples/fabric/build_your_own_trainer/trainer.py b/examples/fabric/build_your_own_trainer/trainer.py
index c5e6836bfeddf..7af01ede054a8 100644
--- a/examples/fabric/build_your_own_trainer/trainer.py
+++ b/examples/fabric/build_your_own_trainer/trainer.py
@@ -227,7 +227,7 @@ def train_loop(
should_optim_step = self.global_step % self.grad_accum_steps == 0
if should_optim_step:
# currently only supports a single optimizer
- self.fabric.call("on_before_optimizer_step", optimizer, 0)
+ self.fabric.call("on_before_optimizer_step", optimizer)
# optimizer step runs train step internally through closure
optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx))
@@ -264,7 +264,7 @@ def val_loop(
val_loader: Optional[torch.utils.data.DataLoader],
limit_batches: Union[int, float] = float("inf"),
):
- """The validation loop ruunning a single validation epoch.
+ """The validation loop running a single validation epoch.
Args:
model: the LightningModule to evaluate
@@ -285,7 +285,10 @@ def val_loop(
)
return
- self.fabric.call("on_validation_model_eval") # calls `model.eval()`
+ if not is_overridden("on_validation_model_eval", _unwrap_objects(model)):
+ model.eval()
+ else:
+ self.fabric.call("on_validation_model_eval") # calls `model.eval()`
torch.set_grad_enabled(False)
@@ -311,7 +314,10 @@ def val_loop(
self.fabric.call("on_validation_epoch_end")
- self.fabric.call("on_validation_model_train")
+ if not is_overridden("on_validation_model_train", _unwrap_objects(model)):
+ model.train()
+ else:
+ self.fabric.call("on_validation_model_train")
torch.set_grad_enabled(True)
def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) -> torch.Tensor:
diff --git a/examples/fabric/reinforcement_learning/train_fabric.py b/examples/fabric/reinforcement_learning/train_fabric.py
index 1f3f83f3f2025..068359602a096 100644
--- a/examples/fabric/reinforcement_learning/train_fabric.py
+++ b/examples/fabric/reinforcement_learning/train_fabric.py
@@ -80,7 +80,7 @@ def main(args: argparse.Namespace):
# Log hyperparameters
fabric.logger.experiment.add_text(
"hyperparameters",
- "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
+ "|param|value|\n|-|-|\n{}".format("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# Environment setup
@@ -146,7 +146,7 @@ def main(args: argparse.Namespace):
# Single environment step
next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
- rewards[step] = torch.tensor(reward, device=device).view(-1)
+ rewards[step] = torch.tensor(reward, device=device, dtype=torch.float32).view(-1)
next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)
if "final_info" in info:
diff --git a/examples/fabric/reinforcement_learning/train_fabric_decoupled.py b/examples/fabric/reinforcement_learning/train_fabric_decoupled.py
index bbc09c977efcf..7150ac3a12529 100644
--- a/examples/fabric/reinforcement_learning/train_fabric_decoupled.py
+++ b/examples/fabric/reinforcement_learning/train_fabric_decoupled.py
@@ -55,7 +55,7 @@ def player(args, world_collective: TorchCollective, player_trainer_collective: T
# Log hyperparameters
logger.experiment.add_text(
"hyperparameters",
- "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
+ "|param|value|\n|-|-|\n{}".format("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# Environment setup
@@ -135,7 +135,7 @@ def player(args, world_collective: TorchCollective, player_trainer_collective: T
# Single environment step
next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
- rewards[step] = torch.tensor(reward, device=device).view(-1)
+ rewards[step] = torch.tensor(reward, device=device, dtype=torch.float32).view(-1)
next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)
if "final_info" in info:
diff --git a/examples/fabric/reinforcement_learning/train_torch.py b/examples/fabric/reinforcement_learning/train_torch.py
index 2cf755f77f9d9..cf74e03f5202e 100644
--- a/examples/fabric/reinforcement_learning/train_torch.py
+++ b/examples/fabric/reinforcement_learning/train_torch.py
@@ -138,7 +138,7 @@ def main(args: argparse.Namespace):
if global_rank == 0:
logger.add_text(
"hyperparameters",
- "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
+ "|param|value|\n|-|-|\n{}".format("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# Environment setup
diff --git a/examples/fabric/tensor_parallel/README.md b/examples/fabric/tensor_parallel/README.md
new file mode 100644
index 0000000000000..e66d9acd2848b
--- /dev/null
+++ b/examples/fabric/tensor_parallel/README.md
@@ -0,0 +1,45 @@
+## Tensor Parallel and 2D Parallel
+
+This example shows how to apply tensor-parallelism to your model (here Llama 2 7B) with the `ModelParallelStrategy`, and how it can be combined with FSDP (2D parallelism).
+PyTorch 2.3+ and a machine with at least 4 GPUs and 24 GB memory each are required to run this example.
+
+```bash
+pip install 'torch>=2.3'
+```
+
+Navigate to this example folder and run the training script:
+
+```bash
+cd examples/fabric/tensor_parallel
+python train.py
+```
+
+You should see an output like this:
+
+```
+Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
+Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
+Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
+Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
+----------------------------------------------------------------------------------------------------
+distributed_backend=nccl
+All distributed processes registered. Starting with 4 processes
+----------------------------------------------------------------------------------------------------
+
+Number of model parameters: 6.7 B
+Starting training ...
+Iteration 0 complete
+Iteration 1 complete
+Iteration 2 complete
+Iteration 3 complete
+Iteration 4 complete
+Iteration 5 complete
+Iteration 6 complete
+Iteration 7 complete
+Saving a (distributed) checkpoint ...
+Training successfully completed!
+Peak memory usage: 17.95 GB
+```
+
+> \[!NOTE\]
+> The `ModelParallelStrategy` is experimental and subject to change. Report issues on [GitHub](https://github.com/Lightning-AI/pytorch-lightning/issues).
diff --git a/examples/fabric/tensor_parallel/data.py b/examples/fabric/tensor_parallel/data.py
new file mode 100644
index 0000000000000..ba36987283ffd
--- /dev/null
+++ b/examples/fabric/tensor_parallel/data.py
@@ -0,0 +1,21 @@
+import torch
+from torch.utils.data import Dataset
+
+
+class RandomTokenDataset(Dataset):
+ def __init__(self, vocab_size: int, seq_length: int):
+ self.vocab_size = vocab_size
+ self.seq_length = seq_length
+ self.tokens = torch.randint(
+ self.vocab_size,
+ size=(len(self), self.seq_length + 1),
+ # Set a seed to make this toy dataset the same on each rank
+ # Fabric will add a `DistributedSampler` to shard the data correctly
+ generator=torch.Generator().manual_seed(42),
+ )
+
+ def __len__(self) -> int:
+ return 128
+
+ def __getitem__(self, item: int):
+ return self.tokens[item]
diff --git a/examples/fabric/tensor_parallel/model.py b/examples/fabric/tensor_parallel/model.py
new file mode 100644
index 0000000000000..3c9e7de472b90
--- /dev/null
+++ b/examples/fabric/tensor_parallel/model.py
@@ -0,0 +1,456 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Llama 2 is licensed under the LLAMA 2 Community License,
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+@dataclass
+class ModelArgs:
+ dim: int = 4096
+ n_layers: int = 32
+ n_heads: int = 32
+ n_kv_heads: Optional[int] = None
+ vocab_size: int = -1 # defined later by tokenizer
+ multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
+ ffn_dim_multiplier: Optional[float] = None
+ norm_eps: float = 1e-5
+ rope_theta: float = 10000
+
+ max_batch_size: int = 32
+ max_seq_len: int = 2048
+ # If `True`, then each transformer block init uses its layer ID, and if
+ # `False`, each uses the total number of transformer blocks
+ depth_init: bool = True
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
+ """Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
+ The returned tensor contains complex values in complex64 data type.
+
+ Args:
+ dim (int): Dimension of the frequency tensor.
+ end (int): End index for precomputing frequencies.
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+
+ Returns:
+ torch.Tensor: Precomputed frequency tensor with complex exponentials.
+
+ """
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+ t = torch.arange(end, device=freqs.device)
+ freqs = torch.outer(t, freqs).float()
+ return torch.polar(torch.ones_like(freqs), freqs) # complex64
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+ """Reshape frequency tensor for broadcasting it with another tensor.
+
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
+
+ The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
+ and the first seqlen elements will be sliced, but dim must match x.
+
+ Args:
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
+
+ Returns:
+ torch.Tensor: Reshaped frequency tensor.
+
+ """
+ ndim = x.ndim
+ assert 0 <= 1 < ndim
+ seqlen = x.shape[1]
+ freqs_cis = freqs_cis[0:seqlen]
+ assert freqs_cis.shape == (seqlen, x.shape[-1])
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Apply rotary embeddings to input tensors using the given frequency tensor.
+
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
+ returned as real tensors.
+
+ Args:
+ xq (torch.Tensor): Query tensor to apply rotary embeddings.
+ xk (torch.Tensor): Key tensor to apply rotary embeddings.
+ freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+
+ """
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """Equivalent to ``torch.repeat_interleave(x, dim=2, repeats=n_rep)``."""
+ bs, slen, n_kv_heads, head_dim = x.shape
+ if n_rep == 1:
+ return x
+ return (
+ x[:, :, :, None, :]
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+ )
+
+
+class RMSNorm(nn.Module):
+ """Initialize the RMSNorm normalization layer.
+
+ Args:
+ dim (int): The dimension of the input tensor.
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+ Attributes:
+ eps (float): A small value added to the denominator for numerical stability.
+ weight (nn.Parameter): Learnable scaling parameter.
+
+ """
+
+ def __init__(self, dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x: torch.Tensor):
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x: torch.Tensor):
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+ def reset_parameters(self):
+ torch.nn.init.ones_(self.weight) # type: ignore
+
+
+class Attention(nn.Module):
+ """Multi-head attention module.
+
+ Args:
+ model_args (ModelArgs): Model configuration arguments.
+
+ Attributes:
+ n_kv_heads (int): Number of key and value heads.
+ n_heads (int): Number of query heads.
+ n_rep (int): Number of repetitions for local heads.
+ head_dim (int): Dimension size of each attention head.
+ wq (Linear): Linear transformation for queries.
+ wk (Linear): Linear transformation for keys.
+ wv (Linear): Linear transformation for values.
+ wo (Linear): Linear transformation for output.
+
+ """
+
+ def __init__(self, model_args: ModelArgs):
+ super().__init__()
+ self.n_heads = model_args.n_heads
+ self.n_kv_heads = model_args.n_heads if model_args.n_kv_heads is None else model_args.n_kv_heads
+ self.n_rep = self.n_heads // self.n_kv_heads
+ self.head_dim = model_args.dim // model_args.n_heads
+
+ self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False)
+ self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
+ self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
+ self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False)
+
+ def init_weights(self, init_std: float):
+ for linear in (self.wq, self.wk, self.wv):
+ nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
+ nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ ):
+ """Forward pass of the attention module.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+ freqs_cis (torch.Tensor): Precomputed frequency tensor.
+
+ Returns:
+ torch.Tensor: Output tensor after attention.
+
+ """
+ bs, seqlen, _ = x.shape
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
+ xq = xq.view(bs, seqlen, self.n_heads, self.head_dim)
+ xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim)
+ xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim)
+
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
+ values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
+
+ xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
+ xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
+ xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
+
+ # we use a causal mask for training
+ output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
+ output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim)
+ output = output.view(bs, seqlen, -1)
+ return self.wo(output)
+
+
+class FeedForward(nn.Module):
+ """FeedForward module.
+
+ Args:
+ dim (int): Input dimension.
+ hidden_dim (int): Hidden dimension of the feedforward layer.
+ multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+ ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.
+
+ Attributes:
+ w1 (Linear): Linear transformation for the first layer.
+ w2 (Linear): Linear transformation for the second layer.
+ w3 (Linear): Linear transformation for the third layer.
+
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ multiple_of: int,
+ ffn_dim_multiplier: Optional[float],
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ # custom dim factor multiplier
+ if ffn_dim_multiplier is not None:
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+ def forward(self, x):
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+ def init_weights(self, init_std: float):
+ nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+ for linear in (self.w2, self.w3):
+ nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+
+
+class TransformerBlock(nn.Module):
+ """TransformerBlock Module.
+
+ Args:
+ layer_id (int): Identifier for the layer.
+ model_args (ModelArgs): Model configuration arguments.
+
+ Attributes:
+ n_heads (int): Number of attention heads.
+ dim (int): Dimension size of the model.
+ head_dim (int): Dimension size of each attention head.
+ attention (Attention): Attention module.
+ feed_forward (FeedForward): FeedForward module.
+ layer_id (int): Identifier for the layer.
+ attention_norm (RMSNorm): RMS normalization applied to the attention input.
+ ffn_norm (RMSNorm): RMS normalization applied to the feedforward input.
+
+ """
+
+ def __init__(self, layer_id: int, model_args: ModelArgs):
+ super().__init__()
+ self.n_heads = model_args.n_heads
+ self.dim = model_args.dim
+ self.attention = Attention(model_args)
+ self.feed_forward = FeedForward(
+ dim=model_args.dim,
+ hidden_dim=4 * model_args.dim,
+ multiple_of=model_args.multiple_of,
+ ffn_dim_multiplier=model_args.ffn_dim_multiplier,
+ )
+ self.layer_id = layer_id
+ self.num_layers = model_args.n_layers
+
+ self.attention_norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
+ self.ffn_norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
+
+ if model_args.depth_init:
+ self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
+ else:
+ self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ ):
+ """Perform a forward pass through the TransformerBlock.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+ freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
+
+ Returns:
+ torch.Tensor: Output tensor after applying attention and feedforward layers.
+
+ """
+ h = x + self.attention(self.attention_norm(x), freqs_cis)
+ return h + self.feed_forward(self.ffn_norm(h))
+
+ def init_weights(self):
+ for norm in (self.attention_norm, self.ffn_norm):
+ norm.reset_parameters()
+ self.attention.init_weights(self.weight_init_std)
+ self.feed_forward.init_weights(self.weight_init_std)
+
+
+class Transformer(nn.Module):
+ """Transformer Module.
+
+ Args:
+ model_args (ModelArgs): Model configuration arguments.
+
+ Attributes:
+ model_args (ModelArgs): Model configuration arguments.
+ vocab_size (int): Vocabulary size.
+ n_layers (int): Number of layers in the model.
+ tok_embeddings (nn.Embedding): Token embeddings.
+ layers (torch.nn.ModuleDict): Transformer blocks, keyed by layer index.
+ norm (RMSNorm): Final normalization applied before the output projection.
+ output (nn.Linear): Linear layer for final output.
+ freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
+
+ """
+
+ def __init__(self, model_args: ModelArgs):
+ super().__init__()
+ self.model_args = model_args
+ self.vocab_size = model_args.vocab_size
+ self.n_layers = model_args.n_layers
+
+ self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
+
+ # TODO persistent should be set to false, since this buffer can be recomputed.
+ # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411,
+ # compile or pipeline-tracer will not correctly handle non-persistent buffers,
+ # so we need to fix that. (2) if we initialize pipeline-parallel models from
+ # a seed checkpoint rather than calling init_weights, we need freqs_cis to be
+ # initialized by the checkpoint, or we need to add a separate initializer for
+ # just the non-persistent buffers that is called after loading checkpoints.
+ self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
+
+ self.layers = torch.nn.ModuleDict()
+ for layer_id in range(model_args.n_layers):
+ self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
+
+ self.norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
+
+ self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
+ self.init_weights()
+
+ def reset_parameters(self):
+ with torch.device(self.freqs_cis.device):
+ self.freqs_cis = self._precompute_freqs_cis()
+
+ def init_weights(self):
+ """Initialize the weights of this module and all of its child modules.
+
+ [Note: On ``init_weights`` vs. ``reset_parameters``]
+ Modules may define ``reset_parameters`` to initialize parameter values.
+ ``reset_parameters`` is meant to only initialize directly owned
+ parameters/buffers, not those of their child modules, and it can be
+ used to give the initial values for these tensors.
+ Separately, users may want custom initialization for their modules,
+ different from that in ``reset_parameters``. For this, we define
+ ``init_weights``. We only call it in the constructor of this
+ ``Transformer`` root module to avoid reinitializing tensors.
+
+ """
+ with torch.device(self.freqs_cis.device):
+ self.freqs_cis = self._precompute_freqs_cis()
+ nn.init.normal_(self.tok_embeddings.weight)
+ for layer in self.layers.values():
+ layer.init_weights()
+ self.norm.reset_parameters()
+ final_out_std = self.model_args.dim**-0.5
+ cutoff_factor = 3
+ nn.init.trunc_normal_(
+ self.output.weight,
+ mean=0.0,
+ std=final_out_std,
+ a=-cutoff_factor * final_out_std,
+ b=cutoff_factor * final_out_std,
+ )
+
+ def _precompute_freqs_cis(self) -> torch.Tensor:
+ return precompute_freqs_cis(
+ self.model_args.dim // self.model_args.n_heads,
+ # Need to compute until at least the max token limit for generation
+ # (use 2x max sequence length to be safe)
+ self.model_args.max_seq_len * 2,
+ self.model_args.rope_theta,
+ )
+
+ def forward(self, tokens: torch.Tensor):
+ """Perform a forward pass through the Transformer model.
+
+ Args:
+ tokens (torch.Tensor): Input token indices.
+
+ Returns:
+ torch.Tensor: Output logits after applying the Transformer model.
+
+ """
+ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
+ h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
+
+ for layer in self.layers.values():
+ h = layer(h, self.freqs_cis)
+
+ h = self.norm(h) if self.norm else h
+ return self.output(h).float() if self.output else h
+
+ @classmethod
+ def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
+ """Initialize a Transformer model from a ModelArgs object.
+
+ Args:
+ model_args (ModelArgs): Model configuration arguments.
+
+ Returns:
+ Transformer: Transformer model.
+
+ """
+ return cls(model_args)
diff --git a/examples/fabric/tensor_parallel/parallelism.py b/examples/fabric/tensor_parallel/parallelism.py
new file mode 100644
index 0000000000000..44d55c8da1cc9
--- /dev/null
+++ b/examples/fabric/tensor_parallel/parallelism.py
@@ -0,0 +1,106 @@
+import torch
+from model import Transformer
+from torch.distributed._composable.fsdp import MixedPrecisionPolicy
+from torch.distributed._composable.fsdp.fully_shard import fully_shard
+from torch.distributed._tensor import Replicate, Shard
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.parallel import (
+ ColwiseParallel,
+ PrepareModuleInput,
+ RowwiseParallel,
+ SequenceParallel,
+ parallelize_module,
+)
+
+
+# Taken and modified from torchtitan
+# https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py
+def parallelize(model: Transformer, device_mesh: DeviceMesh) -> Transformer:
+ """Apply parallelisms and activation checkpointing to the model.
+
+ NOTE: The passed-in model preferably should be on meta device. Otherwise,
+ the model must fit on GPU or CPU memory.
+
+ """
+
+ dp_mesh = device_mesh["data_parallel"]
+ tp_mesh = device_mesh["tensor_parallel"]
+
+ if tp_mesh.size() > 1:
+ # 1. Parallelize the first embedding and the last linear proj layer
+ # 2. Parallelize the root norm layer over the sequence dim
+ # 3. Shard the first transformer block's inputs
+
+ # Parallelize the first embedding and the last linear out projection
+ plan = {
+ "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
+ "output": ColwiseParallel(
+ input_layouts=Shard(1),
+ # Optional: Shard the output along the class dimension to compute the loss in parallel.
+ # See `loss_parallel` in `train.py`
+ output_layouts=Shard(-1),
+ use_local_output=False,
+ ),
+ "norm": SequenceParallel(),
+ "layers.0": PrepareModuleInput(
+ input_layouts=(Replicate(), None),
+ desired_input_layouts=(Shard(1), None),
+ use_local_output=True,
+ ),
+ }
+ model = parallelize_module(model, tp_mesh, plan)
+
+ # Parallelize each transformer block
+ for transformer_block in model.layers.values():
+ plan = {
+ "attention": PrepareModuleInput(
+ input_layouts=(Shard(1), None),
+ desired_input_layouts=(Replicate(), None),
+ ),
+ "attention.wq": ColwiseParallel(),
+ "attention.wk": ColwiseParallel(),
+ "attention.wv": ColwiseParallel(),
+ "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
+ "attention_norm": SequenceParallel(),
+ "feed_forward": PrepareModuleInput(
+ input_layouts=(Shard(1),),
+ desired_input_layouts=(Replicate(),),
+ ),
+ "feed_forward.w1": ColwiseParallel(),
+ "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
+ "feed_forward.w3": ColwiseParallel(),
+ "ffn_norm": SequenceParallel(),
+ }
+
+ # Adjust attention module to use the local number of heads
+ attn_layer = transformer_block.attention
+ attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
+ attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
+
+ # Apply the plan for the current transformer block
+ parallelize_module(transformer_block, tp_mesh, plan)
+
+ if dp_mesh.size() > 1:
+ assert dp_mesh.ndim == 1 # Hybrid-sharding not supported
+
+ # NOTE: Currently, the user is required to manually handle precision settings such as the `mp_policy` here
+ # because the model parallel strategy does not respect all settings of `Fabric(precision=...)` at the moment.
+ mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
+
+ fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+ for layer_id, transformer_block in model.layers.items():
+ # Apply activation checkpointing
+ transformer_block = checkpoint_wrapper(transformer_block)
+ # As an optimization, do not reshard after forward for the last
+ # transformer block since FSDP would prefetch it immediately
+ reshard_after_forward = int(layer_id) < len(model.layers) - 1
+ fully_shard(
+ transformer_block,
+ **fsdp_config,
+ reshard_after_forward=reshard_after_forward,
+ )
+ model.layers[layer_id] = transformer_block
+ model = fully_shard(model, **fsdp_config)
+
+ return model
diff --git a/examples/fabric/tensor_parallel/train.py b/examples/fabric/tensor_parallel/train.py
new file mode 100644
index 0000000000000..4a98f12cf6168
--- /dev/null
+++ b/examples/fabric/tensor_parallel/train.py
@@ -0,0 +1,78 @@
+import lightning as L
+import torch
+import torch.nn.functional as F
+from data import RandomTokenDataset
+from lightning.fabric.strategies import ModelParallelStrategy
+from model import ModelArgs, Transformer
+from parallelism import parallelize
+from torch.distributed.tensor.parallel import loss_parallel
+from torch.utils.data import DataLoader
+
+
+def train():
+ strategy = ModelParallelStrategy(
+ # User-defined function that applies the desired parallelizations specific to the model
+ # (TP, FSDP2, activation checkpointing, ...)
+ parallelize_fn=parallelize,
+ # Define the size of the 2D parallelism
+ # Set to "auto" to apply TP intra-node and DP inter-node
+ data_parallel_size=2,
+ tensor_parallel_size=2,
+ )
+
+ fabric = L.Fabric(accelerator="cuda", devices=4, strategy=strategy)
+ fabric.launch()
+
+ # Initialize the model
+ model_args = ModelArgs(vocab_size=32000)
+ with fabric.init_module(empty_init=True):
+ model = Transformer(model_args)
+
+ fabric.print(f"Number of model parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.1f} B")
+
+ # Set up model and optimizer
+ model = fabric.setup(model)
+ model.init_weights()
+
+ # Define the optimizer
+ optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, foreach=True)
+ optimizer = fabric.setup_optimizers(optimizer)
+
+ # Define dataset/dataloader
+ dataset = RandomTokenDataset(vocab_size=model_args.vocab_size, seq_length=128)
+ dataloader = DataLoader(dataset, batch_size=8)
+
+ # Fabric configures the sampler automatically for you such that
+ # all batches in a tensor-parallel group are identical
+ dataloader = fabric.setup_dataloaders(dataloader)
+
+ # Simplified training loop
+ fabric.print("Starting training ...")
+
+ for i, batch in enumerate(dataloader):
+ inputs = batch[:, :-1]
+ labels = batch[:, 1:]
+
+ output = model(inputs)
+
+ with loss_parallel():
+ loss = F.cross_entropy(output.reshape(-1, output.size(-1)), labels.reshape(-1))
+ fabric.backward(loss)
+
+ optimizer.step()
+ optimizer.zero_grad()
+ fabric.print(f"Iteration {i} complete")
+
+ # See `fabric consolidate --help` if you need to convert the checkpoint to a single file
+ fabric.print("Saving a (distributed) checkpoint ...")
+ state = {"model": model, "optimizer": optimizer, "iteration": i}
+ fabric.save("checkpoint.pt", state)
+
+ fabric.print("Training successfully completed!")
+ fabric.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
+
+
+if __name__ == "__main__":
+ assert torch.cuda.device_count() >= 4, "This example requires at least 4 GPUs with 24 GB of memory each."
+ torch.set_float32_matmul_precision("high")
+ train()
diff --git a/examples/pytorch/domain_templates/generative_adversarial_net.py b/examples/pytorch/domain_templates/generative_adversarial_net.py
index 311d1c38771b8..417e167df0d93 100644
--- a/examples/pytorch/domain_templates/generative_adversarial_net.py
+++ b/examples/pytorch/domain_templates/generative_adversarial_net.py
@@ -19,9 +19,9 @@
"""
+import math
from argparse import ArgumentParser, Namespace
-import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -59,7 +59,7 @@ def block(in_feat, out_feat, normalize=True):
*block(128, 256),
*block(256, 512),
*block(512, 1024),
- nn.Linear(1024, int(np.prod(img_shape))),
+ nn.Linear(1024, int(math.prod(img_shape))),
nn.Tanh(),
)
@@ -80,7 +80,7 @@ def __init__(self, img_shape):
super().__init__()
self.model = nn.Sequential(
- nn.Linear(int(np.prod(img_shape)), 512),
+ nn.Linear(int(math.prod(img_shape)), 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
diff --git a/examples/pytorch/domain_templates/reinforce_learn_Qnet.py b/examples/pytorch/domain_templates/reinforce_learn_Qnet.py
index 9b065db8173fe..497cb658c275f 100644
--- a/examples/pytorch/domain_templates/reinforce_learn_Qnet.py
+++ b/examples/pytorch/domain_templates/reinforce_learn_Qnet.py
@@ -33,11 +33,11 @@
"""
import argparse
+import random
from collections import OrderedDict, deque, namedtuple
from typing import Iterator, List, Tuple
import gym
-import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
@@ -103,15 +103,15 @@ def append(self, experience: Experience) -> None:
self.buffer.append(experience)
def sample(self, batch_size: int) -> Tuple:
- indices = np.random.choice(len(self.buffer), batch_size, replace=False)
+ indices = random.sample(range(len(self.buffer)), batch_size)
states, actions, rewards, dones, next_states = zip(*(self.buffer[idx] for idx in indices))
return (
- np.array(states),
- np.array(actions),
- np.array(rewards, dtype=np.float32),
- np.array(dones, dtype=np.bool),
- np.array(next_states),
+ torch.tensor(states),
+ torch.tensor(actions),
+ torch.tensor(rewards, dtype=torch.float32),
+ torch.tensor(dones, dtype=torch.bool),
+ torch.tensor(next_states),
)
@@ -175,7 +175,7 @@ def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:
action
"""
- if np.random.random() < epsilon:
+ if random.random() < epsilon:
action = self.env.action_space.sample()
else:
state = torch.tensor([self.state])
diff --git a/examples/pytorch/domain_templates/semantic_segmentation.py b/examples/pytorch/domain_templates/semantic_segmentation.py
index c60518e229cbf..12ecbeeb5f0a9 100644
--- a/examples/pytorch/domain_templates/semantic_segmentation.py
+++ b/examples/pytorch/domain_templates/semantic_segmentation.py
@@ -16,7 +16,6 @@
import random
from argparse import ArgumentParser, Namespace
-import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
@@ -107,11 +106,11 @@ def __len__(self):
def __getitem__(self, idx):
img = Image.open(self.img_list[idx])
img = img.resize(self.img_size)
- img = np.array(img)
+ img = torch.tensor(img)
mask = Image.open(self.mask_list[idx]).convert("L")
mask = mask.resize(self.img_size)
- mask = np.array(mask)
+ mask = torch.tensor(mask)
mask = self.encode_segmap(mask)
if self.transform:
diff --git a/examples/pytorch/tensor_parallel/README.md b/examples/pytorch/tensor_parallel/README.md
new file mode 100644
index 0000000000000..d8b81b6de1bff
--- /dev/null
+++ b/examples/pytorch/tensor_parallel/README.md
@@ -0,0 +1,49 @@
+## Tensor Parallel and 2D Parallel
+
+This example shows how to apply tensor-parallelism to your model (here Llama 2 7B) with the `ModelParallelStrategy`, and how it can be combined with FSDP (2D parallelism).
+PyTorch 2.3+ and a machine with at least 4 GPUs and 24 GB memory each are required to run this example.
+
+```bash
+pip install 'torch>=2.3'
+```
+
+Navigate to this example folder and run the training script:
+
+```bash
+cd examples/pytorch/tensor_parallel
+python train.py
+```
+
+You should see an output like this:
+
+```
+GPU available: True (cuda), used: True
+TPU available: False, using: 0 TPU cores
+HPU available: False, using: 0 HPUs
+
+Number of model parameters: 6.7 B
+Starting training ...
+
+Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
+Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
+Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
+Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
+----------------------------------------------------------------------------------------------------
+distributed_backend=nccl
+All distributed processes registered. Starting with 4 processes
+----------------------------------------------------------------------------------------------------
+
+LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
+LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
+LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
+LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
+
+Epoch 0: 100%|█████████████████████████████████████████████| 10/10 [01:49<00:00, 0.09it/s, v_num=2]
+`Trainer.fit` stopped: `max_epochs=1` reached.
+Saving a (distributed) checkpoint ...
+Training successfully completed!
+Peak memory usage: 36.73 GB
+```
+
+> [!NOTE]
+> The `ModelParallelStrategy` is experimental and subject to change. Report issues on [GitHub](https://github.com/Lightning-AI/pytorch-lightning/issues).
diff --git a/examples/pytorch/tensor_parallel/data.py b/examples/pytorch/tensor_parallel/data.py
new file mode 100644
index 0000000000000..ba36987283ffd
--- /dev/null
+++ b/examples/pytorch/tensor_parallel/data.py
@@ -0,0 +1,21 @@
+import torch
+from torch.utils.data import Dataset
+
+
+class RandomTokenDataset(Dataset):
+ def __init__(self, vocab_size: int, seq_length: int):
+ self.vocab_size = vocab_size
+ self.seq_length = seq_length
+ self.tokens = torch.randint(
+ self.vocab_size,
+ size=(len(self), self.seq_length + 1),
+ # Set a seed to make this toy dataset the same on each rank
+ # The Trainer will add a `DistributedSampler` to shard the data correctly
+ generator=torch.Generator().manual_seed(42),
+ )
+
+ def __len__(self) -> int:
+ return 128
+
+ def __getitem__(self, item: int):
+ return self.tokens[item]
diff --git a/examples/pytorch/tensor_parallel/model.py b/examples/pytorch/tensor_parallel/model.py
new file mode 100644
index 0000000000000..3c9e7de472b90
--- /dev/null
+++ b/examples/pytorch/tensor_parallel/model.py
@@ -0,0 +1,456 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Llama 2 is licensed under the LLAMA 2 Community License,
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+@dataclass
+class ModelArgs:
+ dim: int = 4096
+ n_layers: int = 32
+ n_heads: int = 32
+ n_kv_heads: Optional[int] = None
+ vocab_size: int = -1 # defined later by tokenizer
+ multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
+ ffn_dim_multiplier: Optional[float] = None
+ norm_eps: float = 1e-5
+ rope_theta: float = 10000
+
+ max_batch_size: int = 32
+ max_seq_len: int = 2048
+ # If `True`, then each transformer block init uses its layer ID, and if
+ # `False`, each uses the total number of transformer blocks
+ depth_init: bool = True
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
+ """Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
+ The returned tensor contains complex values in complex64 data type.
+
+ Args:
+ dim (int): Dimension of the frequency tensor.
+ end (int): End index for precomputing frequencies.
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+
+ Returns:
+ torch.Tensor: Precomputed frequency tensor with complex exponentials.
+
+ """
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+ t = torch.arange(end, device=freqs.device)
+ freqs = torch.outer(t, freqs).float()
+ return torch.polar(torch.ones_like(freqs), freqs) # complex64
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+ """Reshape frequency tensor for broadcasting it with another tensor.
+
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
+
+ The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
+ and the first seqlen elements will be sliced, but dim must match x.
+
+ Args:
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
+
+ Returns:
+ torch.Tensor: Reshaped frequency tensor.
+
+ """
+ ndim = x.ndim
+ assert 0 <= 1 < ndim
+ seqlen = x.shape[1]
+ freqs_cis = freqs_cis[0:seqlen]
+ assert freqs_cis.shape == (seqlen, x.shape[-1])
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Apply rotary embeddings to input tensors using the given frequency tensor.
+
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
+ returned as real tensors.
+
+ Args:
+ xq (torch.Tensor): Query tensor to apply rotary embeddings.
+ xk (torch.Tensor): Key tensor to apply rotary embeddings.
+ freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+
+ """
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
+ bs, slen, n_kv_heads, head_dim = x.shape
+ if n_rep == 1:
+ return x
+ return (
+ x[:, :, :, None, :]
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+ )
+
+
+class RMSNorm(nn.Module):
+ """Initialize the RMSNorm normalization layer.
+
+ Args:
+ dim (int): The dimension of the input tensor.
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+ Attributes:
+ eps (float): A small value added to the denominator for numerical stability.
+ weight (nn.Parameter): Learnable scaling parameter.
+
+ """
+
+ def __init__(self, dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x: torch.Tensor):
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x: torch.Tensor):
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+ def reset_parameters(self):
+ torch.nn.init.ones_(self.weight) # type: ignore
+
+
+class Attention(nn.Module):
+ """Multi-head attention module.
+
+ Args:
+ model_args (ModelArgs): Model configuration arguments.
+
+ Attributes:
+ n_kv_heads (int): Number of key and value heads.
+ n_heads (int): Number of query heads.
+ n_rep (int): Number of repetitions for local heads.
+ head_dim (int): Dimension size of each attention head.
+ wq (Linear): Linear transformation for queries.
+ wk (Linear): Linear transformation for keys.
+ wv (Linear): Linear transformation for values.
+ wo (Linear): Linear transformation for output.
+
+ """
+
+ def __init__(self, model_args: ModelArgs):
+ super().__init__()
+ self.n_heads = model_args.n_heads
+ self.n_kv_heads = model_args.n_heads if model_args.n_kv_heads is None else model_args.n_kv_heads
+ self.n_rep = self.n_heads // self.n_kv_heads
+ self.head_dim = model_args.dim // model_args.n_heads
+
+ self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False)
+ self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
+ self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
+ self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False)
+
+ def init_weights(self, init_std: float):
+ for linear in (self.wq, self.wk, self.wv):
+ nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
+ nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ ):
+ """Forward pass of the attention module.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+ freqs_cis (torch.Tensor): Precomputed frequency tensor.
+
+ Returns:
+ torch.Tensor: Output tensor after attention.
+
+ """
+ bs, seqlen, _ = x.shape
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
+ xq = xq.view(bs, seqlen, self.n_heads, self.head_dim)
+ xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim)
+ xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim)
+
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
+ values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
+
+ xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
+ xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
+ xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
+
+ # we use a causal mask for training
+ output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
+ output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim)
+ output = output.view(bs, seqlen, -1)
+ return self.wo(output)
+
+
+class FeedForward(nn.Module):
+ """FeedForward module.
+
+ Args:
+ dim (int): Input dimension.
+ hidden_dim (int): Hidden dimension of the feedforward layer.
+ multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+ ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.
+
+ Attributes:
+ w1 (Linear): Linear transformation for the first layer.
+ w2 (Linear): Linear transformation for the second layer.
+ w3 (Linear): Linear transformation for the third layer.
+
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ multiple_of: int,
+ ffn_dim_multiplier: Optional[float],
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ # custom dim factor multiplier
+ if ffn_dim_multiplier is not None:
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+ def forward(self, x):
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+ def init_weights(self, init_std: float):
+ nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+ for linear in (self.w2, self.w3):
+ nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+
+
+class TransformerBlock(nn.Module):
+ """TransformerBlock Module.
+
+ Args:
+ layer_id (int): Identifier for the layer.
+ model_args (ModelArgs): Model configuration arguments.
+
+ Attributes:
+ n_heads (int): Number of attention heads.
+ dim (int): Dimension size of the model.
+ head_dim (int): Dimension size of each attention head.
+ attention (Attention): Attention module.
+ feed_forward (FeedForward): FeedForward module.
+ layer_id (int): Identifier for the layer.
+ attention_norm (RMSNorm): RMS normalization applied to the attention input.
+ ffn_norm (RMSNorm): RMS normalization applied to the feedforward input.
+
+ """
+
+ def __init__(self, layer_id: int, model_args: ModelArgs):
+ super().__init__()
+ self.n_heads = model_args.n_heads
+ self.dim = model_args.dim
+ self.attention = Attention(model_args)
+ self.feed_forward = FeedForward(
+ dim=model_args.dim,
+ hidden_dim=4 * model_args.dim,
+ multiple_of=model_args.multiple_of,
+ ffn_dim_multiplier=model_args.ffn_dim_multiplier,
+ )
+ self.layer_id = layer_id
+ self.num_layers = model_args.n_layers
+
+ self.attention_norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
+ self.ffn_norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
+
+ if model_args.depth_init:
+ self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
+ else:
+ self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ ):
+ """Perform a forward pass through the TransformerBlock.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+ freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
+
+ Returns:
+ torch.Tensor: Output tensor after applying attention and feedforward layers.
+
+ """
+ h = x + self.attention(self.attention_norm(x), freqs_cis)
+ return h + self.feed_forward(self.ffn_norm(h))
+
+ def init_weights(self):
+ for norm in (self.attention_norm, self.ffn_norm):
+ norm.reset_parameters()
+ self.attention.init_weights(self.weight_init_std)
+ self.feed_forward.init_weights(self.weight_init_std)
+
+
+class Transformer(nn.Module):
+ """Transformer Module.
+
+ Args:
+ model_args (ModelArgs): Model configuration arguments.
+
+ Attributes:
+ model_args (ModelArgs): Model configuration arguments.
+ vocab_size (int): Vocabulary size.
+ n_layers (int): Number of layers in the model.
+ tok_embeddings (nn.Embedding): Token embeddings.
+ layers (torch.nn.ModuleDict): Transformer blocks, keyed by layer ID.
+ norm (RMSNorm): Layer normalization for the model output.
+ output (nn.Linear): Linear layer for final output.
+ freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
+
+ """
+
+ def __init__(self, model_args: ModelArgs):
+ super().__init__()
+ self.model_args = model_args
+ self.vocab_size = model_args.vocab_size
+ self.n_layers = model_args.n_layers
+
+ self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
+
+ # TODO persistent should be set to false, since this buffer can be recomputed.
+ # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411,
+ # compile or pipeline-tracer will not correctly handle non-persistent buffers,
+ # so we need to fix that. (2) if we initialize pipeline-parallel models from
+ # a seed checkpoint rather than calling init_weights, we need freqs_cis to be
+ # initialized by the checkpoint, or we need to add a separate initializer for
+ # just the non-persistent buffers that is called after loading checkpoints.
+ self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
+
+ self.layers = torch.nn.ModuleDict()
+ for layer_id in range(model_args.n_layers):
+ self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
+
+ self.norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
+
+ self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
+ self.init_weights()
+
+ def reset_parameters(self):
+ with torch.device(self.freqs_cis.device):
+ self.freqs_cis = self._precompute_freqs_cis()
+
+ def init_weights(self):
+ """[Note: On ``init_weights`` vs.
+
+ ``reset_parameters``]
+ Modules may define ``reset_parameters`` to initialize parameter values.
+ ``reset_parameters`` is meant to only initialize directly owned
+ parameters/buffers, not those of their child modules, and it can be
+ used to give the initial values for these tensors.
+ Separately, users may want custom initialization for their modules,
+ different from that in ``reset_parameters``. For this, we define
+ ``init_weights``. We only call it in the constructor of this
+ ``Transformer`` root module to avoid reinitializing tensors.
+
+ """
+ with torch.device(self.freqs_cis.device):
+ self.freqs_cis = self._precompute_freqs_cis()
+ nn.init.normal_(self.tok_embeddings.weight)
+ for layer in self.layers.values():
+ layer.init_weights()
+ self.norm.reset_parameters()
+ final_out_std = self.model_args.dim**-0.5
+ cutoff_factor = 3
+ nn.init.trunc_normal_(
+ self.output.weight,
+ mean=0.0,
+ std=final_out_std,
+ a=-cutoff_factor * final_out_std,
+ b=cutoff_factor * final_out_std,
+ )
+
+ def _precompute_freqs_cis(self) -> torch.Tensor:
+ return precompute_freqs_cis(
+ self.model_args.dim // self.model_args.n_heads,
+ # Need to compute until at least the max token limit for generation
+ # (use 2x max sequence length to be safe)
+ self.model_args.max_seq_len * 2,
+ self.model_args.rope_theta,
+ )
+
+ def forward(self, tokens: torch.Tensor):
+ """Perform a forward pass through the Transformer model.
+
+ Args:
+ tokens (torch.Tensor): Input token indices.
+
+ Returns:
+ torch.Tensor: Output logits after applying the Transformer model.
+
+ """
+ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
+ h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
+
+ for layer in self.layers.values():
+ h = layer(h, self.freqs_cis)
+
+ h = self.norm(h) if self.norm else h
+ return self.output(h).float() if self.output else h
+
+ @classmethod
+ def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
+ """Initialize a Transformer model from a ModelArgs object.
+
+ Args:
+ model_args (ModelArgs): Model configuration arguments.
+
+ Returns:
+ Transformer: Transformer model.
+
+ """
+ return cls(model_args)
diff --git a/examples/pytorch/tensor_parallel/parallelism.py b/examples/pytorch/tensor_parallel/parallelism.py
new file mode 100644
index 0000000000000..44d55c8da1cc9
--- /dev/null
+++ b/examples/pytorch/tensor_parallel/parallelism.py
@@ -0,0 +1,106 @@
+import torch
+from model import Transformer
+from torch.distributed._composable.fsdp import MixedPrecisionPolicy
+from torch.distributed._composable.fsdp.fully_shard import fully_shard
+from torch.distributed._tensor import Replicate, Shard
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.parallel import (
+ ColwiseParallel,
+ PrepareModuleInput,
+ RowwiseParallel,
+ SequenceParallel,
+ parallelize_module,
+)
+
+
+# Taken and modified from torchtitan
+# https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py
+def parallelize(model: Transformer, device_mesh: DeviceMesh) -> Transformer:
+ """Apply parallelisms and activation checkpointing to the model.
+
+ NOTE: The passed-in model preferably should be on meta device. Otherwise,
+ the model must fit on GPU or CPU memory.
+
+ """
+
+ dp_mesh = device_mesh["data_parallel"]
+ tp_mesh = device_mesh["tensor_parallel"]
+
+ if tp_mesh.size() > 1:
+ # 1. Parallelize the first embedding and the last linear proj layer
+ # 2. Parallelize the root norm layer over the sequence dim
+ # 3. Shard the first transformer block's inputs
+
+ # Parallelize the first embedding and the last linear out projection
+ plan = {
+ "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
+ "output": ColwiseParallel(
+ input_layouts=Shard(1),
+ # Optional: Shard the output along the class dimension to compute the loss in parallel.
+ # See `loss_parallel` in `train.py`
+ output_layouts=Shard(-1),
+ use_local_output=False,
+ ),
+ "norm": SequenceParallel(),
+ "layers.0": PrepareModuleInput(
+ input_layouts=(Replicate(), None),
+ desired_input_layouts=(Shard(1), None),
+ use_local_output=True,
+ ),
+ }
+ model = parallelize_module(model, tp_mesh, plan)
+
+ # Parallelize each transformer block
+ for transformer_block in model.layers.values():
+ plan = {
+ "attention": PrepareModuleInput(
+ input_layouts=(Shard(1), None),
+ desired_input_layouts=(Replicate(), None),
+ ),
+ "attention.wq": ColwiseParallel(),
+ "attention.wk": ColwiseParallel(),
+ "attention.wv": ColwiseParallel(),
+ "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
+ "attention_norm": SequenceParallel(),
+ "feed_forward": PrepareModuleInput(
+ input_layouts=(Shard(1),),
+ desired_input_layouts=(Replicate(),),
+ ),
+ "feed_forward.w1": ColwiseParallel(),
+ "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
+ "feed_forward.w3": ColwiseParallel(),
+ "ffn_norm": SequenceParallel(),
+ }
+
+ # Adjust attention module to use the local number of heads
+ attn_layer = transformer_block.attention
+ attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
+ attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
+
+ # Apply the plan for the current transformer block
+ parallelize_module(transformer_block, tp_mesh, plan)
+
+ if dp_mesh.size() > 1:
+ assert dp_mesh.ndim == 1 # Hybrid-sharding not supported
+
+ # NOTE: Currently, the user is required to manually handle precision settings such as the `mp_policy` here
+ # because the model parallel strategy does not respect all settings of `Fabric(precision=...)` at the moment.
+ mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
+
+ fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+ for layer_id, transformer_block in model.layers.items():
+ # Apply activation checkpointing
+ transformer_block = checkpoint_wrapper(transformer_block)
+ # As an optimization, do not reshard after forward for the last
+ # transformer block since FSDP would prefetch it immediately
+ reshard_after_forward = int(layer_id) < len(model.layers) - 1
+ fully_shard(
+ transformer_block,
+ **fsdp_config,
+ reshard_after_forward=reshard_after_forward,
+ )
+ model.layers[layer_id] = transformer_block
+ model = fully_shard(model, **fsdp_config)
+
+ return model
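Review note on `parallelize` above: the two pieces of per-block bookkeeping (dividing the attention head counts by the tensor-parallel size, and skipping the reshard on the last block) can be checked without any GPUs. The following is a minimal sketch in plain Python; the helper names `local_head_counts` and `reshard_flags` are hypothetical and not part of this PR, and the divisibility check is an added assumption, since the code above uses floor division without validating it.

```python
def local_head_counts(n_heads: int, n_kv_heads: int, tp_size: int) -> tuple[int, int]:
    """Return the per-rank (n_heads, n_kv_heads) once wq/wk/wv are sharded column-wise.

    Assumption (not in the PR): raise if the head counts do not divide evenly,
    instead of silently flooring as `n_heads // tp_mesh.size()` would.
    """
    if n_heads % tp_size != 0 or n_kv_heads % tp_size != 0:
        raise ValueError("head counts must be divisible by the tensor-parallel size")
    return n_heads // tp_size, n_kv_heads // tp_size


def reshard_flags(num_layers: int) -> dict[str, bool]:
    """Mirror `int(layer_id) < len(model.layers) - 1` for string layer ids."""
    return {str(i): i < num_layers - 1 for i in range(num_layers)}


if __name__ == "__main__":
    # Hypothetical sizes chosen for illustration only.
    print(local_head_counts(n_heads=32, n_kv_heads=8, tp_size=2))
    print(reshard_flags(num_layers=3))
```

Only the last block gets `reshard_after_forward=False`, which matches the comment in the diff that FSDP would prefetch that block immediately anyway.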
diff --git a/examples/pytorch/tensor_parallel/train.py b/examples/pytorch/tensor_parallel/train.py
new file mode 100644
index 0000000000000..6a91e1242e4af
--- /dev/null
+++ b/examples/pytorch/tensor_parallel/train.py
@@ -0,0 +1,80 @@
+import lightning as L
+import torch
+import torch.nn.functional as F
+from data import RandomTokenDataset
+from lightning.pytorch.strategies import ModelParallelStrategy
+from model import ModelArgs, Transformer
+from parallelism import parallelize
+from torch.distributed.tensor.parallel import loss_parallel
+from torch.utils.data import DataLoader
+
+
+class Llama3(L.LightningModule):
+ def __init__(self):
+ super().__init__()
+ self.model_args = ModelArgs(vocab_size=32000)
+ self.model = Transformer(self.model_args)
+
+ def configure_model(self):
+ # User-defined function that applies the desired parallelizations specific to the model
+ # (TP, FSDP2, activation checkpointing, ...)
+ parallelize(self.model, device_mesh=self.device_mesh)
+
+ def on_train_start(self) -> None:
+ self.model.init_weights()
+
+ def training_step(self, batch):
+ inputs = batch[:, :-1]
+ labels = batch[:, 1:]
+ output = self.model(inputs)
+ # Optional: Parallelize loss computation across class dimension (see parallelism.py)
+ with loss_parallel():
+ return F.cross_entropy(output.reshape(-1, output.size(-1)), labels.reshape(-1))
+
+ def backward(self, *args, **kwargs):
+ with loss_parallel():
+ super().backward(*args, **kwargs)
+
+ def configure_optimizers(self):
+ return torch.optim.AdamW(self.model.parameters(), lr=3e-3, foreach=True)
+
+ def train_dataloader(self):
+ dataset = RandomTokenDataset(vocab_size=self.model_args.vocab_size, seq_length=128)
+ # Trainer configures the sampler automatically for you such that
+ # all batches in a tensor-parallel group are identical
+ return DataLoader(dataset, batch_size=8, num_workers=4)
+
+
+def train():
+ strategy = ModelParallelStrategy(
+ # Define the size of the 2D parallelism
+ # Set to "auto" to apply TP intra-node and DP inter-node
+ data_parallel_size=2,
+ tensor_parallel_size=2,
+ )
+
+ trainer = L.Trainer(
+ accelerator="cuda",
+ devices=4,
+ strategy=strategy,
+ limit_train_batches=10,
+ max_epochs=1,
+ )
+
+ # Initialize the model
+ with trainer.init_module(empty_init=True):
+ model = Llama3()
+
+ trainer.print(f"Number of model parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.1f} B")
+ trainer.print("Starting training ...")
+
+ trainer.fit(model)
+
+ trainer.print("Training successfully completed!")
+ trainer.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
+
+
+if __name__ == "__main__":
+ assert torch.cuda.device_count() >= 4, "This example requires at least 4 GPUs with 24 GB of memory each."
+ torch.set_float32_matmul_precision("high")
+ train()
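Review note on `train.py` above: with `data_parallel_size=2`, `tensor_parallel_size=2` and `devices=4`, it may help to see which ranks end up in the same group. The sketch below is plain Python and does not call Lightning or PyTorch. It assumes a row-major mesh of shape `(data_parallel, tensor_parallel)` with the tensor-parallel dimension innermost; that layout is an assumption made for illustration, and `mesh_groups` is a hypothetical helper, not an API from this PR.

```python
def mesh_groups(dp_size: int, tp_size: int) -> tuple[list[list[int]], list[list[int]]]:
    """Return (tensor-parallel groups, data-parallel groups) as lists of global ranks.

    Assumes ranks are laid out row-major in a (dp_size, tp_size) grid.
    """
    grid = [[dp * tp_size + tp for tp in range(tp_size)] for dp in range(dp_size)]
    tp_groups = grid  # each row shares one model replica, split across its ranks
    dp_groups = [[grid[dp][tp] for dp in range(dp_size)] for tp in range(tp_size)]
    return tp_groups, dp_groups


if __name__ == "__main__":
    tp_groups, dp_groups = mesh_groups(dp_size=2, tp_size=2)
    print("tensor-parallel groups:", tp_groups)
    print("data-parallel groups:", dp_groups)
```

Under this assumed layout, ranks in the same tensor-parallel group are the ones that must receive identical batches, which is what the sampler comment in `train_dataloader` refers to.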
diff --git a/pyproject.toml b/pyproject.toml
index ecfc70736994a..da4cd7f197d5a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,15 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-[metadata]
-name = "lightning"
-author = "Lightning-AI et al."
-url = "https://github.com/Lightning-AI/lightning"
-
[build-system]
requires = [
"setuptools",
"wheel",
+ "packaging",
]
@@ -80,7 +76,6 @@ ignore = [
"S108",
"E203", # conflicts with black
]
-ignore-init-module-imports = true
[tool.ruff.lint.per-file-ignores]
".actions/*" = ["S101", "S310"]
@@ -139,13 +134,7 @@ max-complexity = 10
files = [
"src/lightning",
]
-# This section is for folders with "-" as they are not valid python modules
-exclude = [
- "src/lightning/app/cli/app-template",
- "src/lightning/app/cli/component-template",
- "src/lightning/app/cli/pl-app-template",
- "src/lightning/app/cli/react-ui-template",
-]
+
install_types = "True"
non_interactive = "True"
disallow_untyped_defs = "True"
@@ -160,100 +149,6 @@ disable_error_code = "attr-defined"
# style choices
warn_no_return = "False"
-# Ignore mypy errors for these files
-# TODO: the goal is for this to be empty
-[[tool.mypy.overrides]]
-# the list can be generated with:
-# mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
-module = [
- "lightning.app.api.http_methods",
- "lightning.app.api.request_types",
- "lightning.app.cli.cmd_install",
- "lightning.app.cli.commands.app_commands",
- "lightning.app.cli.commands.cd",
- "lightning.app.cli.commands.cp",
- "lightning.app.cli.commands.ls",
- "lightning.app.cli.connect.app",
- "lightning.app.components.database.client",
- "lightning.app.components.database.server",
- "lightning.app.components.database.utilities",
- "lightning.app.components.multi_node.base",
- "lightning.app.components.multi_node.fabric",
- "lightning.app.components.multi_node.pytorch_spawn",
- "lightning.app.components.multi_node.trainer",
- "lightning.app.components.python.popen",
- "lightning.app.components.python.tracer",
- "lightning.app.components.serve.auto_scaler",
- "lightning.app.components.serve.gradio_server",
- "lightning.app.components.serve.python_server",
- "lightning.app.components.serve.serve",
- "lightning.app.components.serve.streamlit",
- "lightning.app.components.serve.types.image",
- "lightning.app.components.serve.types.type",
- "lightning.app.components.training",
- "lightning.app.frontend.panel.app_state_comm",
- "lightning.app.frontend.panel.app_state_watcher",
- "lightning.app.frontend.panel.panel_frontend",
- "lightning.app.frontend.panel.panel_serve_render_fn",
- "lightning.app.frontend.streamlit_base",
- "lightning.app.frontend.stream_lit",
- "lightning.app.frontend.utils",
- "lightning.app.frontend.web",
- "lightning.app.launcher.launcher",
- "lightning.app.launcher.lightning_backend",
- "lightning.app.launcher.lightning_hybrid_backend",
- "lightning.app.pdb.pdb",
- "lightning.app.runners.backends.backend",
- "lightning.app.runners.backends.cloud",
- "lightning.app.runners.backends.docker",
- "lightning.app.runners.backends.mp_process",
- "lightning.app.runners.cloud",
- "lightning.app.runners.multiprocess",
- "lightning.app.runners.runtime",
- "lightning.app.source_code.copytree",
- "lightning.app.source_code.hashing",
- "lightning.app.source_code.local",
- "lightning.app.source_code.tar",
- "lightning.app.source_code.uploader",
- "lightning.app.storage.copier",
- "lightning.app.storage.drive",
- "lightning.app.storage.filesystem",
- "lightning.app.storage.orchestrator",
- "lightning.app.storage.path",
- "lightning.app.storage.payload",
- "lightning.app.structures.dict",
- "lightning.app.structures.list",
- "lightning.app.testing.helpers",
- "lightning.app.testing.testing",
- "lightning.app.utilities.app_helpers",
- "lightning.app.utilities.app_logs",
- "lightning.app.utilities.cli_helpers",
- "lightning.app.utilities.cloud",
- "lightning.app.utilities.commands.base",
- "lightning.app.utilities.component",
- "lightning.app.utilities.enum",
- "lightning.app.utilities.exceptions",
- "lightning.app.utilities.git",
- "lightning.app.utilities.imports",
- "lightning.app.utilities.introspection",
- "lightning.app.utilities.layout",
- "lightning.app.utilities.load_app",
- "lightning.app.utilities.log_helpers",
- "lightning.app.utilities.login",
- "lightning.app.utilities.name_generator",
- "lightning.app.utilities.network",
- "lightning.app.utilities.openapi",
- "lightning.app.utilities.packaging.cloud_compute",
- "lightning.app.utilities.packaging.lightning_utils",
- "lightning.app.utilities.proxies",
- "lightning.app.utilities.scheduler",
- "lightning.app.utilities.state",
- "lightning.app.utilities.tracer",
- "lightning.app.utilities.tree",
- "lightning.store.utils",
-]
-ignore_errors = "True"
-
[tool.coverage.report]
exclude_lines = [
diff --git a/requirements.txt b/requirements.txt
index bcb63693fbbe3..4910c7fbe7fc0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,4 @@
# the default package dependencies
--r ./requirements/app/app.txt
-r ./requirements/fabric/base.txt
-r ./requirements/pytorch/base.txt
diff --git a/requirements/app/app.txt b/requirements/app/app.txt
deleted file mode 100644
index 85e5b270c09e8..0000000000000
--- a/requirements/app/app.txt
+++ /dev/null
@@ -1,29 +0,0 @@
-lightning-cloud == 0.5.65 # Must be pinned to ensure compatibility
-packaging
-typing-extensions >=4.4.0, <4.10.0
-deepdiff >=5.7.0, <6.6.0
-fsspec[http] >=2022.5.0, <2023.11.0
-croniter >=1.3.0, <1.5.0 # strict; TODO: for now until we find something more robust.
-traitlets >=5.3.0, <5.12.0
-arrow >=1.2.0, <1.3.0
-lightning-utilities >=0.8.0, <0.12.0
-beautifulsoup4 >=4.8.0, <4.13.0
-inquirer >=2.10.0, <3.2.0
-psutil <5.9.6
-click <8.2
-python-multipart >=0.0.5, <=0.0.6
-backoff >=2.2.1, <2.3.0
-
-fastapi >=0.92.0, <0.104.0
-starlette # https://fastapi.tiangolo.com/deployment/versions/#about-starlette
-pydantic >=1.7.4 # https://fastapi.tiangolo.com/deployment/versions/#about-pydantic
-
-dateutils <0.8.0
-Jinja2 <3.2.0
-PyYAML <=6.0.1
-requests <2.32.0
-rich >=12.3.0, <13.6.0
-urllib3 <2.1.0
-uvicorn <0.24.0
-websocket-client <1.7.0
-websockets <11.1.0
diff --git a/requirements/app/cloud.txt b/requirements/app/cloud.txt
deleted file mode 100644
index ad5d2d583d17f..0000000000000
--- a/requirements/app/cloud.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-redis >=4.0.1, <5.1.0
-docker >=5.0.0, <6.1.4
-s3fs >=2022.5.0, <2023.6.1
-# setuptools==59.5.0
diff --git a/requirements/app/components.txt b/requirements/app/components.txt
deleted file mode 100644
index 78509b6b0269e..0000000000000
--- a/requirements/app/components.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-# deps required by components in the lightning app repository (src/lightning/app/components)
-lightning_api_access >=0.0.3 # serve
-aiohttp >=3.8.0, <3.9.0 # auto_scaler
-lightning-fabric >=1.9.0 # multinode
-pytorch-lightning >=1.9.0 # multinode
diff --git a/requirements/app/docs.txt b/requirements/app/docs.txt
deleted file mode 100644
index f2db5000b9113..0000000000000
--- a/requirements/app/docs.txt
+++ /dev/null
@@ -1 +0,0 @@
--r ../docs.txt
diff --git a/requirements/app/test.txt b/requirements/app/test.txt
deleted file mode 100644
index fd9629649c89b..0000000000000
--- a/requirements/app/test.txt
+++ /dev/null
@@ -1,18 +0,0 @@
-coverage ==7.3.1
-pytest ==7.4.0
-pytest-timeout ==2.1.0
-pytest-cov ==4.1.0
-pytest-doctestplus ==1.0.0
-pytest-asyncio ==0.21.1
-# pytest-random-order ==1.1.0
-pytest-rerunfailures ==12.0
-pytest-xdist ==3.3.1
-
-playwright ==1.38.0
-httpx ==0.25.0
-trio <0.22.0 # strict https://github.com/python-trio/trio/pull/2213
-pympler
-psutil <5.10.0
-setuptools <68.3.0
-requests-mock ==1.11.0
-pandas
diff --git a/requirements/app/ui.txt b/requirements/app/ui.txt
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/requirements/ci.txt b/requirements/ci.txt
index 08c2bd41148ec..cdebc301790e9 100644
--- a/requirements/ci.txt
+++ b/requirements/ci.txt
@@ -1,6 +1,7 @@
-setuptools
-wheel
+setuptools <70.1.1
+wheel <0.44.0
awscli >=1.30.0, <1.31.0
twine ==4.0.1
+importlib-metadata <8.0.0
wget
-packaging
+packaging <24.2
diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt
index c57d24a49e583..0a99614a46870 100644
--- a/requirements/fabric/base.txt
+++ b/requirements/fabric/base.txt
@@ -1,9 +1,8 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
-numpy >=1.17.2, <1.27.0
-torch >=1.13.0, <2.3.0
-fsspec[http] >=2022.5.0, <2023.11.0
+torch >=2.1.0, <2.5.0
+fsspec[http] >=2022.5.0, <2024.4.0
packaging >=20.0, <=23.1
typing-extensions >=4.4.0, <4.10.0
-lightning-utilities >=0.8.0, <0.12.0
+lightning-utilities >=0.10.0, <0.12.0
diff --git a/requirements/fabric/examples.txt b/requirements/fabric/examples.txt
index e077065766b76..cb4135da2409a 100644
--- a/requirements/fabric/examples.txt
+++ b/requirements/fabric/examples.txt
@@ -1,6 +1,6 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
-torchvision >=0.14.0, <0.18.0
+torchvision >=0.16.0, <0.20.0
torchmetrics >=0.10.0, <1.3.0
lightning-utilities >=0.8.0, <0.12.0
diff --git a/requirements/fabric/strategies.txt b/requirements/fabric/strategies.txt
index 6c302f21269e3..394aceb39cd6b 100644
--- a/requirements/fabric/strategies.txt
+++ b/requirements/fabric/strategies.txt
@@ -5,5 +5,6 @@
# note: is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods`
# shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372
-deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" # strict
-bitsandbytes >=0.42.0,<0.43.0
+deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" and platform_system != "Darwin" # strict
+bitsandbytes >=0.44.0,<0.44.2; sys_platform == 'linux' or sys_platform == 'win32'
+bitsandbytes >=0.42.0,<0.43.0 ; sys_platform == 'darwin'
diff --git a/requirements/fabric/test.txt b/requirements/fabric/test.txt
index 3459e7dc87461..8fb9122051eec 100644
--- a/requirements/fabric/test.txt
+++ b/requirements/fabric/test.txt
@@ -1,4 +1,5 @@
coverage ==7.3.1
+numpy >=1.17.2, <1.27.0
pytest ==7.4.0
pytest-cov ==4.1.0
pytest-timeout ==2.1.0
diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt
index ed4250bb3832b..6ff628d7edfb5 100644
--- a/requirements/pytorch/base.txt
+++ b/requirements/pytorch/base.txt
@@ -1,12 +1,11 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
-numpy >=1.17.2, <1.27.0
-torch >=1.13.0, <2.3.0
+torch >=2.1.0, <2.5.0
tqdm >=4.57.0, <4.67.0
PyYAML >=5.4, <6.1.0
-fsspec[http] >=2022.5.0, <2023.11.0
+fsspec[http] >=2022.5.0, <2024.4.0
torchmetrics >=0.7.0, <1.3.0 # needed for using fixed compare_version
packaging >=20.0, <=23.1
typing-extensions >=4.4.0, <4.10.0
-lightning-utilities >=0.8.0, <0.12.0
+lightning-utilities >=0.10.0, <0.12.0
diff --git a/requirements/pytorch/examples.txt b/requirements/pytorch/examples.txt
index 56b7971eb61b0..9a6ae7e47dfb8 100644
--- a/requirements/pytorch/examples.txt
+++ b/requirements/pytorch/examples.txt
@@ -2,8 +2,7 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
requests <2.32.0
-torchvision >=0.14.0, <0.18.0
-gym[classic_control] >=0.17.0, <0.27.0
+torchvision >=0.16.0, <0.20.0
ipython[all] <8.15.0
torchmetrics >=0.10.0, <1.3.0
lightning-utilities >=0.8.0, <0.12.0
diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt
index 8df61e4834769..12bbdf5a70ab0 100644
--- a/requirements/pytorch/extra.txt
+++ b/requirements/pytorch/extra.txt
@@ -3,9 +3,10 @@
# extended list of package dependencies to reach full functionality
matplotlib>3.1, <3.9.0
-omegaconf >=2.0.5, <2.4.0
-hydra-core >=1.0.5, <1.4.0
+omegaconf >=2.2.3, <2.4.0
+hydra-core >=1.2.0, <1.4.0
jsonargparse[signatures] >=4.27.7, <4.28.0
rich >=12.3.0, <13.6.0
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
-bitsandbytes >=0.42.0,<0.43.0
+bitsandbytes >=0.44.0,<0.44.2; sys_platform == 'linux' or sys_platform == 'win32'
+bitsandbytes >=0.42.0,<0.43.0 ; sys_platform == 'darwin'
diff --git a/requirements/pytorch/strategies.txt b/requirements/pytorch/strategies.txt
index 751ca213d3b53..8d3af408a98fe 100644
--- a/requirements/pytorch/strategies.txt
+++ b/requirements/pytorch/strategies.txt
@@ -3,4 +3,4 @@
# note: is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods`
# shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372
-deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" # strict
+deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" and platform_system != "Darwin" # strict
diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt
index 94a06630df61b..4e1da300dd2fd 100644
--- a/requirements/pytorch/test.txt
+++ b/requirements/pytorch/test.txt
@@ -8,8 +8,9 @@ pytest-random-order ==1.1.0
# needed in tests
cloudpickle >=1.3, <2.3.0
scikit-learn >0.22.1, <1.4.0
-onnx >=0.14.0, <1.15.0
-onnxruntime >=0.15.0, <1.17.0
+numpy >=1.17.2, <1.27.0
+onnx >=1.12.0, <1.17.0
+onnxruntime >=1.12.0, <1.19.0
psutil <5.9.6 # for `DeviceStatsMonitor`
pandas >1.0, <2.2.0 # needed in benchmarks
fastapi # for `ServableModuleValidator` # not setting version as re-defined in App
diff --git a/requirements/store/test.txt b/requirements/store/test.txt
deleted file mode 100644
index d30343b08a628..0000000000000
--- a/requirements/store/test.txt
+++ /dev/null
@@ -1,6 +0,0 @@
-coverage ==7.3.1
-pytest ==7.4.0
-pytest-cov ==4.1.0
-pytest-timeout ==2.1.0
-pytest-rerunfailures ==12.0
-pytest-random-order ==1.1.0
diff --git a/requirements/typing.txt b/requirements/typing.txt
index e8a2baaf8713a..0323edfd6098a 100644
--- a/requirements/typing.txt
+++ b/requirements/typing.txt
@@ -1,5 +1,5 @@
-mypy==1.8.0
-torch==2.2.0
+mypy==1.11.0
+torch==2.4.1
types-Markdown
types-PyYAML
diff --git a/setup.py b/setup.py
index 698eaa5abe71e..bfc329bb8fe88 100755
--- a/setup.py
+++ b/setup.py
@@ -17,7 +17,7 @@
There are considered three main scenarios for installing this project:
-1. Using PyPI registry when you can install `pytorch-lightning`, `lightning-app`, etc. or `lightning` for all.
+1. Using the PyPI registry, where you can install `pytorch-lightning`, etc., or `lightning` for all.
2. Installation from source code after cloning repository.
In such case we recommend to use command `pip install .` or `pip install -e .` for development version
@@ -26,12 +26,11 @@
- for `pytorch-lightning` use `export PACKAGE_NAME=pytorch ; pip install .`
- for `lightning-fabric` use `export PACKAGE_NAME=fabric ; pip install .`
- - for `lightning-app` use `export PACKAGE_NAME=app ; pip install .`
3. Building packages as sdist or binary wheel and installing or publish to PyPI afterwords you use command
`python setup.py sdist` or `python setup.py bdist_wheel` accordingly.
In case you want to build just a particular package you want to set an environment variable:
- `PACKAGE_NAME=lightning|pytorch|app|fabric python setup.py sdist|bdist_wheel`
+ `PACKAGE_NAME=lightning|pytorch|fabric python setup.py sdist|bdist_wheel`
4. Automated releasing with GitHub action is natural extension of 3) is composed of three consecutive steps:
a) determine which packages shall be released based on version increment in `__version__.py` and eventually
@@ -57,7 +56,6 @@
_PACKAGE_MAPPING = {
"lightning": "lightning",
"pytorch": "pytorch_lightning",
- "app": "lightning_app",
"fabric": "lightning_fabric",
}
# https://packaging.python.org/guides/single-sourcing-package-version/
diff --git a/src/app-ui-version.info b/src/app-ui-version.info
deleted file mode 100644
index ae39fab35ff1f..0000000000000
--- a/src/app-ui-version.info
+++ /dev/null
@@ -1 +0,0 @@
-v0.0.0
diff --git a/src/lightning/__init__.py b/src/lightning/__init__.py
index c191334d2c218..1b054ed6715f7 100644
--- a/src/lightning/__init__.py
+++ b/src/lightning/__init__.py
@@ -1,7 +1,6 @@
"""Root package info."""
import logging
-import sys
# explicitly don't set root logger's propagation and leave this to subpackages to manage
_logger = logging.getLogger(__name__)
@@ -31,19 +30,3 @@
"Fabric",
"__version__",
]
-
-
-def _cli_entry_point() -> None:
- from lightning_utilities.core.imports import ModuleAvailableCache, RequirementCache
-
- if not (
- ModuleAvailableCache("lightning.app")
- if RequirementCache("lightning-utilities<0.10.0")
- else RequirementCache(module="lightning.app")
- ):
- print("The `lightning` command requires additional dependencies: `pip install lightning[app]`")
- sys.exit(1)
-
- from lightning.app.cli.lightning_cli import main
-
- main()
diff --git a/src/lightning/__main__.py b/src/lightning/__main__.py
deleted file mode 100644
index 57b27ab968c82..0000000000000
--- a/src/lightning/__main__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from lightning.app.cli.lightning_cli import main
-
-if __name__ == "__main__":
- main()
diff --git a/src/lightning/__setup__.py b/src/lightning/__setup__.py
index 81eae48545180..09eab5601f443 100644
--- a/src/lightning/__setup__.py
+++ b/src/lightning/__setup__.py
@@ -5,7 +5,7 @@
from types import ModuleType
from typing import Any, Dict
-from setuptools import find_packages
+from setuptools import find_namespace_packages
_PROJECT_ROOT = "."
_SOURCE_ROOT = os.path.join(_PROJECT_ROOT, "src")
@@ -45,12 +45,8 @@ def _prepare_extras() -> Dict[str, Any]:
extras["fabric-dev"] = extras["fabric-all"] + extras["fabric-test"]
extras["pytorch-all"] = extras["pytorch-extra"] + extras["pytorch-strategies"] + extras["pytorch-examples"]
extras["pytorch-dev"] = extras["pytorch-all"] + extras["pytorch-test"]
- extras["app-extra"] = extras["app-app"] + extras["app-cloud"] + extras["app-ui"] + extras["app-components"]
- extras["app-all"] = extras["app-extra"]
- extras["app-dev"] = extras["app-all"] + extras["app-test"]
- extras["store-store"] = extras["app-app"] # todo: consider cutting/leaning this dependency
- # merge per-project extras of the same category, e.g. `app-test` + `fabric-test`
+ # merge per-project extras of the same category
for extra in list(extras):
name = "-".join(extra.split("-")[1:])
extras[name] = extras.get(name, []) + extras[extra]
@@ -74,17 +70,6 @@ def _setup_args() -> Dict[str, Any]:
_PROJECT_ROOT, homepage=about.__homepage__, version=version.version
)
- # TODO: remove this once lightning-ui package is ready as a dependency
- ui_ver_file = os.path.join(_SOURCE_ROOT, "app-ui-version.info")
- if os.path.isfile(ui_ver_file):
- with open(ui_ver_file, encoding="utf-8") as fo:
- ui_version = fo.readlines()[0].strip()
- download_fe_version = {"version": ui_version}
- else:
- print(f"Missing file with FE version: {ui_ver_file}")
- download_fe_version = {}
- _ASSISTANT._download_frontend(os.path.join(_PACKAGE_ROOT, "app"), **download_fe_version)
-
# TODO: consider invaliding some additional arguments from packages, for example if include data or safe to zip
install_requires = _ASSISTANT.load_requirements(
@@ -102,19 +87,18 @@ def _setup_args() -> Dict[str, Any]:
"url": about.__homepage__,
"download_url": "https://github.com/Lightning-AI/lightning",
"license": about.__license__,
- "packages": find_packages(where="src", include=["lightning", "lightning.*"]),
+ "packages": find_namespace_packages(where="src", include=["lightning", "lightning.*"]),
"package_dir": {"": "src"},
"long_description": long_description,
"long_description_content_type": "text/markdown",
"include_package_data": True,
"zip_safe": False,
- "keywords": ["deep learning", "pytorch", "AI"], # todo: aggregate tags from all packages
- "python_requires": ">=3.8", # todo: take the lowes based on all packages
+ "keywords": ["deep learning", "pytorch", "AI"],
+ "python_requires": ">=3.9",
"entry_points": {
"console_scripts": [
"fabric = lightning.fabric.cli:_main",
"lightning = lightning.fabric.cli:_legacy_main",
- "lightning_app = lightning:_cli_entry_point",
],
},
"setup_requires": [],
@@ -140,9 +124,9 @@ def _setup_args() -> Dict[str, Any]:
"Operating System :: OS Independent",
# Specify the Python versions you support here.
"Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
], # todo: consider aggregation/union of tags from particular packages
}
diff --git a/src/lightning/app/CHANGELOG.md b/src/lightning/app/CHANGELOG.md
deleted file mode 100644
index d09c302b22067..0000000000000
--- a/src/lightning/app/CHANGELOG.md
+++ /dev/null
@@ -1,608 +0,0 @@
-# Changelog
-
-All notable changes to this project will be documented in this file.
-
-The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
-
-## [2.2.0] - 2024-02-09
-
-## Changed
-
-- Renames the `lightning` cli to `lightning_app` ([#19440](https://github.com/Lightning-AI/pytorch-lightning/pull/19440))
-
-
-## [2.1.4] - 2024-01-31
-
-### Changed
-
-- Remove torch distributed for the Dataset Optimizer ([#19182](https://github.com/Lightning-AI/lightning/pull/19182))
-
-
-## [2.1.3] - 2023-12-21
-
-### Changed
-
-- Lightning App: Use the batch get endpoint ([#19180](https://github.com/Lightning-AI/lightning/pull/19180))
-- Drop starsessions from App's requirements ([#18470](https://github.com/Lightning-AI/lightning/pull/18470))
-- Optimize loading time for chunks to be there ([#19109](https://github.com/Lightning-AI/lightning/pull/19109))
-
-
-## [2.1.2] - 2023-11-15
-
-### Changed
-
-- Forced plugin server to use localhost ([#18976](https://github.com/Lightning-AI/lightning/pull/18976))
-- Enabled bundling additional files into app source ([#18980](https://github.com/Lightning-AI/lightning/pull/18980))
-- Limited rate of requests to http queue ([#18981](https://github.com/Lightning-AI/lightning/pull/18981))
-
-
-## [2.1.1] - 2023-11-06
-
-### Added
-
-- Added flow `fail()` ([#18883](https://github.com/Lightning-AI/lightning/pull/18883))
-
-### Fixed
-
-- Fixed failing lightning cli entry point ([#18821](https://github.com/Lightning-AI/lightning/pull/18821))
-
-
-## [2.1.0] - 2023-10-11
-
-### Added
-
-- Allow customizing `gradio` components with lightning colors ([#17054](https://github.com/Lightning-AI/lightning/pull/17054))
-
-### Changed
-
-- Changed `LocalSourceCodeDir` cache_location to not use home in some certain cases ([#17491](https://github.com/Lightning-AI/lightning/pull/17491))
-
-### Removed
-
-- Remove cluster commands from the CLI ([#18151](https://github.com/Lightning-AI/lightning/pull/18151))
-
-
-## [2.0.9] - 2023-09-14
-
-### Fixed
-
-- Replace LightningClient with import from lightning_cloud ([#18544](https://github.com/Lightning-AI/lightning/pull/18544))
-
-
-## [2.0.8] - 2023-08-29
-
-## Changed
-
-- Change top folder ([#18212](https://github.com/Lightning-AI/lightning/pull/18212))
-- Remove `_handle_is_headless` calls in app run loop ([#18362](https://github.com/Lightning-AI/lightning/pull/18362))
-
-
-## [2.0.7] - 2023-08-14
-
-### Changed
-
-- Removed the top-level import `lightning.pdb`; import `lightning.app.pdb` instead ([#18177](https://github.com/Lightning-AI/lightning/pull/18177))
-- Client retries forever ([#18065](https://github.com/Lightning-AI/lightning/pull/18065))
-
-### Fixed
-
-- Fixed an issue that would prevent the user to set the multiprocessing start method after importing lightning ([#18177](https://github.com/Lightning-AI/lightning/pull/18177))
-
-
-## [2.0.6] - 2023-07-20
-
-### Fixed
-
-- Fixed handling a `None` request in the file orchestration queue ([#18111](https://github.com/Lightning-AI/lightning/pull/18111))
-
-
-## [2.0.5] - 2023-07-07
-
-### Added
-
-- plugin: store source app ([#17892](https://github.com/Lightning-AI/lightning/pull/17892))
-- added colocation identifier ([#16796](https://github.com/Lightning-AI/lightning/pull/16796))
-- Added exponential backoff to HTTPQueue put ([#18013](https://github.com/Lightning-AI/lightning/pull/18013))
-- Content for plugins ([#17243](https://github.com/Lightning-AI/lightning/pull/17243))
-
-### Changed
-
-- Save a reference to created tasks, to avoid tasks disappearing ([#17946](https://github.com/Lightning-AI/lightning/pull/17946))
-
-
-## [2.0.4] - 2023-06-22
-
-### Fixed
-
-- bumped several dependencies to address security vulnerabilities.
-
-
-## [2.0.3] - 2023-06-07
-
-- Added the property `LightningWork.public_ip` that exposes the public IP of the `LightningWork` instance ([#17742](https://github.com/Lightning-AI/lightning/pull/17742))
-- Add missing python-multipart dependency ([#17244](https://github.com/Lightning-AI/lightning/pull/17244))
-
-### Changed
-
-- Made type hints public ([#17100](https://github.com/Lightning-AI/lightning/pull/17100))
-
-### Fixed
-
-- Fixed `LightningWork.internal_ip` that was mistakenly exposing the public IP instead; now exposes the private/internal IP address ([#17742](https://github.com/Lightning-AI/lightning/pull/17742))
-- Fixed resolution of latest version in CLI ([#17351](https://github.com/Lightning-AI/lightning/pull/17351))
-- Fixed property raised instead of returned ([#17595](https://github.com/Lightning-AI/lightning/pull/17595))
-- Fixed get project ([#17617](https://github.com/Lightning-AI/lightning/pull/17617), [#17666](https://github.com/Lightning-AI/lightning/pull/17666))
-
-
-## [2.0.2] - 2023-04-24
-
-### Fixed
-
-- Resolved Lightning App with remote storage ([#17426](https://github.com/Lightning-AI/lightning/pull/17426))
-- Fixed `AppState`, streamlit example ([#17452](https://github.com/Lightning-AI/lightning/pull/17452))
-
-
-## [2.0.1] - 2023-04-11
-
-### Fixed
-
-- Fix frontend hosts when running with multi-process in the cloud ([#17324](https://github.com/Lightning-AI/lightning/pull/17324))
-
-
-## [2.0.0] - 2023-03-15
-
-### Added
-
-- Added `--zip` option to the `lightning cp` command to copy content from the Cloud Platform Filesystem as a zipfile
-
-### Changed
-
-- Changed minimum supported version of `rich` from `10.14.0` to `12.13.0` ([#16798](https://github.com/Lightning-AI/lightning/pull/16798))
-
-### Removed
-
-- Removed support for Python 3.7 ([#16579](https://github.com/Lightning-AI/lightning/pull/16579))
-
-
-## [1.9.4] - 2023-03-01
-
-### Removed
-
-- Removed implicit ui testing with `testing.run_app_in_cloud` in favor of headless login and app selection ([#16741](https://github.com/Lightning-AI/lightning/pull/16741))
-
-
-## [1.9.3] - 2023-02-21
-
-### Fixed
-
-- Fixed `lightning open` command and improved redirects ([#16794](https://github.com/Lightning-AI/lightning/pull/16794))
-
-
-## [1.9.2] - 2023-02-15
-
-- Added Storage Commands ([#16740](https://github.com/Lightning-AI/lightning/pull/16740))
- * `rm`: Delete files from your Cloud Platform Filesystem
-- Added `lightning connect data` to register data connection to private s3 buckets ([#16738](https://github.com/Lightning-AI/lightning/pull/16738))
-
-
-## [1.9.1] - 2023-02-10
-
-### Added
-- Added `lightning open` command ([#16482](https://github.com/Lightning-AI/lightning/pull/16482))
-- Added experimental support for interruptible GPU in the cloud ([#16399](https://github.com/Lightning-AI/lightning/pull/16399))
-- Added FileSystem abstraction to simplify manipulation of files ([#16581](https://github.com/Lightning-AI/lightning/pull/16581))
-- Added Storage Commands ([#16606](https://github.com/Lightning-AI/lightning/pull/16606))
- * `ls`: List files from your Cloud Platform Filesystem
- * `cd`: Change the current directory within your Cloud Platform filesystem (terminal session based)
- * `pwd`: Return the current folder in your Cloud Platform Filesystem
- * `cp`: Copy files between your Cloud Platform Filesystem and local filesystem
-- Prevent `cd` into non-existent folders ([#16645](https://github.com/Lightning-AI/lightning/pull/16645))
-- Enabled `cp` (upload) at project level ([#16631](https://github.com/Lightning-AI/lightning/pull/16631))
-- Enabled `ls` and `cp` (download) at project level ([#16622](https://github.com/Lightning-AI/lightning/pull/16622))
-- Added `lightning connect data` to register data connection to s3 buckets ([#16670](https://github.com/Lightning-AI/lightning/pull/16670))
-- Added support for running with multiprocessing in the cloud ([#16624](https://github.com/Lightning-AI/lightning/pull/16624))
-- Initial plugin server ([#16523](https://github.com/Lightning-AI/lightning/pull/16523))
-- Connect and Disconnect node ([#16700](https://github.com/Lightning-AI/lightning/pull/16700))
-
-### Changed
-
-- Changed the default `LightningClient(retry=False)` to `retry=True` ([#16382](https://github.com/Lightning-AI/lightning/pull/16382))
-- Add support for async predict method in PythonServer and remove torch context ([#16453](https://github.com/Lightning-AI/lightning/pull/16453))
-- Renamed `lightning.app.components.LiteMultiNode` to `lightning.app.components.FabricMultiNode` ([#16505](https://github.com/Lightning-AI/lightning/pull/16505))
-- Changed the command `lightning connect` to `lightning connect app` for consistency ([#16670](https://github.com/Lightning-AI/lightning/pull/16670))
-- Refactor cloud dispatch and update to new API ([#16456](https://github.com/Lightning-AI/lightning/pull/16456))
-- Updated app URLs to the latest format ([#16568](https://github.com/Lightning-AI/lightning/pull/16568))
-
-### Fixed
-
-- Fixed a deadlock causing apps not to exit properly when running locally ([#16623](https://github.com/Lightning-AI/lightning/pull/16623))
-- Fixed the Drive root_folder not parsed properly ([#16454](https://github.com/Lightning-AI/lightning/pull/16454))
-- Fixed malformed path when downloading files using `lightning cp` ([#16626](https://github.com/Lightning-AI/lightning/pull/16626))
-- Fixed app name in URL ([#16575](https://github.com/Lightning-AI/lightning/pull/16575))
-
-
-## [1.9.0] - 2023-01-17
-
-### Added
-
-- Added a possibility to set up basic authentication for Lightning apps ([#16105](https://github.com/Lightning-AI/lightning/pull/16105))
-
-### Changed
-
-- The LoadBalancer now uses internal ip + port instead of URL exposed ([#16119](https://github.com/Lightning-AI/lightning/pull/16119))
-- Added support for logging in different trainer stages with `DeviceStatsMonitor` ([#16002](https://github.com/Lightning-AI/lightning/pull/16002))
-- Changed `lightning.app.components.serve.gradio` to `lightning.app.components.serve.gradio_server` ([#16201](https://github.com/Lightning-AI/lightning/pull/16201))
-- Made cluster creation/deletion async by default ([#16185](https://github.com/Lightning-AI/lightning/pull/16185))
-- Expose `LightningFlow.stop` method to stop the flow similar to works ([#16378](https://github.com/Lightning-AI/lightning/pull/16378))
-
-### Fixed
-
-- Fixed not being able to run multiple lightning apps locally due to port collision ([#15819](https://github.com/Lightning-AI/lightning/pull/15819))
-- Avoid `relpath` bug on Windows ([#16164](https://github.com/Lightning-AI/lightning/pull/16164))
-- Avoid using the deprecated `LooseVersion` ([#16162](https://github.com/Lightning-AI/lightning/pull/16162))
-- Porting fixes to autoscaler component ([#16249](https://github.com/Lightning-AI/lightning/pull/16249))
-- Fixed a bug where `lightning login` with env variables would not correctly save the credentials ([#16339](https://github.com/Lightning-AI/lightning/pull/16339))
-
-
-## [1.8.6] - 2022-12-21
-
-### Added
-
-- Added partial support for fastapi `Request` annotation in `configure_api` handlers ([#16047](https://github.com/Lightning-AI/lightning/pull/16047))
-- Added a nicer UI with URL and examples for the autoscaler component ([#16063](https://github.com/Lightning-AI/lightning/pull/16063))
-- Enabled users to have more control over scaling out/in interval ([#16093](https://github.com/Lightning-AI/lightning/pull/16093))
-- Added more datatypes to serving component ([#16018](https://github.com/Lightning-AI/lightning/pull/16018))
-- Added `work.delete` method to delete the work ([#16103](https://github.com/Lightning-AI/lightning/pull/16103))
-- Added `display_name` property to LightningWork for the cloud ([#16095](https://github.com/Lightning-AI/lightning/pull/16095))
-- Added `ColdStartProxy` to the AutoScaler ([#16094](https://github.com/Lightning-AI/lightning/pull/16094))
-- Added status endpoint, enable `ready` ([#16075](https://github.com/Lightning-AI/lightning/pull/16075))
-- Implemented `ready` for components ([#16129](https://github.com/Lightning-AI/lightning/pull/16129))
-
-### Changed
-
-- The default `start_method` for creating Work processes locally on MacOS is now 'spawn' (previously 'fork') ([#16089](https://github.com/Lightning-AI/lightning/pull/16089))
-- The utility `lightning.app.utilities.cloud.is_running_in_cloud` now returns `True` during loading of the app locally when running with `--cloud` ([#16045](https://github.com/Lightning-AI/lightning/pull/16045))
-- Updated Multinode Warning ([#16091](https://github.com/Lightning-AI/lightning/pull/16091))
-- Updated app testing ([#16000](https://github.com/Lightning-AI/lightning/pull/16000))
-- Changed overwrite to `True` ([#16009](https://github.com/Lightning-AI/lightning/pull/16009))
-- Simplified messaging in cloud dispatch ([#16160](https://github.com/Lightning-AI/lightning/pull/16160))
-- Added annotations endpoint ([#16159](https://github.com/Lightning-AI/lightning/pull/16159))
-
-### Fixed
-
-- Fixed `PythonServer` messaging "Your app has started" ([#15989](https://github.com/Lightning-AI/lightning/pull/15989))
-- Fixed auto-batching to also batch requests that arrive after the batch interval but are still in the queue ([#16110](https://github.com/Lightning-AI/lightning/pull/16110))
-- Fixed a bug where `AutoScaler` would fail with min_replica=0 ([#16092](https://github.com/Lightning-AI/lightning/pull/16092))
-- Fixed a non-thread safe deepcopy in the scheduler ([#16114](https://github.com/Lightning-AI/lightning/pull/16114))
-- Fixed Http Queue sleeping for 1 sec by default if no deltas were found ([#16114](https://github.com/Lightning-AI/lightning/pull/16114))
-- Fixed the endpoint info tab not showing up in `AutoScaler` UI ([#16128](https://github.com/Lightning-AI/lightning/pull/16128))
-- Fixed an issue where an exception would be raised in the logs when using a recent version of streamlit ([#16139](https://github.com/Lightning-AI/lightning/pull/16139))
-- Fixed e2e tests ([#16146](https://github.com/Lightning-AI/lightning/pull/16146))
-
-
-## [1.8.5] - 2022-12-15
-
-### Added
-
-- Added `Lightning{Flow,Work}.lightningignores` attributes to programmatically ignore files before uploading to the cloud ([#15818](https://github.com/Lightning-AI/lightning/pull/15818))
-- Added a progress bar while connecting to an app through the CLI ([#16035](https://github.com/Lightning-AI/lightning/pull/16035))
-- Support running on multiple clusters ([#16016](https://github.com/Lightning-AI/lightning/pull/16016))
-- Added guards to cluster deletion from cli ([#16053](https://github.com/Lightning-AI/lightning/pull/16053))
-
-### Changed
-
-- Cleanup cluster waiting ([#16054](https://github.com/Lightning-AI/lightning/pull/16054))
-
-### Fixed
-
-- Fixed `DDPStrategy` import in app framework ([#16029](https://github.com/Lightning-AI/lightning/pull/16029))
-- Fixed `AutoScaler` raising an exception when non-default cloud compute is specified ([#15991](https://github.com/Lightning-AI/lightning/pull/15991))
-- Fixed and improved the login flow ([#16052](https://github.com/Lightning-AI/lightning/pull/16052))
-- Fixed the debugger detection mechanism for lightning App in VSCode ([#16068](https://github.com/Lightning-AI/lightning/pull/16068))
-- Fixed bug where components that are re-instantiated several times failed to initialize if they were modifying `self.lightningignore` ([#16080](https://github.com/Lightning-AI/lightning/pull/16080))
-- Fixed a bug where apps that had previously been deleted could not be run again from the CLI ([#16082](https://github.com/Lightning-AI/lightning/pull/16082))
-- Fixed install/upgrade - removing single quote ([#16079](https://github.com/Lightning-AI/lightning/pull/16079))
-
-
-## [1.8.4] - 2022-12-08
-
-### Added
-
-- Add `code_dir` argument to tracer run ([#15771](https://github.com/Lightning-AI/lightning/pull/15771))
-- Added the CLI command `lightning run model` to launch a `LightningLite` accelerated script ([#15506](https://github.com/Lightning-AI/lightning/pull/15506))
-- Added the CLI command `lightning delete app` to delete a lightning app on the cloud ([#15783](https://github.com/Lightning-AI/lightning/pull/15783))
-- Added a CloudMultiProcessBackend which enables running a child App from within the Flow in the cloud ([#15800](https://github.com/Lightning-AI/lightning/pull/15800))
-- Utility for pickling work object safely even from a child process ([#15836](https://github.com/Lightning-AI/lightning/pull/15836))
-- Added `AutoScaler` component (
- [#15769](https://github.com/Lightning-AI/lightning/pull/15769),
- [#15971](https://github.com/Lightning-AI/lightning/pull/15971),
- [#15966](https://github.com/Lightning-AI/lightning/pull/15966)
-)
-- Added the property `ready` of the LightningFlow to inform when the `Open App` should be visible ([#15921](https://github.com/Lightning-AI/lightning/pull/15921))
-- Added private work attribute `_start_method` to customize how to start the works ([#15923](https://github.com/Lightning-AI/lightning/pull/15923))
-- Added a `configure_layout` method to the `LightningWork` which can be used to control how the work is handled in the layout of a parent flow ([#15926](https://github.com/Lightning-AI/lightning/pull/15926))
-- Added the ability to run a Lightning App or Component directly from the Gallery using `lightning run app organization/name` ([#15941](https://github.com/Lightning-AI/lightning/pull/15941))
-- Added automatic conversion of list and dict of works and flows to structures ([#15961](https://github.com/Lightning-AI/lightning/pull/15961))
-
-### Changed
-
-- The `MultiNode` components now warn the user when running with `num_nodes > 1` locally ([#15806](https://github.com/Lightning-AI/lightning/pull/15806))
-- Cluster creation and deletion now wait by default ([#15458](https://github.com/Lightning-AI/lightning/pull/15458))
-- Running an app without a UI locally no longer opens the browser ([#15875](https://github.com/Lightning-AI/lightning/pull/15875))
-- Show a message when `BuildConfig(requirements=[...])` is passed but a `requirements.txt` file is already present in the Work ([#15799](https://github.com/Lightning-AI/lightning/pull/15799))
-- Show a message when `BuildConfig(dockerfile="...")` is passed but a `Dockerfile` file is already present in the Work ([#15799](https://github.com/Lightning-AI/lightning/pull/15799))
-- Dropped name column from cluster list ([#15721](https://github.com/Lightning-AI/lightning/pull/15721))
-- Apps without UIs no longer activate the "Open App" button when running in the cloud ([#15875](https://github.com/Lightning-AI/lightning/pull/15875))
-- Wait for full file to be transferred in Path / Payload ([#15934](https://github.com/Lightning-AI/lightning/pull/15934))
-
-### Removed
-
-- Removed the `SingleProcessRuntime` ([#15933](https://github.com/Lightning-AI/lightning/pull/15933))
-
-### Fixed
-
-- Fixed SSH CLI command listing stopped components ([#15810](https://github.com/Lightning-AI/lightning/pull/15810))
-- Fixed bug when launching apps on multiple clusters ([#15484](https://github.com/Lightning-AI/lightning/pull/15484))
-- Fixed Sigterm Handler causing thread lock which caused KeyboardInterrupt to hang ([#15881](https://github.com/Lightning-AI/lightning/pull/15881))
-- Fixed MPS error for multinode component (defaults to cpu on mps devices now as distributed operations are not supported by pytorch on mps) ([#15748](https://github.com/Lightning-AI/lightning/pull/15748))
-- Fixed the work not stopped when successful when passed directly to the LightningApp ([#15801](https://github.com/Lightning-AI/lightning/pull/15801))
-- Fixed the PyTorch Inference locally on GPU ([#15813](https://github.com/Lightning-AI/lightning/pull/15813))
-- Fixed the `enable_spawn` method of the `WorkRunExecutor` ([#15812](https://github.com/Lightning-AI/lightning/pull/15812))
-- Fixed require/import decorator ([#15849](https://github.com/Lightning-AI/lightning/pull/15849))
-- Fixed a bug where using `L.app.structures` would cause multiple apps to be opened and fail with an error in the cloud ([#15911](https://github.com/Lightning-AI/lightning/pull/15911))
-- Fixed PythonServer generating noise on M1 ([#15949](https://github.com/Lightning-AI/lightning/pull/15949))
-- Fixed multiprocessing breakpoint ([#15950](https://github.com/Lightning-AI/lightning/pull/15950))
-- Fixed detection of a Lightning App running in debug mode ([#15951](https://github.com/Lightning-AI/lightning/pull/15951))
-- Fixed `ImportError` on Multinode if package not present ([#15963](https://github.com/Lightning-AI/lightning/pull/15963))
-- Fixed MultiNode Component to use separate cloud computes ([#15965](https://github.com/Lightning-AI/lightning/pull/15965))
-- Fixed Registration for CloudComputes of Works in `L.app.structures` ([#15964](https://github.com/Lightning-AI/lightning/pull/15964))
-- Fixed a bug where auto-upgrading to the latest lightning via the CLI could get stuck in a loop ([#15984](https://github.com/Lightning-AI/lightning/pull/15984))
-
-
-## [1.8.3] - 2022-11-22
-
-### Changed
-
-- Deduplicate top level lightning CLI command groups ([#15761](https://github.com/Lightning-AI/lightning/pull/15761))
- * `lightning add ssh-key` CLI command has been transitioned to `lightning create ssh-key`
- * `lightning remove ssh-key` CLI command has been transitioned to `lightning delete ssh-key`
-- Set Torch inference mode for prediction ([#15719](https://github.com/Lightning-AI/lightning/pull/15719))
-- Improved `LightningTrainerScript` start-up time ([#15751](https://github.com/Lightning-AI/lightning/pull/15751))
-- Disable XSRF protection in `StreamlitFrontend` to support upload in localhost ([#15684](https://github.com/Lightning-AI/lightning/pull/15684))
-
-### Fixed
-
-- Fixed debugging with VSCode IDE ([#15747](https://github.com/Lightning-AI/lightning/pull/15747))
-- Fixed setting property to the `LightningFlow` ([#15750](https://github.com/Lightning-AI/lightning/pull/15750))
-- Fixed the PyTorch Inference locally on GPU ([#15813](https://github.com/Lightning-AI/lightning/pull/15813))
-
-
-## [1.8.2] - 2022-11-17
-
-### Added
-
-- Added title and description to ServeGradio ([#15639](https://github.com/Lightning-AI/lightning/pull/15639))
-- Added a friendly error message when attempting to run the default cloud compute with a custom base image configured ([#14929](https://github.com/Lightning-AI/lightning/pull/14929))
-
-### Changed
-
-- Improved support for running apps when dependencies aren't installed ([#15711](https://github.com/Lightning-AI/lightning/pull/15711))
-- Changed the root directory of the app (which gets uploaded) to be the folder containing the app file, rather than any parent folder containing a `.lightning` file ([#15654](https://github.com/Lightning-AI/lightning/pull/15654))
-- Enabled MultiNode Components to support state broadcasting ([#15607](https://github.com/Lightning-AI/lightning/pull/15607))
-- Prevent artefactual "running from outside your current environment" error ([#15647](https://github.com/Lightning-AI/lightning/pull/15647))
-- Rename failed -> error in tables ([#15608](https://github.com/Lightning-AI/lightning/pull/15608))
-
-### Fixed
-
-- Fixed race condition to over-write the frontend with app infos ([#15398](https://github.com/Lightning-AI/lightning/pull/15398))
-- Fixed bi-directional queues sending delta with Drive Component name changes ([#15642](https://github.com/Lightning-AI/lightning/pull/15642))
-- Fixed CloudRuntime works collection with structures and accelerated multi node startup time ([#15650](https://github.com/Lightning-AI/lightning/pull/15650))
-- Fixed catimage import ([#15712](https://github.com/Lightning-AI/lightning/pull/15712))
-- Parse all lines in app file looking for shebangs to run commands ([#15714](https://github.com/Lightning-AI/lightning/pull/15714))
-
-
-## [1.8.1] - 2022-11-10
-
-### Added
-
-- Added the `start` method to the work ([#15523](https://github.com/Lightning-AI/lightning/pull/15523))
-- Added a `MultiNode` Component to run distributed computation with any framework ([#15524](https://github.com/Lightning-AI/lightning/pull/15524))
-- Exposed `RunWorkExecutor` to the work and provided default ones for the `MultiNode` Component ([#15561](https://github.com/Lightning-AI/lightning/pull/15561))
-- Added a `start_with_flow` flag to the `LightningWork` which can be disabled to prevent the work from starting at the same time as the flow ([#15591](https://github.com/Lightning-AI/lightning/pull/15591))
-- Added support for running Lightning App with VSCode IDE debugger ([#15590](https://github.com/Lightning-AI/lightning/pull/15590))
-- Added `bi-directional` delta updates between the flow and the works ([#15582](https://github.com/Lightning-AI/lightning/pull/15582))
-- Added `--setup` flag to `lightning run app` CLI command allowing for dependency installation via app comments ([#15577](https://github.com/Lightning-AI/lightning/pull/15577))
-- Auto-upgrade / detect environment mismatch from the CLI ([#15434](https://github.com/Lightning-AI/lightning/pull/15434))
-- Added Serve component ([#15609](https://github.com/Lightning-AI/lightning/pull/15609))
-
-
-### Changed
-
-- Changed `flow.flows` to be recursive to align its behavior with `flow.works` ([#15466](https://github.com/Lightning-AI/lightning/pull/15466))
-- The `params` argument in `TracerPythonScript.run` no longer prepends `--` automatically to parameters ([#15518](https://github.com/Lightning-AI/lightning/pull/15518))
-- Only check versions / env when not in the cloud ([#15504](https://github.com/Lightning-AI/lightning/pull/15504))
-- Periodically sync database to the drive ([#15441](https://github.com/Lightning-AI/lightning/pull/15441))
-- Slightly safer multi node ([#15538](https://github.com/Lightning-AI/lightning/pull/15538))
-- Reuse existing commands when running connect more than once ([#15471](https://github.com/Lightning-AI/lightning/pull/15471))
-
-### Fixed
-
-- Fixed writing app name and id in connect.txt file for the command CLI ([#15443](https://github.com/Lightning-AI/lightning/pull/15443))
-- Fixed missing root flow among the flows of the app ([#15531](https://github.com/Lightning-AI/lightning/pull/15531))
-- Fixed bug with Multi Node Component and add some examples ([#15557](https://github.com/Lightning-AI/lightning/pull/15557))
-- Fixed a bug where payload would take a very long time locally ([#15557](https://github.com/Lightning-AI/lightning/pull/15557))
-- Fixed an issue with the `lightning` CLI taking a long time to error out when the cloud is not reachable ([#15412](https://github.com/Lightning-AI/lightning/pull/15412))
-
-
-## [1.8.0] - 2022-11-01
-
-### Added
-
-- Added `load_state_dict` and `state_dict` hooks for `LightningFlow` components ([#14100](https://github.com/Lightning-AI/lightning/pull/14100))
-- Added a `--secret` option to CLI to allow binding secrets to app environment variables when running in the cloud ([#14612](https://github.com/Lightning-AI/lightning/pull/14612))
-- Added support for running the works without cloud compute in the default container ([#14819](https://github.com/Lightning-AI/lightning/pull/14819))
-- Added an HTTPQueue as an optional replacement for the default redis queue ([#14978](https://github.com/Lightning-AI/lightning/pull/14978))
-- Added support for configuring flow cloud compute ([#14831](https://github.com/Lightning-AI/lightning/pull/14831))
-- Added support for adding descriptions to commands either through a docstring or the `DESCRIPTION` attribute ([#15193](https://github.com/Lightning-AI/lightning/pull/15193))
-- Added a try / catch mechanism around request processing to avoid killing the flow ([#15187](https://github.com/Lightning-AI/lightning/pull/15187))
-- Added a Database Component ([#14995](https://github.com/Lightning-AI/lightning/pull/14995))
-- Added authentication to HTTP queue ([#15202](https://github.com/Lightning-AI/lightning/pull/15202))
-- Added support to pass a `LightningWork` to the `LightningApp` ([#15215](https://github.com/Lightning-AI/lightning/pull/15215))
-- Added support for getting CLI help for connected apps even if the app isn't running ([#15196](https://github.com/Lightning-AI/lightning/pull/15196))
-- Added support for adding requirements to commands and installing them when missing when running an app command ([#15198](https://github.com/Lightning-AI/lightning/pull/15198))
-- Added Lightning CLI Connection to be terminal session instead of global ([#15241](https://github.com/Lightning-AI/lightning/pull/15241))
-- Added support for managing SSH-keys via CLI ([#15291](https://github.com/Lightning-AI/lightning/pull/15291))
-- Add a `JustPyFrontend` to ease UI creation with `https://github.com/justpy-org/justpy` ([#15002](https://github.com/Lightning-AI/lightning/pull/15002))
-- Added a layout endpoint to the Rest API and the ability to disable pulling or pushing to the state ([#15367](https://github.com/Lightning-AI/lightning/pull/15367))
-- Added support for functions for `configure_api` and `configure_commands` to be executed in the Rest API process ([#15098](https://github.com/Lightning-AI/lightning/pull/15098))
-- Added support for accessing Lightning Apps via SSH ([#15310](https://github.com/Lightning-AI/lightning/pull/15310))
-- Added support to start lightning app on cloud without needing to install dependencies locally ([#15019](https://github.com/Lightning-AI/lightning/pull/15019))
-
-### Changed
-
-- Improved the show logs command to be standalone and reusable ([#15343](https://github.com/Lightning-AI/lightning/pull/15343))
-- Removed the `--instance-types` option when creating clusters ([#15314](https://github.com/Lightning-AI/lightning/pull/15314))
-
-### Fixed
-
-- Fixed an issue when using the CLI without arguments ([#14877](https://github.com/Lightning-AI/lightning/pull/14877))
-- Fixed a bug where the upload files endpoint would raise an error when running locally ([#14924](https://github.com/Lightning-AI/lightning/pull/14924))
-- Fixed BYOC cluster region selector -> hiding it from help since only us-east-1 has been tested and is recommended ([#15277](https://github.com/Lightning-AI/lightning/pull/15277))
-- Fixed a bug when launching an app on multiple clusters ([#15226](https://github.com/Lightning-AI/lightning/pull/15226))
-- Fixed a bug with a default CloudCompute for Lightning flows ([#15371](https://github.com/Lightning-AI/lightning/pull/15371))
-
-## [0.6.2] - 2022-09-21
-
-### Changed
-
-- Improved Lightning App connect logic by disconnecting automatically ([#14532](https://github.com/Lightning-AI/lightning/pull/14532))
-- Improved the error message when the `LightningWork` is missing the `run` method ([#14759](https://github.com/Lightning-AI/lightning/pull/14759))
-- Improved the error message when the root `LightningFlow` passed to `LightningApp` is missing the `run` method ([#14760](https://github.com/Lightning-AI/lightning/pull/14760))
-
-### Fixed
-
-- Fixed a bug where the uploaded command file wasn't properly parsed ([#14532](https://github.com/Lightning-AI/lightning/pull/14532))
-- Fixed an issue where custom property setters were not being used in the `LightningWork` class ([#14259](https://github.com/Lightning-AI/lightning/pull/14259))
-- Fixed an issue where some terminals would display broken icons in the PL app CLI ([#14226](https://github.com/Lightning-AI/lightning/pull/14226))
-
-
-## [0.6.1] - 2022-09-19
-
-### Added
-
-- Add support to upload files to the Drive through an asynchronous `upload_file` endpoint ([#14703](https://github.com/Lightning-AI/lightning/pull/14703))
-
-### Changed
-
-- Application storage prefix moved from `app_id` to `project_id/app_id` ([#14583](https://github.com/Lightning-AI/lightning/pull/14583))
-- LightningCloud client calls to use keyword arguments instead of positional arguments ([#14685](https://github.com/Lightning-AI/lightning/pull/14685))
-
-### Fixed
-
-- Making `threadpool` non-default from LightningCloud client ([#14757](https://github.com/Lightning-AI/lightning/pull/14757))
-- Resolved a bug where the state change detection using DeepDiff won't work with Path, Drive objects ([#14465](https://github.com/Lightning-AI/lightning/pull/14465))
-- Resolved a bug where the wrong client was passed to collect cloud logs ([#14684](https://github.com/Lightning-AI/lightning/pull/14684))
-- Resolved the memory leak issue with the Lightning Cloud package and bumped the requirements to use the latest version ([#14697](https://github.com/Lightning-AI/lightning/pull/14697))
-- Fixing 5000 log line limitation for Lightning AI BYOC cluster logs ([#14458](https://github.com/Lightning-AI/lightning/pull/14458))
-- Fixed a bug where the uploaded command file wasn't properly parsed ([#14532](https://github.com/Lightning-AI/lightning/pull/14532))
-- Resolved `LightningApp(..., debug=True)` ([#14464](https://github.com/Lightning-AI/lightning/pull/14464))
-
-
-## [0.6.0] - 2022-09-08
-
-### Added
-
-- Introduce lightning connect ([#14452](https://github.com/Lightning-AI/lightning/pull/14452))
-- Adds `PanelFrontend` to easily create complex UI in Python ([#13531](https://github.com/Lightning-AI/lightning/pull/13531))
-- Add support for `Lightning App Commands` through the `configure_commands` hook on the Lightning Flow and the `ClientCommand` ([#13602](https://github.com/Lightning-AI/lightning/pull/13602))
-- Add support for Lightning AI BYOC cluster management ([#13835](https://github.com/Lightning-AI/lightning/pull/13835))
-- Add support to see Lightning AI BYOC cluster logs ([#14334](https://github.com/Lightning-AI/lightning/pull/14334))
-- Add support to run Lightning apps on Lightning AI BYOC clusters ([#13894](https://github.com/Lightning-AI/lightning/pull/13894))
-- Add support for listing Lightning AI apps ([#13987](https://github.com/Lightning-AI/lightning/pull/13987))
-- Adds `LightningTrainerScript`. `LightningTrainerScript` orchestrates multi-node training in the cloud ([#13830](https://github.com/Lightning-AI/lightning/pull/13830))
-- Add support for printing application logs using CLI `lightning show logs [components]` ([#13634](https://github.com/Lightning-AI/lightning/pull/13634))
-- Add support for `Lightning API` through the `configure_api` hook on the Lightning Flow and the `Post`, `Get`, `Delete`, `Put` HttpMethods ([#13945](https://github.com/Lightning-AI/lightning/pull/13945))
-- Added a warning when `configure_layout` returns URLs configured with http instead of https ([#14233](https://github.com/Lightning-AI/lightning/pull/14233))
-- Add `--app_args` support from the CLI ([#13625](https://github.com/Lightning-AI/lightning/pull/13625))
-
-### Changed
-
-- Default values and parameter names for Lightning AI BYOC cluster management ([#14132](https://github.com/Lightning-AI/lightning/pull/14132))
-- Run the flow only if the state has changed from the previous execution ([#14076](https://github.com/Lightning-AI/lightning/pull/14076))
-- Increased DeepDiff's verbose level to properly handle dict changes ([#13960](https://github.com/Lightning-AI/lightning/pull/13960))
-- Setup: added requirement freeze for next major version ([#14480](https://github.com/Lightning-AI/lightning/pull/14480))
-
-### Fixed
-
-- Unification of app template: moved `app.py` to root dir for `lightning init app` template ([#13853](https://github.com/Lightning-AI/lightning/pull/13853))
-- Fixed an issue with `lightning --version` command ([#14433](https://github.com/Lightning-AI/lightning/pull/14433))
-- Fixed imports of collections.abc for py3.10 ([#14345](https://github.com/Lightning-AI/lightning/pull/14345))
-
-## [0.5.7] - 2022-08-22
-
-### Changed
-
-- Release LAI docs as stable ([#14250](https://github.com/Lightning-AI/lightning/pull/14250))
-- Compatibility for Python 3.10
-
-### Fixed
-
-- Pinning starsessions to 1.x ([#14333](https://github.com/Lightning-AI/lightning/pull/14333))
-- Parsed local package versions ([#13933](https://github.com/Lightning-AI/lightning/pull/13933))
-
-
-## [0.5.6] - 2022-08-16
-
-### Fixed
-
-- Resolved a bug where the `install` command was not installing the latest version of an app/component by default ([#14181](https://github.com/Lightning-AI/lightning/pull/14181))
-
-
-- Fixed the `examples/app_dag` example ([#14359](https://github.com/Lightning-AI/lightning/pull/14359))
-
-
-## [0.5.5] - 2022-08-09
-
-### Deprecated
-
-- Deprecate sheety API ([#14004](https://github.com/Lightning-AI/lightning/pull/14004))
-
-### Fixed
-
-- Resolved a bug where the work statuses will grow quickly and be duplicated ([#13970](https://github.com/Lightning-AI/lightning/pull/13970))
-- Resolved a bug about a race condition when sending the work state through the caller_queue ([#14074](https://github.com/Lightning-AI/lightning/pull/14074))
-- Fixed Start Lightning App on Cloud if Repo Begins With Name "Lightning" ([#14025](https://github.com/Lightning-AI/lightning/pull/14025))
-
-
-## [0.5.4] - 2022-08-01
-
-### Changed
-
-- Wrapped imports for traceability ([#13924](https://github.com/Lightning-AI/lightning/pull/13924))
-- Set version as today ([#13906](https://github.com/Lightning-AI/lightning/pull/13906))
-
-### Fixed
-
-- Included app templates to the lightning and app packages ([#13731](https://github.com/Lightning-AI/lightning/pull/13731))
-- Added UI for install all ([#13732](https://github.com/Lightning-AI/lightning/pull/13732))
-- Fixed build meta pkg flow ([#13926](https://github.com/Lightning-AI/lightning/pull/13926))
-
-## [0.5.3] - 2022-07-25
-
-### Changed
-
-- Pruned requirements duplicity ([#13739](https://github.com/Lightning-AI/lightning/pull/13739))
-
-### Fixed
-
-- Use correct python version in lightning component template ([#13790](https://github.com/Lightning-AI/lightning/pull/13790))
-
-## [0.5.2] - 2022-07-18
-
-### Added
-
-- Update the Lightning App docs ([#13537](https://github.com/Lightning-AI/lightning/pull/13537))
-
-### Changed
-
-- Added `LIGHTNING_` prefix to Platform AWS credentials ([#13703](https://github.com/Lightning-AI/lightning/pull/13703))
diff --git a/src/lightning/app/__init__.py b/src/lightning/app/__init__.py
deleted file mode 100644
index 5c904cc4a908c..0000000000000
--- a/src/lightning/app/__init__.py
+++ /dev/null
@@ -1,51 +0,0 @@
-"""Root package info."""
-
-import logging
-import os
-
-from lightning_utilities.core.imports import module_available, package_available
-
-_root_logger = logging.getLogger()
-_logger = logging.getLogger(__name__)
-_logger.setLevel(logging.INFO)
-
-_console = logging.StreamHandler()
-_console.setLevel(logging.INFO)
-
-formatter = logging.Formatter("%(levelname)s: %(message)s")
-_console.setFormatter(formatter)
-
-# if root logger has handlers, propagate messages up and let root logger process them,
-# otherwise use our own handler
-if not _root_logger.hasHandlers():
-    _logger.addHandler(_console)
-    _logger.propagate = False
-
-
-if os.path.isfile(os.path.join(os.path.dirname(__file__), "__about__.py")):
-    from lightning.app.__about__ import *  # noqa: F403
-if "__version__" not in locals():
-    if os.path.isfile(os.path.join(os.path.dirname(__file__), "__version__.py")):
-        from lightning.app.__version__ import version as __version__
-    elif package_available("lightning"):
-        from lightning import __version__  # noqa: F401
-
-from lightning.app.core.app import LightningApp # noqa: E402
-from lightning.app.core.flow import LightningFlow # noqa: E402
-from lightning.app.core.work import LightningWork # noqa: E402
-from lightning.app.plugin.plugin import LightningPlugin # noqa: E402
-from lightning.app.utilities.packaging.build_config import BuildConfig # noqa: E402
-from lightning.app.utilities.packaging.cloud_compute import CloudCompute # noqa: E402
-
-if module_available("lightning.app.components.demo"):
-    from lightning.app.components import demo  # noqa: F401
-
-__package_name__ = "lightning.app".split(".")[0]
-
-_PACKAGE_ROOT = os.path.dirname(__file__)
-_PROJECT_ROOT = os.path.dirname(os.path.dirname(_PACKAGE_ROOT))
-if __package_name__ == "lightning":
-    _PACKAGE_ROOT = os.path.dirname(_PACKAGE_ROOT)
-    _PROJECT_ROOT = os.path.dirname(_PROJECT_ROOT)
-
-__all__ = ["LightningApp", "LightningFlow", "LightningWork", "LightningPlugin", "BuildConfig", "CloudCompute"]
diff --git a/src/lightning/app/api/__init__.py b/src/lightning/app/api/__init__.py
deleted file mode 100644
index d850e874da5a2..0000000000000
--- a/src/lightning/app/api/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from lightning.app.api.http_methods import Delete, Get, Post, Put
-
-__all__ = [
-    "Delete",
-    "Get",
-    "Post",
-    "Put",
-]
diff --git a/src/lightning/app/api/http_methods.py b/src/lightning/app/api/http_methods.py
deleted file mode 100644
index aa9e68528e487..0000000000000
--- a/src/lightning/app/api/http_methods.py
+++ /dev/null
@@ -1,258 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import asyncio
-import inspect
-import time
-from copy import deepcopy
-from dataclasses import dataclass
-from functools import wraps
-from multiprocessing import Queue
-from typing import Any, Callable, Dict, List, Optional
-from uuid import uuid4
-
-from fastapi import FastAPI, HTTPException, Request, status
-from lightning_utilities.core.apply_func import apply_to_collection
-
-from lightning.app.api.request_types import _APIRequest, _CommandRequest, _RequestResponse
-from lightning.app.utilities.app_helpers import Logger
-
-logger = Logger(__name__)
-
-
-def _signature_proxy_function():
-    pass
-
-
-@dataclass
-class _FastApiMockRequest:
-    """This class is meant to mock the FastAPI Request class, which isn't picklable.
-
-    If a user relies on the FastAPI Request annotation, the Lightning framework
-    patches the annotation before pickling and restores it right after.
-
-    Finally, the FastAPI request is converted to a _FastApiMockRequest
-    before being delivered to the user.
-
-    Example:
-
-        from lightning.app import LightningFlow
-        from fastapi import Request
-        from lightning.app.api import Post
-
-        class Flow(LightningFlow):
-
-            def request(self, request: Request) -> OutputRequestModel:
-                ...
-
-            def configure_api(self):
-                return [Post("/api/v1/request", self.request)]
-
-    """
-
-    _body: Optional[str] = None
-    _json: Optional[str] = None
-    _method: Optional[str] = None
-    _headers: Optional[Dict] = None
-
-    @property
-    def receive(self):
-        raise NotImplementedError
-
-    @property
-    def method(self):
-        return self._method
-
-    @property
-    def headers(self):
-        return self._headers
-
-    def body(self):
-        return self._body
-
-    def json(self):
-        return self._json
-
-    def stream(self):
-        raise NotImplementedError
-
-    def form(self):
-        raise NotImplementedError
-
-    def close(self):
-        raise NotImplementedError
-
-    def is_disconnected(self):
-        raise NotImplementedError
-
-
-async def _mock_fastapi_request(request: Request):
-    # TODO: Add more request parameters.
-    return _FastApiMockRequest(
-        _body=await request.body(),
-        _json=await request.json(),
-        _headers=request.headers,
-        _method=request.method,
-    )
-
-
-class _HttpMethod:
-    def __init__(
-        self, route: str, method: Callable, method_name: Optional[str] = None, timeout: int = 30, **kwargs: Any
-    ):
-        """This class is used to inject user-defined methods within the App REST API.
-
-        Arguments:
-            route: The path used to route the requests
-            method: The associated flow method
-            timeout: The time in seconds taken before raising a timeout exception.
-
-        """
-        self.route = route
-        self.attached_to_flow = hasattr(method, "__self__")
-        self.method_name = method_name or method.__name__
-        self.method_annotations = method.__annotations__
-        # TODO: Validate the signature contains only pydantic models.
-        self.method_signature = inspect.signature(method)
-
-        if not self.attached_to_flow:
-            self.component_name = method.__name__
-            self.method = method
-        else:
-            self.component_name = method.__self__.name
-
-        self.timeout = timeout
-        self.kwargs = kwargs
-
-        # Enable the users to rely on FastAPI annotation typing with Request.
-        # Note: Only a part of the Request functionalities are supported.
-        self._patch_fast_api_request()
-
-    def add_route(self, app: FastAPI, request_queue: Queue, responses_store: Dict[str, Any]) -> None:
-        # 1: Get the route associated with the http method.
-        route = getattr(app, self.__class__.__name__.lower())
-
-        self._unpatch_fast_api_request()
-
-        # 2: Create a proxy function with the signature of the wrapped method.
-        fn = deepcopy(_signature_proxy_function)
-        fn.__annotations__ = self.method_annotations
-        fn.__name__ = self.method_name
-        setattr(fn, "__signature__", self.method_signature)
-
-        # Note: Handle requests differently if attached to a flow.
-        if not self.attached_to_flow:
-            # 3: Define the request handler.
-            @wraps(_signature_proxy_function)
-            async def _handle_request(*args: Any, **kwargs: Any):
-                if inspect.iscoroutinefunction(self.method):
-                    return await self.method(*args, **kwargs)
-                return self.method(*args, **kwargs)
-
-        else:
-            request_cls = _CommandRequest if self.route.startswith("/command/") else _APIRequest
-
-            # 3: Define the request handler.
-            @wraps(_signature_proxy_function)
-            async def _handle_request(*args: Any, **kwargs: Any):
-                async def fn(*args: Any, **kwargs: Any):
-                    args, kwargs = apply_to_collection((args, kwargs), Request, _mock_fastapi_request)
-                    for k, v in kwargs.items():
-                        if hasattr(v, "__await__"):
-                            kwargs[k] = await v
-
-                    request_id = str(uuid4()).split("-")[0]
-                    logger.debug(f"Processing request {request_id} for route: {self.route}")
-                    request_queue.put(
-                        request_cls(
-                            name=self.component_name,
-                            method_name=self.method_name,
-                            args=args,
-                            kwargs=kwargs,
-                            id=request_id,
-                        )
-                    )
-
-                    t0 = time.time()
-                    while request_id not in responses_store:
-                        await asyncio.sleep(0.01)
-                        if (time.time() - t0) > self.timeout:
-                            raise HTTPException(
-                                status.HTTP_500_INTERNAL_SERVER_ERROR,
-                                detail="The response was never received.",
-                            )
-
-                    logger.debug(f"Processed request {request_id} for route: {self.route}")
-
-                    return responses_store.pop(request_id)
-
-                response: _RequestResponse = await asyncio.create_task(fn(*args, **kwargs))
-
-                if response.status_code != 200:
-                    raise HTTPException(response.status_code, detail=response.content)
-
-                return response.content
-
-        # 4: Register the user provided route to the Rest API.
-        route(self.route, **self.kwargs)(_handle_request)
-
-    def _patch_fast_api_request(self):
-        """This function replaces the signature annotation for Request with its mock."""
-        for k, v in self.method_annotations.items():
-            if v == Request:
-                self.method_annotations[k] = _FastApiMockRequest
-
-        for v in self.method_signature.parameters.values():
-            if v._annotation == Request:
-                v._annotation = _FastApiMockRequest
-
-    def _unpatch_fast_api_request(self):
-        """This function restores the signature annotation to the FastAPI Request."""
-        for k, v in self.method_annotations.items():
-            if v == _FastApiMockRequest:
-                self.method_annotations[k] = Request
-
-        for v in self.method_signature.parameters.values():
-            if v._annotation == _FastApiMockRequest:
-                v._annotation = Request
-
-
-class Post(_HttpMethod):
-    pass
-
-
-class Get(_HttpMethod):
-    pass
-
-
-class Put(_HttpMethod):
-    pass
-
-
-class Delete(_HttpMethod):
-    pass
-
-
-def _add_tags_to_api(apis: List[_HttpMethod], tags: List[str]) -> None:
-    for api in apis:
-        if not api.kwargs.get("tags"):
-            api.kwargs["tags"] = tags
-
-
-def _validate_api(apis: List[_HttpMethod]) -> None:
-    for api in apis:
-        if not isinstance(api, _HttpMethod):
-            raise Exception(f"The provided api should be either [{Delete}, {Get}, {Post}, {Put}]")
-        if api.route.startswith("/command"):
-            raise Exception("The route `/command` is reserved for commands. Please, use something else.")
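Note on the removed request handler above: it puts a request on a queue and then polls a shared response store until the matching id shows up or a timeout elapses. The pattern can be sketched in isolation as below. This is a minimal illustration, not the Lightning implementation: `wait_for_response`, `demo`, and the use of `TimeoutError` (rather than FastAPI's `HTTPException`) are assumptions made so the sketch depends only on the standard library.

```python
import asyncio
import time
from typing import Any, Dict


async def wait_for_response(
    store: Dict[str, Any], request_id: str, timeout: float = 1.0, poll: float = 0.01
) -> Any:
    """Poll `store` until `request_id` appears, or raise TimeoutError."""
    t0 = time.monotonic()
    while request_id not in store:
        # Yield to the event loop so the producer can run.
        await asyncio.sleep(poll)
        if time.monotonic() - t0 > timeout:
            raise TimeoutError(f"No response received for request {request_id}")
    # Remove the entry so the store does not grow without bound.
    return store.pop(request_id)


async def demo() -> str:
    store: Dict[str, Any] = {}

    async def producer() -> None:
        # Hypothetical stand-in for the component that answers the request.
        await asyncio.sleep(0.05)
        store["abc"] = "done"

    task = asyncio.create_task(producer())
    result = await wait_for_response(store, "abc")
    await task
    return result
```

Running `asyncio.run(demo())` returns the value the producer stored; calling `wait_for_response` against a store that is never filled raises `TimeoutError` once `timeout` seconds have passed.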
diff --git a/src/lightning/app/api/request_types.py b/src/lightning/app/api/request_types.py
deleted file mode 100644
index def50e3a20e10..0000000000000
--- a/src/lightning/app/api/request_types.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import asdict, dataclass
-from typing import Any, Optional
-
-from deepdiff import Delta
-
-
-@dataclass
-class _BaseRequest:
-    def to_dict(self):
-        return asdict(self)
-
-
-@dataclass
-class _DeltaRequest(_BaseRequest):
-    delta: Delta
-
-    def to_dict(self):
-        return self.delta.to_dict()
-
-
-@dataclass
-class _CommandRequest(_BaseRequest):
-    id: str
-    name: str
-    method_name: str
-    args: Any
-    kwargs: Any
-
-
-@dataclass
-class _APIRequest(_BaseRequest):
-    id: str
-    name: str
-    method_name: str
-    args: Any
-    kwargs: Any
-
-
-@dataclass
-class _RequestResponse(_BaseRequest):
-    status_code: int
-    content: Optional[str] = None
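Note on the removed request types above: they share a base dataclass whose `to_dict` delegates to `dataclasses.asdict`, so every subclass serializes its own fields without extra code. A minimal sketch of that pattern follows; the class names `BaseRequest`, `ExampleRequest`, and `ExampleResponse` and the sample values are hypothetical and only mirror the shape of the removed classes.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional


@dataclass
class BaseRequest:
    def to_dict(self) -> Dict[str, Any]:
        # asdict walks the fields declared on the concrete subclass.
        return asdict(self)


@dataclass
class ExampleRequest(BaseRequest):
    id: str
    name: str
    kwargs: Any


@dataclass
class ExampleResponse(BaseRequest):
    status_code: int
    content: Optional[str] = None


request = ExampleRequest(id="1a2b", name="root.flow", kwargs={"x": 1})
response = ExampleResponse(status_code=200)
```

Here `request.to_dict()` yields a plain dictionary of the three request fields, and `response.to_dict()` includes `content` with its default of `None`.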
diff --git a/src/lightning/app/cli/__init__.py b/src/lightning/app/cli/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/src/lightning/app/cli/app-template/.gitignore b/src/lightning/app/cli/app-template/.gitignore
deleted file mode 100644
index 70ba25888435f..0000000000000
--- a/src/lightning/app/cli/app-template/.gitignore
+++ /dev/null
@@ -1,157 +0,0 @@
-# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*$py.class
-*install-app*
-
-# C extensions
-*.so
-
-# Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-pip-wheel-metadata/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-
-*.egg
-
-# PyInstaller
-# Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.nox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-*.py,cover
-.hypothesis/
-.pytest_cache/
-
-# Translations
-*.mo
-*.pot
-
-# Sphinx documentation
-docs/_build/
-docs/source/api/
-docs/source/*.md
-
-# PyBuilder
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# IPython
-profile_default/
-ipython_config.py
-
-# pyenv
-.python-version
-
-# PEP 582; used by e.g. github.com/David-OConnor/pyflow
-__pypackages__/
-
-# Celery stuff
-celerybeat-schedule
-celerybeat.pid
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
-.env
-.local_env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
-# Spyder project settings
-.spyderproject
-.spyproject
-
-# Rope project settings
-.ropeproject
-
-# mkdocs documentation
-/site
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
-
-# Pyre type checker
-.pyre/
-
-# PyCharm
-.idea/
-
-# Lightning logs
-lightning_logs
-*.gz
-.DS_Store
-.*_submit.py
-.vscode
-
-MNIST
-*.pt
-.storage/
-.shared/
-infra
-data
-coverage.*
-# Frontend build artifacts
-*lightning/app/ui*
-gradio_cached_examples
-/docs/source/api_reference/generated/*
-examples/my_own_leaderboard/submissions/*
-docs/source/api_reference/generated/*
-*.ckpt
-redis-stable
-node_modules
-*.rdb
-*.webm
-*hars
-examples/quick_start/*
-examples/quick_start
-examples/template_react_ui/*
-examples/template_react_ui
-# Ignore external components
-lightning/app/components/*
-!lightning/app/components/python
-!lightning/app/components/serve
-!lightning/app/components/__init__.py
-!lightning/app/components/README.md
-train_script.py
-*return_values*
-scratch
-storage
diff --git a/src/lightning/app/cli/app-template/LICENSE b/src/lightning/app/cli/app-template/LICENSE
deleted file mode 100644
index 261eeb9e9f8b2..0000000000000
--- a/src/lightning/app/cli/app-template/LICENSE
+++ /dev/null
@@ -1,201 +0,0 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [yyyy] [name of copyright owner]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
diff --git a/src/lightning/app/cli/app-template/README.md b/src/lightning/app/cli/app-template/README.md
deleted file mode 100644
index 76c88e6cedb38..0000000000000
--- a/src/lightning/app/cli/app-template/README.md
+++ /dev/null
@@ -1,37 +0,0 @@
-# placeholdername app
-
-This ⚡ [Lightning app](https://lightning.ai/) ⚡ was generated automatically with:
-
-```bash
-lightning_app init app placeholdername
-```
-
-## To run placeholdername
-
-First, install placeholdername (warning: this app has not been officially approved on the lightning gallery):
-
-```bash
-lightning_app install app https://github.com/theUser/placeholdername
-```
-
-Once the app is installed, run it locally with:
-
-```bash
-lightning_app run app placeholdername/app.py
-```
-
-Run it on the [lightning cloud](https://lightning.ai/) with:
-
-```bash
-lightning_app run app placeholdername/app.py --cloud
-```
-
-## To test and lint
-
-Run flake8 to make sure all your styling is consistent (it keeps your team from going insane)
-
-```bash
-flake8 .
-```
-
-To test, follow the README.md instructions in the tests folder.
diff --git a/src/lightning/app/cli/app-template/app.py b/src/lightning/app/cli/app-template/app.py
deleted file mode 100644
index 4b86551324ccc..0000000000000
--- a/src/lightning/app/cli/app-template/app.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from lightning.app import LightningApp, LightningFlow
-from placeholdername import ComponentA, ComponentB
-
-
-class LitApp(LightningFlow):
-    def __init__(self) -> None:
-        super().__init__()
-        self.component_a = ComponentA()
-        self.component_b = ComponentB()
-
-    def run(self):
-        self.component_a.run()
-        self.component_b.run()
-
-
-app = LightningApp(LitApp())
diff --git a/src/lightning/app/cli/app-template/placeholdername/__init__.py b/src/lightning/app/cli/app-template/placeholdername/__init__.py
deleted file mode 100644
index cf954823e0315..0000000000000
--- a/src/lightning/app/cli/app-template/placeholdername/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from placeholdername.components.component_a import ComponentA
-from placeholdername.components.component_b import ComponentB
-
-__all__ = ["ComponentA", "ComponentB"]
diff --git a/src/lightning/app/cli/app-template/placeholdername/components/component_a/__init__.py b/src/lightning/app/cli/app-template/placeholdername/components/component_a/__init__.py
deleted file mode 100644
index 82753954e0e03..0000000000000
--- a/src/lightning/app/cli/app-template/placeholdername/components/component_a/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from placeholdername.components.component_a.component_a import ComponentA
-
-__all__ = ["ComponentA"]
diff --git a/src/lightning/app/cli/app-template/placeholdername/components/component_a/component_a.py b/src/lightning/app/cli/app-template/placeholdername/components/component_a/component_a.py
deleted file mode 100644
index e11ff40c299db..0000000000000
--- a/src/lightning/app/cli/app-template/placeholdername/components/component_a/component_a.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from lightning.app import LightningFlow
-
-
-class ComponentA(LightningFlow):
-    def run(self):
-        print("hello from component A")
diff --git a/src/lightning/app/cli/app-template/placeholdername/components/component_b/__init__.py b/src/lightning/app/cli/app-template/placeholdername/components/component_b/__init__.py
deleted file mode 100644
index 876454576ad90..0000000000000
--- a/src/lightning/app/cli/app-template/placeholdername/components/component_b/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from placeholdername.components.component_b.component_a import ComponentB
-
-__all__ = ["ComponentB"]
diff --git a/src/lightning/app/cli/app-template/placeholdername/components/component_b/component_a.py b/src/lightning/app/cli/app-template/placeholdername/components/component_b/component_a.py
deleted file mode 100644
index d80505d986026..0000000000000
--- a/src/lightning/app/cli/app-template/placeholdername/components/component_b/component_a.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from lightning.app import LightningFlow
-
-
-class ComponentB(LightningFlow):
-    def run(self):
-        print("hello from component B")
diff --git a/src/lightning/app/cli/app-template/requirements.txt b/src/lightning/app/cli/app-template/requirements.txt
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/src/lightning/app/cli/app-template/setup.py b/src/lightning/app/cli/app-template/setup.py
deleted file mode 100644
index c398ca985f759..0000000000000
--- a/src/lightning/app/cli/app-template/setup.py
+++ /dev/null
@@ -1,15 +0,0 @@
-#!/usr/bin/env python
-
-from setuptools import find_packages, setup
-
-setup(
-    name="placeholdername",
-    version="0.0.0",
-    description="⚡ Lightning app ⚡ generated with command: lightning init app",
-    author="",
-    author_email="",
-    # REPLACE WITH YOUR OWN GITHUB PROJECT LINK
-    url="https://github.com/Lightning-AI/lightning-app-template",
-    install_requires=[],
-    packages=find_packages(),
-)
diff --git a/src/lightning/app/cli/app-template/tests/README.md b/src/lightning/app/cli/app-template/tests/README.md
deleted file mode 100644
index 85e8c7faa08f9..0000000000000
--- a/src/lightning/app/cli/app-template/tests/README.md
+++ /dev/null
@@ -1,17 +0,0 @@
-# Run tests
-
-To run the tests:
-
-```bash
-# go to your app folder
-cd placeholdername
-
-# go to tests folder
-cd tests
-
-# install testing deps
-pip install -r requirements.txt
-
-# run tests
-pytest .
-```
diff --git a/src/lightning/app/cli/app-template/tests/__init__.py b/src/lightning/app/cli/app-template/tests/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/src/lightning/app/cli/app-template/tests/requirements.txt b/src/lightning/app/cli/app-template/tests/requirements.txt
deleted file mode 100644
index 3185d1c44f033..0000000000000
--- a/src/lightning/app/cli/app-template/tests/requirements.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-coverage
-codecov>=2.1
-pytest>=5.0.0
-pytest-cov
-pytest-flake8
-flake8
-check-manifest
-twine==4.0.1
diff --git a/src/lightning/app/cli/app-template/tests/test_placeholdername_app.py b/src/lightning/app/cli/app-template/tests/test_placeholdername_app.py
deleted file mode 100644
index 6c7743b93ce1e..0000000000000
--- a/src/lightning/app/cli/app-template/tests/test_placeholdername_app.py
+++ /dev/null
@@ -1,44 +0,0 @@
-r"""
-To test a Lightning app:
-1. Use LightningTestApp, which is a subclass of LightningApp.
-2. Override run_once in a subclass of LightningTestApp.
-3. In run_once, come up with a way to verify the behavior you expect.
-
-run_once runs your app through one cycle of the event loop and then terminates.
-"""
-
-import io
-import os
-from contextlib import redirect_stdout
-
-from lightning.app.testing.testing import LightningTestApp, application_testing
-
-
-class LightningAppTestInt(LightningTestApp):
- def run_once(self) -> bool:
- f = io.StringIO()
- with redirect_stdout(f):
- super().run_once()
- out = f.getvalue()
- assert out == "hello from component A\nhello from component B\n"
- return True
-
-
-def test_templatename_app():
- start_dir = os.getcwd()
- os.chdir("..")
-
- cwd = os.getcwd()
- cwd = os.path.join(cwd, "placeholdername/app.py")
- command_line = [
- cwd,
- "--blocking",
- "False",
- "--open-ui",
- "False",
- ]
- result = application_testing(LightningAppTestInt, command_line)
- assert result.exit_code == 0
-
- # reset dir
- os.chdir(start_dir)
diff --git a/src/lightning/app/cli/cmd_apps.py b/src/lightning/app/cli/cmd_apps.py
deleted file mode 100644
index d8d7deace2bb4..0000000000000
--- a/src/lightning/app/cli/cmd_apps.py
+++ /dev/null
@@ -1,146 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-from datetime import datetime
-from typing import List, Optional
-
-from lightning_cloud.openapi import (
- Externalv1LightningappInstance,
- Externalv1Lightningwork,
- V1LightningappInstanceState,
- V1LightningappInstanceStatus,
-)
-from rich.console import Console
-from rich.table import Table
-from rich.text import Text
-
-from lightning.app.cli.core import Formatable
-from lightning.app.utilities.cloud import _get_project
-from lightning.app.utilities.network import LightningClient
-
-
-class _AppManager:
- """_AppManager implements API calls specific to Lightning AI BYOC apps."""
-
- def __init__(self) -> None:
- self.api_client = LightningClient(retry=False)
-
- def get_app(self, app_id: str) -> Externalv1LightningappInstance:
- project = _get_project(self.api_client)
- return self.api_client.lightningapp_instance_service_get_lightningapp_instance(
- project_id=project.project_id, id=app_id
- )
-
- def list_apps(self, limit: int = 100, phase_in: Optional[List[str]] = None) -> List[Externalv1LightningappInstance]:
- phase_in = phase_in or []
- project = _get_project(self.api_client)
-
- kwargs = {
- "project_id": project.project_id,
- "limit": limit,
- "phase_in": phase_in,
- }
-
- resp = self.api_client.lightningapp_instance_service_list_lightningapp_instances(**kwargs)
- apps = resp.lightningapps
- while resp.next_page_token is not None and resp.next_page_token != "":
- kwargs["page_token"] = resp.next_page_token
- resp = self.api_client.lightningapp_instance_service_list_lightningapp_instances(**kwargs)
- apps = apps + resp.lightningapps
- return apps
-
- def list_components(self, app_id: str, phase_in: Optional[List[str]] = None) -> List[Externalv1Lightningwork]:
- phase_in = phase_in or []
- project = _get_project(self.api_client)
- resp = self.api_client.lightningwork_service_list_lightningwork(
- project_id=project.project_id,
- app_id=app_id,
- phase_in=phase_in,
- )
- return resp.lightningworks
-
- def list(self, limit: int = 100) -> None:
- console = Console()
- console.print(_AppList(self.list_apps(limit=limit)).as_table())
-
- def delete(self, app_id: str) -> None:
- project = _get_project(self.api_client)
- self.api_client.lightningapp_instance_service_delete_lightningapp_instance(
- project_id=project.project_id,
- id=app_id,
- )
-
-
-class _AppList(Formatable):
- def __init__(self, apps: List[Externalv1LightningappInstance]) -> None:
- self.apps = apps
-
- @staticmethod
- def _textualize_state_transitions(
- desired_state: V1LightningappInstanceState, current_state: V1LightningappInstanceStatus
- ) -> Text:
- phases = {
- V1LightningappInstanceState.IMAGE_BUILDING: Text("building image", style="bold yellow"),
- V1LightningappInstanceState.PENDING: Text("pending", style="bold yellow"),
- V1LightningappInstanceState.RUNNING: Text("running", style="bold green"),
- V1LightningappInstanceState.FAILED: Text("failed", style="bold red"),
- V1LightningappInstanceState.STOPPED: Text("stopped"),
- V1LightningappInstanceState.NOT_STARTED: Text("not started"),
- V1LightningappInstanceState.DELETED: Text("deleted", style="bold red"),
- V1LightningappInstanceState.UNSPECIFIED: Text("unspecified", style="bold red"),
- }
-
- if current_state.phase == V1LightningappInstanceState.UNSPECIFIED and current_state.start_timestamp is None:
- return Text("not yet started", style="bold yellow")
-
- if (
- desired_state == V1LightningappInstanceState.DELETED
- and current_state.phase != V1LightningappInstanceState.DELETED
- ):
- return Text("terminating", style="bold red")
-
- if (
- any(
- phase == current_state.phase
- for phase in [V1LightningappInstanceState.PENDING, V1LightningappInstanceState.STOPPED]
- )
- and desired_state == V1LightningappInstanceState.RUNNING
- ):
- return Text("restarting", style="bold yellow")
-
- return phases[current_state.phase]
-
- def as_json(self) -> str:
- return json.dumps(self.apps)
-
- def as_table(self) -> Table:
- table = Table("id", "name", "status", "created", show_header=True, header_style="bold green")
-
- for app in self.apps:
- status = self._textualize_state_transitions(desired_state=app.spec.desired_state, current_state=app.status)
-
- # this guard is necessary only until 0.3.93 releases which includes the `created_at`
- # field to the external API
- created_at = datetime.now()
- if hasattr(app, "created_at"):
- created_at = app.created_at
-
- table.add_row(
- app.id,
- app.name,
- status,
- created_at.strftime("%Y-%m-%d") if created_at else "",
- )
- return table
diff --git a/src/lightning/app/cli/cmd_init.py b/src/lightning/app/cli/cmd_init.py
deleted file mode 100644
index db83fd41e47d9..0000000000000
--- a/src/lightning/app/cli/cmd_init.py
+++ /dev/null
@@ -1,167 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import re
-import shutil
-from typing import List, Optional, Tuple
-
-from lightning.app.utilities.app_helpers import Logger
-
-logger = Logger(__name__)
-
-
-def app(app_name: str) -> None:
- if app_name is None:
- app_name = _capture_valid_app_component_name(resource_type="app")
-
- # generate resource template
- new_resource_name, _ = _make_resource(resource_dir="app-template", resource_name=app_name)
-
- m = f"""
- ⚡ Lightning app template created! ⚡
- {new_resource_name}
-
- run your app with:
- lightning run app {app_name}/app.py
-
- run it on the cloud to share with your collaborators:
- lightning run app {app_name}/app.py --cloud
- """
- logger.info(m)
-
-
-def _make_resource(resource_dir: str, resource_name: str) -> Tuple[str, str]:
- path = os.path.dirname(os.path.abspath(__file__))
- template_dir = os.path.join(path, resource_dir)
- name_for_files = re.sub("-", "_", resource_name)
-
- new_resource_name = os.path.join(os.getcwd(), resource_name)
-
- # lay out scaffolding
- logger.info(f"laying out component template at {new_resource_name}")
- shutil.copytree(template_dir, new_resource_name)
-
- # rename main folder
- os.rename(os.path.join(new_resource_name, "placeholdername"), os.path.join(new_resource_name, name_for_files))
-
- # for each file, rename the word
- trouble_names = {".DS_Store"}
- files = _ls_recursively(new_resource_name)
- for bad_file in files:
- if bad_file.split("/")[-1] in trouble_names:
- continue
- # find the words and replace
- with open(bad_file) as fo:
- content = fo.read().replace("placeholdername", name_for_files)
- with open(bad_file, "w") as fw:
- fw.write(content)
-
- # rename files
- for file_name in files:
- new_file = re.sub("placeholdername", name_for_files, file_name)
- os.rename(file_name, new_file)
-
- return new_resource_name, name_for_files
-
-
-def _ls_recursively(dir_name: str) -> List[str]:
- fname = []
- for root, d_names, f_names in os.walk(dir_name):
- for f in f_names:
- if "__pycache__" not in root:
- fname.append(os.path.join(root, f))
-
- return fname
-
-
-def _capture_valid_app_component_name(value: Optional[str] = None, resource_type: str = "app") -> str:
- prompt = f"""
- ⚡ Creating Lightning {resource_type} ⚡
- """
- logger.info(prompt)
-
- try:
- if value is None:
- value = input(f"\nName your Lightning {resource_type} (example: the-{resource_type}-name) > ")
- value = value.strip().lower()
- unsafe_chars = set(re.findall(r"[^a-z0-9\-]", value))
- if len(unsafe_chars) > 0:
- m = f"""
- Error: your Lightning {resource_type} name:
- {value}
-
- contains the following unsupported characters:
- {unsafe_chars}
-
- A Lightning {resource_type} name can only contain letters (a-z) numbers (0-9) and the '-' character
-
- valid example:
- lightning-{resource_type}
- """
- raise SystemExit(m)
-
- except KeyboardInterrupt:
- raise SystemExit(
- f"""
- ⚡ {resource_type} init aborted! ⚡
- """
- )
-
- return value
-
-
-def component(component_name: str) -> None:
- if component_name is None:
- component_name = _capture_valid_app_component_name(resource_type="component")
-
- # generate resource template
- new_resource_name, name_for_files = _make_resource(resource_dir="component-template", resource_name=component_name)
-
- m = f"""
- ⚡ Lightning component template created! ⚡
- {new_resource_name}
-
- ⚡ To use your component, first pip install it (with these 3 commands): ⚡
- cd {component_name}
- pip install -r requirements.txt
- pip install -e .
-
- ⚡ Use the component inside an app: ⚡
-
- from {name_for_files} import TemplateComponent
- import lightning.app as la
-
- class LitApp(la.LightningFlow):
- def __init__(self) -> None:
- super().__init__()
- self.{name_for_files} = TemplateComponent()
-
- def run(self):
- print('this is a simple Lightning app to verify your component is working as expected')
- self.{name_for_files}.run()
-
- app = la.LightningApp(LitApp())
-
- ⚡ Checkout the demo app with your {component_name} component: ⚡
- lightning run app {component_name}/app.py
-
- ⚡ Tip: Publish your component to the Lightning Gallery to enable users to install it like so:
- lightning install component YourLightningUserName/{component_name}
-
- so the Lightning community can use it like:
- from {name_for_files} import TemplateComponent
-
- """
- logger.info(m)
diff --git a/src/lightning/app/cli/cmd_install.py b/src/lightning/app/cli/cmd_install.py
deleted file mode 100644
index b43aa3f88fac9..0000000000000
--- a/src/lightning/app/cli/cmd_install.py
+++ /dev/null
@@ -1,657 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import re
-import shutil
-import subprocess
-import sys
-from typing import Dict, Optional, Tuple
-
-import click
-import requests
-from packaging.version import Version
-
-from lightning.app.core.constants import LIGHTNING_APPS_PUBLIC_REGISTRY, LIGHTNING_COMPONENT_PUBLIC_REGISTRY
-from lightning.app.utilities.app_helpers import Logger
-
-logger = Logger(__name__)
-
-
-@click.group(name="install")
-def install() -> None:
- """Install Lightning AI selfresources."""
- pass
-
-
-@install.command("app")
-@click.argument("name", type=str)
-@click.option(
- "--yes",
- "-y",
- is_flag=True,
- help="disables prompt to ask permission to create env and run install cmds",
-)
-@click.option(
- "--version",
- "-v",
- type=str,
- help="Specify the version to install. By default it uses 'latest'",
- default="latest",
- show_default=True,
-)
-@click.option(
- "--overwrite",
- "-f",
- is_flag=True,
- default=False,
- help="When set, overwrite the app directory without asking if it already exists.",
-)
-def install_app(name: str, yes: bool, version: str, overwrite: bool = False) -> None:
- _install_app_command(name, yes, version, overwrite=overwrite)
-
-
-@install.command("component")
-@click.argument("name", type=str)
-@click.option(
- "--yes",
- "-y",
- is_flag=True,
- help="disables prompt to ask permission to create env and run install cmds",
-)
-@click.option(
- "--version",
- "-v",
- type=str,
- help="Specify the version to install. By default it uses 'latest'",
- default="latest",
- show_default=True,
-)
-def install_component(name: str, yes: bool, version: str) -> None:
- _install_component_command(name, yes, version)
-
-
-def _install_app_command(name: str, yes: bool, version: str, overwrite: bool = False) -> None:
- if "github.com" in name:
- if version != "latest":
- logger.warn(
- "When installing from GitHub, only the 'latest' version is supported. "
- f"The provided version ({version}) will be ignored."
- )
- return non_gallery_app(name, yes, overwrite=overwrite)
-
- return gallery_app(name, yes, version, overwrite=overwrite)
-
-
-def _install_component_command(name: str, yes: bool, version: str, overwrite: bool = False) -> None:
- if "github.com" in name:
- if version != "latest":
- logger.warn(
- "When installing from GitHub, only the 'latest' version is supported. "
- f"The provided version ({version}) will be ignored."
- )
- return non_gallery_component(name, yes)
-
- return gallery_component(name, yes, version)
-
-
-def gallery_apps_and_components(
- name: str, yes_arg: bool, version_arg: str, cwd: Optional[str] = None, overwrite: bool = False
-) -> Optional[str]:
- try:
- org, app_or_component = name.split("/")
- except Exception:
- return None
-
- entry, kind = _resolve_entry(name, version_arg)
-
- if kind == "app":
- # give the user the chance to do a manual install
- source_url, git_url, folder_name, git_sha = _show_install_app_prompt(
- entry, app_or_component, org, yes_arg, resource_type="app"
- )
- # run installation if requested
- _install_app_from_source(source_url, git_url, folder_name, cwd=cwd, overwrite=overwrite, git_sha=git_sha)
-
- return os.path.join(os.getcwd(), *entry["appEntrypointFile"].split("/"))
-
- if kind == "component":
- # give the user the chance to do a manual install
- source_url, git_url, folder_name, git_sha = _show_install_app_prompt(
- entry, app_or_component, org, yes_arg, resource_type="component"
- )
- if "@" in git_url:
- git_url = git_url.split("git+")[1].split("@")[0]
- # run installation if requested
- _install_app_from_source(source_url, git_url, folder_name, cwd=cwd, overwrite=overwrite, git_sha=git_sha)
-
- return os.path.join(os.getcwd(), *entry["entrypointFile"].split("/"))
-
- return None
-
-
-def gallery_component(name: str, yes_arg: bool, version_arg: str, cwd: Optional[str] = None) -> str:
- # make sure org/component-name name is correct
- org, component = _validate_name(name, resource_type="component", example="lightning/LAI-slack-component")
-
- # resolve registry (orgs can have a private registry through their environment variables)
- registry_url = _resolve_component_registry()
-
- # load the component resource
- component_entry = _resolve_resource(registry_url, name=name, version_arg=version_arg, resource_type="component")
-
- # give the user the chance to do a manual install
- git_url = _show_install_component_prompt(component_entry, component, org, yes_arg)
-
- # run installation if requested
- _install_component_from_source(git_url)
-
- return os.path.join(os.getcwd(), component_entry["entrypointFile"])
-
-
-def non_gallery_component(gh_url: str, yes_arg: bool, cwd: Optional[str] = None) -> None:
- # give the user the chance to do a manual install
- git_url = _show_non_gallery_install_component_prompt(gh_url, yes_arg)
-
- # run installation if requested
- _install_component_from_source(git_url)
-
-
-def gallery_app(name: str, yes_arg: bool, version_arg: str, cwd: Optional[str] = None, overwrite: bool = False) -> str:
- # make sure org/app-name syntax is correct
- org, app = _validate_name(name, resource_type="app", example="lightning/quick-start")
-
- # resolve registry (orgs can have a private registry through their environment variables)
- registry_url = _resolve_app_registry()
-
- # load the app resource
- app_entry = _resolve_resource(registry_url, name=name, version_arg=version_arg, resource_type="app")
-
- # give the user the chance to do a manual install
- source_url, git_url, folder_name, git_sha = _show_install_app_prompt(
- app_entry, app, org, yes_arg, resource_type="app"
- )
-
- # run installation if requested
- _install_app_from_source(source_url, git_url, folder_name, cwd=cwd, overwrite=overwrite, git_sha=git_sha)
-
- return os.path.join(os.getcwd(), folder_name, app_entry["appEntrypointFile"])
-
-
-def non_gallery_app(gh_url: str, yes_arg: bool, cwd: Optional[str] = None, overwrite: bool = False) -> None:
- # give the user the chance to do a manual install
- repo_url, folder_name = _show_non_gallery_install_app_prompt(gh_url, yes_arg)
-
- # run installation if requested
- _install_app_from_source(repo_url, repo_url, folder_name, cwd=cwd, overwrite=overwrite)
-
-
-def _show_install_component_prompt(entry: Dict[str, str], component: str, org: str, yes_arg: bool) -> str:
- git_url = entry["gitUrl"]
-
- # yes arg does not prompt the user for permission to install anything
- # automatically creates env and sets up the project
- if yes_arg:
- return git_url
-
- prompt = f"""
- ⚡ Installing Lightning component ⚡
-
- component name : {component}
- developer : {org}
-
- Installation runs the following command for you:
-
- pip install {git_url}
- """
- logger.info(prompt)
-
- try:
- value = input("\nPress enter to continue: ")
- value = value.strip().lower()
- should_install = len(value) == 0 or value in {"y", "yes", 1}
- if not should_install:
- raise KeyboardInterrupt()
-
- return git_url
- except KeyboardInterrupt:
- repo = entry["sourceUrl"]
- raise SystemExit(
- f"""
- ⚡ Installation aborted! ⚡
-
- Install the component yourself by visiting:
- {repo}
- """
- )
-
-
-def _show_non_gallery_install_component_prompt(gh_url: str, yes_arg: bool) -> str:
- if ".git@" not in gh_url:
- m = """
- Error, your github url must be in the following format:
- git+https://github.com/OrgName/repo-name.git@ALongCommitSHAString
-
- Example:
- git+https://github.com/Lightning-AI/LAI-slack-messenger.git@14f333456ffb6758bd19458e6fa0bf12cf5575e1
- """
- raise SystemExit(m)
-
- developer = gh_url.split("/")[3]
- component_name = gh_url.split("/")[4].split(".git")[0]
- repo_url = re.search(r"git\+(.*).git", gh_url).group(1) # type: ignore
-
- # yes arg does not prompt the user for permission to install anything
- # automatically creates env and sets up the project
- if yes_arg:
- return gh_url
-
- prompt = f"""
- ⚡ Installing Lightning component ⚡
-
- component name : {component_name}
- developer : {developer}
-
- ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡
- WARNING: this is NOT an official Lightning Gallery component
- Install at your own risk
- ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡
-
- Installation runs the following command for you:
-
- pip install {gh_url}
- """
- logger.info(prompt)
-
- try:
- value = input("\nPress enter to continue: ")
- value = value.strip().lower()
- should_install = len(value) == 0 or value in {"y", "yes", 1}
- if not should_install:
- raise KeyboardInterrupt()
-
- return gh_url
- except KeyboardInterrupt:
- raise SystemExit(
- f"""
- ⚡ Installation aborted! ⚡
-
- Install the component yourself by visiting:
- {repo_url}
- """
- )
-
-
-def _show_install_app_prompt(
- entry: Dict[str, str], app: str, org: str, yes_arg: bool, resource_type: str
-) -> Tuple[str, str, str, Optional[str]]:
- source_url = entry["sourceUrl"] # This URL is used only to display the repo and extract folder name
- full_git_url = entry["gitUrl"] # Used to clone the repo (can include tokens for private repos)
- git_url_parts = full_git_url.split("#ref=")
- git_url = git_url_parts[0]
- git_sha = git_url_parts[1] if len(git_url_parts) == 2 else None
-
- folder_name = source_url.split("/")[-1]
-
- # yes arg does not prompt the user for permission to install anything
- # automatically creates env and sets up the project
- if yes_arg:
- return source_url, git_url, folder_name, git_sha
-
- prompt = f"""
- ⚡ Installing Lightning {resource_type} ⚡
-
- {resource_type} name : {app}
- developer: {org}
-
- Installation creates and runs the following commands for you:
-
- git clone {source_url}
- cd {folder_name}
- pip install -r requirements.txt
- pip install -e .
- """
- logger.info(prompt)
-
- try:
- value = input("\nPress enter to continue: ")
- value = value.strip().lower()
- should_install = len(value) == 0 or value in {"y", "yes", 1}
- if not should_install:
- raise KeyboardInterrupt()
-
- return source_url, git_url, folder_name, git_sha
- except KeyboardInterrupt:
- repo = entry["sourceUrl"]
- raise SystemExit(
- f"""
- ⚡ Installation aborted! ⚡
-
- Install the {resource_type} yourself by visiting:
- {repo}
- """
- )
-
-
-def _show_non_gallery_install_app_prompt(gh_url: str, yes_arg: bool) -> Tuple[str, str]:
- try:
- if gh_url.endswith(".git"):
- # folder_name when it's a GH url with .git
- folder_name = gh_url.split("/")[-1]
- folder_name = folder_name[:-4]
- else:
- # the last part of the url is the folder name otherwise
- folder_name = gh_url.split("/")[-1]
-
- org = re.search(r"github.com\/(.*)\/", gh_url).group(1) # type: ignore
- except Exception:
- raise SystemExit(
- """
- Your github url is not supported. Here's the supported format:
- https://github.com/YourOrgName/your-repo-name
-
- Example:
- https://github.com/Lightning-AI/lightning
- """
- )
-
- # yes arg does not prompt the user for permission to install anything
- # automatically creates env and sets up the project
- if yes_arg:
- return gh_url, folder_name
-
- prompt = f"""
- ⚡ Installing Lightning app ⚡
-
- app source : {gh_url}
- developer : {org}
-
- ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡
- WARNING: this is NOT an official Lightning Gallery app
- Install at your own risk
- ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡
-
- Installation creates and runs the following commands for you:
-
- git clone {gh_url}
- cd {folder_name}
- pip install -r requirements.txt
- pip install -e .
- """
- logger.info(prompt)
-
- try:
- value = input("\nPress enter to continue: ")
- value = value.strip().lower()
- should_install = len(value) == 0 or value in {"y", "yes", 1}
- if not should_install:
- raise KeyboardInterrupt()
-
- return gh_url, folder_name
- except KeyboardInterrupt:
- raise SystemExit(
- f"""
- ⚡ Installation aborted! ⚡
-
- Install the app yourself by visiting {gh_url}
- """
- )
-
-
-def _validate_name(name: str, resource_type: str, example: str) -> Tuple[str, str]:
- # ensure resource identifier is properly formatted
- try:
- org, resource = name.split("/")
- except Exception:
- raise SystemExit(
- f"""
- {resource_type} name format must have organization/{resource_type}-name
-
- Examples:
- {example}
- user/{resource_type}-name
-
- You passed in: {name}
- """
- )
- return org, resource
-
-
-def _resolve_entry(name, version_arg) -> Tuple[Optional[Dict], Optional[str]]:
- entry = None
- kind = None
-
- # resolve registry (orgs can have a private registry through their environment variables)
- registry_url = _resolve_app_registry()
-
- # load the app resource
- entry = _resolve_resource(registry_url, name=name, version_arg=version_arg, resource_type="app", raise_error=False)
-
- if not entry:
- registry_url = _resolve_component_registry()
-
- # load the component resource
- entry = _resolve_resource(
- registry_url, name=name, version_arg=version_arg, resource_type="component", raise_error=False
- )
- kind = "component" if entry else None
-
- else:
- kind = "app"
-
- return entry, kind
-
-
-def _resolve_resource(
- registry_url: str, name: str, version_arg: str, resource_type: str, raise_error: bool = True
-) -> Dict[str, str]:
- gallery_entries = []
- try:
- response = requests.get(registry_url)
- data = response.json()
-
- if resource_type == "app":
- gallery_entries = [a for a in data["apps"] if a["canDownloadSourceCode"]]
-
- elif resource_type == "component":
- gallery_entries = data["components"]
- except requests.ConnectionError:
- sys.tracebacklimit = 0
- raise SystemError(
- f"""
- Network connection error, could not load list of available Lightning {resource_type}s.
-
- Try again when you have a network connection!
- """
- )
-
- entries = []
- all_versions = []
- for x in gallery_entries:
- if name == x["name"]:
- entries.append(x)
- all_versions.append(x["version"])
-
- if len(entries) == 0:
- if raise_error:
- raise SystemExit(f"{resource_type}: '{name}' is not available on ⚡ Lightning AI ⚡")
-
- return None
-
- entry = None
- if version_arg == "latest":
- entry = max(entries, key=lambda app: Version(app["version"]))
- else:
- for e in entries:
- if e["version"] == version_arg:
- entry = e
- break
- if entry is None and raise_error:
- if raise_error:
- raise Exception(
- f"{resource_type}: 'Version {version_arg} for {name}' is not available on ⚡ Lightning AI ⚡. "
- f"Here is the list of all availables versions:{os.linesep}{os.linesep.join(all_versions)}"
- )
- return None
-
- return entry
-
-
-def _install_with_env(repo_url: str, folder_name: str, cwd: Optional[str] = None) -> None:
- if not cwd:
- cwd = os.getcwd()
-
- # clone repo
- logger.info(f"⚡ RUN: git clone {repo_url}")
- subprocess.call(["git", "clone", repo_url])
-
- # step into the repo folder
- os.chdir(f"{folder_name}")
- cwd = os.getcwd()
-
- # create env
- logger.info(f"⚡ CREATE: virtual env at {cwd}")
- subprocess.call(["python", "-m", "venv", cwd])
-
- # activate and install reqs
- # TODO: remove shell=True... but need to run command in venv
- logger.info("⚡ RUN: install requirements (pip install -r requirements.txt)")
- subprocess.call("source bin/activate && pip install -r requirements.txt", shell=True)
-
- # install project
- # TODO: remove shell=True... but need to run command in venv
- logger.info("⚡ RUN: setting up project (pip install -e .)")
- subprocess.call("source bin/activate && pip install -e .", shell=True)
-
- m = f"""
- ⚡ Installed! ⚡ to use your app
- go into the folder: cd {folder_name}
- activate the environment: source bin/activate
- run the app: lightning run app [the_app_file.py]
- """
- logger.info(m)
-
-
-def _install_app_from_source(
- source_url: str,
- git_url: str,
- folder_name: str,
- cwd: Optional[str] = None,
- overwrite: bool = False,
- git_sha: Optional[str] = None,
-) -> None:
- """Installing lighting app from the `git_url`
-
- Args:
- source_url:
- source repo url without any tokens and params, this param is used only for displaying
- git_url:
- repo url that is used to clone, this can contain tokens
- folder_name:
- where to clone the repo ?
- cwd:
- Working director. If not specified, current working directory is used.
- overwrite:
- If true, overwrite the app directory without asking if it already exists
- git_sha:
- The git_sha for checking out the git repo of the app.
-
- """
-
- if not cwd:
- cwd = os.getcwd()
-
- destination = os.path.join(cwd, folder_name)
- if os.path.exists(destination):
- if not overwrite:
- raise SystemExit(
- f"Folder {folder_name} exists, please delete it and try again, "
- f"or force to overwrite the existing folder by passing `--overwrite`.",
- )
- shutil.rmtree(destination)
- # clone repo
- logger.info(f"⚡ RUN: git clone {source_url}")
- try:
- subprocess.check_output(["git", "clone", git_url], stderr=subprocess.STDOUT)
- except subprocess.CalledProcessError as ex:
- if "Repository not found" in str(ex.output):
- raise SystemExit(
- f"""
- Looks like the github url was not found or doesn't exist. Do you have a typo?
- {source_url}
- """
- )
- raise Exception(ex)
-
- # step into the repo folder
- os.chdir(f"{folder_name}")
- cwd = os.getcwd()
-
- try:
- if git_sha:
- subprocess.check_output(["git", "checkout", git_sha], stderr=subprocess.STDOUT)
- except subprocess.CalledProcessError as ex:
- if "did not match any" in str(ex.output):
- raise SystemExit("Looks like the git SHA is not valid or doesn't exist in app repo.")
- raise Exception(ex)
-
- # activate and install reqs
- # TODO: remove shell=True... but need to run command in venv
- logger.info("⚡ RUN: install requirements (pip install -r requirements.txt)")
- subprocess.call("pip install -r requirements.txt", shell=True)
-
- # install project
- # TODO: remove shell=True... but need to run command in venv
- logger.info("⚡ RUN: setting up project (pip install -e .)")
- subprocess.call("pip install -e .", shell=True)
-
- m = f"""
- ⚡ Installed! ⚡ to use your app:
-
- cd {folder_name}
- lightning run app app.py
- """
- logger.info(m)
-
-
-def _install_component_from_source(git_url: str) -> None:
- logger.info("⚡ RUN: pip install")
-
- out = subprocess.check_output(["pip", "install", git_url])
- possible_success_message = [x for x in str(out).split("\\n") if "Successfully installed" in x]
- if len(possible_success_message) > 0:
- uninstall_step = possible_success_message[0]
- uninstall_step = re.sub("Successfully installed", "", uninstall_step).strip()
- uninstall_step = re.sub("-0.0.0", "", uninstall_step).strip()
- m = """
- ⚡ Installed! ⚡
-
- to use your component:
- from the_component import TheClass
-
- make sure to add this entry to your Lightning APP requirements.txt file:
- {git_url}
-
- if you want to uninstall, run this command:
- pip uninstall {uninstall_step}
- """
- logger.info(m)
-
-
-def _resolve_app_registry() -> str:
- return os.environ.get("LIGHTNING_APP_REGISTRY", LIGHTNING_APPS_PUBLIC_REGISTRY)
-
-
-def _resolve_component_registry() -> str:
- return os.environ.get("LIGHTNING_COMPONENT_REGISTRY", LIGHTNING_COMPONENT_PUBLIC_REGISTRY)
diff --git a/src/lightning/app/cli/cmd_pl_init.py b/src/lightning/app/cli/cmd_pl_init.py
deleted file mode 100644
index 2436c28179ef2..0000000000000
--- a/src/lightning/app/cli/cmd_pl_init.py
+++ /dev/null
@@ -1,187 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import pathlib
-import re
-import shutil
-import subprocess
-import sys
-import tarfile
-import urllib.request
-from pathlib import Path
-from tempfile import TemporaryDirectory
-from typing import Any, Dict, List, Optional
-
-import click
-from jinja2 import Environment, FileSystemLoader
-from rich import print
-from rich.panel import Panel
-from rich.status import Status
-from rich.text import Text
-from rich.tree import Tree
-
-import lightning.app
-
-_REPORT_HELP_TEXTS = {
- "core": "Important files for the app such as various components",
- "source": "A copy of all your source code, including the PL script ⚡",
- "tests": "This app comes with tests!",
- "ui": "Source and build files for the user interface",
- "app.py": "This is the main app file!",
- "requirements.txt": "Lists the dependencies required to be installed before running the app",
-}
-
-_REPORT_IGNORE_PATTERNS = [
- r"__pycache__",
- r"__init__\.py",
- r".*egg-info",
- r"\..*",
-]
-
-
-def pl_app(source_dir: str, script_path: str, name: str, overwrite: bool) -> None:
- source_dir = Path(source_dir).resolve()
- script_path = Path(script_path).resolve()
-
- if not source_dir.is_dir():
- click.echo(f"The given source directory does not exist: {source_dir}", err=True)
- raise SystemExit(1)
-
- if not script_path.exists():
- click.echo(f"The given script path does not exist: {script_path}", err=True)
- raise SystemExit(1)
-
- if not script_path.is_file():
- click.echo(f"The given script path must be a file, you passed: {script_path}", err=True)
- raise SystemExit(1)
-
- if source_dir not in script_path.parents:
- click.echo(
- "The given script path must be a subpath of the source directory. Example:"
- " lightning init pl-app ./code ./code/scripts/train.py",
- err=True,
- )
- raise SystemExit(1)
-
- rel_script_path = script_path.relative_to(source_dir)
- cwd = Path.cwd()
- destination = cwd / name
-
- if destination.exists():
- if not overwrite:
- click.echo(
- f"There is already an app with the name {name} in the current working directory. Choose a different"
- f" name with `--name` or force to overwrite the existing folder by passing `--overwrite`.",
- err=True,
- )
- raise SystemExit(1)
-
- shutil.rmtree(destination)
-
- template_dir = Path(lightning.app.cli.__file__).parent / "pl-app-template"
-
- with Status("[bold green]Copying app files"):
- shutil.copytree(template_dir, destination, ignore=shutil.ignore_patterns("node_modules", "build"))
- if (template_dir / "ui" / "build").exists():
- shutil.copytree(template_dir / "ui" / "build", destination / "ui" / "build")
- else:
- download_frontend(destination / "ui" / "build")
-
- with Status("[bold green]Copying source files"):
- shutil.copytree(source_dir, destination / "source", ignore=shutil.ignore_patterns(name))
- project_file_from_template(template_dir, destination, "app.py", script_path=str(rel_script_path))
- project_file_from_template(template_dir, destination, "setup.py", app_name=name)
-
- with Status("[bold green]Installing"):
- subprocess.call(["pip", "install", "--quiet", "-e", str(destination)])
- # TODO: download the ui files
-
- print_pretty_report(
- destination,
- ignore_patterns=_REPORT_IGNORE_PATTERNS,
- help_texts=_REPORT_HELP_TEXTS,
- )
-
-
-def download_frontend(destination: Path) -> None:
- # TODO: Update the URL to the release in GitHub once the PL app repo is public
- url = "https://storage.googleapis.com/grid-packages/pytorch-lightning-app/v0.0.0/build.tar.gz"
- build_dir_name = "build"
- with TemporaryDirectory() as download_dir:
- response = urllib.request.urlopen(url) # noqa: S310
- file = tarfile.open(fileobj=response, mode="r|gz")
- file.extractall(path=download_dir) # noqa: S202
- shutil.move(str(Path(download_dir, build_dir_name)), destination)
-
-
-def project_file_from_template(template_dir: Path, destination_dir: Path, template_name: str, **kwargs: Any) -> None:
- env = Environment(loader=FileSystemLoader(template_dir)) # noqa: S701
- template = env.get_template(template_name)
- rendered_template = template.render(**kwargs)
- with open(destination_dir / template_name, "w") as file:
- file.write(rendered_template)
-
-
-def print_pretty_report(
- directory: pathlib.Path,
- ignore_patterns: Optional[List[str]] = None,
- help_texts: Optional[Dict[str, str]] = None,
-) -> None:
- """Prints a report for the generated app."""
- tree = Tree(
- f":open_file_folder: [link file://{directory}]{directory}",
- guide_style="bold bright_blue",
- )
-
- help_texts = {} if help_texts is None else help_texts
-
- paths = sorted(
- directory.glob("*"),
- key=lambda p: (p.is_file(), p.name.lower()),
- )
- max_width = max(len(p.name) for p in paths)
-
- patterns_to_ignore = [] if ignore_patterns is None else ignore_patterns
- for path in paths:
- if any(re.match(pattern, path.name) for pattern in patterns_to_ignore):
- # Only display relevant files
- continue
-
- help_text = help_texts.get(path.name, "")
- padding = " " * (max_width - len(path.name))
-
- text_pathname = Text(path.name, "green")
- text_pathname.highlight_regex(r"\..*$", "bold red")
- text_pathname.stylize(f"link file://{path}")
- text_pathname.append(f" {padding} {help_text}", "blue")
-
- icon = "📂 " if path.is_dir() else "📄 "
- icon = icon if _can_encode_icon(icon) else ""
-
- tree.add(Text(icon) + text_pathname)
-
- print("\n")
- print("Done. The app is ready here:\n")
- print(tree)
- print("\nRun it:\n")
- print(Panel(f"[red]lightning run app {directory.relative_to(Path.cwd()) / 'app.py'}"))
-
-
-def _can_encode_icon(icon: str) -> bool:
- """Helper function to check whether an icon can be encoded."""
- try:
- icon.encode(sys.stdout.encoding)
- return True
- except UnicodeEncodeError:
- return False
diff --git a/src/lightning/app/cli/cmd_react_ui_init.py b/src/lightning/app/cli/cmd_react_ui_init.py
deleted file mode 100644
index 22e668433e233..0000000000000
--- a/src/lightning/app/cli/cmd_react_ui_init.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import re
-import shutil
-import subprocess
-from typing import Optional
-
-from lightning.app.utilities.app_helpers import Logger
-
-logger = Logger(__name__)
-
-
-def react_ui(dest_dir: Optional[str] = None) -> None:
- # verify all the prereqs for install are met
- _check_react_prerequisites()
-
- # copy template files to the dir
- _copy_and_setup_react_ui(dest_dir)
-
-
-def _copy_and_setup_react_ui(dest_dir: Optional[str] = None) -> None:
- logger.info("⚡ setting up react-ui template")
- path = os.path.dirname(os.path.abspath(__file__))
- template_dir = os.path.join(path, "react-ui-template")
-
- if dest_dir is None:
- dest_dir = os.path.join(os.getcwd(), "react-ui")
-
- shutil.copytree(template_dir, dest_dir)
-
- logger.info("⚡ install react project deps")
- ui_path = os.path.join(dest_dir, "ui")
- subprocess.run(f"cd {ui_path} && yarn install", shell=True)
-
- logger.info("⚡ building react project")
- subprocess.run(f"cd {ui_path} && yarn build", shell=True)
-
- m = f"""
- ⚡⚡ react-ui created! ⚡⚡
-
- ⚡ Connect it to your component using `configure_layout`:
-
- # Use a LightningFlow or LightningWork
- class YourComponent(la.LightningFlow):
- def configure_layout(self):
- return la.frontend.StaticWebFrontend(Path(__file__).parent / "react-ui/src/dist")
-
- ⚡ run the example_app.py to see it live!
- lightning_app run app {dest_dir}/example_app.py
-
- """
- logger.info(m)
-
-
-def _check_react_prerequisites() -> None:
- """Verify that npm, node and yarn are available and log the versions found."""
- missing_msgs = []
- version_regex = r"\d{1,2}\.\d{1,2}\.\d{1,3}"
-
- logger.info("Checking pre-requisites for react")
-
- # make sure npm is installed
- npm_version = subprocess.check_output(["npm", "--version"])
- has_npm = bool(re.search(version_regex, str(npm_version)))
- npm_version = re.search(version_regex, str(npm_version))
- npm_version = None if npm_version is None else npm_version.group(0)
-
- if not has_npm:
- m = """
- This machine is missing 'npm'. Please install npm and rerun 'lightning_app init react-ui'.
-
- Install instructions: https://docs.npmjs.com/downloading-and-installing-node-js-and-npm
- """
- missing_msgs.append(m)
-
- # make sure node is installed
- node_version = subprocess.check_output(["node", "--version"])
- has_node = bool(re.search(version_regex, str(node_version)))
- node_version = re.search(version_regex, str(node_version))
- node_version = None if node_version is None else node_version.group(0)
-
- if not has_node:
- m = """
- This machine is missing 'node'. Please install node and rerun 'lightning_app init react-ui'.
-
- Install instructions: https://docs.npmjs.com/downloading-and-installing-node-js-and-npm
- """
- missing_msgs.append(m)
-
- # make sure yarn is installed
- yarn_version = subprocess.check_output(["yarn", "--version"])
- has_yarn = bool(re.search(version_regex, str(yarn_version)))
- yarn_version = re.search(version_regex, str(yarn_version))
- yarn_version = None if yarn_version is None else yarn_version.group(0)
-
- if not has_yarn:
- m = """
- This machine is missing 'yarn'. Please install npm+node first, then run
-
- npm install --global yarn
-
- Full install instructions: https://classic.yarnpkg.com/lang/en/docs/install/#mac-stable
- """
- missing_msgs.append(m)
-
- # exit or show success message
- if len(missing_msgs) > 0:
- missing_msg = "\n".join(missing_msgs)
- raise SystemExit(missing_msg)
- logger.info(
- f"""
- found npm version: {npm_version}
- found node version: {node_version}
- found yarn version: {yarn_version}
-
- Pre-requisites met!
- """
- )
diff --git a/src/lightning/app/cli/commands/__init__.py b/src/lightning/app/cli/commands/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/src/lightning/app/cli/commands/app_commands.py b/src/lightning/app/cli/commands/app_commands.py
deleted file mode 100644
index bbecceabb6e28..0000000000000
--- a/src/lightning/app/cli/commands/app_commands.py
+++ /dev/null
@@ -1,135 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import sys
-from typing import Dict, Optional
-
-import requests
-
-from lightning.app.cli.connect.app import (
- _clean_lightning_connection,
- _install_missing_requirements,
- _resolve_command_path,
-)
-from lightning.app.utilities.cli_helpers import _LightningAppOpenAPIRetriever
-from lightning.app.utilities.commands.base import _download_command
-from lightning.app.utilities.enum import OpenAPITags
-
-
-def _is_running_help(argv) -> bool:
- return argv[-1] in ["--help", "-"] if argv else False
-
-
-def _run_app_command(app_name: str, app_id: Optional[str]):
- """Execute a function in a running App from its name."""
- # 1: Collect the url and commands from the running application
- _clean_lightning_connection()
-
- running_help = _is_running_help(sys.argv)
-
- retriever = _LightningAppOpenAPIRetriever(app_id, use_cache=running_help)
-
- if not running_help and (retriever.url is None or retriever.api_commands is None):
- if app_name == "localhost":
- print("The command couldn't be executed as your local Lightning App isn't running.")
- else:
- print(f"The command couldn't be executed as your cloud Lightning App `{app_name}` isn't running.")
- sys.exit(0)
-
- if not retriever.api_commands:
- raise Exception("This application doesn't expose any commands yet.")
-
- full_command = "_".join(sys.argv)
-
- has_found = False
- for command in list(retriever.api_commands):
- if command in full_command:
- has_found = True
- for value in sys.argv:
- if value == command and "_" in value:
- print(
- f"The command `{value}` was provided with an underscore, which isn't allowed. "
- f"Instead, use `lightning_app {value.replace('_', ' ')}`."
- )
- sys.exit(0)
- break
-
- if not has_found:
- raise Exception(f"The provided command isn't available in {list(retriever.api_commands)}")
-
- # 2: Send the command from the user
- metadata = retriever.api_commands[command]
-
- try:
- # 3: Execute the command
- if metadata["tag"] == OpenAPITags.APP_COMMAND:
- _handle_command_without_client(command, metadata, retriever.url)
- else:
- _handle_command_with_client(command, metadata, app_name, app_id, retriever.url)
- except ModuleNotFoundError:
- _install_missing_requirements(retriever, fail_if_missing=True)
-
- if running_help:
- print("Your command execution was successful.")
-
-
-def _handle_command_without_client(command: str, metadata: Dict, url: str) -> None:
- supported_params = list(metadata["parameters"])
- if _is_running_help(sys.argv):
- print(f"Usage: lightning_app {command} [ARGS]...")
- print(" ")
- print("Options")
- for param in supported_params:
- print(f" {param}: Add description")
- return
-
- provided_params = [param.replace("--", "") for param in sys.argv[1 + len(command.split("_")) :]]
-
- # TODO: Add support for more argument types.
- if any("=" not in param for param in provided_params):
- raise Exception("Please, use --x=y syntax when providing the command arguments.")
-
- if any(param.split("=")[0] not in supported_params for param in provided_params):
- raise Exception(f"Some of the provided arguments aren't supported. The supported keys are {supported_params}.")
-
- # TODO: Encode the parameters and validate their type.
- query_parameters = "&".join(provided_params)
- resp = requests.post(url + f"/command/{command}?{query_parameters}")
- assert resp.status_code == 200, resp.json()
- print(resp.json())
-
-
-def _handle_command_with_client(command: str, metadata: Dict, app_name: str, app_id: Optional[str], url: str):
- debug_mode = bool(int(os.getenv("DEBUG", "0")))
-
- if app_name == "localhost":
- target_file = metadata["cls_path"]
- else:
- target_file = _resolve_command_path(command)
-
- if debug_mode:
- print(target_file)
-
- client_command = _download_command(
- command,
- metadata["cls_path"],
- metadata["cls_name"],
- app_id,
- debug_mode=debug_mode,
- target_file=target_file if debug_mode else _resolve_command_path(command),
- )
- client_command._setup(command_name=command, app_url=url)
- sys.argv = sys.argv[len(command.split("_")) :]
- client_command.run()
diff --git a/src/lightning/app/cli/commands/cd.py b/src/lightning/app/cli/commands/cd.py
deleted file mode 100644
index 7f84b894bf155..0000000000000
--- a/src/lightning/app/cli/commands/cd.py
+++ /dev/null
@@ -1,117 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-from typing import Optional, Tuple, Union
-
-import click
-from rich.live import Live
-from rich.spinner import Spinner
-from rich.text import Text
-
-from lightning.app.cli.commands import ls
-from lightning.app.cli.connect.app import _LIGHTNING_CONNECTION_FOLDER
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.cli_helpers import _error_and_exit
-
-logger = Logger(__name__)
-
-_HOME = os.path.expanduser("~")
-_CD_FILE = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "cd.txt")
-
-
-@click.argument("path", nargs=-1)
-def cd(path: Optional[Union[Tuple[str], str]], verify: bool = True) -> None:
- """Change the current directory within the Lightning Cloud filesystem."""
- with Live(Spinner("point", text=Text("pending...", style="white")), transient=True) as live:
- root = "/"
-
- if isinstance(path, Tuple) and len(path) > 0:
- path = " ".join(path)
-
- # handle ~/
- if isinstance(path, str) and path.startswith(_HOME):
- path = "/" + path.replace(_HOME, "")
-
- # handle no path -> /
- if path is None or len(path) == 0:
- path = "/"
-
- if not os.path.exists(_LIGHTNING_CONNECTION_FOLDER):
- os.makedirs(_LIGHTNING_CONNECTION_FOLDER)
-
- if not os.path.exists(_CD_FILE):
- # Start from the root
- if path.startswith(".."):
- root = _apply_double_dots(root, path)
-
- with open(_CD_FILE, "w") as f:
- f.write(root + "\n")
-
- live.stop()
-
- print(f"cd {root}")
-
- return root
-
- # read from saved cd
- with open(_CD_FILE) as f:
- lines = f.readlines()
- root = lines[0].replace("\n", "")
-
- if verify:
- if path.startswith("/"):
- paths = [os.path.join(path, p) for p in ls.ls(path, print=False, use_live=False)]
- else:
- paths = [os.path.join(root, p) for p in ls.ls(root, print=False, use_live=False)]
-
- # generate new root
- if root == "/":
- if path == "/":
- root = "/"
- elif not path.startswith(".."):
- if not path.startswith("/"):
- path = "/" + path
- root = path
- else:
- root = _apply_double_dots(root, path)
- else:
- if path.startswith(".."):
- root = _apply_double_dots(root, path)
- elif path.startswith("~"):
- root = path[2:]
- else:
- root = os.path.join(root, path)
-
- if verify and root != "/" and not any(p.startswith(root) or root.startswith(p) for p in paths):
- _error_and_exit(f"no such file or directory: {path}")
-
- os.remove(_CD_FILE)
-
- # store new root
- with open(_CD_FILE, "w") as f:
- f.write(root + "\n")
-
- live.stop()
-
- print(f"cd {root}")
-
- return root
-
-
-def _apply_double_dots(root: str, path: str) -> str:
- splits = [split for split in path.split("/") if split != ""]
- for split in splits:
- root = "/" + os.path.join(*root.split("/")[:-1]) if split == ".." else os.path.join(root, split)
- return root
diff --git a/src/lightning/app/cli/commands/cp.py b/src/lightning/app/cli/commands/cp.py
deleted file mode 100644
index 0b11a874b216d..0000000000000
--- a/src/lightning/app/cli/commands/cp.py
+++ /dev/null
@@ -1,350 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import concurrent
-import contextlib
-import os
-import sys
-from functools import partial
-from multiprocessing.pool import ApplyResult
-from pathlib import Path
-from textwrap import dedent
-from typing import Any, Optional, Tuple, Union
-
-import click
-import requests
-import urllib3
-from lightning_cloud.openapi import (
- Externalv1Cluster,
- Externalv1LightningappInstance,
- ProjectIdStorageBody,
- V1CloudSpace,
-)
-from rich.live import Live
-from rich.progress import BarColumn, DownloadColumn, Progress, TaskID, TextColumn
-from rich.spinner import Spinner
-from rich.text import Text
-
-from lightning.app.cli.commands.ls import _collect_artifacts, _get_prefix
-from lightning.app.cli.commands.pwd import _pwd
-from lightning.app.source_code import FileUploader
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.auth import _AuthTokenGetter
-from lightning.app.utilities.cli_helpers import _error_and_exit
-from lightning.app.utilities.network import LightningClient
-
-logger = Logger(__name__)
-
-
-@click.argument("src_path", required=True)
-@click.argument("dst_path", required=True)
-@click.option("-r", required=False, hidden=True)
-@click.option("--recursive", required=False, hidden=True)
-@click.option("--zip", required=False, is_flag=True, default=False)
-def cp(src_path: str, dst_path: str, r: bool = False, recursive: bool = False, zip: bool = False) -> None:
- """Copy files between your local filesystem and the Lightning Cloud filesystem."""
- if sys.platform == "win32":
- print("`cp` isn't supported on Windows. Open an issue on GitHub.")
- sys.exit(0)
-
- with Live(Spinner("point", text=Text("pending...", style="white")), transient=True) as live:
- pwd = _pwd()
-
- client = LightningClient(retry=False)
-
- src_path, src_remote = _sanitize_path(src_path, pwd)
- dst_path, dst_remote = _sanitize_path(dst_path, pwd)
-
- if src_remote and dst_remote:
- return _error_and_exit("Moving files remotely isn't supported yet. Please open a GitHub issue.")
-
- if not src_remote and dst_remote:
- if dst_path == "/" or len(dst_path.split("/")) == 1:
- return _error_and_exit("Uploading files at the project level isn't allowed yet.")
- if zip:
- return _error_and_exit("Zipping uploads isn't supported yet. Please open a GitHub issue.")
- _upload_files(live, client, src_path, dst_path, pwd)
- return None
- if src_remote and not dst_remote:
- if zip:
- return _zip_files(live, src_path, dst_path)
- _download_files(live, client, src_path, dst_path, pwd)
- return None
-
- return _error_and_exit("Moving files locally isn't supported yet. Please open a GitHub issue.")
-
-
-def _upload_files(live, client: LightningClient, local_src: str, remote_dst: str, pwd: str) -> str:
- remote_splits = [split for split in remote_dst.split("/") if split != ""]
- remote_dst = os.path.join(*remote_splits)
-
- if not os.path.exists(local_src):
- return _error_and_exit(f"The provided source path {local_src} doesn't exist.")
-
- lit_resource = None
-
- if len(remote_splits) > 1:
- project_id, lit_resource = _get_project_id_and_resource(pwd)
- else:
- project_id = _get_project_id_from_name(remote_dst)
-
- if len(remote_splits) > 2:
- remote_dst = os.path.join(*remote_splits[2:])
-
- local_src = Path(local_src).resolve()
- upload_paths = []
-
- if os.path.isdir(local_src):
- for root_dir, _, paths in os.walk(local_src):
- for path in paths:
- upload_paths.append(os.path.join(root_dir, path))
- else:
- upload_paths = [local_src]
-
- _upload_urls = []
-
- clusters = client.projects_service_list_project_cluster_bindings(project_id)
-
- live.stop()
-
- for upload_path in upload_paths:
- for cluster in clusters.clusters:
- filename = str(upload_path).replace(str(os.getcwd()), "")[1:]
- filename = _get_prefix(os.path.join(remote_dst, filename), lit_resource) if lit_resource else "/" + filename
-
- response = client.lightningapp_instance_service_upload_project_artifact(
- project_id=project_id,
- body=ProjectIdStorageBody(cluster_id=cluster.cluster_id, filename=filename),
- async_req=True,
- )
- _upload_urls.append(response)
-
- upload_urls = []
- for upload_url in _upload_urls:
- upload_urls.extend(upload_url.get().urls)
-
- live.stop()
-
- if not upload_paths:
- print("There were no files to upload.")
- return None
-
- progress = _get_progress_bar()
-
- total_size = sum([Path(path).stat().st_size for path in upload_paths]) // max(len(clusters.clusters), 1)
- task_id = progress.add_task("upload", filename="", total=total_size)
-
- progress.start()
-
- _upload_partial = partial(_upload, progress=progress, task_id=task_id)
-
- with concurrent.futures.ThreadPoolExecutor(4) as executor:
- results = executor.map(_upload_partial, upload_paths, upload_urls)
-
- progress.stop()
-
- # Report an error if any worker returned an exception
- exception = next((e for e in results if isinstance(e, Exception)), None)
- if exception:
- _error_and_exit("We detected errors in uploading your files.")
- return None
- return None
-
-
-def _upload(source_file: str, presigned_url: ApplyResult, progress: Progress, task_id: TaskID) -> Optional[Exception]:
- source_file = Path(source_file)
- file_uploader = FileUploader(
- presigned_url,
- source_file,
- total_size=None,
- name=str(source_file),
- )
- file_uploader.progress = progress
- file_uploader.task_id = task_id
- file_uploader.upload()
-
-
-def _zip_files(live: Live, remote_src: str, local_dst: str) -> None:
- if len(remote_src.split("/")) < 3:
- return _error_and_exit(
- dedent(
- f"""
- The source path must be at least two levels deep (e.g. r:/my-project/my-lit-resource).
-
- The path provided was: r:{remote_src}
- """
- )
- )
-
- if os.path.isdir(local_dst):
- local_dst = os.path.join(local_dst, os.path.basename(remote_src) + ".zip")
-
- project_id, lit_resource = _get_project_id_and_resource(remote_src)
-
- # /my-project/my-lit-resource/artifact-path -> cloudspace/my-lit-resource-id/artifact-path
- artifact = "/".join(remote_src.split("/")[3:])
- prefix = _get_prefix(artifact, lit_resource)
-
- token = _AuthTokenGetter(LightningClient().api_client)._get_api_token()
- endpoint = f"/v1/projects/{project_id}/artifacts/download?prefix={prefix}&token={token}"
-
- cluster = _cluster_from_lit_resource(lit_resource)
- url = _storage_host(cluster) + endpoint
-
- live.stop()
- progress = _get_progress_bar(transient=True)
- progress.start()
- task_id = progress.add_task("download zip", total=None)
-
- _download_file(local_dst, url, progress, task_id)
- progress.stop()
-
- click.echo(f"Downloaded to {local_dst}")
- return None
-
-
-def _download_files(live, client, remote_src: str, local_dst: str, pwd: str):
- project_id, lit_resource = _get_project_id_and_resource(pwd)
-
- download_paths = []
- download_urls = []
- total_size = []
-
- prefix = _get_prefix("/".join(pwd.split("/")[3:]), lit_resource) + "/"
-
- for artifact in _collect_artifacts(client, project_id, prefix, include_download_url=True):
- path = os.path.join(local_dst, artifact.filename.replace(remote_src, ""))
- path = Path(path).resolve()
- os.makedirs(path.parent, exist_ok=True)
- download_paths.append(path)
- download_urls.append(artifact.url)
- total_size.append(int(artifact.size_bytes))
-
- live.stop()
-
- if not download_paths:
- print("There were no files to download.")
- return
-
- progress = _get_progress_bar()
-
- progress.start()
-
- task_id = progress.add_task("download", filename="", total=sum(total_size))
-
- _download_file_fn = partial(_download_file, progress=progress, task_id=task_id)
-
- with concurrent.futures.ThreadPoolExecutor(4) as executor:
- results = executor.map(_download_file_fn, download_paths, download_urls)
-
- progress.stop()
-
- # Report an error if any worker returned an exception
- exception = next((e for e in results if isinstance(e, Exception)), None)
- if exception:
- _error_and_exit("There was an error downloading your files.")
-
-
-def _download_file(path: str, url: str, progress: Progress, task_id: TaskID) -> None:
- # Disable warning about making an insecure request
- urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
-
- with contextlib.suppress(ConnectionError):
- request = requests.get(url, stream=True, verify=False) # noqa: S501
-
- chunk_size = 1024
-
- with open(path, "wb") as fp:
- for chunk in request.iter_content(chunk_size=chunk_size):
- fp.write(chunk) # type: ignore
- progress.update(task_id, advance=len(chunk))
-
-
-def _sanitize_path(path: str, pwd: str) -> Tuple[str, bool]:
- is_remote = _is_remote(path)
- if is_remote:
- path = _remove_remote(path)
- path = pwd if path == "." else os.path.join(pwd, path)
- return path, is_remote
-
-
-def _is_remote(path: str) -> bool:
- return path.startswith("r:") or path.startswith("remote:")
-
-
-def _remove_remote(path: str) -> str:
- return path.replace("r:", "").replace("remote:", "")
-
-
-def _get_project_id_and_resource(pwd: str) -> Tuple[str, Union[Externalv1LightningappInstance, V1CloudSpace]]:
- """Convert a root path to a project id and the matching Lightning resource (cloud space or app)."""
- # TODO: Handle project level
- project_name, resource_name, *_ = pwd.split("/")[1:3]
-
- # 1. Collect the projects of the user
- client = LightningClient()
- projects = client.projects_service_list_memberships()
- project_id = [project.project_id for project in projects.memberships if project.name == project_name][0]
-
- # 2. Collect resources
- lit_apps = client.lightningapp_instance_service_list_lightningapp_instances(project_id=project_id).lightningapps
-
- lit_cloud_spaces = client.cloud_space_service_list_cloud_spaces(project_id=project_id).cloudspaces
-
- lit_resources = [lit_resource for lit_resource in lit_cloud_spaces if lit_resource.name == resource_name]
-
- if len(lit_resources) == 0:
- lit_resources = [lit_resource for lit_resource in lit_apps if lit_resource.name == resource_name]
-
- if len(lit_resources) == 0:
- print(f"ERROR: There isn't any Lightning Resource matching the name {resource_name}.")
- sys.exit(0)
-
- return project_id, lit_resources[0]
-
-
-def _get_project_id_from_name(project_name: str) -> str:
- # 1. Collect the projects of the user
- client = LightningClient()
- projects = client.projects_service_list_memberships()
- return [project.project_id for project in projects.memberships if project.name == project_name][0]
-
-
-def _get_progress_bar(**kwargs: Any) -> Progress:
- return Progress(
- TextColumn("[bold blue]{task.description}", justify="left"),
- BarColumn(bar_width=None),
- "[self.progress.percentage]{task.percentage:>3.1f}%",
- DownloadColumn(),
- **kwargs,
- )
-
-
-def _storage_host(cluster: Externalv1Cluster) -> str:
- dev_host = os.environ.get("LIGHTNING_STORAGE_HOST")
- if dev_host:
- return dev_host
- return f"https://storage.{cluster.spec.driver.kubernetes.root_domain_name}"
-
-
-def _cluster_from_lit_resource(lit_resource: Union[Externalv1LightningappInstance, V1CloudSpace]) -> Externalv1Cluster:
- client = LightningClient()
- if isinstance(lit_resource, Externalv1LightningappInstance):
- return client.cluster_service_get_cluster(lit_resource.spec.cluster_id)
-
- clusters = client.cluster_service_list_clusters()
- for cluster in clusters.clusters:
- if cluster.id == clusters.default_cluster:
- return cluster
- return None
diff --git a/src/lightning/app/cli/commands/logs.py b/src/lightning/app/cli/commands/logs.py
deleted file mode 100644
index 4587987ae5f17..0000000000000
--- a/src/lightning/app/cli/commands/logs.py
+++ /dev/null
@@ -1,122 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import List
-
-import click
-import rich
-from rich.color import ANSI_COLOR_NAMES
-
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.app_logs import _app_logs_reader
-from lightning.app.utilities.cloud import _get_project
-from lightning.app.utilities.logs_socket_api import _LightningLogsSocketAPI
-from lightning.app.utilities.network import LightningClient
-
-logger = Logger(__name__)
-
-
-@click.argument("app_name", required=False)
-@click.argument("components", nargs=-1, required=False)
-@click.option("-f", "--follow", required=False, is_flag=True, help="Wait for new logs, to exit use CTRL+C.")
-def logs(app_name: str, components: List[str], follow: bool) -> None:
- """Show cloud application logs. By default, prints logs for all currently available components.
-
- Example uses:
-
- Print all application logs:
-
- $ lightning show logs my-application
-
- Print logs only from the flow (no work):
-
- $ lightning show logs my-application flow
-
- Print logs only from selected works:
-
- $ lightning show logs my-application root.work_a root.work_b
-
- """
- _show_logs(app_name, components, follow)
-
-
-def _show_logs(app_name: str, components: List[str], follow: bool) -> None:
- client = LightningClient(retry=False)
- project = _get_project(client)
-
- apps = {
- getattr(app, "display_name", None) or app.name: app
- for app in client.lightningapp_instance_service_list_lightningapp_instances(
- project_id=project.project_id
- ).lightningapps
- }
-
- if not apps:
- raise click.ClickException(
- "You don't have any application in the cloud. Please, run an application first with `--cloud`."
- )
-
- if not app_name:
- raise click.ClickException(
- f"You have not specified any Lightning App. Please select one of the following: [{', '.join(apps.keys())}]."
- )
-
- if app_name not in apps:
- raise click.ClickException(
- f"The Lightning App '{app_name}' does not exist. "
- f"Please select one of the following: [{', '.join(apps.keys())}]."
- )
-
- # Fetch all lightning works from given application
- # 'Flow' component is somewhat implicit, only one for whole app,
- # and not listed in lightningwork API - so we add it directly to the list
- works = client.lightningwork_service_list_lightningwork(
- project_id=project.project_id, app_id=apps[app_name].id
- ).lightningworks
-
- app_component_names = ["flow"] + [f.name for f in apps[app_name].spec.flow_servers] + [w.name for w in works]
-
- if not components:
- components = app_component_names
-
- else:
-
- def add_prefix(c: str) -> str:
- if c == "flow":
- return c
- if not c.startswith("root."):
- return "root." + c
- return c
-
- components = [add_prefix(c) for c in components]
-
- for component in components:
- if component not in app_component_names:
- raise click.ClickException(f"Component '{component}' does not exist in app {app_name}.")
-
- log_reader = _app_logs_reader(
- logs_api_client=_LightningLogsSocketAPI(client.api_client),
- project_id=project.project_id,
- app_id=apps[app_name].id,
- component_names=components,
- follow=follow,
- )
-
- rich_colors = list(ANSI_COLOR_NAMES)
- colors = {c: rich_colors[i + 1] for i, c in enumerate(components)}
-
- for log_event in log_reader:
- date = log_event.timestamp.strftime("%m/%d/%Y %H:%M:%S")
- color = colors[log_event.component_name]
- rich.print(f"[{color}]{log_event.component_name}[/{color}] {date} {log_event.message}")
diff --git a/src/lightning/app/cli/commands/ls.py b/src/lightning/app/cli/commands/ls.py
deleted file mode 100644
index e16a354f66d8c..0000000000000
--- a/src/lightning/app/cli/commands/ls.py
+++ /dev/null
@@ -1,268 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import contextlib
-import os
-import sys
-from contextlib import nullcontext
-from typing import Generator, List, Optional
-
-import click
-import lightning_cloud
-import rich
-from lightning_cloud.openapi import Externalv1LightningappInstance
-from rich.console import Console
-from rich.live import Live
-from rich.spinner import Spinner
-from rich.text import Text
-
-from lightning.app.cli.connect.app import _LIGHTNING_CONNECTION_FOLDER
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.cli_helpers import _error_and_exit
-from lightning.app.utilities.network import LightningClient
-
-_FOLDER_COLOR = "sky_blue1"
-_FILE_COLOR = "white"
-
-logger = Logger(__name__)
-
-
-@click.argument("path", required=False)
-def ls(path: Optional[str] = None, print: bool = True, use_live: bool = True) -> List[str]:
- """List the contents of a folder in the Lightning Cloud Filesystem."""
- from lightning.app.cli.commands.cd import _CD_FILE
-
- if sys.platform == "win32":
- _error_and_exit("`ls` isn't supported on Windows. Open an issue on GitHub.")
-
- root = "/"
-
- context = (
- Live(Spinner("point", text=Text("pending...", style="white")), transient=True) if use_live else nullcontext()
- )
-
- with context:
- if not os.path.exists(_LIGHTNING_CONNECTION_FOLDER):
- os.makedirs(_LIGHTNING_CONNECTION_FOLDER)
-
- if not os.path.exists(_CD_FILE):
- with open(_CD_FILE, "w") as f:
- f.write(root + "\n")
- else:
- with open(_CD_FILE) as f:
- lines = f.readlines()
- root = lines[0].replace("\n", "")
-
- client = LightningClient(retry=False)
- projects = client.projects_service_list_memberships()
-
- if root == "/":
- project_names = [project.name for project in projects.memberships]
- if print:
- _print_names_with_colors(project_names, [_FOLDER_COLOR] * len(project_names))
- return project_names
-
- # Note: Root format has the following structure:
- # /{PROJECT_NAME}/{APP_NAME}/{ARTIFACTS_PATHS}
- splits = root.split("/")[1:]
-
- project = [project for project in projects.memberships if project.name == splits[0]]
-
- # This happens if the user changes cluster and the project doesn't exist.
- if len(project) == 0:
- return _error_and_exit(
- f"There isn't any Lightning Project matching the name {splits[0]}." " HINT: Use `lightning_app cd`."
- )
-
- project_id = project[0].project_id
-
- # Parallelise calls
- lit_apps = client.lightningapp_instance_service_list_lightningapp_instances(
- project_id=project_id, async_req=True
- )
- lit_cloud_spaces = client.cloud_space_service_list_cloud_spaces(project_id=project_id, async_req=True)
-
- lit_apps = lit_apps.get().lightningapps
- lit_cloud_spaces = lit_cloud_spaces.get().cloudspaces
-
- if len(splits) == 1:
- apps = [lit_app.name for lit_app in lit_apps]
- cloud_spaces = [lit_cloud_space.name for lit_cloud_space in lit_cloud_spaces]
- ressource_names = sorted(set(cloud_spaces + apps))
- if print:
- _print_names_with_colors(ressource_names, [_FOLDER_COLOR] * len(ressource_names))
- return ressource_names
-
- lit_ressources = [lit_resource for lit_resource in lit_cloud_spaces if lit_resource.name == splits[1]]
-
- if len(lit_ressources) == 0:
- lit_ressources = [lit_resource for lit_resource in lit_apps if lit_resource.name == splits[1]]
-
- if len(lit_ressources) == 0:
- _error_and_exit(f"There isn't any Lightning Resource matching the name {splits[1]}.")
-
- lit_resource = lit_ressources[0]
-
- app_paths = []
- app_colors = []
-
- cloud_spaces_paths = []
- cloud_spaces_colors = []
-
- depth = len(splits)
-
- prefix = "/".join(splits[2:])
- prefix = _get_prefix(prefix, lit_resource)
-
- for artifact in _collect_artifacts(client=client, project_id=project_id, prefix=prefix):
- if str(artifact.filename).startswith("/"):
- artifact.filename = artifact.filename[1:]
-
- path = os.path.join(project_id, prefix[1:], artifact.filename)
-
- artifact_splits = path.split("/")
-
- if len(artifact_splits) <= depth + 1:
- continue
-
- path = artifact_splits[depth + 1]
-
- paths = app_paths if isinstance(lit_resource, Externalv1LightningappInstance) else cloud_spaces_paths
- colors = app_colors if isinstance(lit_resource, Externalv1LightningappInstance) else cloud_spaces_colors
-
- if path not in paths:
- paths.append(path)
-
- # display files otherwise folders
- colors.append(_FILE_COLOR if len(artifact_splits) == depth + 1 else _FOLDER_COLOR)
-
- if print:
- if app_paths and cloud_spaces_paths:
- if app_paths:
- rich.print("Lightning App")
- _print_names_with_colors(app_paths, app_colors)
-
- if cloud_spaces_paths:
- rich.print("Lightning CloudSpaces")
- _print_names_with_colors(cloud_spaces_paths, cloud_spaces_colors)
- else:
- _print_names_with_colors(app_paths + cloud_spaces_paths, app_colors + cloud_spaces_colors)
-
- return app_paths + cloud_spaces_paths
-
-
-def _add_colors(filename: str, color: Optional[str] = None) -> str:
- return f"[{color}]{filename}[/{color}]"
-
-
-def _print_names_with_colors(names: List[str], colors: List[str], padding: int = 5) -> None:
- console = Console()
- width = console.width
-
- max_L = max([len(name) for name in names] + [0]) + padding
-
- use_spacing = False
-
- if max_L * len(names) < width:
- use_spacing = True
-
- num_cols = width // max_L
-
- columns = {}
- for index, (name, color) in enumerate(zip(names, colors)):
- row = index // num_cols
- if row not in columns:
- columns[row] = []
- columns[row].append((name, color))
-
- for row_index in sorted(columns):
- row = ""
- for name, color in columns[row_index]:
- spacing = padding if use_spacing else max_L - len(name)
- spaces = " " * spacing
- row += _add_colors(name, color) + spaces
- rich.print(row)
-
-
-def _collect_artifacts(
- client: LightningClient,
- project_id: str,
- prefix: str = "",
- page_token: Optional[str] = "",
- cluster_id: Optional[str] = None,
- page_size: int = 100_000,
- tokens=None,
- include_download_url: bool = False,
-) -> Generator:
- if tokens is None:
- tokens = []
-
- if cluster_id is None:
- clusters = client.projects_service_list_project_cluster_bindings(project_id)
- for cluster in clusters.clusters:
- yield from _collect_artifacts(
- client,
- project_id,
- prefix=prefix,
- cluster_id=cluster.cluster_id,
- page_token=page_token,
- tokens=tokens,
- page_size=page_size,
- include_download_url=include_download_url,
- )
- else:
- if page_token in tokens:
- return
-
- # Note: This is triggered when the request is wrong.
- # This is currently happening due to looping through the user clusters.
- with contextlib.suppress(lightning_cloud.openapi.rest.ApiException):
- response = client.lightningapp_instance_service_list_project_artifacts(
- project_id,
- prefix=prefix,
- cluster_id=cluster_id,
- page_token=page_token,
- include_download_url=include_download_url,
- page_size=str(page_size),
- )
- for artifact in response.artifacts:
- if ".lightning-app-sync" in artifact.filename:
- continue
- yield artifact
-
- if response.next_page_token:
- tokens.append(page_token)
- yield from _collect_artifacts(
- client,
- project_id,
- prefix=prefix,
- cluster_id=cluster_id,
- page_token=response.next_page_token,
- tokens=tokens,
- )
-
-
-def _add_resource_prefix(prefix: str, resource_path: str):
- if resource_path in prefix:
- return prefix
- prefix = os.path.join(resource_path, prefix)
- if not prefix.startswith("/"):
- prefix = "/" + prefix
- return prefix
-
-
-def _get_prefix(prefix: str, lit_resource) -> str:
- if isinstance(lit_resource, Externalv1LightningappInstance):
- return _add_resource_prefix(prefix, f"lightningapps/{lit_resource.id}")
-
- return _add_resource_prefix(prefix, f"cloudspaces/{lit_resource.id}")
diff --git a/src/lightning/app/cli/commands/pwd.py b/src/lightning/app/cli/commands/pwd.py
deleted file mode 100644
index 7768309e4e6bb..0000000000000
--- a/src/lightning/app/cli/commands/pwd.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import sys
-
-from rich.live import Live
-from rich.spinner import Spinner
-from rich.text import Text
-
-from lightning.app.cli.commands.cd import _CD_FILE
-from lightning.app.utilities.app_helpers import Logger
-
-logger = Logger(__name__)
-
-
-def pwd() -> str:
- """Print your current working directory in the Lightning Cloud filesystem."""
- if sys.platform == "win32":
- print("`pwd` isn't supported on Windows. Open an issue on GitHub.")
- sys.exit(0)
-
- with Live(Spinner("point", text=Text("pending...", style="white")), transient=True):
- root = _pwd()
-
- print(root)
-
- return root
-
-
-def _pwd() -> str:
- root = "/"
-
- if not os.path.exists(_CD_FILE):
- with open(_CD_FILE, "w") as f:
- f.write(root + "\n")
- else:
- with open(_CD_FILE) as f:
- lines = f.readlines()
- root = lines[0].replace("\n", "")
-
- return root
diff --git a/src/lightning/app/cli/commands/rm.py b/src/lightning/app/cli/commands/rm.py
deleted file mode 100644
index 587cc50469131..0000000000000
--- a/src/lightning/app/cli/commands/rm.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import contextlib
-import os
-
-import click
-import lightning_cloud
-import rich
-
-from lightning.app.cli.commands.ls import _add_colors, _get_prefix
-from lightning.app.cli.commands.pwd import _pwd
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.cli_helpers import _error_and_exit
-from lightning.app.utilities.network import LightningClient
-
-logger = Logger(__name__)
-
-
-@click.argument("rm_path", required=True)
-@click.option("-r", required=False, hidden=True)
-@click.option("--recursive", required=False, hidden=True)
-def rm(rm_path: str, r: bool = False, recursive: bool = False) -> None:
- """Delete files on the Lightning Cloud filesystem."""
- root = _pwd()
-
- if rm_path in (".", ".."):
- return _error_and_exit('rm "." and ".." may not be removed')
-
- if ".." in rm_path:
- return _error_and_exit('rm ".." or higher may not be removed')
-
- root = os.path.join(root, rm_path)
- splits = [split for split in root.split("/") if split != ""]
-
- if root == "/" or len(splits) == 1:
- return _error_and_exit("rm at the project level isn't supported")
-
- client = LightningClient(retry=False)
- projects = client.projects_service_list_memberships()
-
- project = [project for project in projects.memberships if project.name == splits[0]]
-
- # This happens if the user changes cluster and the project doesn't exist.
- if len(project) == 0:
- return _error_and_exit(
- f"There isn't any Lightning Project matching the name {splits[0]}." " HINT: Use `lightning cd`."
- )
-
- project_id = project[0].project_id
-
- # Parallelise calls
- lit_apps = client.lightningapp_instance_service_list_lightningapp_instances(project_id=project_id, async_req=True)
- lit_cloud_spaces = client.cloud_space_service_list_cloud_spaces(project_id=project_id, async_req=True)
-
- lit_apps = lit_apps.get().lightningapps
- lit_cloud_spaces = lit_cloud_spaces.get().cloudspaces
-
- lit_ressources = [lit_resource for lit_resource in lit_cloud_spaces if lit_resource.name == splits[1]]
-
- if len(lit_ressources) == 0:
- lit_ressources = [lit_resource for lit_resource in lit_apps if lit_resource.name == splits[1]]
-
- if len(lit_ressources) == 0:
- _error_and_exit(f"There isn't any Lightning Resource matching the name {splits[1]}.")
-
- lit_resource = lit_ressources[0]
-
- prefix = "/".join(splits[2:])
- prefix = _get_prefix(prefix, lit_resource)
-
- clusters = client.projects_service_list_project_cluster_bindings(project_id)
- succeeded = False
-
- for cluster in clusters.clusters:
- with contextlib.suppress(lightning_cloud.openapi.rest.ApiException):
- client.lightningapp_instance_service_delete_project_artifact(
- project_id=project_id,
- cluster_id=cluster.cluster_id,
- filename=prefix,
- )
- succeeded = True
- break
-
- prefix = os.path.join(*splits)
-
- if succeeded:
- rich.print(_add_colors(f"Successfully deleted `{prefix}`.", color="green"))
- return None
-
- return _error_and_exit(f"No file or folder named `{prefix}` was found.")
diff --git a/src/lightning/app/cli/component-template/.github/workflows/ci-testing.yml b/src/lightning/app/cli/component-template/.github/workflows/ci-testing.yml
deleted file mode 100644
index 16abb8f418b89..0000000000000
--- a/src/lightning/app/cli/component-template/.github/workflows/ci-testing.yml
+++ /dev/null
@@ -1,79 +0,0 @@
-name: CI testing
-
-# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows
-on:
- # Trigger the workflow on push or pull request, but only for the main branch
- push:
- branches: [main]
- pull_request:
- branches: [main]
-
-jobs:
- pytest:
- runs-on: ${{ matrix.os }}
- strategy:
- fail-fast: false
- matrix:
- os: [ubuntu-20.04, macOS-11, windows-2019]
- python-version: [3.8]
-
- # Timeout: https://stackoverflow.com/a/59076067/4521646
- timeout-minutes: 35
-
- steps:
- - uses: actions/checkout@v2
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
-
- # Github Actions: Run step on specific OS: https://stackoverflow.com/a/57948488/4521646
- - name: Setup macOS
- if: runner.os == 'macOS'
- run: |
- brew install libomp # https://github.com/pytorch/pytorch/issues/20030
-
- - name: Get pip cache dir
- id: pip-cache
- run: echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
-
- - name: Cache pip
- uses: actions/cache@v2
- with:
- path: ${{ steps.pip-cache.outputs.dir }}
- key: ${{ runner.os }}-py${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }}
- restore-keys: |
- ${{ runner.os }}-py${{ matrix.python-version }}-
-
- - name: Clone Template React UI Repo
- uses: actions/checkout@v3
- with:
- repository: Lightning-AI/lightning
- token: ${{ secrets.PAT_GHOST }}
- ref: "master"
- path: lightning
-
- - name: Install Lightning
- run: |
- cd lightning
- pip install -r requirements.txt
- pip install -e .
- shell: bash
-
- - name: Install dependencies
- run: |
- python --version
- pip --version
- pip install --requirement requirements.txt --upgrade --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
- pip install --requirement tests/requirements.txt --quiet
- pip list
- shell: bash
-
- - name: Tests
- run: |
- coverage run --source placeholdername -m py.test placeholdername tests -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}.xml
-
- - name: Statistics
- if: success()
- run: |
- coverage report
diff --git a/src/lightning/app/cli/component-template/.gitignore b/src/lightning/app/cli/component-template/.gitignore
deleted file mode 100644
index 70ba25888435f..0000000000000
--- a/src/lightning/app/cli/component-template/.gitignore
+++ /dev/null
@@ -1,157 +0,0 @@
-# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*$py.class
-*install-app*
-
-# C extensions
-*.so
-
-# Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-pip-wheel-metadata/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-
-*.egg
-
-# PyInstaller
-# Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.nox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-*.py,cover
-.hypothesis/
-.pytest_cache/
-
-# Translations
-*.mo
-*.pot
-
-# Sphinx documentation
-docs/_build/
-docs/source/api/
-docs/source/*.md
-
-# PyBuilder
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# IPython
-profile_default/
-ipython_config.py
-
-# pyenv
-.python-version
-
-# PEP 582; used by e.g. github.com/David-OConnor/pyflow
-__pypackages__/
-
-# Celery stuff
-celerybeat-schedule
-celerybeat.pid
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
-.env
-.local_env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
-# Spyder project settings
-.spyderproject
-.spyproject
-
-# Rope project settings
-.ropeproject
-
-# mkdocs documentation
-/site
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
-
-# Pyre type checker
-.pyre/
-
-# PyCharm
-.idea/
-
-# Lightning logs
-lightning_logs
-*.gz
-.DS_Store
-.*_submit.py
-.vscode
-
-MNIST
-*.pt
-.storage/
-.shared/
-infra
-data
-coverage.*
-# Frontend build artifacts
-*lightning/app/ui*
-gradio_cached_examples
-/docs/source/api_reference/generated/*
-examples/my_own_leaderboard/submissions/*
-docs/source/api_reference/generated/*
-*.ckpt
-redis-stable
-node_modules
-*.rdb
-*.webm
-*hars
-examples/quick_start/*
-examples/quick_start
-examples/template_react_ui/*
-examples/template_react_ui
-# Ignore external components
-lightning/app/components/*
-!lightning/app/components/python
-!lightning/app/components/serve
-!lightning/app/components/__init__.py
-!lightning/app/components/README.md
-train_script.py
-*return_values*
-scratch
-storage
diff --git a/src/lightning/app/cli/component-template/LICENSE b/src/lightning/app/cli/component-template/LICENSE
deleted file mode 100644
index 261eeb9e9f8b2..0000000000000
--- a/src/lightning/app/cli/component-template/LICENSE
+++ /dev/null
@@ -1,201 +0,0 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [yyyy] [name of copyright owner]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
diff --git a/src/lightning/app/cli/component-template/README.md b/src/lightning/app/cli/component-template/README.md
deleted file mode 100644
index 1d700e286461b..0000000000000
--- a/src/lightning/app/cli/component-template/README.md
+++ /dev/null
@@ -1,35 +0,0 @@
-# placeholdername component
-
-This ⚡ [Lightning component](https://lightning.ai/) ⚡ was generated automatically with:
-
-```bash
-lightning_app init component placeholdername
-```
-
-## To run placeholdername
-
-First, install placeholdername (warning: this component has not been officially approved on the lightning gallery):
-
-```bash
-lightning_app install component https://github.com/theUser/placeholdername
-```
-
-Once the app is installed, use it in an app:
-
-```python
-from placeholdername import TemplateComponent
-import lightning as L
-
-
-class LitApp(L.LightningFlow):
- def __init__(self) -> None:
- super().__init__()
- self.placeholdername = TemplateComponent()
-
- def run(self):
- print("this is a simple Lightning app to verify your component is working as expected")
- self.placeholdername.run()
-
-
-app = L.LightningApp(LitApp())
-```
diff --git a/src/lightning/app/cli/component-template/app.py b/src/lightning/app/cli/component-template/app.py
deleted file mode 100644
index 0a10532204043..0000000000000
--- a/src/lightning/app/cli/component-template/app.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from lightning.app import LightningApp, LightningFlow
-from placeholdername import TemplateComponent
-
-
-class LitApp(LightningFlow):
- def __init__(self) -> None:
- super().__init__()
- self.placeholdername = TemplateComponent()
-
- def run(self):
- print("this is a simple Lightning app to verify your component is working as expected")
- self.placeholdername.run()
-
-
-app = LightningApp(LitApp())
diff --git a/src/lightning/app/cli/component-template/placeholdername/__init__.py b/src/lightning/app/cli/component-template/placeholdername/__init__.py
deleted file mode 100644
index 92b4ef47d8062..0000000000000
--- a/src/lightning/app/cli/component-template/placeholdername/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from placeholdername.component import TemplateComponent
-
-__all__ = ["TemplateComponent"]
diff --git a/src/lightning/app/cli/component-template/placeholdername/component.py b/src/lightning/app/cli/component-template/placeholdername/component.py
deleted file mode 100644
index 251a4e10c6a9f..0000000000000
--- a/src/lightning/app/cli/component-template/placeholdername/component.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from lightning.app import LightningWork
-
-
-class TemplateComponent(LightningWork):
- def __init__(self) -> None:
- super().__init__()
- self.value = 0
-
- def run(self):
- self.value += 1
- print("welcome to your work component")
- print("this is running inside a work")
diff --git a/src/lightning/app/cli/component-template/requirements.txt b/src/lightning/app/cli/component-template/requirements.txt
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/src/lightning/app/cli/component-template/setup.py b/src/lightning/app/cli/component-template/setup.py
deleted file mode 100644
index 78631901190b2..0000000000000
--- a/src/lightning/app/cli/component-template/setup.py
+++ /dev/null
@@ -1,15 +0,0 @@
-#!/usr/bin/env python
-
-from setuptools import find_packages, setup
-
-setup(
- name="placeholdername",
- version="0.0.0",
- description="⚡ Lightning component ⚡ generated with command: lightning_app init component",
- author="",
- author_email="",
- # REPLACE WITH YOUR OWN GITHUB PROJECT LINK
- url="https://github.com/Lightning-AI/lightning-component-template",
- install_requires=[],
- packages=find_packages(),
-)
diff --git a/src/lightning/app/cli/component-template/tests/README.md b/src/lightning/app/cli/component-template/tests/README.md
deleted file mode 100644
index bef681691185d..0000000000000
--- a/src/lightning/app/cli/component-template/tests/README.md
+++ /dev/null
@@ -1,17 +0,0 @@
-# Run tests
-
-To run the tests:
-
-```bash
-# go to your component folder
-cd placeholdername
-
-# go to tests folder
-cd tests
-
-# install testing deps
-pip install -r requirements.txt
-
-# run tests
-pytest .
-```
diff --git a/src/lightning/app/cli/component-template/tests/__init__.py b/src/lightning/app/cli/component-template/tests/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/src/lightning/app/cli/component-template/tests/requirements.txt b/src/lightning/app/cli/component-template/tests/requirements.txt
deleted file mode 100644
index 3185d1c44f033..0000000000000
--- a/src/lightning/app/cli/component-template/tests/requirements.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-coverage
-codecov>=2.1
-pytest>=5.0.0
-pytest-cov
-pytest-flake8
-flake8
-check-manifest
-twine==4.0.1
diff --git a/src/lightning/app/cli/component-template/tests/test_placeholdername_component.py b/src/lightning/app/cli/component-template/tests/test_placeholdername_component.py
deleted file mode 100644
index 6b9c28845749c..0000000000000
--- a/src/lightning/app/cli/component-template/tests/test_placeholdername_component.py
+++ /dev/null
@@ -1,14 +0,0 @@
-r"""To test a lightning component:
-
-1. Init the component.
-2. call .run()
-
-"""
-
-from placeholdername.component import TemplateComponent
-
-
-def test_placeholder_component():
- messenger = TemplateComponent()
- messenger.run()
- assert messenger.value == 1
diff --git a/src/lightning/app/cli/connect/__init__.py b/src/lightning/app/cli/connect/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/src/lightning/app/cli/connect/app.py b/src/lightning/app/cli/connect/app.py
deleted file mode 100644
index e3f0d0b151bb5..0000000000000
--- a/src/lightning/app/cli/connect/app.py
+++ /dev/null
@@ -1,387 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-import os
-import shutil
-import sys
-from subprocess import Popen
-from typing import List, Optional, Tuple
-
-import click
-import psutil
-from lightning_utilities.core.imports import package_available
-from rich.progress import Progress
-
-from lightning.app.utilities.cli_helpers import _get_app_display_name, _LightningAppOpenAPIRetriever
-from lightning.app.utilities.cloud import _get_project
-from lightning.app.utilities.enum import OpenAPITags
-from lightning.app.utilities.log import get_logfile
-from lightning.app.utilities.network import LightningClient
-
-_HOME = os.path.expanduser("~")
-_PPID = os.getenv("LIGHTNING_CONNECT_PPID", str(psutil.Process(os.getpid()).ppid()))
-_LIGHTNING_CONNECTION = os.path.join(_HOME, ".lightning", "lightning_connection")
-_LIGHTNING_CONNECTION_FOLDER = os.path.join(_LIGHTNING_CONNECTION, _PPID)
-
-
-@click.argument("app_name_or_id", required=True)
-def connect_app(app_name_or_id: str):
- """Connect your local terminal to a running lightning app.
-
- After connecting, the lightning CLI will respond to commands exposed by the app.
-
- Example:
-
- \b
- # connect to an app named pizza-cooker-123
- lightning connect pizza-cooker-123
- \b
- # this will now show the commands exposed by pizza-cooker-123
- lightning --help
- \b
- # while connected, you can run the cook-pizza command exposed
- # by pizza-cooker-123. BTW, this should arguably generate an exception :-)
- lightning cook-pizza --flavor pineapple
- \b
- # once done, disconnect and go back to the standard lightning CLI commands
- lightning disconnect
-
- """
- from lightning.app.utilities.commands.base import _download_command
-
- _clean_lightning_connection()
-
- if not os.path.exists(_LIGHTNING_CONNECTION_FOLDER):
- os.makedirs(_LIGHTNING_CONNECTION_FOLDER)
-
- connected_file = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "connect.txt")
-
- matched_connection_path = _scan_lightning_connections(app_name_or_id)
-
- if os.path.exists(connected_file):
- with open(connected_file) as f:
- result = f.readlines()[0].replace("\n", "")
-
- if result == app_name_or_id:
- if app_name_or_id == "localhost":
- click.echo("You are connected to the local Lightning App.")
- else:
- click.echo(f"You are already connected to the cloud Lightning App: {app_name_or_id}.")
- else:
- disconnect_app()
- connect_app(app_name_or_id)
-
- elif app_name_or_id.startswith("localhost"):
- with Progress() as progress_bar:
- connecting = progress_bar.add_task("[magenta]Setting things up for you...", total=1.0)
-
- if app_name_or_id != "localhost":
- raise Exception("You need to pass localhost to connect to the local Lightning App.")
-
- retriever = _LightningAppOpenAPIRetriever(None)
-
- if retriever.api_commands is None:
- raise Exception(f"Connection wasn't successful. Is your app {app_name_or_id} running?")
-
- increment = 1 / (1 + len(retriever.api_commands))
-
- progress_bar.update(connecting, advance=increment)
-
- commands_folder = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "commands")
- if not os.path.exists(commands_folder):
- os.makedirs(commands_folder)
-
- _write_commands_metadata(retriever.api_commands)
-
- with open(os.path.join(commands_folder, "openapi.json"), "w") as f:
- json.dump(retriever.openapi, f)
-
- _install_missing_requirements(retriever)
-
- for command_name, metadata in retriever.api_commands.items():
- if "cls_path" in metadata:
- target_file = os.path.join(commands_folder, f"{command_name.replace(' ', '_')}.py")
- _download_command(
- command_name,
- metadata["cls_path"],
- metadata["cls_name"],
- None,
- target_file=target_file,
- )
- else:
- with open(os.path.join(commands_folder, f"{command_name}.txt"), "w") as f:
- f.write(command_name)
-
- progress_bar.update(connecting, advance=increment)
-
- with open(connected_file, "w") as f:
- f.write(app_name_or_id + "\n")
-
- click.echo("The lightning App CLI now responds to app commands. Use 'lightning_app --help' to see them.")
- click.echo(" ")
-
- Popen(
- f"LIGHTNING_CONNECT_PPID={_PPID} {sys.executable} -m lightning_app --help",
- shell=True,
- stdout=sys.stdout,
- stderr=sys.stderr,
- ).wait()
-
- elif matched_connection_path:
- matched_connected_file = os.path.join(matched_connection_path, "connect.txt")
- matched_commands = os.path.join(matched_connection_path, "commands")
- if os.path.isdir(matched_commands):
- commands = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "commands")
- shutil.copytree(matched_commands, commands)
- shutil.copy(matched_connected_file, connected_file)
-
- click.echo("The lightning App CLI now responds to app commands. Use 'lightning_app --help' to see them.")
- click.echo(" ")
-
- Popen(
- f"LIGHTNING_CONNECT_PPID={_PPID} {sys.executable} -m lightning_app --help",
- shell=True,
- stdout=sys.stdout,
- stderr=sys.stderr,
- ).wait()
-
- else:
- with Progress() as progress_bar:
- connecting = progress_bar.add_task("[magenta]Setting things up for you...", total=1.0)
-
- retriever = _LightningAppOpenAPIRetriever(app_name_or_id)
-
- if not retriever.api_commands:
- client = LightningClient(retry=False)
- project = _get_project(client)
- apps = client.lightningapp_instance_service_list_lightningapp_instances(project_id=project.project_id)
- click.echo(
- "We didn't find a matching App. Here are the available Apps that you can "
- f"connect to {[_get_app_display_name(app) for app in apps.lightningapps]}."
- )
- return
-
- increment = 1 / (1 + len(retriever.api_commands))
-
- progress_bar.update(connecting, advance=increment)
-
- _install_missing_requirements(retriever)
-
- commands_folder = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "commands")
- if not os.path.exists(commands_folder):
- os.makedirs(commands_folder)
-
- _write_commands_metadata(retriever.api_commands)
-
- for command_name, metadata in retriever.api_commands.items():
- if "cls_path" in metadata:
- target_file = os.path.join(commands_folder, f"{command_name}.py")
- _download_command(
- command_name,
- metadata["cls_path"],
- metadata["cls_name"],
- retriever.app_id,
- target_file=target_file,
- )
- else:
- with open(os.path.join(commands_folder, f"{command_name}.txt"), "w") as f:
- f.write(command_name)
-
- progress_bar.update(connecting, advance=increment)
-
- with open(connected_file, "w") as f:
- f.write(retriever.app_name + "\n")
- f.write(retriever.app_id + "\n")
-
- click.echo("The lightning App CLI now responds to app commands. Use 'lightning_app --help' to see them.")
- click.echo(" ")
-
- Popen(
- f"LIGHTNING_CONNECT_PPID={_PPID} {sys.executable} -m lightning_app --help",
- shell=True,
- stdout=sys.stdout,
- stderr=sys.stderr,
- ).wait()
-
-
-def disconnect_app(logout: bool = False):
- """Disconnect from an App."""
- _clean_lightning_connection()
-
- connected_file = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "connect.txt")
- if os.path.exists(connected_file):
- with open(connected_file) as f:
- result = f.readlines()[0].replace("\n", "")
-
- os.remove(connected_file)
- commands_folder = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "commands")
- if os.path.exists(commands_folder):
- shutil.rmtree(commands_folder)
-
- if result == "localhost":
- click.echo("You are disconnected from the local Lightning App.")
- else:
- click.echo(f"You are disconnected from the cloud Lightning App: {result}.")
- else:
- if not logout:
- click.echo(
- "You aren't connected to any Lightning App. "
- "Please use `lightning_app connect app_name_or_id` to connect to one."
- )
-
-
-def _read_connected_file(connected_file):
- if os.path.exists(connected_file):
- with open(connected_file) as f:
- lines = [line.replace("\n", "") for line in f.readlines()]
- if len(lines) == 2:
- return lines[0], lines[1]
- return lines[0], None
- return None, None
-
-
-def _retrieve_connection_to_an_app() -> Tuple[Optional[str], Optional[str]]:
- connected_file = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "connect.txt")
- return _read_connected_file(connected_file)
-
-
-def _get_commands_folder() -> str:
- return os.path.join(_LIGHTNING_CONNECTION_FOLDER, "commands")
-
-
-def _write_commands_metadata(api_commands):
- metadata = dict(api_commands.items())
- metadata_path = os.path.join(_get_commands_folder(), ".meta.json")
- with open(metadata_path, "w") as f:
- json.dump(metadata, f)
-
-
-def _get_commands_metadata():
- metadata_path = os.path.join(_get_commands_folder(), ".meta.json")
- with open(metadata_path) as f:
- return json.load(f)
-
-
-def _resolve_command_path(command: str) -> str:
- return os.path.join(_get_commands_folder(), f"{command}.py")
-
-
-def _list_app_commands(echo: bool = True) -> List[str]:
- metadata = _get_commands_metadata()
- metadata = {key.replace("_", " "): value for key, value in metadata.items()}
-
- command_names = sorted(metadata.keys())
- if not command_names:
- click.echo("The current Lightning App doesn't have commands.")
- return []
-
- app_info = metadata[command_names[0]].get("app_info", None)
-
- title, description, on_connect_end = "Lightning", None, None
- if app_info:
- title = app_info.get("title")
- description = app_info.get("description")
- on_connect_end = app_info.get("on_connect_end")
-
- if echo:
- click.echo(f"{title} App")
- if description:
- click.echo("")
- click.echo("Description:")
- if description.endswith("\n"):
- description = description[:-1]
- click.echo(f" {description}")
- click.echo("")
- click.echo("Commands:")
- max_length = max(len(n) for n in command_names)
- for command_name in command_names:
- padding = (max_length + 1 - len(command_name)) * " "
- click.echo(f" {command_name}{padding}{metadata[command_name].get('description', '')}")
- if "LIGHTNING_CONNECT_PPID" in os.environ and on_connect_end:
- if on_connect_end.endswith("\n"):
- on_connect_end = on_connect_end[:-1]
- click.echo(on_connect_end)
- return command_names
-
-
-def _install_missing_requirements(
- retriever: _LightningAppOpenAPIRetriever,
- fail_if_missing: bool = False,
-):
- requirements = set()
- for metadata in retriever.api_commands.values():
- if metadata["tag"] == OpenAPITags.APP_CLIENT_COMMAND:
- for req in metadata.get("requirements", []) or []:
- requirements.add(req)
-
- if requirements:
- missing_requirements = []
- for req in requirements:
- if not (package_available(req) or package_available(req.replace("-", "_"))):
- missing_requirements.append(req)
-
- if missing_requirements:
- if fail_if_missing:
- missing_requirements = " ".join(missing_requirements)
- print(f"The command failed as you are missing the following requirements: `{missing_requirements}`.")
- sys.exit(0)
-
- for req in missing_requirements:
- std_out_out = get_logfile("output.log")
- with open(std_out_out, "wb") as stdout:
- Popen(
- f"{sys.executable} -m pip install {req}",
- shell=True,
- stdout=stdout,
- stderr=stdout,
- ).wait()
- os.remove(std_out_out)
-
-
-def _clean_lightning_connection():
- if not os.path.exists(_LIGHTNING_CONNECTION):
- return
-
- for ppid in os.listdir(_LIGHTNING_CONNECTION):
- try:
- psutil.Process(int(ppid))
- except (psutil.NoSuchProcess, ValueError):
- connection = os.path.join(_LIGHTNING_CONNECTION, str(ppid))
- if os.path.exists(connection):
- shutil.rmtree(connection)
-
-
-def _scan_lightning_connections(app_name_or_id):
- if not os.path.exists(_LIGHTNING_CONNECTION):
- return None
-
- for ppid in os.listdir(_LIGHTNING_CONNECTION):
- try:
- psutil.Process(int(ppid))
- except (psutil.NoSuchProcess, ValueError):
- continue
-
- connection_path = os.path.join(_LIGHTNING_CONNECTION, str(ppid))
-
- connected_file = os.path.join(connection_path, "connect.txt")
- curr_app_name, curr_app_id = _read_connected_file(connected_file)
-
- if not curr_app_name:
- continue
-
- if app_name_or_id in (curr_app_name, curr_app_id):
- return connection_path
-
- return None
diff --git a/src/lightning/app/cli/connect/data.py b/src/lightning/app/cli/connect/data.py
deleted file mode 100644
index 6069432ddb187..0000000000000
--- a/src/lightning/app/cli/connect/data.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# Copyright The Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import ast
-import sys
-
-import click
-import lightning_cloud
-import rich
-from rich.live import Live
-from rich.spinner import Spinner
-from rich.text import Text
-
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.cli_helpers import _error_and_exit
-from lightning.app.utilities.cloud import _get_project
-from lightning.app.utilities.network import LightningClient
-
-logger = Logger(__name__)
-
-
-@click.argument("name", required=True)
-@click.option("--region", help="The AWS region of your bucket. Example: `us-west-1`.", required=True)
-@click.option(
- "--source", help="The URL path to your AWS S3 folder. Example: `s3://pl-flash-data/images/`.", required=True
-)
-@click.option(
- "--secret_arn_name",
- help="The name of role stored as a secret on Lightning AI to access your data. "
- "Learn more with https://gist.github.com/tchaton/12ad4b788012e83c0eb35e6223ae09fc. "
- "Example: `my_role`.",
- required=False,
-)
-@click.option(
- "--destination", help="Where your data should appear in the cloud. Currently not supported.", required=False
-)
-@click.option("--project_name", help="The project name on which to create the data connection.", required=False)
-def connect_data(
- name: str,
- region: str,
- source: str,
- secret_arn_name: str = "",
- destination: str = "",
- project_name: str = "",
-) -> None:
- """Create a new data connection."""
-
- from lightning_cloud.openapi import Create, V1AwsDataConnection
-
- if sys.platform == "win32":
- _error_and_exit("Data connection isn't supported on Windows. Open an issue on GitHub.")
-
- with Live(Spinner("point", text=Text("pending...", style="white")), transient=True) as live:
- live.stop()
-
- client = LightningClient(retry=False)
- projects = client.projects_service_list_memberships()
-
- project_id = None
-
- for project in projects.memberships:
- if project.name == project_name:
- project_id = project.project_id
- break
-
- if project_id is None:
- project_id = _get_project(client).project_id
-
- if not source.startswith("s3://"):
- return _error_and_exit(
- "Only public S3 folders are supported for now. Please open a GitHub issue with your use case."
- )
-
- try:
- client.data_connection_service_create_data_connection(
- body=Create(
- name=name,
- aws=V1AwsDataConnection(
- region=region,
- source=source,
- destination=destination,
- secret_arn_name=secret_arn_name,
- ),
- ),
- project_id=project_id,
- )
-
- # Note: Expose through lightning show data {DATA_NAME}
- # response = client.data_connection_service_list_data_connection_artifacts(
- # project_id=project_id,
- # id=response.id,
- # )
- except lightning_cloud.openapi.rest.ApiException as e:
- message = ast.literal_eval(e.body.decode("utf-8"))["message"]
- _error_and_exit(f"The data connection creation failed. Message: {message}")
-
- rich.print(f"[green]Succeeded[/green]: You have created a new data connection {name}.")
- return None
diff --git a/src/lightning/app/cli/core.py b/src/lightning/app/cli/core.py
deleted file mode 100644
index 6d54c31426ee1..0000000000000
--- a/src/lightning/app/cli/core.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import abc
-
-from rich.table import Table
-
-
-class Formatable(abc.ABC):
- @abc.abstractmethod
- def as_table(self) -> Table:
- pass
-
- @abc.abstractmethod
- def as_json(self) -> str:
- pass
diff --git a/src/lightning/app/cli/lightning_cli.py b/src/lightning/app/cli/lightning_cli.py
deleted file mode 100644
index 8f61554019652..0000000000000
--- a/src/lightning/app/cli/lightning_cli.py
+++ /dev/null
@@ -1,395 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import sys
-from pathlib import Path
-from typing import Tuple, Union
-
-import click
-from requests.exceptions import ConnectionError
-
-import lightning.app.core.constants as constants
-from lightning.app import __version__ as ver
-from lightning.app.cli import cmd_init, cmd_install, cmd_pl_init, cmd_react_ui_init
-from lightning.app.cli.commands.app_commands import _run_app_command
-from lightning.app.cli.commands.cd import cd
-from lightning.app.cli.commands.cp import cp
-from lightning.app.cli.commands.logs import logs
-from lightning.app.cli.commands.ls import ls
-from lightning.app.cli.commands.pwd import pwd
-from lightning.app.cli.commands.rm import rm
-from lightning.app.cli.connect.app import (
- _list_app_commands,
- _retrieve_connection_to_an_app,
- connect_app,
- disconnect_app,
-)
-from lightning.app.cli.connect.data import connect_data
-from lightning.app.cli.lightning_cli_delete import delete
-from lightning.app.cli.lightning_cli_launch import launch
-from lightning.app.cli.lightning_cli_list import get_list
-from lightning.app.core.constants import ENABLE_APP_COMMENT_COMMAND_EXECUTION, get_lightning_cloud_url
-from lightning.app.runners.cloud import CloudRuntime
-from lightning.app.runners.runtime import dispatch
-from lightning.app.runners.runtime_type import RuntimeType
-from lightning.app.utilities.app_commands import run_app_commands
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.cli_helpers import (
- _check_environment_and_redirect,
- _check_version_and_upgrade,
- _format_input_env_variables,
-)
-from lightning.app.utilities.exceptions import _ApiExceptionHandler
-from lightning.app.utilities.login import Auth
-from lightning.app.utilities.port import _find_lit_app_port
-
-logger = Logger(__name__)
-
-
-def main() -> None:
- # Check environment and versions if not in the cloud and not testing
- is_testing = bool(int(os.getenv("LIGHTING_TESTING", "0")))
- if not is_testing and "LIGHTNING_APP_STATE_URL" not in os.environ:
- try:
- # Enforce running in PATH Python
- _check_environment_and_redirect()
-
- # Check for newer versions and upgrade
- _check_version_and_upgrade()
- except SystemExit:
- raise
- except Exception:
- # Note: We intentionally ignore all exceptions here so that we never panic if one of the above calls fails.
- # If they fail for some reason users should still be able to continue with their command.
- click.echo(
- "We encountered an unexpected problem while checking your environment. "
- "We will still proceed with the command, however, there is a chance that errors may occur."
- )
-
- # 1: Handle connection to a Lightning App.
- if len(sys.argv) > 1 and sys.argv[1] in ("connect", "disconnect", "logout"):
- _main()
- else:
- # 2: Retrieve the connection to a Lightning App.
- app_name, app_id = _retrieve_connection_to_an_app()
- if app_name:
- # 3: Handle development use case.
- is_local_app = app_name == "localhost"
- if sys.argv[1:3] == ["run", "app"] or (
- sys.argv[1:3] == ["show", "logs"] and "show logs" not in _list_app_commands(False)
- ):
- _main()
- else:
- if is_local_app:
- message = "You are connected to the local Lightning App."
- else:
- message = f"You are connected to the cloud Lightning App: {app_name}."
-
- if (len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help"]) or len(sys.argv) == 1:
- _list_app_commands()
- else:
- _run_app_command(app_name, app_id)
-
- click.echo()
- click.echo(message + " Return to the primary CLI with `lightning_app disconnect`.")
- else:
- _main()
-
-
-@click.group(cls=_ApiExceptionHandler)
-@click.version_option(ver)
-def _main() -> None:
- pass
-
-
-@_main.group()
-def show() -> None:
- """Show given resource."""
- pass
-
-
-@_main.group()
-def connect() -> None:
- """Connect apps and data."""
- pass
-
-
-@_main.group()
-def disconnect() -> None:
- """Disconnect apps."""
- pass
-
-
-connect.command("app")(connect_app)
-disconnect.command("app")(disconnect_app)
-connect.command("data", hidden=True)(connect_data)
-_main.command(hidden=True)(ls)
-_main.command(hidden=True)(cd)
-_main.command(hidden=True)(cp)
-_main.command(hidden=True)(pwd)
-_main.command(hidden=True)(rm)
-show.command()(logs)
-
-
-@_main.command()
-def login() -> None:
- """Log in to your lightning.ai account."""
- auth = Auth()
- auth.clear()
-
- try:
- auth.authenticate()
- except ConnectionError:
- click.echo(f"Unable to connect to {get_lightning_cloud_url()}. Please check your internet connection.")
- exit(1)
-
-
-@_main.command()
-def logout() -> None:
- """Log out of your lightning.ai account."""
- Auth().clear()
- disconnect_app(logout=True)
-
-
-def _run_app(
- file: str,
- cloud: bool,
- without_server: bool,
- no_cache: bool,
- name: str,
- blocking: bool,
- open_ui: bool,
- env: tuple,
- secret: tuple,
- run_app_comment_commands: bool,
- enable_basic_auth: str,
-) -> None:
- if not os.path.exists(file):
- original_file = file
- file = cmd_install.gallery_apps_and_components(file, True, "latest", overwrite=True) # type: ignore[assignment] # E501
- if file is None:
- click.echo(f"The provided entrypoint `{original_file}` doesn't exist.")
- sys.exit(1)
- run_app_comment_commands = True
-
- runtime_type = RuntimeType.CLOUD if cloud else RuntimeType.MULTIPROCESS
-
- # Cloud specific validations
- if runtime_type != RuntimeType.CLOUD:
- if no_cache:
- raise click.ClickException(
- "Caching is a property of apps running in cloud. "
- "Using the flag --no-cache in local execution is not supported."
- )
- if secret:
- raise click.ClickException(
- "Secrets can only be used for apps running in cloud. "
- "Using the option --secret in local execution is not supported."
- )
- if (ENABLE_APP_COMMENT_COMMAND_EXECUTION or run_app_comment_commands) and file is not None:
- run_app_commands(str(file))
-
- env_vars = _format_input_env_variables(env)
- os.environ.update(env_vars)
-
- secrets = _format_input_env_variables(secret)
-
- port = _find_lit_app_port(constants.APP_SERVER_PORT)
- constants.APP_SERVER_PORT = port
-
- click.echo("Your Lightning App is starting. This won't take long.")
-
- # TODO: Fixme when Grid utilities are available.
- # And refactor test_lightning_run_app_cloud
- file_path = Path(file)
- dispatch(
- file_path,
- runtime_type,
- start_server=not without_server,
- no_cache=no_cache,
- blocking=blocking,
- open_ui=open_ui,
- name=name,
- env_vars=env_vars,
- secrets=secrets,
- run_app_comment_commands=run_app_comment_commands,
- enable_basic_auth=enable_basic_auth,
- port=port,
- )
- if runtime_type == RuntimeType.CLOUD:
- click.echo("Application is ready in the cloud")
-
-
-@_main.group()
-def run() -> None:
- """Run a Lightning application locally or on the cloud."""
-
-
-@run.command("app")
-@click.argument("file", type=str)
-@click.option("--cloud", type=bool, default=False, is_flag=True)
-@click.option("--name", help="The current application name", default="", type=str)
-@click.option("--without-server", is_flag=True, default=False)
-@click.option(
- "--no-cache",
- is_flag=True,
- default=False,
- help="Disable caching of packages " "installed from requirements.txt",
-)
-@click.option("--blocking", "blocking", type=bool, default=False)
-@click.option(
- "--open-ui",
- type=bool,
- default=True,
- help="Decide whether to launch the app UI in a web browser",
-)
-@click.option("--env", type=str, default=[], multiple=True, help="Environment variables to be set for the app.")
-@click.option("--secret", type=str, default=[], multiple=True, help="Secret variables to be set for the app.")
-@click.option("--app_args", type=str, default=[], multiple=True, help="Collection of arguments for the app.")
-@click.option(
- "--setup",
- "-s",
- "run_app_comment_commands",
- is_flag=True,
- default=False,
- help="run environment setup commands from the app comments.",
-)
-@click.option(
- "--enable-basic-auth",
- type=str,
- default="",
- help="Enable basic authentication for the app and use credentials provided in the format username:password",
-)
-def run_app(
- file: str,
- cloud: bool,
- without_server: bool,
- no_cache: bool,
- name: str,
- blocking: bool,
- open_ui: bool,
- env: tuple,
- secret: tuple,
- app_args: tuple,
- run_app_comment_commands: bool,
- enable_basic_auth: str,
-) -> None:
- """Run an app from a file."""
- _run_app(
- file,
- cloud,
- without_server,
- no_cache,
- name,
- blocking,
- open_ui,
- env,
- secret,
- run_app_comment_commands,
- enable_basic_auth,
- )
-
-
-@_main.command("open", hidden=True)
-@click.argument("path", type=str, default=".")
-@click.option("--name", help="The name to use for the CloudSpace", default="", type=str)
-def open(path: str, name: str) -> None:
- """Open files or folders from your machine on the cloud."""
- if not os.path.exists(path):
- click.echo(f"The provided path `{path}` doesn't exist.")
- sys.exit(1)
-
- runtime = CloudRuntime(entrypoint=Path(path))
- runtime.open(name)
-
-
-_main.add_command(get_list)
-_main.add_command(delete)
-_main.add_command(launch)
-_main.add_command(cmd_install.install)
-
-
-@_main.group()
-def init() -> None:
- """Init a Lightning App and/or component."""
-
-
-@init.command("app")
-@click.argument("name", type=str, required=False)
-def init_app(name: str) -> None:
- cmd_init.app(name)
-
-
-@init.command("pl-app")
-@click.argument("source", nargs=-1)
-@click.option(
- "--name",
- "-n",
- type=str,
- default="pl-app",
- help="The name of the folder where the app code will be. Default: pl-app",
-)
-@click.option(
- "--overwrite",
- "-f",
- is_flag=True,
- default=False,
- help="When set, overwrite the output directory without asking if it already exists.",
-)
-def init_pl_app(source: Union[Tuple[str], Tuple[str, str]], name: str, overwrite: bool = False) -> None:
- """Create an app from your PyTorch Lightning source files."""
- if len(source) == 1:
- script_path = source[0]
- source_dir = str(Path(script_path).resolve().parent)
- elif len(source) == 2:
- # enable type checking once https://github.com/python/mypy/issues/1178 is available
- source_dir, script_path = source
- else:
- click.echo(
- f"Incorrect number of arguments. You passed ({', '.join(source)}) but only either one argument"
- f" (script path) or two arguments (root dir, script path) are allowed. Examples:\n"
- f"lightning init pl-app ./path/to/script.py\n"
- f"lightning init pl-app ./code ./code/path/to/script.py",
- err=True,
- )
- raise SystemExit(1)
-
- cmd_pl_init.pl_app(source_dir=source_dir, script_path=script_path, name=name, overwrite=overwrite)
-
-
-@init.command("component")
-@click.argument("name", type=str, required=False)
-def init_component(name: str) -> None:
- cmd_init.component(name)
-
-
-@init.command("react-ui")
-@click.option(
- "--dest_dir",
- "-dest_dir",
- type=str,
- help="Optional destination directory in which to create the React UI.",
-)
-def init_react_ui(dest_dir: str) -> None:
- """Create a React UI to give a Lightning component a React.js web user interface (UI)."""
- cmd_react_ui_init.react_ui(dest_dir)
-
-
-def _prepare_file(file: str) -> str:
- exists = os.path.exists(file)
- if exists:
- return file
-
- raise FileNotFoundError(f"The provided file {file} hasn't been found.")
diff --git a/src/lightning/app/cli/lightning_cli_delete.py b/src/lightning/app/cli/lightning_cli_delete.py
deleted file mode 100644
index 179e5b6fc365d..0000000000000
--- a/src/lightning/app/cli/lightning_cli_delete.py
+++ /dev/null
@@ -1,124 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import click
-import inquirer
-from inquirer.themes import GreenPassion
-from rich.console import Console
-
-from lightning.app.cli.cmd_apps import _AppManager
-
-
-@click.group("delete")
-def delete() -> None:
- """Delete Lightning AI self-managed resources (e.g. apps)"""
- pass
-
-
-def _find_selected_app_instance_id(app_name: str) -> str:
- console = Console()
- app_manager = _AppManager()
-
- all_app_names_and_ids = {}
- selected_app_instance_id = None
-
- for app in app_manager.list_apps():
- all_app_names_and_ids[app.name] = app.id
- # figure out the ID of some app_name
- if app_name == app.name or app_name == app.id:
- selected_app_instance_id = app.id
- break
-
- if selected_app_instance_id is None:
- # when there is no app with the given app_name,
- # ask the user which app they would like to delete.
- console.print(f'[b][yellow]Cannot find app named "{app_name}"[/yellow][/b]')
- try:
- ask = [
- inquirer.List(
- "app_name",
- message="Select the app name to delete",
- choices=list(all_app_names_and_ids.keys()),
- ),
- ]
- app_name = inquirer.prompt(ask, theme=GreenPassion(), raise_keyboard_interrupt=True)["app_name"]
- selected_app_instance_id = all_app_names_and_ids[app_name]
- except KeyboardInterrupt:
- console.print("[b][red]Cancelled by user![/b][/red]")
- raise InterruptedError
-
- return selected_app_instance_id
-
-
-def _delete_app_confirmation_prompt(app_name: str) -> None:
- console = Console()
-
- # when the --yes / -y flags were not passed, do a final
- # confirmation that the user wants to delete the app.
- try:
- ask = [
- inquirer.Confirm(
- "confirm",
- message=f'Are you sure you want to delete app "{app_name}"?',
- default=False,
- ),
- ]
- if inquirer.prompt(ask, theme=GreenPassion(), raise_keyboard_interrupt=True)["confirm"] is False:
- console.print("[b][red]Aborted![/b][/red]")
- raise InterruptedError
- except KeyboardInterrupt:
- console.print("[b][red]Cancelled by user![/b][/red]")
- raise InterruptedError
-
-
-@delete.command("app")
-@click.argument("app-name", type=str)
-@click.option(
- "skip_user_confirm_prompt",
- "--yes",
- "-y",
- is_flag=True,
- default=False,
- help="Do not prompt for confirmation.",
-)
-def delete_app(app_name: str, skip_user_confirm_prompt: bool) -> None:
- """Delete a Lightning app.
-
- Deleting an app also deletes all app websites, works, artifacts, and logs. This permanently removes any record of
- the app as well as all of its associated resources and data. This does not affect any resources and data
- associated with other Lightning apps on your account.
-
- """
- console = Console()
-
- try:
- selected_app_instance_id = _find_selected_app_instance_id(app_name=app_name)
- if not skip_user_confirm_prompt:
- _delete_app_confirmation_prompt(app_name=app_name)
- except InterruptedError:
- return
-
- try:
- # Delete the app!
- app_manager = _AppManager()
- app_manager.delete(app_id=selected_app_instance_id)
- except Exception as ex:
- console.print(
- f'[b][red]An issue occurred while deleting app "{app_name}". If the issue persists, please '
- "reach out to us at [link=mailto:support@lightning.ai]support@lightning.ai[/link][/b][/red]."
- )
- raise click.ClickException(str(ex))
-
- console.print(f'[b][green]App "{app_name}" has been successfully deleted![/green][/b]')
- return
diff --git a/src/lightning/app/cli/lightning_cli_launch.py b/src/lightning/app/cli/lightning_cli_launch.py
deleted file mode 100644
index c171fd7b946f1..0000000000000
--- a/src/lightning/app/cli/lightning_cli_launch.py
+++ /dev/null
@@ -1,130 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import Tuple
-
-import click
-
-from lightning.app.core.constants import APP_SERVER_HOST, APP_SERVER_PORT
-from lightning.app.launcher.launcher import (
- run_lightning_flow,
- run_lightning_work,
- serve_frontend,
- start_application_server,
- start_flow_and_servers,
-)
-
-logger = logging.getLogger(__name__)
-
-
-@click.group(name="launch", hidden=True)
-def launch() -> None:
- """Launch your application."""
-
-
-@launch.command("server", hidden=True)
-@click.argument("file", type=click.Path(exists=True))
-@click.option("--queue-id", help="ID for identifying queue", default="", type=str)
-@click.option("--host", help="Application running host", default=APP_SERVER_HOST, type=str)
-@click.option("--port", help="Application running port", default=APP_SERVER_PORT, type=int)
-def run_server(file: str, queue_id: str, host: str, port: int) -> None:
- """Takes the application file as input, builds the application object, and then uses it to run the application
- server.
-
- This is used by the cloud runners to start the status server for the application.
-
- """
- logger.debug(f"Run Server: {file} {queue_id} {host} {port}")
- start_application_server(file, host, port, queue_id=queue_id)
-
-
-@launch.command("flow", hidden=True)
-@click.argument("file", type=click.Path(exists=True))
-@click.option("--queue-id", help="ID for identifying queue", default="", type=str)
-@click.option("--base-url", help="Base url at which the app server is hosted", default="")
-def run_flow(file: str, queue_id: str, base_url: str) -> None:
- """Takes the application file as input, builds the application object, proxies all the work components, and then
- runs the application flow defined in the root component.
-
- It does exactly what a single-process dispatcher would do, but with proxied work components.
-
- """
- logger.debug(f"Run Flow: {file} {queue_id} {base_url}")
- run_lightning_flow(file, queue_id=queue_id, base_url=base_url)
-
-
-@launch.command("work", hidden=True)
-@click.argument("file", type=click.Path(exists=True))
-@click.option("--work-name", type=str)
-@click.option("--queue-id", help="ID for identifying queue", default="", type=str)
-def run_work(file: str, work_name: str, queue_id: str) -> None:
- """Unlike other entrypoints, this command will take the file path or module details for a work component and run
- that by fetching the states from the queues."""
- logger.debug(f"Run Work: {file} {work_name} {queue_id}")
- run_lightning_work(
- file=file,
- work_name=work_name,
- queue_id=queue_id,
- )
-
-
-@launch.command("frontend", hidden=True)
-@click.argument("file", type=click.Path(exists=True))
-@click.option("--flow-name")
-@click.option("--host")
-@click.option("--port", type=int)
-def run_frontend(file: str, flow_name: str, host: str, port: int) -> None:
- """Serve the frontend specified by the given flow."""
- logger.debug(f"Run Frontend: {file} {flow_name} {host}")
- serve_frontend(file=file, flow_name=flow_name, host=host, port=port)
-
-
-@launch.command("flow-and-servers", hidden=True)
-@click.argument("file", type=click.Path(exists=True))
-@click.option("--queue-id", help="ID for identifying queue", default="", type=str)
-@click.option("--base-url", help="Base url at which the app server is hosted", default="")
-@click.option("--host", help="Application running host", default=APP_SERVER_HOST, type=str)
-@click.option("--port", help="Application running port", default=APP_SERVER_PORT, type=int)
-@click.option(
- "--flow-port",
- help="Pair of flow name and frontend port",
- type=(str, int),
- multiple=True,
-)
-def run_flow_and_servers(
- file: str,
- base_url: str,
- queue_id: str,
- host: str,
- port: int,
- flow_port: Tuple[Tuple[str, int]],
-) -> None:
- """Takes the application file as input, builds the application object, and then uses it to run the application
- flow defined in the root component, the application server, and all the flow frontends.
-
- This is used by the cloud runners to start the flow, the status server, and all frontends for the application.
-
- """
- logger.debug(f"Run Flow: {file} {queue_id} {base_url}")
- logger.debug(f"Run Server: {file} {queue_id} {host} {port}.")
- logger.debug(f"Run Frontends: {flow_port}")
- start_flow_and_servers(
- entrypoint_file=file,
- base_url=base_url,
- queue_id=queue_id,
- host=host,
- port=port,
- flow_names_and_ports=flow_port,
- )
diff --git a/src/lightning/app/cli/lightning_cli_list.py b/src/lightning/app/cli/lightning_cli_list.py
deleted file mode 100644
index 0cbc8e3cc1887..0000000000000
--- a/src/lightning/app/cli/lightning_cli_list.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any
-
-import click
-
-from lightning.app.cli.cmd_apps import _AppManager
-
-
-@click.group(name="list")
-def get_list() -> None:
- """List Lightning AI self-managed resources (e.g. apps)"""
- pass
-
-
-@get_list.command("apps")
-def list_apps(**kwargs: Any) -> None:
- """List your Lightning AI apps."""
- app_manager = _AppManager()
- app_manager.list()
diff --git a/src/lightning/app/cli/pl-app-template/.gitignore b/src/lightning/app/cli/pl-app-template/.gitignore
deleted file mode 100644
index 01aa0091c3945..0000000000000
--- a/src/lightning/app/cli/pl-app-template/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-.storage
diff --git a/src/lightning/app/cli/pl-app-template/.lightningignore b/src/lightning/app/cli/pl-app-template/.lightningignore
deleted file mode 100644
index 5895fd5187660..0000000000000
--- a/src/lightning/app/cli/pl-app-template/.lightningignore
+++ /dev/null
@@ -1,2 +0,0 @@
-.storage
-ui/node_modules
diff --git a/src/lightning/app/cli/pl-app-template/app.py b/src/lightning/app/cli/pl-app-template/app.py
deleted file mode 100644
index b15ea21a02276..0000000000000
--- a/src/lightning/app/cli/pl-app-template/app.py
+++ /dev/null
@@ -1,105 +0,0 @@
-import os
-from typing import Dict, List, Optional, Union
-
-from core.components import TensorBoard, WeightsAndBiases
-from core.components.script_runner import ScriptRunner
-from lightning.app import LightningApp, LightningFlow
-from lightning.app.frontend import StaticWebFrontend
-from lightning.app.storage.path import Path
-from lightning.app.utilities.packaging.cloud_compute import CloudCompute
-
-
-class ReactUI(LightningFlow):
- def configure_layout(self):
- return StaticWebFrontend(str(Path(__file__).parent / "ui/build"))
-
-
-class ScriptOrchestrator(LightningFlow):
- def __init__(self) -> None:
- super().__init__()
- self.script_runner: Optional[ScriptRunner] = None
- self.triggered: bool = False
- self.running: bool = False
- self.succeeded: bool = False
- self.failed: bool = False
- self.script_args: List[str] = []
- self.cloud_compute_args: Dict[str, Union[str, int]] = {"name": "cpu-small"}
- self.environment_variables: Dict[str, str] = {}
- self.script_path = "{{ script_path }}"
-
- def run(self) -> None:
- if not self.triggered:
- return
-
- if self.script_runner is None:
- self.script_runner = ScriptRunner(
- root_path=str(Path(__file__).parent / "source"),
- script_path=str(Path(__file__).parent / "source" / self.script_path),
- script_args=self.script_args,
- env=self._prepare_environment(),
- parallel=True,
- cloud_compute=CloudCompute(**self.cloud_compute_args),
- raise_exception=False,
- )
- self.script_runner.run()
-
- self.running = self.script_runner is not None and self.script_runner.has_started
- self.succeeded = self.script_runner is not None and self.script_runner.has_succeeded
- self.failed = self.script_runner is not None and self.script_runner.has_failed
-
- if self.succeeded or self.failed:
- self.triggered = False
- # TODO: support restarting
- # self.script_runner = None
-
- def _prepare_environment(self) -> Dict[str, str]:
- env = os.environ.copy()
- env.update(self.environment_variables)
- return env
-
-
-class Main(LightningFlow):
- def __init__(self) -> None:
- super().__init__()
- self.react_ui = ReactUI()
- self.script_orchestrator = ScriptOrchestrator()
- self.running_in_cloud = bool(os.environ.get("LIGHTNING_CLOUD_APP_ID", False))
-
- def run(self) -> None:
- self.react_ui.run()
- self.script_orchestrator.run()
-
- if self.script_orchestrator.script_runner and self.script_orchestrator.script_runner.logger_metadatas:
- if not getattr(self, "logger_component", None):
- # TODO: Hack with hasattr and setattr until
- # https://linear.app/gridai/issue/LAI2-8970/work-getting-set-to-none-in-state-update-from-appstate
- # is resolved
- logger_component = self._choose_logger_component()
- if logger_component is not None:
- setattr(self, "logger_component", logger_component)
- else:
- self.logger_component.run()
-
- def configure_layout(self):
- tabs = [{"name": "Home", "content": self.react_ui}]
- if hasattr(self, "logger_component"):
- tabs.extend(self.logger_component.configure_layout())
- return tabs
-
- def _choose_logger_component(self) -> Optional[Union[TensorBoard, WeightsAndBiases]]:
- logger_metadatas = self.script_orchestrator.script_runner.logger_metadatas
- if not logger_metadatas:
- return None
- if logger_metadatas[0].get("class_name") == "TensorBoardLogger":
- return TensorBoard(log_dir=self.script_orchestrator.script_runner.log_dir)
- if logger_metadatas[0].get("class_name") == "WandbLogger":
- return WeightsAndBiases(
- username=logger_metadatas[0]["username"],
- project_name=logger_metadatas[0]["project_name"],
- run_id=logger_metadatas[0]["run_id"],
- api_key=self.script_orchestrator.environment_variables.get("WANDB_API_KEY"),
- )
- return None
-
-
-app = LightningApp(Main())
diff --git a/src/lightning/app/cli/pl-app-template/core/__init__.py b/src/lightning/app/cli/pl-app-template/core/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/src/lightning/app/cli/pl-app-template/core/callbacks.py b/src/lightning/app/cli/pl-app-template/core/callbacks.py
deleted file mode 100644
index 87ec8e9bcbda2..0000000000000
--- a/src/lightning/app/cli/pl-app-template/core/callbacks.py
+++ /dev/null
@@ -1,319 +0,0 @@
-import inspect
-from typing import TYPE_CHECKING, Any, Dict, Union
-
-import lightning.pytorch as pl
-from lightning.app.storage.path import Path
-from lightning.app.utilities.app_helpers import Logger
-from lightning.pytorch import Callback
-from lightning.pytorch.callbacks.progress.progress_bar import get_standard_metrics
-from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
-from lightning.pytorch.utilities.parsing import collect_init_args
-
-from core.state import ProgressBarState, TrainerState
-
-if TYPE_CHECKING:
- from core.components.script_runner import ScriptRunner
-
-
-_log = Logger(__name__)
-
-
-class PLAppProgressTracker(Callback):
- """This callback tracks and communicates the Trainer's progress to the running PyTorch Lightning App."""
-
- def __init__(self, work: "ScriptRunner", refresh_rate: int = 1) -> None:
- super().__init__()
- self.work = work
- self.refresh_rate = refresh_rate
- self.is_enabled = False
- self._state = ProgressBarState()
-
- def setup(
- self,
- trainer: "pl.Trainer",
- pl_module: "pl.LightningModule",
- stage: str,
- ) -> None:
- self.is_enabled = trainer.is_global_zero
-
- def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- # We calculate the estimated stepping batches here instead of in the setup hook, because calling the
- # `Trainer.estimated_stepping_batches` too early would lead to a barrier() call in case of DDP and since this
- # callback is only attached on rank 0, would lead to a stall.
- self._state.fit.estimated_stepping_batches = trainer.estimated_stepping_batches
-
- def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
- self._state.fit.total_train_batches = self._total_train_batches(trainer)
- self._state.fit.total_val_batches = self._total_val_batches(trainer)
- self._state.fit.current_epoch = trainer.current_epoch
- if self.is_enabled:
- self._send_state()
-
- def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
- self._state.metrics = self._progress_bar_metrics(trainer, pl_module)
- current = self._train_batch_idx(trainer)
- self._state.fit.train_batch_idx = current
- self._state.fit.global_step = trainer.global_step
- if self._should_send(current, self._total_train_batches(trainer)):
- self._send_state()
-
- def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- self._state.metrics = self._progress_bar_metrics(trainer, pl_module)
- if self.is_enabled:
- self._send_state()
-
- def on_validation_batch_start(
- self,
- trainer: "pl.Trainer",
- pl_module: "pl.LightningModule",
- batch: Any,
- batch_idx: int,
- dataloader_idx: int,
- ) -> None:
- if trainer.state.fn == "fit":
- self._state.fit.val_dataloader_idx = dataloader_idx
- self._state.fit.total_val_batches = self._total_val_batches(trainer)
- if trainer.state.fn == "validate":
- self._state.val.dataloader_idx = dataloader_idx
- self._state.val.total_val_batches = self._total_val_batches(trainer)
-
- def on_validation_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
- self._state.metrics = self._progress_bar_metrics(trainer, pl_module)
- current = self._val_batch_idx(trainer)
- if trainer.state.fn == "fit":
- self._state.fit.val_batch_idx = current
- if trainer.state.fn == "validate":
- self._state.val.val_batch_idx = current
- if self._should_send(current, self._total_val_batches(trainer)):
- self._send_state()
-
- def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- self._state.metrics = self._progress_bar_metrics(trainer, pl_module)
- if self.is_enabled:
- self._send_state()
-
- def on_test_batch_start(
- self,
- trainer: "pl.Trainer",
- pl_module: "pl.LightningModule",
- batch: Any,
- batch_idx: int,
- dataloader_idx: int,
- ) -> None:
- self._state.test.dataloader_idx = dataloader_idx
- self._state.test.total_test_batches = trainer.num_test_batches[dataloader_idx]
-
- def on_test_batch_end(
- self,
- trainer: "pl.Trainer",
- pl_module: "pl.LightningModule",
- outputs: Any,
- batch: Any,
- batch_idx: int,
- dataloader_idx: int = 0,
- ) -> None:
- self._state.metrics = self._progress_bar_metrics(trainer, pl_module)
- current = self._test_batch_idx(trainer)
- self._state.test.test_batch_idx = current
- if self._should_send(current, trainer.num_test_batches[dataloader_idx]):
- self._send_state()
-
- def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- self._state.metrics = self._progress_bar_metrics(trainer, pl_module)
- if self.is_enabled:
- self._send_state()
-
- def on_predict_batch_start(
- self,
- trainer: "pl.Trainer",
- pl_module: "pl.LightningModule",
- batch: Any,
- batch_idx: int,
- dataloader_idx: int,
- ) -> None:
- self._state.predict.dataloader_idx = dataloader_idx
- self._state.predict.total_predict_batches = trainer.num_predict_batches[dataloader_idx]
-
- def on_predict_batch_end(
- self,
- trainer: "pl.Trainer",
- pl_module: "pl.LightningModule",
- outputs: Any,
- batch: Any,
- batch_idx: int,
- dataloader_idx: int = 0,
- ) -> None:
- self._state.metrics = self._progress_bar_metrics(trainer, pl_module)
- current = self._predict_batch_idx(trainer)
- self._state.predict.predict_batch_idx = current
- if self._should_send(current, trainer.num_predict_batches[dataloader_idx]):
- self._send_state()
-
- def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- self._state.metrics = self._progress_bar_metrics(trainer, pl_module)
- if self.is_enabled:
- self._send_state()
-
- def _train_batch_idx(self, trainer: "pl.Trainer") -> int:
- return trainer.fit_loop.epoch_loop.batch_progress.current.processed
-
- def _val_batch_idx(self, trainer: "pl.Trainer") -> int:
- loop = trainer.fit_loop.epoch_loop.val_loop if trainer.state.fn == "fit" else trainer.validate_loop
-
- return loop.epoch_loop.batch_progress.current.processed
-
- def _test_batch_idx(self, trainer: "pl.Trainer") -> int:
- return trainer.test_loop.epoch_loop.batch_progress.current.processed
-
- def _predict_batch_idx(self, trainer: "pl.Trainer") -> int:
- return trainer.predict_loop.epoch_loop.batch_progress.current.processed
-
- def _total_train_batches(self, trainer: "pl.Trainer") -> Union[int, float]:
- return trainer.num_training_batches
-
- def _total_val_batches(self, trainer: "pl.Trainer") -> Union[int, float]:
- return sum(trainer.num_val_batches) if trainer.fit_loop.epoch_loop._should_check_val_epoch() else 0
-
- def _progress_bar_metrics(
- self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
- ) -> Dict[str, Union[str, float]]:
- standard_metrics = get_standard_metrics(trainer, pl_module)
- pbar_metrics = trainer.progress_bar_metrics
- return {**standard_metrics, **pbar_metrics}
-
- def _send_state(self) -> None:
- self.work.trainer_progress = self._state.dict()
-
- def _should_send(self, current: int, total: int) -> bool:
- return self.is_enabled and current % self.refresh_rate == 0 or current == total
-
-
-class PLAppTrainerStateTracker(Callback):
- def __init__(self, work: "ScriptRunner") -> None:
- super().__init__()
- self.work = work
- self._state = TrainerState()
-
- def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- self._state.fn = "fit"
- self.work.trainer_state = self._state.dict()
-
- def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- self._state.fn = None
- self.work.trainer_state = self._state.dict()
-
- def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- self._state.stage = "training"
- self.work.trainer_state = self._state.dict()
-
- def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- self._state.stage = None
- self.work.trainer_state = self._state.dict()
-
- def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- self._state.stage = "validating"
- self.work.trainer_state = self._state.dict()
-
- def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- self._state.stage = None
- self.work.trainer_state = self._state.dict()
-
- def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- self._state.fn = "test"
- self._state.stage = "testing"
- self.work.trainer_state = self._state.dict()
-
- def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- self._state.fn = None
- self._state.stage = None
- self.work.trainer_state = self._state.dict()
-
- def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- self._state.fn = "predict"
- self._state.stage = "predicting"
- self.work.trainer_state = self._state.dict()
-
- def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- self._state.fn = None
- self._state.stage = None
- self.work.trainer_state = self._state.dict()
-
-
-class PLAppSummary(Callback):
- def __init__(self, work: "ScriptRunner") -> None:
- super().__init__()
- self.work = work
-
- def on_init_end(self, trainer: "pl.Trainer") -> None:
- current_frame = inspect.currentframe()
- # Trainer.init() -> Trainer._call_callback_hooks() -> Callback.on_init_end()
- frame = current_frame.f_back.f_back
- init_args = {}
- for local_args in collect_init_args(frame, []):
- init_args.update(local_args)
-
- self.work.trainer_hparams = self._sanitize_trainer_init_args(init_args)
-
- def setup(
- self,
- trainer: "pl.Trainer",
- pl_module: "pl.LightningModule",
- stage: str,
- ) -> None:
- self.work.model_hparams = self._sanitize_model_init_args(dict(**pl_module.hparams))
-
- def _sanitize_trainer_init_args(self, init_args: Dict[str, Any]) -> Dict[str, str]:
- if init_args["callbacks"]:
- init_args["callbacks"] = [c.__class__.__name__ for c in init_args["callbacks"]]
- return {k: str(v) for k, v in init_args.items()}
-
- def _sanitize_model_init_args(self, init_args: Dict[str, Any]) -> Dict[str, str]:
- return {k: str(v) for k, v in init_args.items()}
-
-
-class PLAppArtifactsTracker(Callback):
- def __init__(self, work: "ScriptRunner") -> None:
- super().__init__()
- self.work = work
-
- def setup(
- self,
- trainer: "pl.Trainer",
- pl_module: "pl.LightningModule",
- stage: str,
- ) -> None:
- log_dir = self._get_logdir(trainer)
- self.work.log_dir = Path(log_dir) if log_dir is not None else None
- self._collect_logger_metadata(trainer)
-
- def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- if trainer.checkpoint_callback and trainer.checkpoint_callback.dirpath is not None:
- self.work.checkpoint_dir = Path(trainer.checkpoint_callback.dirpath)
-
- def _collect_logger_metadata(self, trainer: "pl.Trainer") -> None:
- if not trainer.loggers:
- return
-
- for logger in trainer.loggers:
- metadata = {"class_name": logger.__class__.__name__}
- if isinstance(logger, WandbLogger) and not logger._offline:
- metadata.update({
- "username": logger.experiment.entity,
- "project_name": logger.name,
- "run_id": logger.version,
- })
-
- if metadata and metadata not in self.work.logger_metadatas:
- self.work.logger_metadatas.append(metadata)
-
- @staticmethod
- def _get_logdir(trainer: "pl.Trainer") -> str:
- """The code here is the same as in the ``Trainer.log_dir``, with the exception of the broadcast call."""
- if len(trainer.loggers) == 1:
- if isinstance(trainer.logger, TensorBoardLogger):
- dirpath = trainer.logger.log_dir
- else:
- dirpath = trainer.logger.save_dir
- else:
- dirpath = trainer.default_root_dir
- return dirpath
diff --git a/src/lightning/app/cli/pl-app-template/core/components/__init__.py b/src/lightning/app/cli/pl-app-template/core/components/__init__.py
deleted file mode 100644
index 75f49eb7da05e..0000000000000
--- a/src/lightning/app/cli/pl-app-template/core/components/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from core.components.logger.tensorboard import TensorBoard # noqa: F401
-from core.components.logger.weights_and_biases import WeightsAndBiases # noqa: F401
diff --git a/src/lightning/app/cli/pl-app-template/core/components/logger/__init__.py b/src/lightning/app/cli/pl-app-template/core/components/logger/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/src/lightning/app/cli/pl-app-template/core/components/logger/tensorboard.py b/src/lightning/app/cli/pl-app-template/core/components/logger/tensorboard.py
deleted file mode 100644
index 0e1a536ff4859..0000000000000
--- a/src/lightning/app/cli/pl-app-template/core/components/logger/tensorboard.py
+++ /dev/null
@@ -1,49 +0,0 @@
-import subprocess
-import time
-from typing import Dict, List
-
-from lightning.app import BuildConfig, LightningFlow, LightningWork
-from lightning.app.storage.path import Path
-
-
-class TensorBoard(LightningFlow):
- def __init__(self, log_dir: Path, sync_every_n_seconds: int = 5) -> None:
- """This TensorBoard component synchronizes the log directory of an experiment and starts up the server.
-
- Args:
- log_dir: The path to the directory where the TensorBoard log-files will appear.
- sync_every_n_seconds: How often to sync the log directory (given as an argument to the run method)
-
- """
- super().__init__()
- self.worker = TensorBoardWorker(log_dir=log_dir, sync_every_n_seconds=sync_every_n_seconds)
-
- def run(self) -> None:
- self.worker.run()
-
- def configure_layout(self) -> List[Dict[str, str]]:
- return [{"name": "Training Logs", "content": self.worker.url}]
-
-
-class TensorBoardWorker(LightningWork):
- def __init__(self, log_dir: Path, sync_every_n_seconds: int = 5) -> None:
- super().__init__(cloud_build_config=BuildConfig(requirements=["tensorboard"]))
- self.log_dir = log_dir
- self._sync_every_n_seconds = sync_every_n_seconds
-
- def run(self) -> None:
- subprocess.Popen([
- "tensorboard",
- "--logdir",
- str(self.log_dir),
- "--host",
- self.host,
- "--port",
- str(self.port),
- ])
-
- # Download the log directory periodically
- while True:
- time.sleep(self._sync_every_n_seconds)
- if self.log_dir.exists_remote():
- self.log_dir.get(overwrite=True)
diff --git a/src/lightning/app/cli/pl-app-template/core/components/logger/weights_and_biases.py b/src/lightning/app/cli/pl-app-template/core/components/logger/weights_and_biases.py
deleted file mode 100644
index bf20d17de033c..0000000000000
--- a/src/lightning/app/cli/pl-app-template/core/components/logger/weights_and_biases.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import os
-from typing import TYPE_CHECKING, Dict, List, Optional
-
-from lightning.app import LightningFlow
-
-if TYPE_CHECKING:
- import wandb
-
-
-class WeightsAndBiases(LightningFlow):
- def __init__(self, username: str, project_name: str, run_id: str, api_key: Optional[str] = None) -> None:
- super().__init__()
- self.username = username
- self.project_name = project_name
- self.run_id = run_id
- self._api_key = api_key
- self._run: Optional[wandb.Run] = None
-
- def run(self) -> None:
- if self._run is not None:
- return
-
- if self._api_key:
- os.environ["WANDB_API_KEY"] = self._api_key
-
- import wandb
-
- self._run = wandb.init(project=self.project_name, id=self.run_id, entity=self.username)
-
- def configure_layout(self) -> List[Dict[str, str]]:
- if self._run is not None:
- return [{"name": "Training Logs", "content": self._run.get_url()}]
- return []
diff --git a/src/lightning/app/cli/pl-app-template/core/components/script_runner/__init__.py b/src/lightning/app/cli/pl-app-template/core/components/script_runner/__init__.py
deleted file mode 100644
index b74bcabd5fbd7..0000000000000
--- a/src/lightning/app/cli/pl-app-template/core/components/script_runner/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from core.components.script_runner.script_runner import ScriptRunner # noqa: F401
diff --git a/src/lightning/app/cli/pl-app-template/core/components/script_runner/script_runner.py b/src/lightning/app/cli/pl-app-template/core/components/script_runner/script_runner.py
deleted file mode 100644
index 0c2f09e372237..0000000000000
--- a/src/lightning/app/cli/pl-app-template/core/components/script_runner/script_runner.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import sys
-import traceback
-from typing import Any, Dict, List, Optional, Tuple
-
-from lightning.app.components.python import TracerPythonScript
-from lightning.app.storage.path import Path
-from lightning.app.utilities.packaging.build_config import BuildConfig, load_requirements
-from lightning.app.utilities.tracer import Tracer
-
-
-class ScriptRunner(TracerPythonScript):
- """The ScriptRunner executes the script using ``runpy`` and also patches the Trainer methods to inject additional
- code."""
-
- def __init__(self, root_path: str, *args: Any, **kwargs: Any) -> None:
- super().__init__(*args, cloud_build_config=self._get_build_config(root_path), **kwargs)
- self.root_path = root_path
- self.exception_message: str = ""
- self.trainer_progress: dict = {}
- self.trainer_state: dict = {}
- self.trainer_hparams: dict = {}
- self.model_hparams: dict = {}
- self.log_dir: Optional[Path] = None
- self.checkpoint_dir: Optional[Path] = None
- self.logger_metadatas: List[Dict[str, str]] = []
-
- def configure_tracer(self) -> Tracer:
- from lightning.pytorch import Trainer
-
- from core.callbacks import PLAppArtifactsTracker, PLAppProgressTracker, PLAppSummary, PLAppTrainerStateTracker
-
- tracer = Tracer()
- trainer_artifacts_tracker = PLAppArtifactsTracker(work=self)
- trainer_state_tracker = PLAppTrainerStateTracker(work=self)
- progress_tracker = PLAppProgressTracker(work=self)
- summary = PLAppSummary(work=self)
-
- def pre_trainer_init(_, *args: Any, **kwargs: Any) -> Tuple[Dict, Tuple[Any, ...], Dict[str, Any]]:
- kwargs.setdefault("callbacks", [])
- kwargs["callbacks"].extend([
- trainer_artifacts_tracker,
- trainer_state_tracker,
- progress_tracker,
- summary,
- ])
- return {}, args, kwargs
-
- tracer.add_traced(Trainer, "__init__", pre_fn=pre_trainer_init)
- return tracer
-
- def run(self) -> None:
- self.exception_message = ""
- # We need to set the module path both in sys.path and the PYTHONPATH env variable.
- # The former is for the current process which is already running, and the env variable is needed in case
- # the script launches subprocesses
- sys.path.insert(0, self.root_path)
- self.env["PYTHONPATH"] = self.root_path
- super().run()
-
- def on_exception(self, exception: BaseException) -> None:
- self.exception_message = traceback.format_exc()
- super().on_exception(exception)
-
- @staticmethod
- def _get_build_config(root_path: str) -> Optional[BuildConfig]:
- # These are the requirements for the script runner itself
- requirements = [
- "protobuf<4.21.0",
- "pytorch-lightning<=1.6.3",
- "pydantic<=1.9.0",
- ]
- if Path(root_path, "requirements.txt").exists():
- # Requirements from the user's code folder
- requirements.extend(load_requirements(root_path, file_name="requirements.txt"))
-
- return BuildConfig(requirements=requirements)
diff --git a/src/lightning/app/cli/pl-app-template/core/state.py b/src/lightning/app/cli/pl-app-template/core/state.py
deleted file mode 100644
index 80a9f3d4e0619..0000000000000
--- a/src/lightning/app/cli/pl-app-template/core/state.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from typing import Dict, Optional, Union
-
-from pydantic import BaseModel, Field
-
-
-class FitProgress(BaseModel):
- current_epoch: int = 0
- train_batch_idx: int = 0
- total_train_batches: int = 0
- val_dataloader_idx: int = 0
- val_batch_idx: int = 0
- total_val_batches: int = 0
- global_step: int = 0
- estimated_stepping_batches: int = 0
-
-
-class ValidateProgress(BaseModel):
- dataloader_idx: int = 0
- val_batch_idx: int = 0
- total_val_batches: int = 0
-
-
-class TestProgress(BaseModel):
- dataloader_idx: int = 0
- test_batch_idx: int = 0
- total_test_batches: int = 0
-
-
-class PredictProgress(BaseModel):
- dataloader_idx: int = 0
- predict_batch_idx: int = 0
- total_predict_batches: int = 0
-
-
-class ProgressBarState(BaseModel):
- fit: FitProgress = Field(default_factory=FitProgress)
- val: ValidateProgress = Field(alias="validate", default_factory=ValidateProgress)
- test: TestProgress = Field(default_factory=TestProgress)
- predict: PredictProgress = Field(default_factory=PredictProgress)
- metrics: Dict[str, Union[float, str]] = {}
-
-
-class TrainerState(BaseModel):
- fn: Optional[str] = None
- stage: Optional[str] = None
diff --git a/src/lightning/app/cli/pl-app-template/setup.py b/src/lightning/app/cli/pl-app-template/setup.py
deleted file mode 100644
index dc223931779a2..0000000000000
--- a/src/lightning/app/cli/pl-app-template/setup.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import os
-from typing import List
-
-from setuptools import find_packages, setup
-
-_PROJECT_ROOT = os.path.dirname(__file__)
-
-
-def _load_requirements(path_dir: str, file_name: str = "requirements.txt", comment_char: str = "#") -> List[str]:
- """Load requirements from a file."""
- with open(os.path.join(path_dir, file_name)) as file:
- lines = [ln.strip() for ln in file.readlines()]
- reqs = []
- for ln in lines:
- # filer all comments
- if comment_char in ln:
- ln = ln[: ln.index(comment_char)].strip()
- # skip directly installed dependencies
- if ln.startswith("http"):
- continue
- # skip index url
- if ln.startswith("--extra-index-url"):
- continue
- if ln: # if requirement is not empty
- reqs.append(ln)
- return reqs
-
-
-setup(
- name="{{ app_name }}",
- version="0.0.1",
- packages=find_packages(exclude=["ui"]),
- python_requires=">=3.8",
-)
diff --git a/src/lightning/app/cli/pl-app-template/tests/__init__.py b/src/lightning/app/cli/pl-app-template/tests/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/src/lightning/app/cli/pl-app-template/tests/core/__init__.py b/src/lightning/app/cli/pl-app-template/tests/core/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/src/lightning/app/cli/pl-app-template/tests/core/test_callbacks.py b/src/lightning/app/cli/pl-app-template/tests/core/test_callbacks.py
deleted file mode 100644
index a211da35e6a90..0000000000000
--- a/src/lightning/app/cli/pl-app-template/tests/core/test_callbacks.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import os.path
-from unittest.mock import Mock
-
-import pytest
-from core.callbacks import PLAppArtifactsTracker, PLAppProgressTracker, PLAppSummary
-from core.components.script_runner import ScriptRunner
-from lightning.app.storage.path import Path
-from lightning.pytorch import LightningModule, Trainer
-from lightning.pytorch.loggers import TensorBoardLogger
-
-
-@pytest.mark.parametrize("rank", [0, 1])
-def test_progress_tracker_enabled(rank):
- trainer = Mock()
- trainer.global_rank = rank
- trainer.is_global_zero = rank == 0
- work = Mock()
- tracker = PLAppProgressTracker(work)
- assert not tracker.is_enabled
- tracker.setup(trainer, Mock(), Mock())
- assert tracker.is_enabled == trainer.is_global_zero
-
-
-def test_summary_callback_tracks_hyperparameters():
- class ModelWithParameters(LightningModule):
- def __init__(self, float_arg=0.1, int_arg=5, bool_arg=True, string_arg="string"):
- super().__init__()
- self.save_hyperparameters()
-
- model = ModelWithParameters()
- work = Mock()
- summary = PLAppSummary(work)
- trainer = Trainer(max_epochs=22, callbacks=[summary]) # this triggers the `Callback.on_init_end` hook
- summary.setup(trainer, model)
- assert work.model_hparams == {
- "float_arg": "0.1",
- "int_arg": "5",
- "bool_arg": "True",
- "string_arg": "string",
- }
-
- assert work.trainer_hparams["max_epochs"] == "22"
- assert work.trainer_hparams["logger"] == "True"
- assert "ModelCheckpoint" in work.trainer_hparams["callbacks"]
- assert "PLAppSummary" in work.trainer_hparams["callbacks"]
-
-
-def test_artifacts_tracker(tmpdir):
- work = ScriptRunner(root_path=os.path.dirname(__file__), script_path=__file__)
- tracker = PLAppArtifactsTracker(work=work)
- trainer = Mock()
-
- trainer.loggers = []
- trainer.default_root_dir = "default_root_dir"
- tracker.setup(trainer=trainer, pl_module=Mock())
- assert work.log_dir == Path("default_root_dir")
- assert not work.logger_metadatas
-
- trainer.loggers = [TensorBoardLogger(save_dir=tmpdir)]
- trainer.logger = trainer.loggers[0]
- tracker.setup(trainer=trainer, pl_module=Mock())
- assert work.log_dir == Path(tmpdir / "lightning_logs" / "version_0")
- assert len(work.logger_metadatas) == 1
- assert work.logger_metadatas[0] == {"class_name": "TensorBoardLogger"}
-
- # call setup a second time and the metadata length should not change
- tracker.setup(trainer=trainer, pl_module=Mock())
- assert len(work.logger_metadatas) == 1
diff --git a/src/lightning/app/cli/pl-app-template/tests/test_app.py b/src/lightning/app/cli/pl-app-template/tests/test_app.py
deleted file mode 100644
index 3fc14bfcdbf69..0000000000000
--- a/src/lightning/app/cli/pl-app-template/tests/test_app.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import pytest
-
-
-@pytest.mark.skip()
-def test_is_running_in_cloud(monkeypatch):
- from app import Main
-
- monkeypatch.setenv("LIGHTNING_CLOUD_APP_ID", "anything")
- app = Main()
- assert app.running_in_cloud
-
- monkeypatch.delenv("LIGHTNING_CLOUD_APP_ID", raising=False)
- app = Main()
- assert not app.running_in_cloud
diff --git a/src/lightning/app/cli/pl-app-template/ui/.gitignore b/src/lightning/app/cli/pl-app-template/ui/.gitignore
deleted file mode 100644
index 6c2d44cd3ba13..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/.gitignore
+++ /dev/null
@@ -1,25 +0,0 @@
-# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
-
-# dependencies
-/node_modules
-/.pnp
-.pnp.js
-
-# testing
-/coverage
-
-# production
-/build
-
-# misc
-.DS_Store
-
-npm-debug.log*
-yarn-debug.log*
-yarn-error.log*
-
-/cypress/videos
-/cypress/screenshots
-/cypress/downloads
-
-.eslintcache
diff --git a/src/lightning/app/cli/pl-app-template/ui/.prettierignore b/src/lightning/app/cli/pl-app-template/ui/.prettierignore
deleted file mode 100644
index 2ea70f096046d..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/.prettierignore
+++ /dev/null
@@ -1,3 +0,0 @@
-resources
-build
-node_modules
diff --git a/src/lightning/app/cli/pl-app-template/ui/.prettierrc b/src/lightning/app/cli/pl-app-template/ui/.prettierrc
deleted file mode 100644
index cad1459af3548..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/.prettierrc
+++ /dev/null
@@ -1,24 +0,0 @@
-{
- "jsxSingleQuote": false,
- "arrowParens": "avoid",
- "tabWidth": 2,
- "useTabs": false,
- "printWidth": 119,
- "singleQuote": false,
- "semi": true,
- "endOfLine": "lf",
- "proseWrap": "always",
- "bracketSameLine": true,
- "quoteProps": "consistent",
- "trailingComma": "all",
- "bracketSpacing": true,
- "importOrder": [
- "^react$",
- "",
- "^(components|hooks|resources|utils|lightning-.*)",
- "^tests",
- "^[./]"
- ],
- "importOrderSeparation": true,
- "importOrderSortSpecifiers": true
-}
diff --git a/src/lightning/app/cli/pl-app-template/ui/craco.config.js b/src/lightning/app/cli/pl-app-template/ui/craco.config.js
deleted file mode 100644
index 979d08985ff19..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/craco.config.js
+++ /dev/null
@@ -1,29 +0,0 @@
-const path = require("path");
-const fs = require("fs");
-const cracoBabelLoader = require("craco-babel-loader");
-
-// manage relative paths to packages
-const appDirectory = fs.realpathSync(process.cwd());
-const resolvePackage = relativePath => path.resolve(appDirectory, relativePath);
-
-module.exports = {
- devServer: {
- // When launching `yarn start dev`, write the files to the build folder too
- devMiddleware: { writeToDisk: true },
- },
- webpack: {
- configure: {
- output: {
- publicPath: "./",
- },
- },
- },
- plugins: [
- {
- plugin: cracoBabelLoader,
- options: {
- includes: [resolvePackage("node_modules/lightning-ui")],
- },
- },
- ],
-};
diff --git a/src/lightning/app/cli/pl-app-template/ui/package.json b/src/lightning/app/cli/pl-app-template/ui/package.json
deleted file mode 100644
index 690bd85b6498c..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/package.json
+++ /dev/null
@@ -1,95 +0,0 @@
-{
- "name": "pytorch-lightning-app",
- "version": "0.1.0",
- "private": true,
- "dependencies": {
- "@emotion/react": "^11.7.1",
- "@emotion/styled": "^11.6.0",
- "@mui/icons-material": "^5.6.2",
- "@mui/lab": "^5.0.0-alpha.64",
- "@mui/material": "^5.2.7",
- "@reduxjs/toolkit": "^1.8.0",
- "@stripe/stripe-js": "^1.29.0",
- "axios": "^0.25.0",
- "boring-avatars": "^1.6.3",
- "filter-material-ui": "2.7.0",
- "fontfaceobserver": "^2.1.0",
- "lightning-ui": "git+ssh://git@github.com/gridai/lightning-ui.git#35f4124cc8a16a313174fe63ec82cb74af388c6b",
- "lodash": "^4.17.21",
- "notistack": "^2.0.4",
- "query-string": "^7.1.0",
- "react": "^17.0.2",
- "react-dom": "^17.0.2",
- "react-github-btn": "^1.2.1",
- "react-hook-form": "^7.27.1",
- "react-query": "^3.34.7",
- "react-router-dom": "^6.2.1",
- "react-scripts": "5.0.0",
- "react-spring": "^9.4.4",
- "react-table": "^7.7.0",
- "rxjs": "^7.5.2",
- "typescript": "^4.4.2",
- "use-debounce": "^7.0.1",
- "web-vitals": "^2.1.0",
- "xterm": "^4.18.0",
- "xterm-addon-fit": "^0.5.0",
- "xterm-addon-search": "^0.8.2"
- },
- "scripts": {
- "start": "craco start",
- "build": "craco build",
- "lint": "eslint --cache --max-warnings=0 . && prettier -c .",
- "lint:fix": "eslint --cache --max-warnings=0 . --fix && prettier -w .",
- "eject": "react-scripts eject"
- },
- "lint-staged": {
- "**/*": "prettier --write --ignore-unknown"
- },
- "eslintConfig": {
- "extends": [
- "react-app"
- ],
- "ignorePatterns": [
- "node_modules/**",
- "build/**"
- ],
- "rules": {
- "react/jsx-sort-props": [
- "off",
- {
- "callbacksLast": true,
- "ignoreCase": true,
- "noSortAlphabetically": false,
- "reservedFirst": true,
- "shorthandFirst": true
- }
- ],
- "react/jsx-pascal-case": "warn"
- }
- },
- "browserslist": {
- "production": [
- ">0.2%",
- "not dead",
- "not op_mini all"
- ],
- "development": [
- "last 1 chrome version",
- "last 1 firefox version",
- "last 1 safari version"
- ]
- },
- "devDependencies": {
- "@craco/craco": "^6.4.3",
- "@trivago/prettier-plugin-sort-imports": "^3.1.1",
- "@types/fontfaceobserver": "^2.1.0",
- "@types/lodash": "^4.14.182",
- "@types/node": "^16.7.13",
- "@types/react": "^17.0.20",
- "@types/react-dom": "^17.0.9",
- "@types/react-table": "^7.7.9",
- "craco-babel-loader": "^1.0.3",
- "lint-staged": "^12.3.2",
- "prettier": "2.5.1"
- }
-}
diff --git a/src/lightning/app/cli/pl-app-template/ui/public/favicon.svg b/src/lightning/app/cli/pl-app-template/ui/public/favicon.svg
deleted file mode 100644
index 94a65989d0b4b..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/public/favicon.svg
+++ /dev/null
@@ -1,9 +0,0 @@
-
-
-
-
-
-
-
-
-
diff --git a/src/lightning/app/cli/pl-app-template/ui/public/index.html b/src/lightning/app/cli/pl-app-template/ui/public/index.html
deleted file mode 100644
index 0f2384212e1d3..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/public/index.html
+++ /dev/null
@@ -1,65 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
- PyTorch Lightning App
-
-
-
-
-
- You need to enable JavaScript to run this app.
-
-
-
-
diff --git a/src/lightning/app/cli/pl-app-template/ui/public/manifest.json b/src/lightning/app/cli/pl-app-template/ui/public/manifest.json
deleted file mode 100644
index 4a97d68d6b7ba..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/public/manifest.json
+++ /dev/null
@@ -1,15 +0,0 @@
-{
- "short_name": "PL App",
- "name": "PyTorch Lightning App",
- "icons": [
- {
- "src": "favicon.svg",
- "sizes": "512x512 192x192 64x64 32x32 24x24 16x16",
- "type": "image/svg+xml"
- }
- ],
- "start_url": ".",
- "display": "standalone",
- "theme_color": "#000000",
- "background_color": "#ffffff"
-}
diff --git a/src/lightning/app/cli/pl-app-template/ui/public/robots.txt b/src/lightning/app/cli/pl-app-template/ui/public/robots.txt
deleted file mode 100644
index e9e57dc4d41b9..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/public/robots.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-# https://www.robotstxt.org/robotstxt.html
-User-agent: *
-Disallow:
diff --git a/src/lightning/app/cli/pl-app-template/ui/src/App.tsx b/src/lightning/app/cli/pl-app-template/ui/src/App.tsx
deleted file mode 100644
index 717984216f298..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/src/App.tsx
+++ /dev/null
@@ -1,126 +0,0 @@
-import { useEffect } from "react";
-
-import { QueryClient, QueryClientProvider } from "react-query";
-import { BrowserRouter } from "react-router-dom";
-
-import ErrorPanel from "components/ErrorPanel";
-import HyperparameterSummary from "components/HyperparameterSummary";
-import Launcher from "components/Launcher";
-import ProgressBarGroup from "components/ProgressBarGroup";
-import {
- Breadcrumbs,
- Card,
- CardContent,
- CardHeader,
- Grid,
- SnackbarProvider,
- Stack,
- useSnackbar,
-} from "lightning-ui/src/design-system/components";
-import ThemeProvider from "lightning-ui/src/design-system/theme";
-
-import ExecutionSummary from "./components/ExecutionSummary";
-import { useLightningState } from "./hooks/useLightningState";
-
-const queryClient = new QueryClient();
-
-function AppContainer() {
- const { lightningState } = useLightningState();
-
- const trainer_progress = lightningState?.flows.script_orchestrator.works.script_runner?.vars.trainer_progress;
- const trainer_state = lightningState?.flows.script_orchestrator.works.script_runner?.vars.trainer_state;
- const trainer_hparams = lightningState?.flows.script_orchestrator.works.script_runner?.vars.trainer_hparams;
- const model_hparams = lightningState?.flows.script_orchestrator.works.script_runner?.vars.model_hparams;
-
- const script_running = lightningState?.flows.script_orchestrator.vars.running;
- const script_succeeded = lightningState?.flows.script_orchestrator.vars.succeeded;
- const script_failed = lightningState?.flows.script_orchestrator.vars.failed;
- const start_triggered = lightningState?.flows.script_orchestrator.vars.triggered;
- const script_path = lightningState?.flows.script_orchestrator.vars.script_path;
- const running_in_cloud = lightningState?.vars.running_in_cloud;
-
- const breadCrumbItems = [
- { title: "Users", href: "url/to/href/1" },
- { title: "adrian", href: "url/to/href/2" },
- { title: "projects", href: "url/to/href/3" },
- { title: "app_name", href: "url/to/href/4" },
- { title: "source", href: "url/to/href/5" },
- { title: "train.py", href: "url/to/href/6" },
- ];
-
- const { enqueueSnackbar } = useSnackbar();
- const exception_message = lightningState?.flows.script_orchestrator.works.script_runner?.vars?.exception_message;
- useEffect(() => {
- if (exception_message) {
- enqueueSnackbar({
- title: "The script failed to complete",
- severity: "error",
- children: "See the error message",
- });
- }
- }, [exception_message]);
-
- return (
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- );
-}
-
-function App() {
- return (
-
-
-
-
-
-
-
-
-
- );
-}
-
-export default App;
diff --git a/src/lightning/app/cli/pl-app-template/ui/src/components/EnvironmentConfigurator.tsx b/src/lightning/app/cli/pl-app-template/ui/src/components/EnvironmentConfigurator.tsx
deleted file mode 100644
index 2d26f86ad9965..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/src/components/EnvironmentConfigurator.tsx
+++ /dev/null
@@ -1,67 +0,0 @@
-import { Button, Stack, TextField } from "lightning-ui/src/design-system/components";
-
-interface Data {
- [key: string]: string;
-}
-
-export function data2dict(data: Data[]) {
- var dict: Data = {};
- for (var i = 0; i < data.length; i++) {
- if (data[i]["name"] === "") {
- continue;
- }
- dict[data[i]["name"]] = data[i]["value"];
- }
- return dict;
-}
-
-export default function EnvironmentConfigurator(props: any) {
- const data: Data[] = props.data;
- const setData = props.setData;
- const addItemAllowed = data[data.length - 1].name.length > 0;
-
- const onItemAdd = () => {
- setData([...data, { name: "", value: "" }]);
- };
-
- const onItemChange = (fieldName: string, index: number, text: any) => {
- let newData = [...data];
-
- text = text.trim();
- if (fieldName == "name") {
- text = text.replace(/[^0-9a-zA-Z_]+/gi, "").toUpperCase();
- }
-
- newData[index][fieldName] = text;
- setData(newData);
- };
-
- return (
-
- {data.map((entry, index) => (
-
- onItemChange("name", index, e)}
- placeholder="KEY"
- size="medium"
- statusText=""
- type="text"
- value={entry.name || ""}
- />
- onItemChange("value", index, e)}
- placeholder="VALUE"
- size="medium"
- statusText=""
- type="text"
- value={entry.value || ""}
- />
-
- ))}
-
-
-
- );
-}
diff --git a/src/lightning/app/cli/pl-app-template/ui/src/components/ErrorPanel.tsx b/src/lightning/app/cli/pl-app-template/ui/src/components/ErrorPanel.tsx
deleted file mode 100644
index 7b88e20007003..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/src/components/ErrorPanel.tsx
+++ /dev/null
@@ -1,24 +0,0 @@
-import * as React from "react";
-
-import ExpandMoreIcon from "@mui/icons-material/ExpandMore";
-import Accordion from "@mui/material/Accordion";
-import AccordionDetails from "@mui/material/AccordionDetails";
-import AccordionSummary from "@mui/material/AccordionSummary";
-import Typography from "@mui/material/Typography";
-
-export default function SimpleAccordion(props: any) {
- return (
-
- } aria-controls="panel1a-content" id="panel1a-header">
- Errors
-
-
-
-
- {props?.error_message}
-
-
-
-
- );
-}
diff --git a/src/lightning/app/cli/pl-app-template/ui/src/components/ExecutionSummary.tsx b/src/lightning/app/cli/pl-app-template/ui/src/components/ExecutionSummary.tsx
deleted file mode 100644
index 48489c1246847..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/src/components/ExecutionSummary.tsx
+++ /dev/null
@@ -1,79 +0,0 @@
-import Typography from "@mui/material/Typography";
-
-import Timer from "components/Timer";
-import { Grid, Stack } from "lightning-ui/src/design-system/components";
-
-export default function ExecutionSummary(props: any) {
- const trainer_progress = props?.trainer_progress;
- const trainer_state = props?.trainer_state;
- const script_running = props?.script_running;
-
- const global_step = trainer_progress?.fit?.global_step || 0;
- const estimated_stepping_batches = trainer_progress?.fit?.estimated_stepping_batches || Infinity;
- const total_estimated_progress = Math.round((100 * global_step) / estimated_stepping_batches);
-
- return (
-
-
-
-
- Duration
-
-
-
-
-
-
-
-
-
-
-
- Stage
-
-
-
- {trainer_state?.stage != undefined ? trainer_state?.stage : "-"}
-
-
-
-
-
-
-
- Epoch
-
-
-
- {trainer_progress?.fit?.current_epoch != undefined ? trainer_progress?.fit?.current_epoch : "-"}
-
-
-
-
-
-
-
- Batch
-
-
-
- {trainer_progress?.fit?.train_batch_idx != undefined ? trainer_progress?.fit?.train_batch_idx : "-"}
-
-
-
-
-
-
-
- Total Progress
-
-
-
- {trainer_progress?.fit != undefined ? total_estimated_progress + "%" : "-"}
-
-
-
-
-
- );
-}
diff --git a/src/lightning/app/cli/pl-app-template/ui/src/components/HyperparameterSummary.tsx b/src/lightning/app/cli/pl-app-template/ui/src/components/HyperparameterSummary.tsx
deleted file mode 100644
index 085348c9a7355..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/src/components/HyperparameterSummary.tsx
+++ /dev/null
@@ -1,95 +0,0 @@
-import { useState } from "react";
-
-import { TextField } from "@mui/material";
-
-import { Checkbox, Stack, Table, Typography } from "lightning-ui/src/design-system/components";
-
-function HyperparametersTable(props: any) {
- return ;
-}
-
-export default function HyperparameterSummary(props: any) {
- const model_hparams = props?.model_hparams ? props.model_hparams : {};
- const trainer_hparams = props?.trainer_hparams ? props.trainer_hparams : {};
- const model_hparams_keys = Object.keys(model_hparams);
- const trainer_hparams_keys = Object.keys(trainer_hparams);
-
- const [searched, setSearched] = useState("");
- const filteredModelHparams = model_hparams_keys
- .filter(key => key.toLowerCase().includes(searched.toLowerCase()))
- .map(key => [key, model_hparams[key]]);
- const filteredTrainerHparams = trainer_hparams_keys
- .filter(key => key.toLowerCase().includes(searched.toLowerCase()))
- .map(key => [key, trainer_hparams[key]]);
-
- const [modelHparamsVisible, setModelHparamsVisible] = useState(true);
- const [trainerHparamsVisible, setTrainerHparamsVisible] = useState(true);
-
- const requestSearch = (event: any) => {
- const searchedVal = event?.target.value;
- setSearched(searchedVal);
- };
-
- const toggleModelHparamsCheckbox = (value: boolean) => {
- setModelHparamsVisible(value);
- };
-
- const toggleTrainerHparamsCheckbox = (value: boolean) => {
- setTrainerHparamsVisible(value);
- };
-
- return (
-
-
- Trainer>}
- checked={trainerHparamsVisible}
- disabled={filteredTrainerHparams.length == 0}
- />
- Model>}
- checked={modelHparamsVisible}
- disabled={filteredModelHparams.length == 0}
- />
-
-
-
- {(!model_hparams || Object.keys(model_hparams).length == 0) && (
- Hyperparameters will appear when script is running.
- )}
-
-
- {modelHparamsVisible && model_hparams_keys && model_hparams_keys.length > 0 && (
-
- Model
-
-
- )}
-
- {trainerHparamsVisible && trainer_hparams && Object.keys(trainer_hparams).length > 0 && (
-
- Trainer
-
-
- )}
-
-
- );
-}
diff --git a/src/lightning/app/cli/pl-app-template/ui/src/components/Launcher.tsx b/src/lightning/app/cli/pl-app-template/ui/src/components/Launcher.tsx
deleted file mode 100644
index d16ebd27d3229..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/src/components/Launcher.tsx
+++ /dev/null
@@ -1,172 +0,0 @@
-import { useState } from "react";
-
-import PlayCircleFilledWhiteOutlinedIcon from "@mui/icons-material/PlayCircleFilledWhiteOutlined";
-import cloneDeep from "lodash/cloneDeep";
-
-import { useLightningState } from "hooks/useLightningState";
-import {
- Banner,
- Button,
- Dialog,
- DialogActions,
- DialogContent,
- DialogTitle,
- Label,
- Select,
- Stack,
- TextField,
- Typography,
- useSnackbar,
-} from "lightning-ui/src/design-system/components";
-
-import EnvironmentConfigurator, { data2dict } from "./EnvironmentConfigurator";
-
-export default function Launcher(props: any) {
- // We had to pass the updateLightningState as props because accessing them through useLightningState does not work
- const { updateLightningState } = useLightningState();
- const [hardwareType, setHardwareType] = useState("cpu-small");
- const [environmentVariables, setEnvironmentVariables] = useState([{ name: "", value: "" }]);
- const [showHardwareDialog, setShowHardwareDialog] = useState(false);
- const { enqueueSnackbar } = useSnackbar();
- const [scriptArgs, setScriptArgs] = useState("");
-
- let status_text = "";
- let status_color: "default" | "primary" | "success" | "error" | "warning" | "purple" = "default";
- if (props.script_running) {
- status_text = "Running";
- status_color = "success";
- } else if (props.start_triggered && !props.script_running) {
- status_text = "Starting";
- status_color = "warning";
- } else if (props.script_succeeded) {
- status_text = "Finished";
- status_color = "success";
- } else if (props.script_failed) {
- status_text = "Failed";
- status_color = "error";
- }
-
- const onStartClick = async () => {
- setShowHardwareDialog(true);
- };
-
- const onHardwareDialogCancelClick = () => {
- setShowHardwareDialog(false);
- };
-
- const onHardwareDialogConfirmClick = () => {
- setShowHardwareDialog(false);
- enqueueSnackbar({
- title: "Hardware request sent",
- severity: "info",
- children: "Your script will start once the hardware is ready.",
- });
- if (props.lightningState) {
- const newLightningState = cloneDeep(props.lightningState);
- newLightningState.flows.script_orchestrator.vars.triggered =
- !newLightningState.flows.script_orchestrator.vars.triggered;
-
- newLightningState.flows.script_orchestrator.vars.cloud_compute_args = {
- name: hardwareType,
- };
- newLightningState.flows.script_orchestrator.vars.environment_variables = data2dict(environmentVariables);
- newLightningState.flows.script_orchestrator.vars.script_args =
- scriptArgs.length > 0 ? scriptArgs.trim().split(/[ ]+/) : [];
- updateLightningState(newLightningState);
- }
- };
-
- const handleHardwareTypeChange = (new_value: any) => {
- setHardwareType(new_value);
- };
-
- const handleScriptArgsChange = (new_value: string | null) => {
- if (new_value !== null) {
- setScriptArgs(new_value);
- }
- };
-
- return (
-
-
- }
- disabled={props.script_running || props.start_triggered}
- />
-
- {props.script_path}
-
- {status_text ? : }
-
-
-
-
-
- Hardware
-
-
-
-
- Hardware selection is only available in the cloud.
-
- Hint: Try running the app with --cloud.
-
-
- Script Arguments
-
-
- Environment Variables
-
-
-
-
-
-
-
-
-
- );
-}
diff --git a/src/lightning/app/cli/pl-app-template/ui/src/components/ProgressBar.tsx b/src/lightning/app/cli/pl-app-template/ui/src/components/ProgressBar.tsx
deleted file mode 100644
index 0fc1bec3999fc..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/src/components/ProgressBar.tsx
+++ /dev/null
@@ -1,35 +0,0 @@
-import * as React from "react";
-
-import Box from "@mui/material/Box";
-import CircularProgress, { CircularProgressProps, circularProgressClasses } from "@mui/material/CircularProgress";
-import LinearProgress, { linearProgressClasses } from "@mui/material/LinearProgress";
-import Typography from "@mui/material/Typography";
-import { styled } from "@mui/material/styles";
-
-import { LIGHTNING_PURPLE } from "lightning-colors";
-
-const BorderLinearProgress = styled(LinearProgress)(({ theme }) => ({
- height: 10,
- borderRadius: 5,
- [`&.${linearProgressClasses.colorPrimary}`]: {
- backgroundColor: theme.palette.grey[theme.palette.mode === "light" ? 200 : 800],
- },
- [`& .${linearProgressClasses.bar}`]: {
- borderRadius: 5,
- backgroundColor: LIGHTNING_PURPLE,
- },
-}));
-
-export default function ProgressBar(props: any) {
- const percentage = props.current ? (props.current * 100) / props.total : 0;
- return (
-
-
-
-
-
- {`${Math.round(percentage)}%`}
-
-
- );
-}
diff --git a/src/lightning/app/cli/pl-app-template/ui/src/components/ProgressBarGroup.tsx b/src/lightning/app/cli/pl-app-template/ui/src/components/ProgressBarGroup.tsx
deleted file mode 100644
index 90787ab4f03ee..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/src/components/ProgressBarGroup.tsx
+++ /dev/null
@@ -1,49 +0,0 @@
-import { toPath } from "lodash";
-
-import { Stack, Typography } from "lightning-ui/src/design-system/components";
-
-import ProgressBar from "./ProgressBar";
-
-export default function ProgressBarGroup(props: any) {
- const trainer_progress = props.trainer_progress;
- const trainer_state = props.trainer_state;
-
- const primary_bar_title = "Training";
- var secondary_bar_title = "";
- var current = 0;
- var total = 0;
-
- switch (trainer_state?.stage) {
- case "validating":
- secondary_bar_title = "Validation";
- current = trainer_progress?.fit?.val_batch_idx;
- total = trainer_progress?.fit?.total_val_batches;
- break;
- case "testing":
- secondary_bar_title = "Test";
- current = trainer_progress?.test?.test_batch_idx;
- total = trainer_progress?.test?.total_test_batches;
- break;
- case "predicting":
- secondary_bar_title = "Prediction";
- current = trainer_progress?.predict?.predict_batch_idx;
- total = trainer_progress?.predict?.total_predict_batches;
- break;
- default:
- secondary_bar_title = "Hidden";
- }
-
- return (
-
- {primary_bar_title}
-
-
-
- {secondary_bar_title}
-
-
-
- );
-}
diff --git a/src/lightning/app/cli/pl-app-template/ui/src/components/Timer.tsx b/src/lightning/app/cli/pl-app-template/ui/src/components/Timer.tsx
deleted file mode 100644
index 29a6f5bca42b0..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/src/components/Timer.tsx
+++ /dev/null
@@ -1,29 +0,0 @@
-import { useEffect, useState } from "react";
-
-export default function Timer(props: any) {
- const isActive = props?.isActive;
- var [totalSeconds, setTotalSeconds] = useState(0);
-
- var hours = Math.floor(totalSeconds / 3600);
- var totalSeconds = totalSeconds % 3600;
- var minutes = Math.floor(totalSeconds / 60);
- var seconds = totalSeconds % 60;
-
- useEffect(() => {
- let interval: any = null;
- if (isActive) {
- interval = setInterval(() => {
- setTotalSeconds(totalSeconds => totalSeconds + 1);
- }, 1000);
- } else if (!isActive && totalSeconds !== 0) {
- clearInterval(interval);
- }
- return () => clearInterval(interval);
- }, [isActive, totalSeconds]);
-
- return (
-
- {("0" + hours).slice(-2)}:{("0" + minutes).slice(-2)}:{("0" + seconds).slice(-2)}
-
- );
-}
diff --git a/src/lightning/app/cli/pl-app-template/ui/src/hooks/useLightningState.ts b/src/lightning/app/cli/pl-app-template/ui/src/hooks/useLightningState.ts
deleted file mode 100644
index f4aa42604843b..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/src/hooks/useLightningState.ts
+++ /dev/null
@@ -1,31 +0,0 @@
-import { useEffect, useState } from "react";
-
-import type { LightingState } from "../types/lightning";
-
-interface LightningState {
- subscribe(handler: (state: any) => void): () => void;
- next(state: any): void;
-}
-
-declare global {
- interface Window {
- LightningState: LightningState;
- }
-}
-
-export const useLightningState = () => {
- const [lightningState, setLightningState] = useState();
-
- useEffect(() => {
- const unsubscribe = window.LightningState.subscribe(setLightningState);
-
- return unsubscribe;
- }, []);
-
- const updateLightningState = window.LightningState.next;
-
- return {
- lightningState,
- updateLightningState,
- };
-};
diff --git a/src/lightning/app/cli/pl-app-template/ui/src/index.css b/src/lightning/app/cli/pl-app-template/ui/src/index.css
deleted file mode 100644
index 6e15903b858f8..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/src/index.css
+++ /dev/null
@@ -1,19 +0,0 @@
-body {
- margin: 0;
- font-family: "Roboto", "Oxygen", "Ubuntu", "Cantarell", "Fira Sans", "Droid Sans", "Helvetica Neue", sans-serif;
- -webkit-font-smoothing: antialiased;
- -moz-osx-font-smoothing: grayscale;
- overflow-y: overlay;
-}
-
-code {
- font-family: source-code-pro, Menlo, Monaco, Consolas, "Courier New", monospace;
-}
-
-pre {
- overflow: scroll;
- font-size: 0.8em;
- max-height: 40em;
-}
-
-@import "lightning-ui/src/design-system/style.css";
diff --git a/src/lightning/app/cli/pl-app-template/ui/src/index.tsx b/src/lightning/app/cli/pl-app-template/ui/src/index.tsx
deleted file mode 100644
index b1feb569d6770..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/src/index.tsx
+++ /dev/null
@@ -1,26 +0,0 @@
-import React from "react";
-
-import FontFaceObserver from "fontfaceobserver";
-import ReactDOM from "react-dom";
-
-import App from "./App";
-import "./index.css";
-import reportWebVitals from "./reportWebVitals";
-
-// Make sure Roboto Mono is loaded before rendering the app as it is used within a canvas element
-// The rest of the fonts don't need to be loaded before the render as they will be applied
-// as soon as they are available
-const font = new FontFaceObserver("Roboto Mono");
-font.load().then(() => {
- ReactDOM.render(
-
-
- ,
- document.getElementById("root"),
- );
-});
-
-// If you want to start measuring performance in your app, pass a function
-// to log results (for example: reportWebVitals(console.log))
-// or send to an analytics endpoint. Learn more: https://bit.ly/CRA-vitals
-reportWebVitals();
diff --git a/src/lightning/app/cli/pl-app-template/ui/src/lightning-colors.ts b/src/lightning/app/cli/pl-app-template/ui/src/lightning-colors.ts
deleted file mode 100644
index a910b9a1dc455..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/src/lightning-colors.ts
+++ /dev/null
@@ -1,2 +0,0 @@
-export const LIGHTNING_PURPLE = "#6162D1";
-export const BREADCRUMBS_BACKGROUND = "#F7F8FB";
diff --git a/src/lightning/app/cli/pl-app-template/ui/src/react-app-env.d.ts b/src/lightning/app/cli/pl-app-template/ui/src/react-app-env.d.ts
deleted file mode 100644
index 6431bc5fc6b2c..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/src/react-app-env.d.ts
+++ /dev/null
@@ -1 +0,0 @@
-///
diff --git a/src/lightning/app/cli/pl-app-template/ui/src/reportWebVitals.ts b/src/lightning/app/cli/pl-app-template/ui/src/reportWebVitals.ts
deleted file mode 100644
index 5fa3583b7500b..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/src/reportWebVitals.ts
+++ /dev/null
@@ -1,15 +0,0 @@
-import { ReportHandler } from "web-vitals";
-
-const reportWebVitals = (onPerfEntry?: ReportHandler) => {
- if (onPerfEntry && onPerfEntry instanceof Function) {
- import("web-vitals").then(({ getCLS, getFID, getFCP, getLCP, getTTFB }) => {
- getCLS(onPerfEntry);
- getFID(onPerfEntry);
- getFCP(onPerfEntry);
- getLCP(onPerfEntry);
- getTTFB(onPerfEntry);
- });
- }
-};
-
-export default reportWebVitals;
diff --git a/src/lightning/app/cli/pl-app-template/ui/src/types/lightning.ts b/src/lightning/app/cli/pl-app-template/ui/src/types/lightning.ts
deleted file mode 100644
index 02893ca2a9321..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/src/types/lightning.ts
+++ /dev/null
@@ -1,57 +0,0 @@
-/**
- * Represents the internal state of a Lightning app as exposed by
- * the `/state` endpoint of the Lightning HTTP API.
- */
-export type LightingState = {
- vars: {
- _layout: Layout | Layout[];
- [key: string]: any;
- };
- calls: {
- [key: string]: {
- name: string;
- call_hash: string;
- ret: boolean;
- };
- };
- flows: {
- [key: string]: ChildState;
- };
- works: {
- [key: string]: any;
- };
- changes: {
- [key: string]: any;
- };
- app_state: {
- stage: AppStage;
- };
-};
-
-export type ChildState = Omit;
-
-export type Layout = LayoutBranch | LayoutLeaf;
-
-export type LayoutBranch = {
- name: string;
- content: string;
-};
-
-export type LayoutLeaf = {
- name: string;
- type: LayoutType;
- source?: string;
- target: string;
-};
-
-export enum LayoutType {
- web = "web",
- streamlit = "streamlit",
-}
-
-export enum AppStage {
- blocking = "blocking",
- restarting = "restarting",
- running = "running",
- stopping = "stopping",
-}
diff --git a/src/lightning/app/cli/pl-app-template/ui/tsconfig.json b/src/lightning/app/cli/pl-app-template/ui/tsconfig.json
deleted file mode 100644
index cf7275fc65df5..0000000000000
--- a/src/lightning/app/cli/pl-app-template/ui/tsconfig.json
+++ /dev/null
@@ -1,22 +0,0 @@
-{
- "compilerOptions": {
- "target": "es5",
- "lib": ["dom", "dom.iterable", "esnext"],
- "allowJs": true,
- "baseUrl": "src/",
- "skipLibCheck": true,
- "esModuleInterop": true,
- "allowSyntheticDefaultImports": true,
- "strict": true,
- "forceConsistentCasingInFileNames": true,
- "noFallthroughCasesInSwitch": true,
- "module": "esnext",
- "moduleResolution": "node",
- "resolveJsonModule": true,
- "isolatedModules": true,
- "noEmit": true,
- "jsx": "react-jsx"
- },
- "types": ["node"],
- "include": ["src"]
-}
diff --git a/src/lightning/app/cli/react-ui-template/README.md b/src/lightning/app/cli/react-ui-template/README.md
deleted file mode 100644
index 1f28215d22199..0000000000000
--- a/src/lightning/app/cli/react-ui-template/README.md
+++ /dev/null
@@ -1,103 +0,0 @@
-# React-ui template
-
-This is a full React template, ready to use in a component.
-
-This UI was automatically generated with:
-
-```commandline
-lightning_app init react-ui
-```
-
-### Delete files
-
-This template has 3 main files/folders:
-
-- README.md
-- example_app.py
-- ui
-
-The README.md and example_app.py are only there to help you get started and can be deleted afterwards.
-All you need is the `ui` folder when you connect this React UI to your component.
-
-### Run the example_app
-
-This template comes with `example_app.py` to show how to integrate the UI into a component.
-
-Run it with:
-
-```bash
-lightning_app run app react-ui/example_app.py
-```
-
-### Connect React to your component
-
-To connect the React UI to your component, point the `StaticWebFrontend` to the `dist/` folder that yarn generates when you build the React app.
-
-```python
-from pathlib import Path
-import lightning as L
-
-class YourComponent(L.LightningFlow):
-    def configure_layout(self):
-        return L.app.frontend.StaticWebFrontend(Path(__file__).parent / "react-ui/src/dist")
-```
-
-### Set up interactions between React and the component
-
-To communicate bi-directionally between React and the component, use the `useLightningState` hook:
-
-```js
-// App.tsx
-
-import { useLightningState } from "./hooks/useLightningState";
-import cloneDeep from "lodash/cloneDeep";
-
-function App() {
- const { lightningState, updateLightningState } = useLightningState();
-
- const modify_and_send_back_the_state = async (event: ChangeEvent) => {
- if (lightningState) {
- const newLightningState = cloneDeep(lightningState);
- // Update the state and send it back.
- newLightningState.vars.counter += 1;
-
- updateLightningState(newLightningState);
- }
- };
-
- return (
-
-
- );
-}
-
-export default App;
-```
-
-#### Application Folder Structure
-
-The `/react` folder has the following structure:
-
-```
-react
-├── dist
-├── index.html
-├── node_modules
-├── package.json
-├── src
-│ ├── App.css
-│ ├── App.tsx
-│ ├── hooks
-│ ├── index.css
-│ ├── main.tsx
-│ ├── types
-│ └── vite-env.d.ts
-├── tsconfig.json
-├── tsconfig.node.json
-├── vite.config.ts
-└── yarn.lock
-```
-
-### Read the docs
-
-For more features and information, [read the documentation](https://ideal-bassoon-5313f557.pages.github.io/workflows/add_web_ui/react/index.html).
diff --git a/src/lightning/app/cli/react-ui-template/example_app.py b/src/lightning/app/cli/react-ui-template/example_app.py
deleted file mode 100644
index 362a2a0477f7e..0000000000000
--- a/src/lightning/app/cli/react-ui-template/example_app.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# example_app.py
-
-from pathlib import Path
-
-from lightning.app import LightningApp, LightningFlow, frontend
-
-
-class YourComponent(LightningFlow):
-    def __init__(self):
-        super().__init__()
-        self.message_to_print = "Hello World!"
-        self.should_print = False
-
-    def configure_layout(self):
-        return frontend.StaticWebFrontend(Path(__file__).parent / "ui/dist")
-
-
-class HelloLitReact(LightningFlow):
-    def __init__(self):
-        super().__init__()
-        self.counter = 0
-        self.react_ui = YourComponent()
-
-    def run(self):
-        if self.react_ui.should_print:
-            print(f"{self.counter}: {self.react_ui.message_to_print}")
-            self.counter += 1
-
-    def configure_layout(self):
-        return [{"name": "React UI", "content": self.react_ui}]
-
-
-app = LightningApp(HelloLitReact())
diff --git a/src/lightning/app/cli/react-ui-template/ui/index.html b/src/lightning/app/cli/react-ui-template/ui/index.html
deleted file mode 100644
index 837abf069d329..0000000000000
--- a/src/lightning/app/cli/react-ui-template/ui/index.html
+++ /dev/null
@@ -1,14 +0,0 @@
-
-
-
-
-
-
- Vite App
-
-
-
-
-
-
-
diff --git a/src/lightning/app/cli/react-ui-template/ui/package.json b/src/lightning/app/cli/react-ui-template/ui/package.json
deleted file mode 100644
index d43665302c55a..0000000000000
--- a/src/lightning/app/cli/react-ui-template/ui/package.json
+++ /dev/null
@@ -1,31 +0,0 @@
-{
- "name": "hello-world",
- "private": true,
- "version": "0.0.0",
- "scripts": {
- "start": "vite",
- "build": "tsc --noEmit && vite build",
- "preview": "vite preview"
- },
- "dependencies": {
- "@emotion/react": "^11.8.2",
- "@emotion/styled": "^11.8.1",
- "@mui/material": "5.8.5",
- "axios": "^1.6.0",
- "lodash": "^4.17.21",
- "nanoid": "^3.3.1",
- "react": "^17.0.2",
- "react-dom": "^17.0.2"
- },
- "devDependencies": {
- "@types/lodash": "^4.14.179",
- "@types/react": "^18.0.1",
- "@types/react-dom": "^18.0.0",
- "@vitejs/plugin-react": "^1.0.7",
- "prettier": "^2.5.1",
- "typescript": "^4.5.4",
- "vite": "^2.9.17"
- },
- "main": "index.js",
- "license": "MIT"
-}
diff --git a/src/lightning/app/cli/react-ui-template/ui/src/App.css b/src/lightning/app/cli/react-ui-template/ui/src/App.css
deleted file mode 100644
index 4fce394eeb216..0000000000000
--- a/src/lightning/app/cli/react-ui-template/ui/src/App.css
+++ /dev/null
@@ -1,10 +0,0 @@
-.App {
- text-align: center;
- height: 100vh;
- background: rgb(34, 193, 195);
- background: linear-gradient(0deg, rgba(34, 193, 195, 1) 0%, rgba(253, 187, 45, 1) 100%);
-}
-
-.App .wrapper {
- transform: translateY(100%);
-}
diff --git a/src/lightning/app/cli/react-ui-template/ui/src/App.tsx b/src/lightning/app/cli/react-ui-template/ui/src/App.tsx
deleted file mode 100644
index 716280938fbc4..0000000000000
--- a/src/lightning/app/cli/react-ui-template/ui/src/App.tsx
+++ /dev/null
@@ -1,69 +0,0 @@
-// App.tsx
-
-import { Button } from "@mui/material";
-import { TextField } from "@mui/material";
-import Box from "@mui/material/Box";
-import { ChangeEvent } from "react";
-import cloneDeep from "lodash/cloneDeep";
-
-import "./App.css";
-import { useLightningState } from "./hooks/useLightningState";
-
-function App() {
- const { lightningState, updateLightningState } = useLightningState();
-
- const counter = lightningState?.vars.counter;
-
- const handleClick = async () => {
- if (lightningState) {
- const newLightningState = cloneDeep(lightningState);
- newLightningState.flows.react_ui.vars.should_print = !newLightningState.flows.react_ui.vars.should_print;
-
- updateLightningState(newLightningState);
- }
- };
-
- const handleTextField = async (event: ChangeEvent) => {
- if (lightningState) {
- const newLightningState = cloneDeep(lightningState);
- newLightningState.flows.react_ui.vars.message_to_print = event.target.value;
-
- updateLightningState(newLightningState);
- }
- };
-
- return (
-
-
-
- handleClick()}>
-
- {lightningState?.["flows"]?.["react_ui"]?.["vars"]?.["should_print"] ? "Stop printing" : "Start Printing"}
-
-
-
-
-
-
-
-
-
- Total number of prints in your terminal: {counter}
-
-
-
-
- );
-}
-
-export default App;
diff --git a/src/lightning/app/cli/react-ui-template/ui/src/favicon.svg b/src/lightning/app/cli/react-ui-template/ui/src/favicon.svg
deleted file mode 100644
index de4aeddc12bdf..0000000000000
--- a/src/lightning/app/cli/react-ui-template/ui/src/favicon.svg
+++ /dev/null
@@ -1,15 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/src/lightning/app/cli/react-ui-template/ui/src/hooks/useLightningState.ts b/src/lightning/app/cli/react-ui-template/ui/src/hooks/useLightningState.ts
deleted file mode 100644
index 5a44a304c0d9b..0000000000000
--- a/src/lightning/app/cli/react-ui-template/ui/src/hooks/useLightningState.ts
+++ /dev/null
@@ -1,31 +0,0 @@
-import { useState, useEffect } from "react";
-
-import type { LightingState } from "../types/lightning";
-
-interface LightningState {
- subscribe(handler: (state: any) => void): () => void;
- next(state: any): void;
-}
-
-declare global {
- interface Window {
- LightningState: LightningState;
- }
-}
-
-export const useLightningState = () => {
- const [lightningState, setLightningState] = useState();
-
- useEffect(() => {
- const unsubscribe = window.LightningState.subscribe(setLightningState);
-
- return unsubscribe;
- }, []);
-
- const updateLightningState = window.LightningState.next;
-
- return {
- lightningState,
- updateLightningState,
- };
-};
diff --git a/src/lightning/app/cli/react-ui-template/ui/src/index.css b/src/lightning/app/cli/react-ui-template/ui/src/index.css
deleted file mode 100644
index e9927237d70c0..0000000000000
--- a/src/lightning/app/cli/react-ui-template/ui/src/index.css
+++ /dev/null
@@ -1,11 +0,0 @@
-body {
- margin: 0;
- font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", "Roboto", "Oxygen", "Ubuntu", "Cantarell", "Fira Sans",
- "Droid Sans", "Helvetica Neue", sans-serif;
- -webkit-font-smoothing: antialiased;
- -moz-osx-font-smoothing: grayscale;
-}
-
-code {
- font-family: source-code-pro, Menlo, Monaco, Consolas, "Courier New", monospace;
-}
diff --git a/src/lightning/app/cli/react-ui-template/ui/src/main.tsx b/src/lightning/app/cli/react-ui-template/ui/src/main.tsx
deleted file mode 100644
index e7cd236bbac70..0000000000000
--- a/src/lightning/app/cli/react-ui-template/ui/src/main.tsx
+++ /dev/null
@@ -1,11 +0,0 @@
-import React from "react";
-import ReactDOM from "react-dom";
-import "./index.css";
-import App from "./App";
-
-ReactDOM.render(
-
-
- ,
- document.getElementById("root"),
-);
diff --git a/src/lightning/app/cli/react-ui-template/ui/src/types/lightning.ts b/src/lightning/app/cli/react-ui-template/ui/src/types/lightning.ts
deleted file mode 100644
index 02893ca2a9321..0000000000000
--- a/src/lightning/app/cli/react-ui-template/ui/src/types/lightning.ts
+++ /dev/null
@@ -1,57 +0,0 @@
-/**
- * Represents the internal state of a Lightning app as exposed by
- * the `/state` endpoint of the Lightning HTTP API.
- */
-export type LightingState = {
- vars: {
- _layout: Layout | Layout[];
- [key: string]: any;
- };
- calls: {
- [key: string]: {
- name: string;
- call_hash: string;
- ret: boolean;
- };
- };
- flows: {
- [key: string]: ChildState;
- };
- works: {
- [key: string]: any;
- };
- changes: {
- [key: string]: any;
- };
- app_state: {
- stage: AppStage;
- };
-};
-
-export type ChildState = Omit;
-
-export type Layout = LayoutBranch | LayoutLeaf;
-
-export type LayoutBranch = {
- name: string;
- content: string;
-};
-
-export type LayoutLeaf = {
- name: string;
- type: LayoutType;
- source?: string;
- target: string;
-};
-
-export enum LayoutType {
- web = "web",
- streamlit = "streamlit",
-}
-
-export enum AppStage {
- blocking = "blocking",
- restarting = "restarting",
- running = "running",
- stopping = "stopping",
-}
diff --git a/src/lightning/app/cli/react-ui-template/ui/src/vite-env.d.ts b/src/lightning/app/cli/react-ui-template/ui/src/vite-env.d.ts
deleted file mode 100644
index 11f02fe2a0061..0000000000000
--- a/src/lightning/app/cli/react-ui-template/ui/src/vite-env.d.ts
+++ /dev/null
@@ -1 +0,0 @@
-///
diff --git a/src/lightning/app/cli/react-ui-template/ui/tsconfig.json b/src/lightning/app/cli/react-ui-template/ui/tsconfig.json
deleted file mode 100644
index c8bdc64082aa2..0000000000000
--- a/src/lightning/app/cli/react-ui-template/ui/tsconfig.json
+++ /dev/null
@@ -1,21 +0,0 @@
-{
- "compilerOptions": {
- "target": "ESNext",
- "useDefineForClassFields": true,
- "lib": ["DOM", "DOM.Iterable", "ESNext"],
- "allowJs": false,
- "skipLibCheck": false,
- "esModuleInterop": false,
- "allowSyntheticDefaultImports": true,
- "strict": true,
- "forceConsistentCasingInFileNames": true,
- "module": "ESNext",
- "moduleResolution": "Node",
- "resolveJsonModule": true,
- "isolatedModules": true,
- "noEmit": true,
- "jsx": "react-jsx"
- },
- "include": ["src"],
- "references": [{ "path": "./tsconfig.node.json" }]
-}
diff --git a/src/lightning/app/cli/react-ui-template/ui/tsconfig.node.json b/src/lightning/app/cli/react-ui-template/ui/tsconfig.node.json
deleted file mode 100644
index e993792cb12c9..0000000000000
--- a/src/lightning/app/cli/react-ui-template/ui/tsconfig.node.json
+++ /dev/null
@@ -1,8 +0,0 @@
-{
- "compilerOptions": {
- "composite": true,
- "module": "esnext",
- "moduleResolution": "node"
- },
- "include": ["vite.config.ts"]
-}
diff --git a/src/lightning/app/cli/react-ui-template/ui/vite.config.ts b/src/lightning/app/cli/react-ui-template/ui/vite.config.ts
deleted file mode 100644
index 82632ac0ed8e6..0000000000000
--- a/src/lightning/app/cli/react-ui-template/ui/vite.config.ts
+++ /dev/null
@@ -1,9 +0,0 @@
-import { defineConfig } from "vite";
-import react from "@vitejs/plugin-react";
-
-// https://vitejs.dev/config/
-export default defineConfig({
- plugins: [react()],
- // NOTE: Component UIs are served under the `/{componentName}/` subpath, so the app needs to be configured with a relative base path.
- base: "./",
-});
diff --git a/src/lightning/app/cli/react-ui-template/ui/yarn.lock b/src/lightning/app/cli/react-ui-template/ui/yarn.lock
deleted file mode 100644
index 50570338c1877..0000000000000
--- a/src/lightning/app/cli/react-ui-template/ui/yarn.lock
+++ /dev/null
@@ -1,1278 +0,0 @@
-# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY.
-# yarn lockfile v1
-
-
-"@ampproject/remapping@^2.1.0":
- version "2.2.0"
- resolved "https://registry.yarnpkg.com/@ampproject/remapping/-/remapping-2.2.0.tgz#56c133824780de3174aed5ab6834f3026790154d"
- integrity sha512-qRmjj8nj9qmLTQXXmaR1cck3UXSRMPrbsLJAasZpF+t3riI71BXed5ebIOYwQntykeZuhjsdweEc9BxH5Jc26w==
- dependencies:
- "@jridgewell/gen-mapping" "^0.1.0"
- "@jridgewell/trace-mapping" "^0.3.9"
-
-"@babel/code-frame@^7.0.0", "@babel/code-frame@^7.18.6":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/code-frame/-/code-frame-7.18.6.tgz#3b25d38c89600baa2dcc219edfa88a74eb2c427a"
- integrity sha512-TDCmlK5eOvH+eH7cdAFlNXeVJqWIQ7gW9tY1GJIpUtFb6CmjVyq2VM3u71bOyR8CRihcCgMUYoDNyLXao3+70Q==
- dependencies:
- "@babel/highlight" "^7.18.6"
-
-"@babel/code-frame@^7.22.13":
- version "7.22.13"
- resolved "https://registry.yarnpkg.com/@babel/code-frame/-/code-frame-7.22.13.tgz#e3c1c099402598483b7a8c46a721d1038803755e"
- integrity sha512-XktuhWlJ5g+3TJXc5upd9Ks1HutSArik6jf2eAjYFyIOf4ej3RN+184cZbzDvbPnuTJIUhPKKJE3cIsYTiAT3w==
- dependencies:
- "@babel/highlight" "^7.22.13"
- chalk "^2.4.2"
-
-"@babel/compat-data@^7.18.6":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/compat-data/-/compat-data-7.18.6.tgz#8b37d24e88e8e21c499d4328db80577d8882fa53"
- integrity sha512-tzulrgDT0QD6U7BJ4TKVk2SDDg7wlP39P9yAx1RfLy7vP/7rsDRlWVfbWxElslu56+r7QOhB2NSDsabYYruoZQ==
-
-"@babel/core@^7.17.10":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/core/-/core-7.18.6.tgz#54a107a3c298aee3fe5e1947a6464b9b6faca03d"
- integrity sha512-cQbWBpxcbbs/IUredIPkHiAGULLV8iwgNRMFzvbhEXISp4f3rUUXE5+TIw6KwUWUR3DwyI6gmBRnmAtYaWehwQ==
- dependencies:
- "@ampproject/remapping" "^2.1.0"
- "@babel/code-frame" "^7.18.6"
- "@babel/generator" "^7.18.6"
- "@babel/helper-compilation-targets" "^7.18.6"
- "@babel/helper-module-transforms" "^7.18.6"
- "@babel/helpers" "^7.18.6"
- "@babel/parser" "^7.18.6"
- "@babel/template" "^7.18.6"
- "@babel/traverse" "^7.18.6"
- "@babel/types" "^7.18.6"
- convert-source-map "^1.7.0"
- debug "^4.1.0"
- gensync "^1.0.0-beta.2"
- json5 "^2.2.1"
- semver "^6.3.0"
-
-"@babel/generator@^7.18.6":
- version "7.18.7"
- resolved "https://registry.yarnpkg.com/@babel/generator/-/generator-7.18.7.tgz#2aa78da3c05aadfc82dbac16c99552fc802284bd"
- integrity sha512-shck+7VLlY72a2w9c3zYWuE1pwOKEiQHV7GTUbSnhyl5eu3i04t30tBY82ZRWrDfo3gkakCFtevExnxbkf2a3A==
- dependencies:
- "@babel/types" "^7.18.7"
- "@jridgewell/gen-mapping" "^0.3.2"
- jsesc "^2.5.1"
-
-"@babel/generator@^7.23.0":
- version "7.23.0"
- resolved "https://registry.yarnpkg.com/@babel/generator/-/generator-7.23.0.tgz#df5c386e2218be505b34837acbcb874d7a983420"
- integrity sha512-lN85QRR+5IbYrMWM6Y4pE/noaQtg4pNiqeNGX60eqOfo6gtEj6uw/JagelB8vVztSd7R6M5n1+PQkDbHbBRU4g==
- dependencies:
- "@babel/types" "^7.23.0"
- "@jridgewell/gen-mapping" "^0.3.2"
- "@jridgewell/trace-mapping" "^0.3.17"
- jsesc "^2.5.1"
-
-"@babel/helper-annotate-as-pure@^7.18.6":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/helper-annotate-as-pure/-/helper-annotate-as-pure-7.18.6.tgz#eaa49f6f80d5a33f9a5dd2276e6d6e451be0a6bb"
- integrity sha512-duORpUiYrEpzKIop6iNbjnwKLAKnJ47csTyRACyEmWj0QdUrm5aqNJGHSSEQSUAvNW0ojX0dOmK9dZduvkfeXA==
- dependencies:
- "@babel/types" "^7.18.6"
-
-"@babel/helper-compilation-targets@^7.18.6":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/helper-compilation-targets/-/helper-compilation-targets-7.18.6.tgz#18d35bfb9f83b1293c22c55b3d576c1315b6ed96"
- integrity sha512-vFjbfhNCzqdeAtZflUFrG5YIFqGTqsctrtkZ1D/NB0mDW9TwW3GmmUepYY4G9wCET5rY5ugz4OGTcLd614IzQg==
- dependencies:
- "@babel/compat-data" "^7.18.6"
- "@babel/helper-validator-option" "^7.18.6"
- browserslist "^4.20.2"
- semver "^6.3.0"
-
-"@babel/helper-environment-visitor@^7.18.6":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/helper-environment-visitor/-/helper-environment-visitor-7.18.6.tgz#b7eee2b5b9d70602e59d1a6cad7dd24de7ca6cd7"
- integrity sha512-8n6gSfn2baOY+qlp+VSzsosjCVGFqWKmDF0cCWOybh52Dw3SEyoWR1KrhMJASjLwIEkkAufZ0xvr+SxLHSpy2Q==
-
-"@babel/helper-environment-visitor@^7.22.20":
- version "7.22.20"
- resolved "https://registry.yarnpkg.com/@babel/helper-environment-visitor/-/helper-environment-visitor-7.22.20.tgz#96159db61d34a29dba454c959f5ae4a649ba9167"
- integrity sha512-zfedSIzFhat/gFhWfHtgWvlec0nqB9YEIVrpuwjruLlXfUSnA8cJB0miHKwqDnQ7d32aKo2xt88/xZptwxbfhA==
-
-"@babel/helper-function-name@^7.23.0":
- version "7.23.0"
- resolved "https://registry.yarnpkg.com/@babel/helper-function-name/-/helper-function-name-7.23.0.tgz#1f9a3cdbd5b2698a670c30d2735f9af95ed52759"
- integrity sha512-OErEqsrxjZTJciZ4Oo+eoZqeW9UIiOcuYKRJA4ZAgV9myA+pOXhhmpfNCKjEH/auVfEYVFJ6y1Tc4r0eIApqiw==
- dependencies:
- "@babel/template" "^7.22.15"
- "@babel/types" "^7.23.0"
-
-"@babel/helper-hoist-variables@^7.22.5":
- version "7.22.5"
- resolved "https://registry.yarnpkg.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.22.5.tgz#c01a007dac05c085914e8fb652b339db50d823bb"
- integrity sha512-wGjk9QZVzvknA6yKIUURb8zY3grXCcOZt+/7Wcy8O2uctxhplmUPkOdlgoNhmdVee2c92JXbf1xpMtVNbfoxRw==
- dependencies:
- "@babel/types" "^7.22.5"
-
-"@babel/helper-module-imports@^7.12.13", "@babel/helper-module-imports@^7.18.6":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/helper-module-imports/-/helper-module-imports-7.18.6.tgz#1e3ebdbbd08aad1437b428c50204db13c5a3ca6e"
- integrity sha512-0NFvs3VkuSYbFi1x2Vd6tKrywq+z/cLeYC/RJNFrIX/30Bf5aiGYbtvGXolEktzJH8o5E5KJ3tT+nkxuuZFVlA==
- dependencies:
- "@babel/types" "^7.18.6"
-
-"@babel/helper-module-transforms@^7.18.6":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/helper-module-transforms/-/helper-module-transforms-7.18.6.tgz#57e3ca669e273d55c3cda55e6ebf552f37f483c8"
- integrity sha512-L//phhB4al5uucwzlimruukHB3jRd5JGClwRMD/ROrVjXfLqovYnvQrK/JK36WYyVwGGO7OD3kMyVTjx+WVPhw==
- dependencies:
- "@babel/helper-environment-visitor" "^7.18.6"
- "@babel/helper-module-imports" "^7.18.6"
- "@babel/helper-simple-access" "^7.18.6"
- "@babel/helper-split-export-declaration" "^7.18.6"
- "@babel/helper-validator-identifier" "^7.18.6"
- "@babel/template" "^7.18.6"
- "@babel/traverse" "^7.18.6"
- "@babel/types" "^7.18.6"
-
-"@babel/helper-plugin-utils@^7.18.6":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/helper-plugin-utils/-/helper-plugin-utils-7.18.6.tgz#9448974dd4fb1d80fefe72e8a0af37809cd30d6d"
- integrity sha512-gvZnm1YAAxh13eJdkb9EWHBnF3eAub3XTLCZEehHT2kWxiKVRL64+ae5Y6Ivne0mVHmMYKT+xWgZO+gQhuLUBg==
-
-"@babel/helper-simple-access@^7.18.6":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/helper-simple-access/-/helper-simple-access-7.18.6.tgz#d6d8f51f4ac2978068df934b569f08f29788c7ea"
- integrity sha512-iNpIgTgyAvDQpDj76POqg+YEt8fPxx3yaNBg3S30dxNKm2SWfYhD0TGrK/Eu9wHpUW63VQU894TsTg+GLbUa1g==
- dependencies:
- "@babel/types" "^7.18.6"
-
-"@babel/helper-split-export-declaration@^7.18.6":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/helper-split-export-declaration/-/helper-split-export-declaration-7.18.6.tgz#7367949bc75b20c6d5a5d4a97bba2824ae8ef075"
- integrity sha512-bde1etTx6ZyTmobl9LLMMQsaizFVZrquTEHOqKeQESMKo4PlObf+8+JA25ZsIpZhT/WEd39+vOdLXAFG/nELpA==
- dependencies:
- "@babel/types" "^7.18.6"
-
-"@babel/helper-split-export-declaration@^7.22.6":
- version "7.22.6"
- resolved "https://registry.yarnpkg.com/@babel/helper-split-export-declaration/-/helper-split-export-declaration-7.22.6.tgz#322c61b7310c0997fe4c323955667f18fcefb91c"
- integrity sha512-AsUnxuLhRYsisFiaJwvp1QF+I3KjD5FOxut14q/GzovUe6orHLesW2C7d754kRm53h5gqrz6sFl6sxc4BVtE/g==
- dependencies:
- "@babel/types" "^7.22.5"
-
-"@babel/helper-string-parser@^7.22.5":
- version "7.22.5"
- resolved "https://registry.yarnpkg.com/@babel/helper-string-parser/-/helper-string-parser-7.22.5.tgz#533f36457a25814cf1df6488523ad547d784a99f"
- integrity sha512-mM4COjgZox8U+JcXQwPijIZLElkgEpO5rsERVDJTc2qfCDfERyob6k5WegS14SX18IIjv+XD+GrqNumY5JRCDw==
-
-"@babel/helper-validator-identifier@^7.18.6":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.18.6.tgz#9c97e30d31b2b8c72a1d08984f2ca9b574d7a076"
- integrity sha512-MmetCkz9ej86nJQV+sFCxoGGrUbU3q02kgLciwkrt9QqEB7cP39oKEY0PakknEO0Gu20SskMRi+AYZ3b1TpN9g==
-
-"@babel/helper-validator-identifier@^7.22.20":
- version "7.22.20"
- resolved "https://registry.yarnpkg.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.22.20.tgz#c4ae002c61d2879e724581d96665583dbc1dc0e0"
- integrity sha512-Y4OZ+ytlatR8AI+8KZfKuL5urKp7qey08ha31L8b3BwewJAoJamTzyvxPR/5D+KkdJCGPq/+8TukHBlY10FX9A==
-
-"@babel/helper-validator-option@^7.18.6":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/helper-validator-option/-/helper-validator-option-7.18.6.tgz#bf0d2b5a509b1f336099e4ff36e1a63aa5db4db8"
- integrity sha512-XO7gESt5ouv/LRJdrVjkShckw6STTaB7l9BrpBaAHDeF5YZT+01PCwmR0SJHnkW6i8OwW/EVWRShfi4j2x+KQw==
-
-"@babel/helpers@^7.18.6":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/helpers/-/helpers-7.18.6.tgz#4c966140eaa1fcaa3d5a8c09d7db61077d4debfd"
- integrity sha512-vzSiiqbQOghPngUYt/zWGvK3LAsPhz55vc9XNN0xAl2gV4ieShI2OQli5duxWHD+72PZPTKAcfcZDE1Cwc5zsQ==
- dependencies:
- "@babel/template" "^7.18.6"
- "@babel/traverse" "^7.18.6"
- "@babel/types" "^7.18.6"
-
-"@babel/highlight@^7.18.6":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/highlight/-/highlight-7.18.6.tgz#81158601e93e2563795adcbfbdf5d64be3f2ecdf"
- integrity sha512-u7stbOuYjaPezCuLj29hNW1v64M2Md2qupEKP1fHc7WdOA3DgLh37suiSrZYY7haUB7iBeQZ9P1uiRF359do3g==
- dependencies:
- "@babel/helper-validator-identifier" "^7.18.6"
- chalk "^2.0.0"
- js-tokens "^4.0.0"
-
-"@babel/highlight@^7.22.13":
- version "7.22.20"
- resolved "https://registry.yarnpkg.com/@babel/highlight/-/highlight-7.22.20.tgz#4ca92b71d80554b01427815e06f2df965b9c1f54"
- integrity sha512-dkdMCN3py0+ksCgYmGG8jKeGA/8Tk+gJwSYYlFGxG5lmhfKNoAy004YpLxpS1W2J8m/EK2Ew+yOs9pVRwO89mg==
- dependencies:
- "@babel/helper-validator-identifier" "^7.22.20"
- chalk "^2.4.2"
- js-tokens "^4.0.0"
-
-"@babel/parser@^7.18.6":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.18.6.tgz#845338edecad65ebffef058d3be851f1d28a63bc"
- integrity sha512-uQVSa9jJUe/G/304lXspfWVpKpK4euFLgGiMQFOCpM/bgcAdeoHwi/OQz23O9GK2osz26ZiXRRV9aV+Yl1O8tw==
-
-"@babel/parser@^7.22.15", "@babel/parser@^7.23.0":
- version "7.23.0"
- resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.23.0.tgz#da950e622420bf96ca0d0f2909cdddac3acd8719"
- integrity sha512-vvPKKdMemU85V9WE/l5wZEmImpCtLqbnTvqDS2U1fJ96KrxoW7KrXhNsNCblQlg8Ck4b85yxdTyelsMUgFUXiw==
-
-"@babel/plugin-syntax-jsx@^7.12.13", "@babel/plugin-syntax-jsx@^7.18.6":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/plugin-syntax-jsx/-/plugin-syntax-jsx-7.18.6.tgz#a8feef63b010150abd97f1649ec296e849943ca0"
- integrity sha512-6mmljtAedFGTWu2p/8WIORGwy+61PLgOMPOdazc7YoJ9ZCWUyFy3A6CpPkRKLKD1ToAesxX8KGEViAiLo9N+7Q==
- dependencies:
- "@babel/helper-plugin-utils" "^7.18.6"
-
-"@babel/plugin-transform-react-jsx-development@^7.16.7":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/plugin-transform-react-jsx-development/-/plugin-transform-react-jsx-development-7.18.6.tgz#dbe5c972811e49c7405b630e4d0d2e1380c0ddc5"
- integrity sha512-SA6HEjwYFKF7WDjWcMcMGUimmw/nhNRDWxr+KaLSCrkD/LMDBvWRmHAYgE1HDeF8KUuI8OAu+RT6EOtKxSW2qA==
- dependencies:
- "@babel/plugin-transform-react-jsx" "^7.18.6"
-
-"@babel/plugin-transform-react-jsx-self@^7.16.7":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/plugin-transform-react-jsx-self/-/plugin-transform-react-jsx-self-7.18.6.tgz#3849401bab7ae8ffa1e3e5687c94a753fc75bda7"
- integrity sha512-A0LQGx4+4Jv7u/tWzoJF7alZwnBDQd6cGLh9P+Ttk4dpiL+J5p7NSNv/9tlEFFJDq3kjxOavWmbm6t0Gk+A3Ig==
- dependencies:
- "@babel/helper-plugin-utils" "^7.18.6"
-
-"@babel/plugin-transform-react-jsx-source@^7.16.7":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/plugin-transform-react-jsx-source/-/plugin-transform-react-jsx-source-7.18.6.tgz#06e9ae8a14d2bc19ce6e3c447d842032a50598fc"
- integrity sha512-utZmlASneDfdaMh0m/WausbjUjEdGrQJz0vFK93d7wD3xf5wBtX219+q6IlCNZeguIcxS2f/CvLZrlLSvSHQXw==
- dependencies:
- "@babel/helper-plugin-utils" "^7.18.6"
-
-"@babel/plugin-transform-react-jsx@^7.17.3", "@babel/plugin-transform-react-jsx@^7.18.6":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/plugin-transform-react-jsx/-/plugin-transform-react-jsx-7.18.6.tgz#2721e96d31df96e3b7ad48ff446995d26bc028ff"
- integrity sha512-Mz7xMPxoy9kPS/JScj6fJs03TZ/fZ1dJPlMjRAgTaxaS0fUBk8FV/A2rRgfPsVCZqALNwMexD+0Uaf5zlcKPpw==
- dependencies:
- "@babel/helper-annotate-as-pure" "^7.18.6"
- "@babel/helper-module-imports" "^7.18.6"
- "@babel/helper-plugin-utils" "^7.18.6"
- "@babel/plugin-syntax-jsx" "^7.18.6"
- "@babel/types" "^7.18.6"
-
-"@babel/runtime@^7.13.10", "@babel/runtime@^7.17.2", "@babel/runtime@^7.5.5", "@babel/runtime@^7.7.2", "@babel/runtime@^7.8.7":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/runtime/-/runtime-7.18.6.tgz#6a1ef59f838debd670421f8c7f2cbb8da9751580"
- integrity sha512-t9wi7/AW6XtKahAe20Yw0/mMljKq0B1r2fPdvaAdV/KPDZewFXdaaa6K7lxmZBZ8FBNpCiAT6iHPmd6QO9bKfQ==
- dependencies:
- regenerator-runtime "^0.13.4"
-
-"@babel/template@^7.18.6":
- version "7.18.6"
- resolved "https://registry.yarnpkg.com/@babel/template/-/template-7.18.6.tgz#1283f4993e00b929d6e2d3c72fdc9168a2977a31"
- integrity sha512-JoDWzPe+wgBsTTgdnIma3iHNFC7YVJoPssVBDjiHfNlyt4YcunDtcDOUmfVDfCK5MfdsaIoX9PkijPhjH3nYUw==
- dependencies:
- "@babel/code-frame" "^7.18.6"
- "@babel/parser" "^7.18.6"
- "@babel/types" "^7.18.6"
-
-"@babel/template@^7.22.15":
- version "7.22.15"
- resolved "https://registry.yarnpkg.com/@babel/template/-/template-7.22.15.tgz#09576efc3830f0430f4548ef971dde1350ef2f38"
- integrity sha512-QPErUVm4uyJa60rkI73qneDacvdvzxshT3kksGqlGWYdOTIUOwJ7RDUL8sGqslY1uXWSL6xMFKEXDS3ox2uF0w==
- dependencies:
- "@babel/code-frame" "^7.22.13"
- "@babel/parser" "^7.22.15"
- "@babel/types" "^7.22.15"
-
-"@babel/traverse@^7.18.6":
- version "7.23.2"
- resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.23.2.tgz#329c7a06735e144a506bdb2cad0268b7f46f4ad8"
- integrity sha512-azpe59SQ48qG6nu2CzcMLbxUudtN+dOM9kDbUqGq3HXUJRlo7i8fvPoxQUzYgLZ4cMVmuZgm8vvBpNeRhd6XSw==
- dependencies:
- "@babel/code-frame" "^7.22.13"
- "@babel/generator" "^7.23.0"
- "@babel/helper-environment-visitor" "^7.22.20"
- "@babel/helper-function-name" "^7.23.0"
- "@babel/helper-hoist-variables" "^7.22.5"
- "@babel/helper-split-export-declaration" "^7.22.6"
- "@babel/parser" "^7.23.0"
- "@babel/types" "^7.23.0"
- debug "^4.1.0"
- globals "^11.1.0"
-
-"@babel/types@^7.18.6", "@babel/types@^7.18.7":
- version "7.18.7"
- resolved "https://registry.yarnpkg.com/@babel/types/-/types-7.18.7.tgz#a4a2c910c15040ea52cdd1ddb1614a65c8041726"
- integrity sha512-QG3yxTcTIBoAcQmkCs+wAPYZhu7Dk9rXKacINfNbdJDNERTbLQbHGyVG8q/YGMPeCJRIhSY0+fTc5+xuh6WPSQ==
- dependencies:
- "@babel/helper-validator-identifier" "^7.18.6"
- to-fast-properties "^2.0.0"
-
-"@babel/types@^7.22.15", "@babel/types@^7.22.5", "@babel/types@^7.23.0":
- version "7.23.0"
- resolved "https://registry.yarnpkg.com/@babel/types/-/types-7.23.0.tgz#8c1f020c9df0e737e4e247c0619f58c68458aaeb"
- integrity sha512-0oIyUfKoI3mSqMvsxBdclDwxXKXAUA8v/apZbc+iSyARYou1o8ZGDxbUYyLFoW2arqS2jDGqJuZvv1d/io1axg==
- dependencies:
- "@babel/helper-string-parser" "^7.22.5"
- "@babel/helper-validator-identifier" "^7.22.20"
- to-fast-properties "^2.0.0"
-
-"@emotion/babel-plugin@^11.7.1":
- version "11.9.2"
- resolved "https://registry.yarnpkg.com/@emotion/babel-plugin/-/babel-plugin-11.9.2.tgz#723b6d394c89fb2ef782229d92ba95a740576e95"
- integrity sha512-Pr/7HGH6H6yKgnVFNEj2MVlreu3ADqftqjqwUvDy/OJzKFgxKeTQ+eeUf20FOTuHVkDON2iNa25rAXVYtWJCjw==
- dependencies:
- "@babel/helper-module-imports" "^7.12.13"
- "@babel/plugin-syntax-jsx" "^7.12.13"
- "@babel/runtime" "^7.13.10"
- "@emotion/hash" "^0.8.0"
- "@emotion/memoize" "^0.7.5"
- "@emotion/serialize" "^1.0.2"
- babel-plugin-macros "^2.6.1"
- convert-source-map "^1.5.0"
- escape-string-regexp "^4.0.0"
- find-root "^1.1.0"
- source-map "^0.5.7"
- stylis "4.0.13"
-
-"@emotion/cache@^11.7.1", "@emotion/cache@^11.9.3":
- version "11.9.3"
- resolved "https://registry.yarnpkg.com/@emotion/cache/-/cache-11.9.3.tgz#96638449f6929fd18062cfe04d79b29b44c0d6cb"
- integrity sha512-0dgkI/JKlCXa+lEXviaMtGBL0ynpx4osh7rjOXE71q9bIF8G+XhJgvi+wDu0B0IdCVx37BffiwXlN9I3UuzFvg==
- dependencies:
- "@emotion/memoize" "^0.7.4"
- "@emotion/sheet" "^1.1.1"
- "@emotion/utils" "^1.0.0"
- "@emotion/weak-memoize" "^0.2.5"
- stylis "4.0.13"
-
-"@emotion/hash@^0.8.0":
- version "0.8.0"
- resolved "https://registry.yarnpkg.com/@emotion/hash/-/hash-0.8.0.tgz#bbbff68978fefdbe68ccb533bc8cbe1d1afb5413"
- integrity sha512-kBJtf7PH6aWwZ6fka3zQ0p6SBYzx4fl1LoZXE2RrnYST9Xljm7WfKJrU4g/Xr3Beg72MLrp1AWNUmuYJTL7Cow==
-
-"@emotion/is-prop-valid@^1.1.2", "@emotion/is-prop-valid@^1.1.3":
- version "1.1.3"
- resolved "https://registry.yarnpkg.com/@emotion/is-prop-valid/-/is-prop-valid-1.1.3.tgz#f0907a416368cf8df9e410117068e20fe87c0a3a"
- integrity sha512-RFg04p6C+1uO19uG8N+vqanzKqiM9eeV1LDOG3bmkYmuOj7NbKNlFC/4EZq5gnwAIlcC/jOT24f8Td0iax2SXA==
- dependencies:
- "@emotion/memoize" "^0.7.4"
-
-"@emotion/memoize@^0.7.4", "@emotion/memoize@^0.7.5":
- version "0.7.5"
- resolved "https://registry.yarnpkg.com/@emotion/memoize/-/memoize-0.7.5.tgz#2c40f81449a4e554e9fc6396910ed4843ec2be50"
- integrity sha512-igX9a37DR2ZPGYtV6suZ6whr8pTFtyHL3K/oLUotxpSVO2ASaprmAe2Dkq7tBo7CRY7MMDrAa9nuQP9/YG8FxQ==
-
-"@emotion/react@^11.8.2":
- version "11.9.3"
- resolved "https://registry.yarnpkg.com/@emotion/react/-/react-11.9.3.tgz#f4f4f34444f6654a2e550f5dab4f2d360c101df9"
- integrity sha512-g9Q1GcTOlzOEjqwuLF/Zd9LC+4FljjPjDfxSM7KmEakm+hsHXk+bYZ2q+/hTJzr0OUNkujo72pXLQvXj6H+GJQ==
- dependencies:
- "@babel/runtime" "^7.13.10"
- "@emotion/babel-plugin" "^11.7.1"
- "@emotion/cache" "^11.9.3"
- "@emotion/serialize" "^1.0.4"
- "@emotion/utils" "^1.1.0"
- "@emotion/weak-memoize" "^0.2.5"
- hoist-non-react-statics "^3.3.1"
-
-"@emotion/serialize@^1.0.2", "@emotion/serialize@^1.0.4":
- version "1.0.4"
- resolved "https://registry.yarnpkg.com/@emotion/serialize/-/serialize-1.0.4.tgz#ff31fd11bb07999611199c2229e152faadc21a3c"
- integrity sha512-1JHamSpH8PIfFwAMryO2bNka+y8+KA5yga5Ocf2d7ZEiJjb7xlLW7aknBGZqJLajuLOvJ+72vN+IBSwPlXD1Pg==
- dependencies:
- "@emotion/hash" "^0.8.0"
- "@emotion/memoize" "^0.7.4"
- "@emotion/unitless" "^0.7.5"
- "@emotion/utils" "^1.0.0"
- csstype "^3.0.2"
-
-"@emotion/sheet@^1.1.1":
- version "1.1.1"
- resolved "https://registry.yarnpkg.com/@emotion/sheet/-/sheet-1.1.1.tgz#015756e2a9a3c7c5f11d8ec22966a8dbfbfac787"
- integrity sha512-J3YPccVRMiTZxYAY0IOq3kd+hUP8idY8Kz6B/Cyo+JuXq52Ek+zbPbSQUrVQp95aJ+lsAW7DPL1P2Z+U1jGkKA==
-
-"@emotion/styled@^11.8.1":
- version "11.9.3"
- resolved "https://registry.yarnpkg.com/@emotion/styled/-/styled-11.9.3.tgz#47f0c71137fec7c57035bf3659b52fb536792340"
- integrity sha512-o3sBNwbtoVz9v7WB1/Y/AmXl69YHmei2mrVnK7JgyBJ//Rst5yqPZCecEJlMlJrFeWHp+ki/54uN265V2pEcXA==
- dependencies:
- "@babel/runtime" "^7.13.10"
- "@emotion/babel-plugin" "^11.7.1"
- "@emotion/is-prop-valid" "^1.1.3"
- "@emotion/serialize" "^1.0.4"
- "@emotion/utils" "^1.1.0"
-
-"@emotion/unitless@^0.7.5":
- version "0.7.5"
- resolved "https://registry.yarnpkg.com/@emotion/unitless/-/unitless-0.7.5.tgz#77211291c1900a700b8a78cfafda3160d76949ed"
- integrity sha512-OWORNpfjMsSSUBVrRBVGECkhWcULOAJz9ZW8uK9qgxD+87M7jHRcvh/A96XXNhXTLmKcoYSQtBEX7lHMO7YRwg==
-
-"@emotion/utils@^1.0.0", "@emotion/utils@^1.1.0":
- version "1.1.0"
- resolved "https://registry.yarnpkg.com/@emotion/utils/-/utils-1.1.0.tgz#86b0b297f3f1a0f2bdb08eeac9a2f49afd40d0cf"
- integrity sha512-iRLa/Y4Rs5H/f2nimczYmS5kFJEbpiVvgN3XVfZ022IYhuNA1IRSHEizcof88LtCTXtl9S2Cxt32KgaXEu72JQ==
-
-"@emotion/weak-memoize@^0.2.5":
- version "0.2.5"
- resolved "https://registry.yarnpkg.com/@emotion/weak-memoize/-/weak-memoize-0.2.5.tgz#8eed982e2ee6f7f4e44c253e12962980791efd46"
- integrity sha512-6U71C2Wp7r5XtFtQzYrW5iKFT67OixrSxjI4MptCHzdSVlgabczzqLe0ZSgnub/5Kp4hSbpDB1tMytZY9pwxxA==
-
-"@jridgewell/gen-mapping@^0.1.0":
- version "0.1.1"
- resolved "https://registry.yarnpkg.com/@jridgewell/gen-mapping/-/gen-mapping-0.1.1.tgz#e5d2e450306a9491e3bd77e323e38d7aff315996"
- integrity sha512-sQXCasFk+U8lWYEe66WxRDOE9PjVz4vSM51fTu3Hw+ClTpUSQb718772vH3pyS5pShp6lvQM7SxgIDXXXmOX7w==
- dependencies:
- "@jridgewell/set-array" "^1.0.0"
- "@jridgewell/sourcemap-codec" "^1.4.10"
-
-"@jridgewell/gen-mapping@^0.3.2":
- version "0.3.2"
- resolved "https://registry.yarnpkg.com/@jridgewell/gen-mapping/-/gen-mapping-0.3.2.tgz#c1aedc61e853f2bb9f5dfe6d4442d3b565b253b9"
- integrity sha512-mh65xKQAzI6iBcFzwv28KVWSmCkdRBWoOh+bYQGW3+6OZvbbN3TqMGo5hqYxQniRcH9F2VZIoJCm4pa3BPDK/A==
- dependencies:
- "@jridgewell/set-array" "^1.0.1"
- "@jridgewell/sourcemap-codec" "^1.4.10"
- "@jridgewell/trace-mapping" "^0.3.9"
-
-"@jridgewell/resolve-uri@^3.0.3":
- version "3.0.8"
- resolved "https://registry.yarnpkg.com/@jridgewell/resolve-uri/-/resolve-uri-3.0.8.tgz#687cc2bbf243f4e9a868ecf2262318e2658873a1"
- integrity sha512-YK5G9LaddzGbcucK4c8h5tWFmMPBvRZ/uyWmN1/SbBdIvqGUdWGkJ5BAaccgs6XbzVLsqbPJrBSFwKv3kT9i7w==
-
-"@jridgewell/resolve-uri@^3.1.0":
- version "3.1.1"
- resolved "https://registry.yarnpkg.com/@jridgewell/resolve-uri/-/resolve-uri-3.1.1.tgz#c08679063f279615a3326583ba3a90d1d82cc721"
- integrity sha512-dSYZh7HhCDtCKm4QakX0xFpsRDqjjtZf/kjI/v3T3Nwt5r8/qz/M19F9ySyOqU94SXBmeG9ttTul+YnR4LOxFA==
-
-"@jridgewell/set-array@^1.0.0", "@jridgewell/set-array@^1.0.1":
- version "1.1.2"
- resolved "https://registry.yarnpkg.com/@jridgewell/set-array/-/set-array-1.1.2.tgz#7c6cf998d6d20b914c0a55a91ae928ff25965e72"
- integrity sha512-xnkseuNADM0gt2bs+BvhO0p78Mk762YnZdsuzFV018NoG1Sj1SCQvpSqa7XUaTam5vAGasABV9qXASMKnFMwMw==
-
-"@jridgewell/sourcemap-codec@^1.4.10":
- version "1.4.14"
- resolved "https://registry.yarnpkg.com/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.14.tgz#add4c98d341472a289190b424efbdb096991bb24"
- integrity sha512-XPSJHWmi394fuUuzDnGz1wiKqWfo1yXecHQMRf2l6hztTO+nPru658AyDngaBe7isIxEkRsPR3FZh+s7iVa4Uw==
-
-"@jridgewell/sourcemap-codec@^1.4.14":
- version "1.4.15"
- resolved "https://registry.yarnpkg.com/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.15.tgz#d7c6e6755c78567a951e04ab52ef0fd26de59f32"
- integrity sha512-eF2rxCRulEKXHTRiDrDy6erMYWqNw4LPdQ8UQA4huuxaQsVeRPFl2oM8oDGxMFhJUWZf9McpLtJasDDZb/Bpeg==
-
-"@jridgewell/trace-mapping@^0.3.17":
- version "0.3.20"
- resolved "https://registry.yarnpkg.com/@jridgewell/trace-mapping/-/trace-mapping-0.3.20.tgz#72e45707cf240fa6b081d0366f8265b0cd10197f"
- integrity sha512-R8LcPeWZol2zR8mmH3JeKQ6QRCFb7XgUhV9ZlGhHLGyg4wpPiPZNQOOWhFZhxKw8u//yTbNGI42Bx/3paXEQ+Q==
- dependencies:
- "@jridgewell/resolve-uri" "^3.1.0"
- "@jridgewell/sourcemap-codec" "^1.4.14"
-
-"@jridgewell/trace-mapping@^0.3.9":
- version "0.3.14"
- resolved "https://registry.yarnpkg.com/@jridgewell/trace-mapping/-/trace-mapping-0.3.14.tgz#b231a081d8f66796e475ad588a1ef473112701ed"
- integrity sha512-bJWEfQ9lPTvm3SneWwRFVLzrh6nhjwqw7TUFFBEMzwvg7t7PCDenf2lDwqo4NQXzdpgBXyFgDWnQA+2vkruksQ==
- dependencies:
- "@jridgewell/resolve-uri" "^3.0.3"
- "@jridgewell/sourcemap-codec" "^1.4.10"
-
-"@mui/base@5.0.0-alpha.86":
- version "5.0.0-alpha.86"
- resolved "https://registry.yarnpkg.com/@mui/base/-/base-5.0.0-alpha.86.tgz#7ac5af939cec7e763c1bf49bf5e30bb9464c4ebf"
- integrity sha512-0vi/Nni1mizrgrzKeyksEjw5JVSrgT8Vr2NhxzFtYxqpMgtdSrBvcmcuzBf9kE/ECMPbgpSIcqv0nLbLZUYkOQ==
- dependencies:
- "@babel/runtime" "^7.17.2"
- "@emotion/is-prop-valid" "^1.1.2"
- "@mui/types" "^7.1.4"
- "@mui/utils" "^5.8.4"
- "@popperjs/core" "^2.11.5"
- clsx "^1.1.1"
- prop-types "^15.8.1"
- react-is "^17.0.2"
-
-"@mui/material@5.8.5":
- version "5.8.5"
- resolved "https://registry.yarnpkg.com/@mui/material/-/material-5.8.5.tgz#a1a79fc57b212a9781eb4a53e9995c4a9df04753"
- integrity sha512-wngPXlOI9BurLSGlObQM/2L0QFFaIcvJnDK5A+ALxuUyuQnPviVWfC1l/r8rPlxQ4PCbSYpq3gzLlgnLoWcO/g==
- dependencies:
- "@babel/runtime" "^7.17.2"
- "@mui/base" "5.0.0-alpha.86"
- "@mui/system" "^5.8.5"
- "@mui/types" "^7.1.4"
- "@mui/utils" "^5.8.4"
- "@types/react-transition-group" "^4.4.4"
- clsx "^1.1.1"
- csstype "^3.1.0"
- prop-types "^15.8.1"
- react-is "^17.0.2"
- react-transition-group "^4.4.2"
-
-"@mui/private-theming@^5.8.6":
- version "5.8.6"
- resolved "https://registry.yarnpkg.com/@mui/private-theming/-/private-theming-5.8.6.tgz#db2bafeda1699e43e67b3ff4f770d6b7a234501f"
- integrity sha512-yHsJk1qU9r/q0DlnxGRJPHyM0Y/nUv8FTNgDTiI9I58GWuVuZqeTUr7JRvPh6ybeP/FLtW5eXEavRK9wxVk4uQ==
- dependencies:
- "@babel/runtime" "^7.17.2"
- "@mui/utils" "^5.8.6"
- prop-types "^15.8.1"
-
-"@mui/styled-engine@^5.8.0":
- version "5.8.0"
- resolved "https://registry.yarnpkg.com/@mui/styled-engine/-/styled-engine-5.8.0.tgz#89ed42efe7c8749e5a60af035bc5d3a6bea362bf"
- integrity sha512-Q3spibB8/EgeMYHc+/o3RRTnAYkSl7ROCLhXJ830W8HZ2/iDiyYp16UcxKPurkXvLhUaILyofPVrP3Su2uKsAw==
- dependencies:
- "@babel/runtime" "^7.17.2"
- "@emotion/cache" "^11.7.1"
- prop-types "^15.8.1"
-
-"@mui/system@^5.8.5":
- version "5.8.6"
- resolved "https://registry.yarnpkg.com/@mui/system/-/system-5.8.6.tgz#aed7e501c513429dab9cfbbe86da5dcd056c2a0a"
- integrity sha512-+a+rD58XltKQHDrrjcuCta2cUBqdnLDUDwnphSLCMFigRl8/uk+R+fdQRlMNRXAOgnMb8ioWIgfjxri5pmTH4A==
- dependencies:
- "@babel/runtime" "^7.17.2"
- "@mui/private-theming" "^5.8.6"
- "@mui/styled-engine" "^5.8.0"
- "@mui/types" "^7.1.4"
- "@mui/utils" "^5.8.6"
- clsx "^1.1.1"
- csstype "^3.1.0"
- prop-types "^15.8.1"
-
-"@mui/types@^7.1.4":
- version "7.1.4"
- resolved "https://registry.yarnpkg.com/@mui/types/-/types-7.1.4.tgz#4185c05d6df63ec673cda15feab80440abadc764"
- integrity sha512-uveM3byMbthO+6tXZ1n2zm0W3uJCQYtwt/v5zV5I77v2v18u0ITkb8xwhsDD2i3V2Kye7SaNR6FFJ6lMuY/WqQ==
-
-"@mui/utils@^5.8.4", "@mui/utils@^5.8.6":
- version "5.8.6"
- resolved "https://registry.yarnpkg.com/@mui/utils/-/utils-5.8.6.tgz#543de64a64bb9135316ecfd91d75a8740544d79f"
- integrity sha512-QM2Sd1xZo2jOt2Vz5Rmro+pi2FLJyiv4+OjxkUwXR3oUM65KSMAMLl/KNYU55s3W3DLRFP5MVwE4FhAbHseHAg==
- dependencies:
- "@babel/runtime" "^7.17.2"
- "@types/prop-types" "^15.7.5"
- "@types/react-is" "^16.7.1 || ^17.0.0"
- prop-types "^15.8.1"
- react-is "^17.0.2"
-
-"@popperjs/core@^2.11.5":
- version "2.11.5"
- resolved "https://registry.yarnpkg.com/@popperjs/core/-/core-2.11.5.tgz#db5a11bf66bdab39569719555b0f76e138d7bd64"
- integrity sha512-9X2obfABZuDVLCgPK9aX0a/x4jaOEweTTWE2+9sr0Qqqevj2Uv5XorvusThmc9XGYpS9yI+fhh8RTafBtGposw==
-
-"@rollup/pluginutils@^4.2.1":
- version "4.2.1"
- resolved "https://registry.yarnpkg.com/@rollup/pluginutils/-/pluginutils-4.2.1.tgz#e6c6c3aba0744edce3fb2074922d3776c0af2a6d"
- integrity sha512-iKnFXr7NkdZAIHiIWE+BX5ULi/ucVFYWD6TbAV+rZctiRTY2PL6tsIKhoIOaoskiWAkgu+VsbXgUVDNLHf+InQ==
- dependencies:
- estree-walker "^2.0.1"
- picomatch "^2.2.2"
-
-"@types/lodash@^4.14.179":
- version "4.14.182"
- resolved "https://registry.yarnpkg.com/@types/lodash/-/lodash-4.14.182.tgz#05301a4d5e62963227eaafe0ce04dd77c54ea5c2"
- integrity sha512-/THyiqyQAP9AfARo4pF+aCGcyiQ94tX/Is2I7HofNRqoYLgN1PBoOWu2/zTA5zMxzP5EFutMtWtGAFRKUe961Q==
-
-"@types/parse-json@^4.0.0":
- version "4.0.0"
- resolved "https://registry.yarnpkg.com/@types/parse-json/-/parse-json-4.0.0.tgz#2f8bb441434d163b35fb8ffdccd7138927ffb8c0"
- integrity sha512-//oorEZjL6sbPcKUaCdIGlIUeH26mgzimjBB77G6XRgnDl/L5wOnpyBGRe/Mmf5CVW3PwEBE1NjiMZ/ssFh4wA==
-
-"@types/prop-types@*", "@types/prop-types@^15.7.5":
- version "15.7.5"
- resolved "https://registry.yarnpkg.com/@types/prop-types/-/prop-types-15.7.5.tgz#5f19d2b85a98e9558036f6a3cacc8819420f05cf"
- integrity sha512-JCB8C6SnDoQf0cNycqd/35A7MjcnK+ZTqE7judS6o7utxUCg6imJg3QK2qzHKszlTjcj2cn+NwMB2i96ubpj7w==
-
-"@types/react-dom@^18.0.0":
- version "18.0.5"
- resolved "https://registry.yarnpkg.com/@types/react-dom/-/react-dom-18.0.5.tgz#330b2d472c22f796e5531446939eacef8378444a"
- integrity sha512-OWPWTUrY/NIrjsAPkAk1wW9LZeIjSvkXRhclsFO8CZcZGCOg2G0YZy4ft+rOyYxy8B7ui5iZzi9OkDebZ7/QSA==
- dependencies:
- "@types/react" "*"
-
-"@types/react-is@^16.7.1 || ^17.0.0":
- version "17.0.3"
- resolved "https://registry.yarnpkg.com/@types/react-is/-/react-is-17.0.3.tgz#2d855ba575f2fc8d17ef9861f084acc4b90a137a"
- integrity sha512-aBTIWg1emtu95bLTLx0cpkxwGW3ueZv71nE2YFBpL8k/z5czEW8yYpOo8Dp+UUAFAtKwNaOsh/ioSeQnWlZcfw==
- dependencies:
- "@types/react" "*"
-
-"@types/react-transition-group@^4.4.4":
- version "4.4.5"
- resolved "https://registry.yarnpkg.com/@types/react-transition-group/-/react-transition-group-4.4.5.tgz#aae20dcf773c5aa275d5b9f7cdbca638abc5e416"
- integrity sha512-juKD/eiSM3/xZYzjuzH6ZwpP+/lejltmiS3QEzV/vmb/Q8+HfDmxu+Baga8UEMGBqV88Nbg4l2hY/K2DkyaLLA==
- dependencies:
- "@types/react" "*"
-
-"@types/react@*", "@types/react@^18.0.1":
- version "18.0.14"
- resolved "https://registry.yarnpkg.com/@types/react/-/react-18.0.14.tgz#e016616ffff51dba01b04945610fe3671fdbe06d"
- integrity sha512-x4gGuASSiWmo0xjDLpm5mPb52syZHJx02VKbqUKdLmKtAwIh63XClGsiTI1K6DO5q7ox4xAsQrU+Gl3+gGXF9Q==
- dependencies:
- "@types/prop-types" "*"
- "@types/scheduler" "*"
- csstype "^3.0.2"
-
-"@types/scheduler@*":
- version "0.16.2"
- resolved "https://registry.yarnpkg.com/@types/scheduler/-/scheduler-0.16.2.tgz#1a62f89525723dde24ba1b01b092bf5df8ad4d39"
- integrity sha512-hppQEBDmlwhFAXKJX2KnWLYu5yMfi91yazPb2l+lbJiwW+wdo1gNeRA+3RgNSO39WYX2euey41KEwnqesU2Jew==
-
-"@vitejs/plugin-react@^1.0.7":
- version "1.3.2"
- resolved "https://registry.yarnpkg.com/@vitejs/plugin-react/-/plugin-react-1.3.2.tgz#2fcf0b6ce9bcdcd4cec5c760c199779d5657ece1"
- integrity sha512-aurBNmMo0kz1O4qRoY+FM4epSA39y3ShWGuqfLRA/3z0oEJAdtoSfgA3aO98/PCCHAqMaduLxIxErWrVKIFzXA==
- dependencies:
- "@babel/core" "^7.17.10"
- "@babel/plugin-transform-react-jsx" "^7.17.3"
- "@babel/plugin-transform-react-jsx-development" "^7.16.7"
- "@babel/plugin-transform-react-jsx-self" "^7.16.7"
- "@babel/plugin-transform-react-jsx-source" "^7.16.7"
- "@rollup/pluginutils" "^4.2.1"
- react-refresh "^0.13.0"
- resolve "^1.22.0"
-
-ansi-styles@^3.2.1:
- version "3.2.1"
- resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-3.2.1.tgz#41fbb20243e50b12be0f04b8dedbf07520ce841d"
- integrity sha512-VT0ZI6kZRdTh8YyJw3SMbYm/u+NqfsAxEpWO0Pf9sq8/e94WxxOpPKx9FR1FlyCtOVDNOQ+8ntlqFxiRc+r5qA==
- dependencies:
- color-convert "^1.9.0"
-
-asynckit@^0.4.0:
- version "0.4.0"
- resolved "https://registry.yarnpkg.com/asynckit/-/asynckit-0.4.0.tgz#c79ed97f7f34cb8f2ba1bc9790bcc366474b4b79"
- integrity sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==
-
-axios@^1.6.0:
- version "1.6.0"
- resolved "https://registry.yarnpkg.com/axios/-/axios-1.6.0.tgz#f1e5292f26b2fd5c2e66876adc5b06cdbd7d2102"
- integrity sha512-EZ1DYihju9pwVB+jg67ogm+Tmqc6JmhamRN6I4Zt8DfZu5lbcQGw3ozH9lFejSJgs/ibaef3A9PMXPLeefFGJg==
- dependencies:
- follow-redirects "^1.15.0"
- form-data "^4.0.0"
- proxy-from-env "^1.1.0"
-
-babel-plugin-macros@^2.6.1:
- version "2.8.0"
- resolved "https://registry.yarnpkg.com/babel-plugin-macros/-/babel-plugin-macros-2.8.0.tgz#0f958a7cc6556b1e65344465d99111a1e5e10138"
- integrity sha512-SEP5kJpfGYqYKpBrj5XU3ahw5p5GOHJ0U5ssOSQ/WBVdwkD2Dzlce95exQTs3jOVWPPKLBN2rlEWkCK7dSmLvg==
- dependencies:
- "@babel/runtime" "^7.7.2"
- cosmiconfig "^6.0.0"
- resolve "^1.12.0"
-
-browserslist@^4.20.2:
- version "4.21.1"
- resolved "https://registry.yarnpkg.com/browserslist/-/browserslist-4.21.1.tgz#c9b9b0a54c7607e8dc3e01a0d311727188011a00"
- integrity sha512-Nq8MFCSrnJXSc88yliwlzQe3qNe3VntIjhsArW9IJOEPSHNx23FalwApUVbzAWABLhYJJ7y8AynWI/XM8OdfjQ==
- dependencies:
- caniuse-lite "^1.0.30001359"
- electron-to-chromium "^1.4.172"
- node-releases "^2.0.5"
- update-browserslist-db "^1.0.4"
-
-callsites@^3.0.0:
- version "3.1.0"
- resolved "https://registry.yarnpkg.com/callsites/-/callsites-3.1.0.tgz#b3630abd8943432f54b3f0519238e33cd7df2f73"
- integrity sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==
-
-caniuse-lite@^1.0.30001359:
- version "1.0.30001361"
- resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001361.tgz#ba2adb2527566fb96f3ac7c67698ae7fc495a28d"
- integrity sha512-ybhCrjNtkFji1/Wto6SSJKkWk6kZgVQsDq5QI83SafsF6FXv2JB4df9eEdH6g8sdGgqTXrFLjAxqBGgYoU3azQ==
-
-chalk@^2.0.0, chalk@^2.4.2:
- version "2.4.2"
- resolved "https://registry.yarnpkg.com/chalk/-/chalk-2.4.2.tgz#cd42541677a54333cf541a49108c1432b44c9424"
- integrity sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ==
- dependencies:
- ansi-styles "^3.2.1"
- escape-string-regexp "^1.0.5"
- supports-color "^5.3.0"
-
-clsx@^1.1.1:
- version "1.1.1"
- resolved "https://registry.yarnpkg.com/clsx/-/clsx-1.1.1.tgz#98b3134f9abbdf23b2663491ace13c5c03a73188"
- integrity sha512-6/bPho624p3S2pMyvP5kKBPXnI3ufHLObBFCfgx+LkeR5lg2XYy2hqZqUf45ypD8COn2bhgGJSUE+l5dhNBieA==
-
-color-convert@^1.9.0:
- version "1.9.3"
- resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-1.9.3.tgz#bb71850690e1f136567de629d2d5471deda4c1e8"
- integrity sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==
- dependencies:
- color-name "1.1.3"
-
-color-name@1.1.3:
- version "1.1.3"
- resolved "https://registry.yarnpkg.com/color-name/-/color-name-1.1.3.tgz#a7d0558bd89c42f795dd42328f740831ca53bc25"
- integrity sha512-72fSenhMw2HZMTVHeCA9KCmpEIbzWiQsjN+BHcBbS9vr1mtt+vJjPdksIBNUmKAW8TFUDPJK5SUU3QhE9NEXDw==
-
-combined-stream@^1.0.8:
- version "1.0.8"
- resolved "https://registry.yarnpkg.com/combined-stream/-/combined-stream-1.0.8.tgz#c3d45a8b34fd730631a110a8a2520682b31d5a7f"
- integrity sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==
- dependencies:
- delayed-stream "~1.0.0"
-
-convert-source-map@^1.5.0, convert-source-map@^1.7.0:
- version "1.8.0"
- resolved "https://registry.yarnpkg.com/convert-source-map/-/convert-source-map-1.8.0.tgz#f3373c32d21b4d780dd8004514684fb791ca4369"
- integrity sha512-+OQdjP49zViI/6i7nIJpA8rAl4sV/JdPfU9nZs3VqOwGIgizICvuN2ru6fMd+4llL0tar18UYJXfZ/TWtmhUjA==
- dependencies:
- safe-buffer "~5.1.1"
-
-cosmiconfig@^6.0.0:
- version "6.0.0"
- resolved "https://registry.yarnpkg.com/cosmiconfig/-/cosmiconfig-6.0.0.tgz#da4fee853c52f6b1e6935f41c1a2fc50bd4a9982"
- integrity sha512-xb3ZL6+L8b9JLLCx3ZdoZy4+2ECphCMo2PwqgP1tlfVq6M6YReyzBJtvWWtbDSpNr9hn96pkCiZqUcFEc+54Qg==
- dependencies:
- "@types/parse-json" "^4.0.0"
- import-fresh "^3.1.0"
- parse-json "^5.0.0"
- path-type "^4.0.0"
- yaml "^1.7.2"
-
-csstype@^3.0.2, csstype@^3.1.0:
- version "3.1.0"
- resolved "https://registry.yarnpkg.com/csstype/-/csstype-3.1.0.tgz#4ddcac3718d787cf9df0d1b7d15033925c8f29f2"
- integrity sha512-uX1KG+x9h5hIJsaKR9xHUeUraxf8IODOwq9JLNPq6BwB04a/xgpq3rcx47l5BZu5zBPlgD342tdke3Hom/nJRA==
-
-debug@^4.1.0:
- version "4.3.4"
- resolved "https://registry.yarnpkg.com/debug/-/debug-4.3.4.tgz#1319f6579357f2338d3337d2cdd4914bb5dcc865"
- integrity sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==
- dependencies:
- ms "2.1.2"
-
-delayed-stream@~1.0.0:
- version "1.0.0"
- resolved "https://registry.yarnpkg.com/delayed-stream/-/delayed-stream-1.0.0.tgz#df3ae199acadfb7d440aaae0b29e2272b24ec619"
- integrity sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==
-
-dom-helpers@^5.0.1:
- version "5.2.1"
- resolved "https://registry.yarnpkg.com/dom-helpers/-/dom-helpers-5.2.1.tgz#d9400536b2bf8225ad98fe052e029451ac40e902"
- integrity sha512-nRCa7CK3VTrM2NmGkIy4cbK7IZlgBE/PYMn55rrXefr5xXDP0LdtfPnblFDoVdcAfslJ7or6iqAUnx0CCGIWQA==
- dependencies:
- "@babel/runtime" "^7.8.7"
- csstype "^3.0.2"
-
-electron-to-chromium@^1.4.172:
- version "1.4.174"
- resolved "https://registry.yarnpkg.com/electron-to-chromium/-/electron-to-chromium-1.4.174.tgz#ffdf57f26dd4558c5aabdb4b190c47af1c4e443b"
- integrity sha512-JER+w+9MV2MBVFOXxP036bLlNOnzbYAWrWU8sNUwoOO69T3w4564WhM5H5atd8VVS8U4vpi0i0kdoYzm1NPQgQ==
-
-error-ex@^1.3.1:
- version "1.3.2"
- resolved "https://registry.yarnpkg.com/error-ex/-/error-ex-1.3.2.tgz#b4ac40648107fdcdcfae242f428bea8a14d4f1bf"
- integrity sha512-7dFHNmqeFSEt2ZBsCriorKnn3Z2pj+fd9kmI6QoWw4//DL+icEBfc0U7qJCisqrTsKTjw4fNFy2pW9OqStD84g==
- dependencies:
- is-arrayish "^0.2.1"
-
-esbuild-android-64@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-android-64/-/esbuild-android-64-0.14.48.tgz#7e6394a0e517f738641385aaf553c7e4fb6d1ae3"
- integrity sha512-3aMjboap/kqwCUpGWIjsk20TtxVoKck8/4Tu19rubh7t5Ra0Yrpg30Mt1QXXlipOazrEceGeWurXKeFJgkPOUg==
-
-esbuild-android-arm64@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-android-arm64/-/esbuild-android-arm64-0.14.48.tgz#6877566be0f82dd5a43030c0007d06ece7f7c02f"
- integrity sha512-vptI3K0wGALiDq+EvRuZotZrJqkYkN5282iAfcffjI5lmGG9G1ta/CIVauhY42MBXwEgDJkweiDcDMRLzBZC4g==
-
-esbuild-darwin-64@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-darwin-64/-/esbuild-darwin-64-0.14.48.tgz#ea3caddb707d88f844b1aa1dea5ff3b0a71ef1fd"
- integrity sha512-gGQZa4+hab2Va/Zww94YbshLuWteyKGD3+EsVon8EWTWhnHFRm5N9NbALNbwi/7hQ/hM1Zm4FuHg+k6BLsl5UA==
-
-esbuild-darwin-arm64@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-darwin-arm64/-/esbuild-darwin-arm64-0.14.48.tgz#4e5eaab54df66cc319b76a2ac0e8af4e6f0d9c2f"
- integrity sha512-bFjnNEXjhZT+IZ8RvRGNJthLWNHV5JkCtuOFOnjvo5pC0sk2/QVk0Qc06g2PV3J0TcU6kaPC3RN9yy9w2PSLEA==
-
-esbuild-freebsd-64@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-freebsd-64/-/esbuild-freebsd-64-0.14.48.tgz#47b5abc7426eae66861490ffbb380acc67af5b15"
- integrity sha512-1NOlwRxmOsnPcWOGTB10JKAkYSb2nue0oM1AfHWunW/mv3wERfJmnYlGzL3UAOIUXZqW8GeA2mv+QGwq7DToqA==
-
-esbuild-freebsd-arm64@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-freebsd-arm64/-/esbuild-freebsd-arm64-0.14.48.tgz#e8c54c8637cd44feed967ea12338b0a4da3a7b11"
- integrity sha512-gXqKdO8wabVcYtluAbikDH2jhXp+Klq5oCD5qbVyUG6tFiGhrC9oczKq3vIrrtwcxDQqK6+HDYK8Zrd4bCA9Gw==
-
-esbuild-linux-32@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-linux-32/-/esbuild-linux-32-0.14.48.tgz#229cf3246de2b7937c3ac13fac622d4d7a1344c5"
- integrity sha512-ghGyDfS289z/LReZQUuuKq9KlTiTspxL8SITBFQFAFRA/IkIvDpnZnCAKTCjGXAmUqroMQfKJXMxyjJA69c/nQ==
-
-esbuild-linux-64@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-linux-64/-/esbuild-linux-64-0.14.48.tgz#7c0e7226c02c42aacc5656c36977493dc1e96c4f"
- integrity sha512-vni3p/gppLMVZLghI7oMqbOZdGmLbbKR23XFARKnszCIBpEMEDxOMNIKPmMItQrmH/iJrL1z8Jt2nynY0bE1ug==
-
-esbuild-linux-arm64@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-linux-arm64/-/esbuild-linux-arm64-0.14.48.tgz#0af1eda474b5c6cc0cace8235b74d0cb8fcf57a7"
- integrity sha512-3CFsOlpoxlKPRevEHq8aAntgYGYkE1N9yRYAcPyng/p4Wyx0tPR5SBYsxLKcgPB9mR8chHEhtWYz6EZ+H199Zw==
-
-esbuild-linux-arm@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-linux-arm/-/esbuild-linux-arm-0.14.48.tgz#de4d1fa6b77cdcd00e2bb43dd0801e4680f0ab52"
- integrity sha512-+VfSV7Akh1XUiDNXgqgY1cUP1i2vjI+BmlyXRfVz5AfV3jbpde8JTs5Q9sYgaoq5cWfuKfoZB/QkGOI+QcL1Tw==
-
-esbuild-linux-mips64le@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-linux-mips64le/-/esbuild-linux-mips64le-0.14.48.tgz#822c1778495f7868e990d4da47ad7281df28fd15"
- integrity sha512-cs0uOiRlPp6ymknDnjajCgvDMSsLw5mST2UXh+ZIrXTj2Ifyf2aAP3Iw4DiqgnyYLV2O/v/yWBJx+WfmKEpNLA==
-
-esbuild-linux-ppc64le@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-linux-ppc64le/-/esbuild-linux-ppc64le-0.14.48.tgz#55de0a9ec4a48fedfe82a63e083164d001709447"
- integrity sha512-+2F0vJMkuI0Wie/wcSPDCqXvSFEELH7Jubxb7mpWrA/4NpT+/byjxDz0gG6R1WJoeDefcrMfpBx4GFNN1JQorQ==
-
-esbuild-linux-riscv64@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-linux-riscv64/-/esbuild-linux-riscv64-0.14.48.tgz#cd2b7381880b2f4b21a5a598fb673492120f18a5"
- integrity sha512-BmaK/GfEE+5F2/QDrIXteFGKnVHGxlnK9MjdVKMTfvtmudjY3k2t8NtlY4qemKSizc+QwyombGWTBDc76rxePA==
-
-esbuild-linux-s390x@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-linux-s390x/-/esbuild-linux-s390x-0.14.48.tgz#4b319eca2a5c64637fc7397ffbd9671719cdb6bf"
- integrity sha512-tndw/0B9jiCL+KWKo0TSMaUm5UWBLsfCKVdbfMlb3d5LeV9WbijZ8Ordia8SAYv38VSJWOEt6eDCdOx8LqkC4g==
-
-esbuild-netbsd-64@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-netbsd-64/-/esbuild-netbsd-64-0.14.48.tgz#c27cde8b5cb55dcc227943a18ab078fb98d0adbf"
- integrity sha512-V9hgXfwf/T901Lr1wkOfoevtyNkrxmMcRHyticybBUHookznipMOHoF41Al68QBsqBxnITCEpjjd4yAos7z9Tw==
-
-esbuild-openbsd-64@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-openbsd-64/-/esbuild-openbsd-64-0.14.48.tgz#af5ab2d1cb41f09064bba9465fc8bf1309150df1"
- integrity sha512-+IHf4JcbnnBl4T52egorXMatil/za0awqzg2Vy6FBgPcBpisDWT2sVz/tNdrK9kAqj+GZG/jZdrOkj7wsrNTKA==
-
-esbuild-sunos-64@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-sunos-64/-/esbuild-sunos-64-0.14.48.tgz#db3ae20526055cf6fd5c4582676233814603ac54"
- integrity sha512-77m8bsr5wOpOWbGi9KSqDphcq6dFeJyun8TA+12JW/GAjyfTwVtOnN8DOt6DSPUfEV+ltVMNqtXUeTeMAxl5KA==
-
-esbuild-windows-32@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-windows-32/-/esbuild-windows-32-0.14.48.tgz#021ffceb0a3f83078262870da88a912293c57475"
- integrity sha512-EPgRuTPP8vK9maxpTGDe5lSoIBHGKO/AuxDncg5O3NkrPeLNdvvK8oywB0zGaAZXxYWfNNSHskvvDgmfVTguhg==
-
-esbuild-windows-64@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-windows-64/-/esbuild-windows-64-0.14.48.tgz#a4d3407b580f9faac51f61eec095fa985fb3fee4"
- integrity sha512-YmpXjdT1q0b8ictSdGwH3M8VCoqPpK1/UArze3X199w6u8hUx3V8BhAi1WjbsfDYRBanVVtduAhh2sirImtAvA==
-
-esbuild-windows-arm64@0.14.48:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild-windows-arm64/-/esbuild-windows-arm64-0.14.48.tgz#762c0562127d8b09bfb70a3c816460742dd82880"
- integrity sha512-HHaOMCsCXp0rz5BT2crTka6MPWVno121NKApsGs/OIW5QC0ggC69YMGs1aJct9/9FSUF4A1xNE/cLvgB5svR4g==
-
-esbuild@^0.14.27:
- version "0.14.48"
- resolved "https://registry.yarnpkg.com/esbuild/-/esbuild-0.14.48.tgz#da5d8d25cd2d940c45ea0cfecdca727f7aee2b85"
- integrity sha512-w6N1Yn5MtqK2U1/WZTX9ZqUVb8IOLZkZ5AdHkT6x3cHDMVsYWC7WPdiLmx19w3i4Rwzy5LqsEMtVihG3e4rFzA==
- optionalDependencies:
- esbuild-android-64 "0.14.48"
- esbuild-android-arm64 "0.14.48"
- esbuild-darwin-64 "0.14.48"
- esbuild-darwin-arm64 "0.14.48"
- esbuild-freebsd-64 "0.14.48"
- esbuild-freebsd-arm64 "0.14.48"
- esbuild-linux-32 "0.14.48"
- esbuild-linux-64 "0.14.48"
- esbuild-linux-arm "0.14.48"
- esbuild-linux-arm64 "0.14.48"
- esbuild-linux-mips64le "0.14.48"
- esbuild-linux-ppc64le "0.14.48"
- esbuild-linux-riscv64 "0.14.48"
- esbuild-linux-s390x "0.14.48"
- esbuild-netbsd-64 "0.14.48"
- esbuild-openbsd-64 "0.14.48"
- esbuild-sunos-64 "0.14.48"
- esbuild-windows-32 "0.14.48"
- esbuild-windows-64 "0.14.48"
- esbuild-windows-arm64 "0.14.48"
-
-escalade@^3.1.1:
- version "3.1.1"
- resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.1.tgz#d8cfdc7000965c5a0174b4a82eaa5c0552742e40"
- integrity sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw==
-
-escape-string-regexp@^1.0.5:
- version "1.0.5"
- resolved "https://registry.yarnpkg.com/escape-string-regexp/-/escape-string-regexp-1.0.5.tgz#1b61c0562190a8dff6ae3bb2cf0200ca130b86d4"
- integrity sha512-vbRorB5FUQWvla16U8R/qgaFIya2qGzwDrNmCZuYKrbdSUMG6I1ZCGQRefkRVhuOkIGVne7BQ35DSfo1qvJqFg==
-
-escape-string-regexp@^4.0.0:
- version "4.0.0"
- resolved "https://registry.yarnpkg.com/escape-string-regexp/-/escape-string-regexp-4.0.0.tgz#14ba83a5d373e3d311e5afca29cf5bfad965bf34"
- integrity sha512-TtpcNJ3XAzx3Gq8sWRzJaVajRs0uVxA2YAkdb1jm2YkPz4G6egUFAyA3n5vtEIZefPk5Wa4UXbKuS5fKkJWdgA==
-
-estree-walker@^2.0.1:
- version "2.0.2"
- resolved "https://registry.yarnpkg.com/estree-walker/-/estree-walker-2.0.2.tgz#52f010178c2a4c117a7757cfe942adb7d2da4cac"
- integrity sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==
-
-find-root@^1.1.0:
- version "1.1.0"
- resolved "https://registry.yarnpkg.com/find-root/-/find-root-1.1.0.tgz#abcfc8ba76f708c42a97b3d685b7e9450bfb9ce4"
- integrity sha512-NKfW6bec6GfKc0SGx1e07QZY9PE99u0Bft/0rzSD5k3sO/vwkVUpDUKVm5Gpp5Ue3YfShPFTX2070tDs5kB9Ng==
-
-follow-redirects@^1.15.0:
- version "1.15.6"
- resolved "https://registry.yarnpkg.com/follow-redirects/-/follow-redirects-1.15.6.tgz#7f815c0cda4249c74ff09e95ef97c23b5fd0399b"
- integrity sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==
-
-form-data@^4.0.0:
- version "4.0.0"
- resolved "https://registry.yarnpkg.com/form-data/-/form-data-4.0.0.tgz#93919daeaf361ee529584b9b31664dc12c9fa452"
- integrity sha512-ETEklSGi5t0QMZuiXoA/Q6vcnxcLQP5vdugSpuAyi6SVGi2clPPp+xgEhuMaHC+zGgn31Kd235W35f7Hykkaww==
- dependencies:
- asynckit "^0.4.0"
- combined-stream "^1.0.8"
- mime-types "^2.1.12"
-
-fsevents@~2.3.2:
- version "2.3.2"
- resolved "https://registry.yarnpkg.com/fsevents/-/fsevents-2.3.2.tgz#8a526f78b8fdf4623b709e0b975c52c24c02fd1a"
- integrity sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==
-
-function-bind@^1.1.1:
- version "1.1.1"
- resolved "https://registry.yarnpkg.com/function-bind/-/function-bind-1.1.1.tgz#a56899d3ea3c9bab874bb9773b7c5ede92f4895d"
- integrity sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A==
-
-gensync@^1.0.0-beta.2:
- version "1.0.0-beta.2"
- resolved "https://registry.yarnpkg.com/gensync/-/gensync-1.0.0-beta.2.tgz#32a6ee76c3d7f52d46b2b1ae5d93fea8580a25e0"
- integrity sha512-3hN7NaskYvMDLQY55gnW3NQ+mesEAepTqlg+VEbj7zzqEMBVNhzcGYYeqFo/TlYz6eQiFcp1HcsCZO+nGgS8zg==
-
-globals@^11.1.0:
- version "11.12.0"
- resolved "https://registry.yarnpkg.com/globals/-/globals-11.12.0.tgz#ab8795338868a0babd8525758018c2a7eb95c42e"
- integrity sha512-WOBp/EEGUiIsJSp7wcv/y6MO+lV9UoncWqxuFfm8eBwzWNgyfBd6Gz+IeKQ9jCmyhoH99g15M3T+QaVHFjizVA==
-
-has-flag@^3.0.0:
- version "3.0.0"
- resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-3.0.0.tgz#b5d454dc2199ae225699f3467e5a07f3b955bafd"
- integrity sha512-sKJf1+ceQBr4SMkvQnBDNDtf4TXpVhVGateu0t918bl30FnbE2m4vNLX+VWe/dpjlb+HugGYzW7uQXH98HPEYw==
-
-has@^1.0.3:
- version "1.0.3"
- resolved "https://registry.yarnpkg.com/has/-/has-1.0.3.tgz#722d7cbfc1f6aa8241f16dd814e011e1f41e8796"
- integrity sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw==
- dependencies:
- function-bind "^1.1.1"
-
-hoist-non-react-statics@^3.3.1:
- version "3.3.2"
- resolved "https://registry.yarnpkg.com/hoist-non-react-statics/-/hoist-non-react-statics-3.3.2.tgz#ece0acaf71d62c2969c2ec59feff42a4b1a85b45"
- integrity sha512-/gGivxi8JPKWNm/W0jSmzcMPpfpPLc3dY/6GxhX2hQ9iGj3aDfklV4ET7NjKpSinLpJ5vafa9iiGIEZg10SfBw==
- dependencies:
- react-is "^16.7.0"
-
-import-fresh@^3.1.0:
- version "3.3.0"
- resolved "https://registry.yarnpkg.com/import-fresh/-/import-fresh-3.3.0.tgz#37162c25fcb9ebaa2e6e53d5b4d88ce17d9e0c2b"
- integrity sha512-veYYhQa+D1QBKznvhUHxb8faxlrwUnxseDAbAp457E0wLNio2bOSKnjYDhMj+YiAq61xrMGhQk9iXVk5FzgQMw==
- dependencies:
- parent-module "^1.0.0"
- resolve-from "^4.0.0"
-
-is-arrayish@^0.2.1:
- version "0.2.1"
- resolved "https://registry.yarnpkg.com/is-arrayish/-/is-arrayish-0.2.1.tgz#77c99840527aa8ecb1a8ba697b80645a7a926a9d"
- integrity sha512-zz06S8t0ozoDXMG+ube26zeCTNXcKIPJZJi8hBrF4idCLms4CG9QtK7qBl1boi5ODzFpjswb5JPmHCbMpjaYzg==
-
-is-core-module@^2.9.0:
- version "2.9.0"
- resolved "https://registry.yarnpkg.com/is-core-module/-/is-core-module-2.9.0.tgz#e1c34429cd51c6dd9e09e0799e396e27b19a9c69"
- integrity sha512-+5FPy5PnwmO3lvfMb0AsoPaBG+5KHUI0wYFXOtYPnVVVspTFUuMZNfNaNVRt3FZadstu2c8x23vykRW/NBoU6A==
- dependencies:
- has "^1.0.3"
-
-"js-tokens@^3.0.0 || ^4.0.0", js-tokens@^4.0.0:
- version "4.0.0"
- resolved "https://registry.yarnpkg.com/js-tokens/-/js-tokens-4.0.0.tgz#19203fb59991df98e3a287050d4647cdeaf32499"
- integrity sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==
-
-jsesc@^2.5.1:
- version "2.5.2"
- resolved "https://registry.yarnpkg.com/jsesc/-/jsesc-2.5.2.tgz#80564d2e483dacf6e8ef209650a67df3f0c283a4"
- integrity sha512-OYu7XEzjkCQ3C5Ps3QIZsQfNpqoJyZZA99wd9aWd05NCtC5pWOkShK2mkL6HXQR6/Cy2lbNdPlZBpuQHXE63gA==
-
-json-parse-even-better-errors@^2.3.0:
- version "2.3.1"
- resolved "https://registry.yarnpkg.com/json-parse-even-better-errors/-/json-parse-even-better-errors-2.3.1.tgz#7c47805a94319928e05777405dc12e1f7a4ee02d"
- integrity sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w==
-
-json5@^2.2.1:
- version "2.2.3"
- resolved "https://registry.yarnpkg.com/json5/-/json5-2.2.3.tgz#78cd6f1a19bdc12b73db5ad0c61efd66c1e29283"
- integrity sha512-XmOWe7eyHYH14cLdVPoyg+GOH3rYX++KpzrylJwSW98t3Nk+U8XOl8FWKOgwtzdb8lXGf6zYwDUzeHMWfxasyg==
-
-lines-and-columns@^1.1.6:
- version "1.2.4"
- resolved "https://registry.yarnpkg.com/lines-and-columns/-/lines-and-columns-1.2.4.tgz#eca284f75d2965079309dc0ad9255abb2ebc1632"
- integrity sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==
-
-lodash@^4.17.21:
- version "4.17.21"
- resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.21.tgz#679591c564c3bffaae8454cf0b3df370c3d6911c"
- integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==
-
-loose-envify@^1.1.0, loose-envify@^1.4.0:
- version "1.4.0"
- resolved "https://registry.yarnpkg.com/loose-envify/-/loose-envify-1.4.0.tgz#71ee51fa7be4caec1a63839f7e682d8132d30caf"
- integrity sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==
- dependencies:
- js-tokens "^3.0.0 || ^4.0.0"
-
-mime-db@1.52.0:
- version "1.52.0"
- resolved "https://registry.yarnpkg.com/mime-db/-/mime-db-1.52.0.tgz#bbabcdc02859f4987301c856e3387ce5ec43bf70"
- integrity sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==
-
-mime-types@^2.1.12:
- version "2.1.35"
- resolved "https://registry.yarnpkg.com/mime-types/-/mime-types-2.1.35.tgz#381a871b62a734450660ae3deee44813f70d959a"
- integrity sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==
- dependencies:
- mime-db "1.52.0"
-
-ms@2.1.2:
- version "2.1.2"
- resolved "https://registry.yarnpkg.com/ms/-/ms-2.1.2.tgz#d09d1f357b443f493382a8eb3ccd183872ae6009"
- integrity sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==
-
-nanoid@^3.3.1:
- version "3.3.4"
- resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-3.3.4.tgz#730b67e3cd09e2deacf03c027c81c9d9dbc5e8ab"
- integrity sha512-MqBkQh/OHTS2egovRtLk45wEyNXwF+cokD+1YPf9u5VfJiRdAiRwB2froX5Co9Rh20xs4siNPm8naNotSD6RBw==
-
-nanoid@^3.3.6:
- version "3.3.6"
- resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-3.3.6.tgz#443380c856d6e9f9824267d960b4236ad583ea4c"
- integrity sha512-BGcqMMJuToF7i1rt+2PWSNVnWIkGCU78jBG3RxO/bZlnZPK2Cmi2QaffxGO/2RvWi9sL+FAiRiXMgsyxQ1DIDA==
-
-node-releases@^2.0.5:
- version "2.0.5"
- resolved "https://registry.yarnpkg.com/node-releases/-/node-releases-2.0.5.tgz#280ed5bc3eba0d96ce44897d8aee478bfb3d9666"
- integrity sha512-U9h1NLROZTq9uE1SNffn6WuPDg8icmi3ns4rEl/oTfIle4iLjTliCzgTsbaIFMq/Xn078/lfY/BL0GWZ+psK4Q==
-
-object-assign@^4.1.1:
- version "4.1.1"
- resolved "https://registry.yarnpkg.com/object-assign/-/object-assign-4.1.1.tgz#2109adc7965887cfc05cbbd442cac8bfbb360863"
- integrity sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==
-
-parent-module@^1.0.0:
- version "1.0.1"
- resolved "https://registry.yarnpkg.com/parent-module/-/parent-module-1.0.1.tgz#691d2709e78c79fae3a156622452d00762caaaa2"
- integrity sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==
- dependencies:
- callsites "^3.0.0"
-
-parse-json@^5.0.0:
- version "5.2.0"
- resolved "https://registry.yarnpkg.com/parse-json/-/parse-json-5.2.0.tgz#c76fc66dee54231c962b22bcc8a72cf2f99753cd"
- integrity sha512-ayCKvm/phCGxOkYRSCM82iDwct8/EonSEgCSxWxD7ve6jHggsFl4fZVQBPRNgQoKiuV/odhFrGzQXZwbifC8Rg==
- dependencies:
- "@babel/code-frame" "^7.0.0"
- error-ex "^1.3.1"
- json-parse-even-better-errors "^2.3.0"
- lines-and-columns "^1.1.6"
-
-path-parse@^1.0.7:
- version "1.0.7"
- resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.7.tgz#fbc114b60ca42b30d9daf5858e4bd68bbedb6735"
- integrity sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==
-
-path-type@^4.0.0:
- version "4.0.0"
- resolved "https://registry.yarnpkg.com/path-type/-/path-type-4.0.0.tgz#84ed01c0a7ba380afe09d90a8c180dcd9d03043b"
- integrity sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw==
-
-picocolors@^1.0.0:
- version "1.0.0"
- resolved "https://registry.yarnpkg.com/picocolors/-/picocolors-1.0.0.tgz#cb5bdc74ff3f51892236eaf79d68bc44564ab81c"
- integrity sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==
-
-picomatch@^2.2.2:
- version "2.3.1"
- resolved "https://registry.yarnpkg.com/picomatch/-/picomatch-2.3.1.tgz#3ba3833733646d9d3e4995946c1365a67fb07a42"
- integrity sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==
-
-postcss@^8.4.13:
- version "8.4.31"
- resolved "https://registry.yarnpkg.com/postcss/-/postcss-8.4.31.tgz#92b451050a9f914da6755af352bdc0192508656d"
- integrity sha512-PS08Iboia9mts/2ygV3eLpY5ghnUcfLV/EXTOW1E2qYxJKGGBUtNjN76FYHnMs36RmARn41bC0AZmn+rR0OVpQ==
- dependencies:
- nanoid "^3.3.6"
- picocolors "^1.0.0"
- source-map-js "^1.0.2"
-
-prettier@^2.5.1:
- version "2.7.1"
- resolved "https://registry.yarnpkg.com/prettier/-/prettier-2.7.1.tgz#e235806850d057f97bb08368a4f7d899f7760c64"
- integrity sha512-ujppO+MkdPqoVINuDFDRLClm7D78qbDt0/NR+wp5FqEZOoTNAjPHWj17QRhu7geIHJfcNhRk1XVQmF8Bp3ye+g==
-
-prop-types@^15.6.2, prop-types@^15.8.1:
- version "15.8.1"
- resolved "https://registry.yarnpkg.com/prop-types/-/prop-types-15.8.1.tgz#67d87bf1a694f48435cf332c24af10214a3140b5"
- integrity sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==
- dependencies:
- loose-envify "^1.4.0"
- object-assign "^4.1.1"
- react-is "^16.13.1"
-
-proxy-from-env@^1.1.0:
- version "1.1.0"
- resolved "https://registry.yarnpkg.com/proxy-from-env/-/proxy-from-env-1.1.0.tgz#e102f16ca355424865755d2c9e8ea4f24d58c3e2"
- integrity sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==
-
-react-dom@^17.0.2:
- version "17.0.2"
- resolved "https://registry.yarnpkg.com/react-dom/-/react-dom-17.0.2.tgz#ecffb6845e3ad8dbfcdc498f0d0a939736502c23"
- integrity sha512-s4h96KtLDUQlsENhMn1ar8t2bEa+q/YAtj8pPPdIjPDGBDIVNsrD9aXNWqspUe6AzKCIG0C1HZZLqLV7qpOBGA==
- dependencies:
- loose-envify "^1.1.0"
- object-assign "^4.1.1"
- scheduler "^0.20.2"
-
-react-is@^16.13.1, react-is@^16.7.0:
- version "16.13.1"
- resolved "https://registry.yarnpkg.com/react-is/-/react-is-16.13.1.tgz#789729a4dc36de2999dc156dd6c1d9c18cea56a4"
- integrity sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==
-
-react-is@^17.0.2:
- version "17.0.2"
- resolved "https://registry.yarnpkg.com/react-is/-/react-is-17.0.2.tgz#e691d4a8e9c789365655539ab372762b0efb54f0"
- integrity sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w==
-
-react-refresh@^0.13.0:
- version "0.13.0"
- resolved "https://registry.yarnpkg.com/react-refresh/-/react-refresh-0.13.0.tgz#cbd01a4482a177a5da8d44c9755ebb1f26d5a1c1"
- integrity sha512-XP8A9BT0CpRBD+NYLLeIhld/RqG9+gktUjW1FkE+Vm7OCinbG1SshcK5tb9ls4kzvjZr9mOQc7HYgBngEyPAXg==
-
-react-transition-group@^4.4.2:
- version "4.4.2"
- resolved "https://registry.yarnpkg.com/react-transition-group/-/react-transition-group-4.4.2.tgz#8b59a56f09ced7b55cbd53c36768b922890d5470"
- integrity sha512-/RNYfRAMlZwDSr6z4zNKV6xu53/e2BuaBbGhbyYIXTrmgu/bGHzmqOs7mJSJBHy9Ud+ApHx3QjrkKSp1pxvlFg==
- dependencies:
- "@babel/runtime" "^7.5.5"
- dom-helpers "^5.0.1"
- loose-envify "^1.4.0"
- prop-types "^15.6.2"
-
-react@^17.0.2:
- version "17.0.2"
- resolved "https://registry.yarnpkg.com/react/-/react-17.0.2.tgz#d0b5cc516d29eb3eee383f75b62864cfb6800037"
- integrity sha512-gnhPt75i/dq/z3/6q/0asP78D0u592D5L1pd7M8P+dck6Fu/jJeL6iVVK23fptSUZj8Vjf++7wXA8UNclGQcbA==
- dependencies:
- loose-envify "^1.1.0"
- object-assign "^4.1.1"
-
-regenerator-runtime@^0.13.4:
- version "0.13.9"
- resolved "https://registry.yarnpkg.com/regenerator-runtime/-/regenerator-runtime-0.13.9.tgz#8925742a98ffd90814988d7566ad30ca3b263b52"
- integrity sha512-p3VT+cOEgxFsRRA9X4lkI1E+k2/CtnKtU4gcxyaCUreilL/vqI6CdZ3wxVUx3UOUg+gnUOQQcRI7BmSI656MYA==
-
-resolve-from@^4.0.0:
- version "4.0.0"
- resolved "https://registry.yarnpkg.com/resolve-from/-/resolve-from-4.0.0.tgz#4abcd852ad32dd7baabfe9b40e00a36db5f392e6"
- integrity sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g==
-
-resolve@^1.12.0, resolve@^1.22.0:
- version "1.22.1"
- resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.22.1.tgz#27cb2ebb53f91abb49470a928bba7558066ac177"
- integrity sha512-nBpuuYuY5jFsli/JIs1oldw6fOQCBioohqWZg/2hiaOybXOft4lonv85uDOKXdf8rhyK159cxU5cDcK/NKk8zw==
- dependencies:
- is-core-module "^2.9.0"
- path-parse "^1.0.7"
- supports-preserve-symlinks-flag "^1.0.0"
-
-"rollup@>=2.59.0 <2.78.0":
- version "2.77.3"
- resolved "https://registry.yarnpkg.com/rollup/-/rollup-2.77.3.tgz#8f00418d3a2740036e15deb653bed1a90ee0cc12"
- integrity sha512-/qxNTG7FbmefJWoeeYJFbHehJ2HNWnjkAFRKzWN/45eNBBF/r8lo992CwcJXEzyVxs5FmfId+vTSTQDb+bxA+g==
- optionalDependencies:
- fsevents "~2.3.2"
-
-safe-buffer@~5.1.1:
- version "5.1.2"
- resolved "https://registry.yarnpkg.com/safe-buffer/-/safe-buffer-5.1.2.tgz#991ec69d296e0313747d59bdfd2b745c35f8828d"
- integrity sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==
-
-scheduler@^0.20.2:
- version "0.20.2"
- resolved "https://registry.yarnpkg.com/scheduler/-/scheduler-0.20.2.tgz#4baee39436e34aa93b4874bddcbf0fe8b8b50e91"
- integrity sha512-2eWfGgAqqWFGqtdMmcL5zCMK1U8KlXv8SQFGglL3CEtd0aDVDWgeF/YoCmvln55m5zSk3J/20hTaSBeSObsQDQ==
- dependencies:
- loose-envify "^1.1.0"
- object-assign "^4.1.1"
-
-semver@^6.3.0:
- version "6.3.1"
- resolved "https://registry.yarnpkg.com/semver/-/semver-6.3.1.tgz#556d2ef8689146e46dcea4bfdd095f3434dffcb4"
- integrity sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==
-
-source-map-js@^1.0.2:
- version "1.0.2"
- resolved "https://registry.yarnpkg.com/source-map-js/-/source-map-js-1.0.2.tgz#adbc361d9c62df380125e7f161f71c826f1e490c"
- integrity sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw==
-
-source-map@^0.5.7:
- version "0.5.7"
- resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.5.7.tgz#8a039d2d1021d22d1ea14c80d8ea468ba2ef3fcc"
- integrity sha512-LbrmJOMUSdEVxIKvdcJzQC+nQhe8FUZQTXQy6+I75skNgn3OoQ0DZA8YnFa7gp8tqtL3KPf1kmo0R5DoApeSGQ==
-
-stylis@4.0.13:
- version "4.0.13"
- resolved "https://registry.yarnpkg.com/stylis/-/stylis-4.0.13.tgz#f5db332e376d13cc84ecfe5dace9a2a51d954c91"
- integrity sha512-xGPXiFVl4YED9Jh7Euv2V220mriG9u4B2TA6Ybjc1catrstKD2PpIdU3U0RKpkVBC2EhmL/F0sPCr9vrFTNRag==
-
-supports-color@^5.3.0:
- version "5.5.0"
- resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-5.5.0.tgz#e2e69a44ac8772f78a1ec0b35b689df6530efc8f"
- integrity sha512-QjVjwdXIt408MIiAqCX4oUKsgU2EqAGzs2Ppkm4aQYbjm+ZEWEcW4SfFNTr4uMNZma0ey4f5lgLrkB0aX0QMow==
- dependencies:
- has-flag "^3.0.0"
-
-supports-preserve-symlinks-flag@^1.0.0:
- version "1.0.0"
- resolved "https://registry.yarnpkg.com/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz#6eda4bd344a3c94aea376d4cc31bc77311039e09"
- integrity sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==
-
-to-fast-properties@^2.0.0:
- version "2.0.0"
- resolved "https://registry.yarnpkg.com/to-fast-properties/-/to-fast-properties-2.0.0.tgz#dc5e698cbd079265bc73e0377681a4e4e83f616e"
- integrity sha512-/OaKK0xYrs3DmxRYqL/yDc+FxFUVYhDlXMhRmv3z915w2HF1tnN1omB354j8VUGO/hbRzyD6Y3sA7v7GS/ceog==
-
-typescript@^4.5.4:
- version "4.7.4"
- resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.7.4.tgz#1a88596d1cf47d59507a1bcdfb5b9dfe4d488235"
- integrity sha512-C0WQT0gezHuw6AdY1M2jxUO83Rjf0HP7Sk1DtXj6j1EwkQNZrHAg2XPWlq62oqEhYvONq5pkC2Y9oPljWToLmQ==
-
-update-browserslist-db@^1.0.4:
- version "1.0.4"
- resolved "https://registry.yarnpkg.com/update-browserslist-db/-/update-browserslist-db-1.0.4.tgz#dbfc5a789caa26b1db8990796c2c8ebbce304824"
- integrity sha512-jnmO2BEGUjsMOe/Fg9u0oczOe/ppIDZPebzccl1yDWGLFP16Pa1/RM5wEoKYPG2zstNcDuAStejyxsOuKINdGA==
- dependencies:
- escalade "^3.1.1"
- picocolors "^1.0.0"
-
-vite@^2.9.17:
- version "2.9.17"
- resolved "https://registry.yarnpkg.com/vite/-/vite-2.9.17.tgz#6b770525e12fa2a2e3a0fa0d028d304f4f7dc7d4"
- integrity sha512-XxcRzra6d7xrKXH66jZUgb+srThoPu+TLJc06GifUyKq9JmjHkc1Numc8ra0h56rju2jfVWw3B3fs5l3OFMvUw==
- dependencies:
- esbuild "^0.14.27"
- postcss "^8.4.13"
- resolve "^1.22.0"
- rollup ">=2.59.0 <2.78.0"
- optionalDependencies:
- fsevents "~2.3.2"
-
-yaml@^1.7.2:
- version "1.10.2"
- resolved "https://registry.yarnpkg.com/yaml/-/yaml-1.10.2.tgz#2301c5ffbf12b467de8da2333a459e29e7920e4b"
- integrity sha512-r3vXyErRCYJ7wg28yvBY5VSoAF8ZvlcW9/BwUzEtUsjvX/DKs24dIkuwjtuprwJJHsbyUbLApepYTR1BN4uHrg==
diff --git a/src/lightning/app/components/README.md b/src/lightning/app/components/README.md
deleted file mode 100644
index 5d3e9eeaf608e..0000000000000
--- a/src/lightning/app/components/README.md
+++ /dev/null
@@ -1 +0,0 @@
-TODO: add a guide how to add and use an external component
diff --git a/src/lightning/app/components/__init__.py b/src/lightning/app/components/__init__.py
deleted file mode 100644
index 8e44c0370d38b..0000000000000
--- a/src/lightning/app/components/__init__.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from lightning.app.components.database.client import DatabaseClient
-from lightning.app.components.database.server import Database
-from lightning.app.components.multi_node import (
- FabricMultiNode,
- LightningTrainerMultiNode,
- MultiNode,
- PyTorchSpawnMultiNode,
-)
-from lightning.app.components.python.popen import PopenPythonScript
-from lightning.app.components.python.tracer import Code, TracerPythonScript
-from lightning.app.components.serve.auto_scaler import AutoScaler
-from lightning.app.components.serve.cold_start_proxy import ColdStartProxy
-from lightning.app.components.serve.gradio_server import ServeGradio
-from lightning.app.components.serve.python_server import Category, Image, Number, PythonServer, Text
-from lightning.app.components.serve.serve import ModelInferenceAPI
-from lightning.app.components.serve.streamlit import ServeStreamlit
-from lightning.app.components.training import LightningTrainerScript, PyTorchLightningScriptRunner
-
-__all__ = [
- "AutoScaler",
- "ColdStartProxy",
- "DatabaseClient",
- "Database",
- "PopenPythonScript",
- "Code",
- "TracerPythonScript",
- "ServeGradio",
- "ServeStreamlit",
- "ModelInferenceAPI",
- "PythonServer",
- "Image",
- "Number",
- "Category",
- "Text",
- "MultiNode",
- "FabricMultiNode",
- "LightningTrainerScript",
- "PyTorchLightningScriptRunner",
- "PyTorchSpawnMultiNode",
- "LightningTrainerMultiNode",
-]
diff --git a/src/lightning/app/components/database/__init__.py b/src/lightning/app/components/database/__init__.py
deleted file mode 100644
index 50e7d095716b4..0000000000000
--- a/src/lightning/app/components/database/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from lightning.app.components.database.client import DatabaseClient
-from lightning.app.components.database.server import Database
-
-__all__ = ["Database", "DatabaseClient"]
diff --git a/src/lightning/app/components/database/client.py b/src/lightning/app/components/database/client.py
deleted file mode 100644
index 81f0862918934..0000000000000
--- a/src/lightning/app/components/database/client.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any, Dict, List, Optional, Type, TypeVar
-
-import requests
-from requests import Session
-from requests.adapters import HTTPAdapter
-from urllib3.util.retry import Retry
-
-from lightning.app.components.database.utilities import _GeneralModel
-
-_CONNECTION_RETRY_TOTAL = 5
-_CONNECTION_RETRY_BACKOFF_FACTOR = 1
-
-
-def _configure_session() -> Session:
- """Configures the session for GET and POST requests.
-
- It enables a generous retrial strategy that waits for the application server to connect.
-
- """
- retry_strategy = Retry(
- # wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1))
- total=_CONNECTION_RETRY_TOTAL,
- backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR,
- status_forcelist=[429, 500, 502, 503, 504],
- )
- adapter = HTTPAdapter(max_retries=retry_strategy)
- http = requests.Session()
- http.mount("https://", adapter)
- http.mount("http://", adapter)
- return http
-
-
-T = TypeVar("T")
-
-
-class DatabaseClient:
- def __init__(self, db_url: str, token: Optional[str] = None, model: Optional[T] = None) -> None:
- self.db_url = db_url
- self.model = model
- self.token = token or ""
- self._session = None
-
- def select_all(self, model: Optional[Type[T]] = None) -> List[T]:
- cls = model if model else self.model
- resp = self.session.post(
- self.db_url + "/select_all/", data=_GeneralModel.from_cls(cls, token=self.token).json()
- )
- assert resp.status_code == 200
- return [cls(**data) for data in resp.json()]
-
- def insert(self, model: T) -> None:
- resp = self.session.post(
- self.db_url + "/insert/",
- data=_GeneralModel.from_obj(model, token=self.token).json(),
- )
- assert resp.status_code == 200
-
- def update(self, model: T) -> None:
- resp = self.session.post(
- self.db_url + "/update/",
- data=_GeneralModel.from_obj(model, token=self.token).json(),
- )
- assert resp.status_code == 200
-
- def delete(self, model: T) -> None:
- resp = self.session.post(
- self.db_url + "/delete/",
- data=_GeneralModel.from_obj(model, token=self.token).json(),
- )
- assert resp.status_code == 200
-
- @property
- def session(self):
- if self._session is None:
- self._session = _configure_session()
- return self._session
-
- def to_dict(self) -> Dict[str, Any]:
- return {"db_url": self.db_url, "model": self.model.__name__ if self.model else None}
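Note on the removed `_configure_session` helper: its inline comment gives the wait time between retries as `backoff_factor * (2 ** (retry - 1))`. The sketch below only evaluates that formula for the removed constants (`total=5`, `backoff_factor=1`). The function name `backoff_schedule` is hypothetical, and the delays urllib3 actually applies may differ by version (for example, a capped maximum or no sleep before the first retry), so treat the numbers as the nominal schedule rather than measured behavior.

```python
from typing import List


def backoff_schedule(total: int, backoff_factor: float) -> List[float]:
    """Nominal wait (in seconds) before each retry, per the formula in the removed comment."""
    # Hypothetical helper for illustration only; it is not part of requests or urllib3.
    return [backoff_factor * (2 ** (retry - 1)) for retry in range(1, total + 1)]


schedule = backoff_schedule(total=5, backoff_factor=1)
print(schedule)       # [1, 2, 4, 8, 16]
print(sum(schedule))  # 31
```

Under that formula, a client would wait roughly half a minute in total before giving up, which matches the docstring's description of a "generous" strategy for a server that is still starting.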
diff --git a/src/lightning/app/components/database/server.py b/src/lightning/app/components/database/server.py
deleted file mode 100644
index 3fbf75f01ac85..0000000000000
--- a/src/lightning/app/components/database/server.py
+++ /dev/null
@@ -1,243 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import asyncio
-import os
-import sqlite3
-import sys
-import tempfile
-import threading
-import traceback
-from typing import List, Optional, Type, Union
-
-import uvicorn
-from fastapi import FastAPI
-from uvicorn import run
-
-from lightning.app.components.database.utilities import _create_database, _Delete, _Insert, _SelectAll, _Update
-from lightning.app.core.work import LightningWork
-from lightning.app.storage import Drive
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.imports import _is_sqlmodel_available
-from lightning.app.utilities.packaging.build_config import BuildConfig
-
-if _is_sqlmodel_available():
- from sqlmodel import SQLModel
-else:
- SQLModel = object
-
-
-logger = Logger(__name__)
-
-
-# Required to avoid Uvicorn Server overriding Lightning App signal handlers.
-# Discussions: https://github.com/encode/uvicorn/discussions/1708
-class _DatabaseUvicornServer(uvicorn.Server):
- has_started_queue = None
-
- def run(self, sockets=None):
- self.config.setup_event_loop()
- loop = asyncio.get_event_loop()
- asyncio.ensure_future(self.serve(sockets=sockets))
- loop.run_forever()
-
- def install_signal_handlers(self):
- """Ignore Uvicorn Signal Handlers."""
-
-
-_lock = threading.Lock()
-
-
-class Database(LightningWork):
- def __init__(
- self,
- models: Union[Type["SQLModel"], List[Type["SQLModel"]]],
- db_filename: str = "database.db",
- store_interval: int = 10,
- debug: bool = False,
- ) -> None:
- """The Database Component enables to interact with an SQLite database to store some structured information
- about your application.
-
- The provided models are SQLModel tables
-
- Arguments:
- models: A SQLModel or a list of SQLModels table to be added to the database.
- db_filename: The name of the SQLite database.
- store_interval: Time interval (in seconds) at which the database is periodically synchronized to the Drive.
- Note that the database is also always synchronized on exit.
- debug: Whether to run the database in debug mode.
-
- Example::
-
- from typing import List
- from sqlmodel import SQLModel, Field
- from uuid import uuid4
-
- from lightning.app import LightningFlow, LightningApp
- from lightning.app.components.database import Database, DatabaseClient
-
- class CounterModel(SQLModel, table=True):
- __table_args__ = {"extend_existing": True}
-
- id: int = Field(default=None, primary_key=True)
- count: int
-
-
- class Flow(LightningFlow):
-
- def __init__(self):
- super().__init__()
- self._private_token = uuid4().hex
- self.db = Database(models=[CounterModel])
- self._client = None
- self.counter = 0
-
- def run(self):
- self.db.run(token=self._private_token)
-
- if not self.db.alive():
- return
-
- if self.counter == 0:
- self._client = DatabaseClient(
- model=CounterModel,
- db_url=self.db.url,
- token=self._private_token,
- )
-
- rows = self._client.select_all()
-
- print(f"{self.counter}: {rows}")
-
- if not rows:
- self._client.insert(CounterModel(count=0))
- else:
- row: CounterModel = rows[0]
- row.count += 1
- self._client.update(row)
-
- if self.counter >= 100:
- row: CounterModel = rows[0]
- self._client.delete(row)
- self.stop()
-
- self.counter += 1
-
- app = LightningApp(Flow())
-
- If you want to use nested SQLModels, we provide a utility to do so as follows:
-
- Example::
-
- from typing import List
- from sqlmodel import SQLModel, Field
- from sqlalchemy import Column
-
- from lightning.app.components.database.utilities import pydantic_column_type
-
- class KeyValuePair(SQLModel):
- name: str
- value: str
-
- class CounterModel(SQLModel, table=True):
- __table_args__ = {"extend_existing": True}
-
- name: int = Field(default=None, primary_key=True)
-
- # RIGHT THERE ! You need to use Field and Column with the `pydantic_column_type` utility.
- kv: List[KeyValuePair] = Field(..., sa_column=Column(pydantic_column_type(List[KeyValuePair])))
-
- """
- super().__init__(parallel=True, cloud_build_config=BuildConfig(["sqlmodel"]))
- self.db_filename = db_filename
- self._root_folder = os.path.dirname(db_filename)
- self.debug = debug
- self.store_interval = store_interval
- self._models = models if isinstance(models, list) else [models]
- self._store_thread = None
- self._exit_event = None
-
- def store_database(self):
- try:
- with tempfile.TemporaryDirectory() as tmpdir:
- tmp_db_filename = os.path.join(tmpdir, os.path.basename(self.db_filename))
-
- source = sqlite3.connect(self.db_filename)
- dest = sqlite3.connect(tmp_db_filename)
-
- source.backup(dest)
-
- source.close()
- dest.close()
-
- drive = Drive("lit://database", component_name=self.name, root_folder=tmpdir)
- drive.put(os.path.basename(tmp_db_filename))
-
- logger.debug("Stored the database to the Drive.")
- except Exception:
- traceback.print_exc()
-
- def periodic_store_database(self, store_interval):
- while not self._exit_event.is_set():
- with _lock:
- self.store_database()
- self._exit_event.wait(store_interval)
-
- def run(self, token: Optional[str] = None) -> None:
- """
- Arguments:
- token: Token used to protect the database access. Ensure you don't expose it through the App State.
- """
- drive = Drive("lit://database", component_name=self.name, root_folder=self._root_folder)
- filenames = drive.list(component_name=self.name)
- if self.db_filename in filenames:
- drive.get(self.db_filename)
- print("Retrieved the database from Drive.")
-
- app = FastAPI()
-
- _create_database(self.db_filename, self._models, self.debug)
- models = {m.__name__: m for m in self._models}
- app.post("/select_all/")(_SelectAll(models, token))
- app.post("/insert/")(_Insert(models, token))
- app.post("/update/")(_Update(models, token))
- app.post("/delete/")(_Delete(models, token))
-
- sys.modules["uvicorn.main"].Server = _DatabaseUvicornServer
-
- self._exit_event = threading.Event()
- self._store_thread = threading.Thread(target=self.periodic_store_database, args=(self.store_interval,))
- self._store_thread.start()
-
- run(app, host=self.host, port=self.port, log_level="error")
-
- def alive(self) -> bool:
- """Hack: Returns whether the server is alive."""
- return self.db_url != ""
-
- @property
- def db_url(self) -> Optional[str]:
- use_localhost = "LIGHTNING_APP_STATE_URL" not in os.environ
- if use_localhost:
- return self.url
- ip_addr = self.public_ip or self.internal_ip
- if ip_addr != "":
- return f"http://{ip_addr}:{self.port}"
- return ip_addr
-
- def on_exit(self):
- self._exit_event.set()
- with _lock:
- self.store_database()
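Note on `store_database` above: the removed component snapshots the live SQLite file with `sqlite3.Connection.backup` before uploading the copy. A minimal standalone sketch of that step, assuming only the standard library; `snapshot_sqlite` and `demo` are hypothetical names, and the Drive upload is left out:

```python
import os
import sqlite3
import tempfile


def snapshot_sqlite(db_path: str, dest_dir: str) -> str:
    """Copy a SQLite database into dest_dir using the online backup API."""
    dest_path = os.path.join(dest_dir, os.path.basename(db_path))
    source = sqlite3.connect(db_path)
    dest = sqlite3.connect(dest_path)
    try:
        # backup() copies every page, even while other connections use the source.
        source.backup(dest)
    finally:
        source.close()
        dest.close()
    return dest_path


def demo() -> list:
    """Create a tiny database, snapshot it, and read the rows from the copy."""
    with tempfile.TemporaryDirectory() as workdir:
        db_path = os.path.join(workdir, "counter.db")
        conn = sqlite3.connect(db_path)
        conn.execute("CREATE TABLE counter (count INTEGER)")
        conn.execute("INSERT INTO counter VALUES (3)")
        conn.commit()
        conn.close()

        backup_dir = os.path.join(workdir, "backup")
        os.makedirs(backup_dir)
        copy_path = snapshot_sqlite(db_path, backup_dir)

        copy = sqlite3.connect(copy_path)
        rows = copy.execute("SELECT count FROM counter").fetchall()
        copy.close()
        return rows
```

Unlike the removed method, this sketch closes both connections in a `finally` block so they are released even if the backup fails.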
diff --git a/src/lightning/app/components/database/utilities.py b/src/lightning/app/components/database/utilities.py
deleted file mode 100644
index 129f8a210a758..0000000000000
--- a/src/lightning/app/components/database/utilities.py
+++ /dev/null
@@ -1,263 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import functools
-import json
-import pathlib
-from typing import Any, Dict, Generic, List, Type, TypeVar
-
-from fastapi import Response, status
-from fastapi.encoders import jsonable_encoder
-from lightning_utilities.core.imports import RequirementCache
-from pydantic import BaseModel, parse_obj_as
-
-if RequirementCache("pydantic>=2.0.0"):
- from pydantic.v1.main import ModelMetaclass
-else:
- from pydantic.main import ModelMetaclass
-
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.imports import _is_sqlmodel_available
-
-if _is_sqlmodel_available():
- from sqlalchemy.inspection import inspect as sqlalchemy_inspect
- from sqlmodel import JSON, Session, SQLModel, TypeDecorator, select
-
-logger = Logger(__name__)
-engine = None
-
-T = TypeVar("T")
-
-
-# Taken from https://github.com/tiangolo/sqlmodel/issues/63#issuecomment-1081555082
-def _pydantic_column_type(pydantic_type: Any) -> Any:
- """This function enables support for JSON types with SQLModel.
-
- Example::
-
- from sqlmodel import SQLModel
- from sqlalchemy import Column
-
- class TrialConfig(SQLModel, table=False):
- ...
- params: Dict[str, Union[Dict[str, float]]] = Field(sa_column=Column(pydantic_column_type(Dict[str, float])))
-
- """
-
- class PydanticJSONType(TypeDecorator, Generic[T]):
- impl = JSON()
-
- def __init__(
- self,
- json_encoder=json,
- ):
- self.json_encoder = json_encoder
- super().__init__()
-
- def bind_processor(self, dialect):
- impl_processor = self.impl.bind_processor(dialect)
- dumps = self.json_encoder.dumps
- if impl_processor:
-
- def process(value: T):
- if value is not None:
- if isinstance(pydantic_type, ModelMetaclass):
- # This allows to assign non-InDB models and if they're
- # compatible, they're directly parsed into the InDB
- # representation, thus hiding the implementation in the
- # background. However, the InDB model will still be returned
- value_to_dump = pydantic_type.from_orm(value)
- else:
- value_to_dump = value
- value = jsonable_encoder(value_to_dump)
- return impl_processor(value)
-
- else:
-
- def process(value):
- if isinstance(pydantic_type, ModelMetaclass):
- # This allows to assign non-InDB models and if they're
- # compatible, they're directly parsed into the InDB
- # representation, thus hiding the implementation in the
- # background. However, the InDB model will still be returned
- value_to_dump = pydantic_type.from_orm(value)
- else:
- value_to_dump = value
- return dumps(jsonable_encoder(value_to_dump))
-
- return process
-
- def result_processor(self, dialect, coltype) -> T:
- impl_processor = self.impl.result_processor(dialect, coltype)
- if impl_processor:
-
- def process(value):
- value = impl_processor(value)
- if value is None:
- return None
-
- data = value
- # Explicitly use the generic directly, not type(T)
- return parse_obj_as(pydantic_type, data)
-
- else:
-
- def process(value):
- if value is None:
- return None
-
- # Explicitly use the generic directly, not type(T)
- return parse_obj_as(pydantic_type, value)
-
- return process
-
- def compare_values(self, x, y):
- return x == y
-
- return PydanticJSONType
-
-
-@functools.lru_cache(maxsize=128)
-def _get_primary_key(model_type: Type["SQLModel"]) -> str:
- primary_keys = sqlalchemy_inspect(model_type).primary_key
-
- if len(primary_keys) != 1:
- raise ValueError(f"The model {model_type.__name__} should have a single primary key field.")
-
- return primary_keys[0].name
-
-
-class _GeneralModel(BaseModel):
- cls_name: str
- data: str
- token: str
-
- def convert_to_model(self, models: Dict[str, BaseModel]):
- return models[self.cls_name].parse_raw(self.data)
-
- @classmethod
- def from_obj(cls, obj, token):
- return cls(**{
- "cls_name": obj.__class__.__name__,
- "data": obj.json(),
- "token": token,
- })
-
- @classmethod
- def from_cls(cls, obj_cls, token):
- return cls(**{
- "cls_name": obj_cls.__name__,
- "data": "",
- "token": token,
- })
-
-
-class _SelectAll:
- def __init__(self, models, token):
- print(models, token)
- self.models = models
- self.token = token
-
- def __call__(self, data: Dict, response: Response):
- if self.token and data["token"] != self.token:
- response.status_code = status.HTTP_401_UNAUTHORIZED
- return {"status": "failure", "reason": "Unauthorized request to the database."}
-
- with Session(engine) as session:
- cls: Type["SQLModel"] = self.models[data["cls_name"]]
- statement = select(cls)
- results = session.exec(statement)
- return results.all()
-
-
-class _Insert:
- def __init__(self, models, token):
- self.models = models
- self.token = token
-
- def __call__(self, data: Dict, response: Response):
- if self.token and data["token"] != self.token:
- response.status_code = status.HTTP_401_UNAUTHORIZED
- return {"status": "failure", "reason": "Unauthorized request to the database."}
-
- with Session(engine) as session:
- ele = self.models[data["cls_name"]].parse_raw(data["data"])
- session.add(ele)
- session.commit()
- session.refresh(ele)
- return ele
-
-
-class _Update:
- def __init__(self, models, token):
- self.models = models
- self.token = token
-
- def __call__(self, data: Dict, response: Response):
- if self.token and data["token"] != self.token:
- response.status_code = status.HTTP_401_UNAUTHORIZED
- return {"status": "failure", "reason": "Unauthorized request to the database."}
-
- with Session(engine) as session:
- update_data = self.models[data["cls_name"]].parse_raw(data["data"])
- primary_key = _get_primary_key(update_data.__class__)
- identifier = getattr(update_data.__class__, primary_key, None)
- statement = select(update_data.__class__).where(identifier == getattr(update_data, primary_key))
- results = session.exec(statement)
- result = results.one()
- for k, v in vars(update_data).items():
- if k in ("id", "_sa_instance_state"):
- continue
- if getattr(result, k) != v:
- setattr(result, k, v)
- session.add(result)
- session.commit()
- session.refresh(result)
- return None
-
-
-class _Delete:
- def __init__(self, models, token):
- self.models = models
- self.token = token
-
- def __call__(self, data: Dict, response: Response):
- if self.token and data["token"] != self.token:
- response.status_code = status.HTTP_401_UNAUTHORIZED
- return {"status": "failure", "reason": "Unauthorized request to the database."}
-
- with Session(engine) as session:
- update_data = self.models[data["cls_name"]].parse_raw(data["data"])
- primary_key = _get_primary_key(update_data.__class__)
- identifier = getattr(update_data.__class__, primary_key, None)
- statement = select(update_data.__class__).where(identifier == getattr(update_data, primary_key))
- results = session.exec(statement)
- result = results.one()
- session.delete(result)
- session.commit()
- return None
-
-
-def _create_database(db_filename: str, models: List[Type["SQLModel"]], echo: bool = False):
- global engine
-
- from sqlmodel import create_engine
-
- engine = create_engine(f"sqlite:///{pathlib.Path(db_filename).resolve()}", echo=echo)
-
- logger.debug(f"Creating the following tables {models}")
- try:
- SQLModel.metadata.create_all(engine)
- except Exception as ex:
- logger.debug(ex)
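Note on `_SelectAll`, `_Insert`, `_Update` and `_Delete` above: each handler repeats the same token check before touching the session. A minimal sketch of that guard as a reusable wrapper, assuming plain dicts instead of FastAPI's `Response`; `guarded` and the status constants are hypothetical names, and the sketch uses `data.get("token")` where the removed code indexes `data["token"]` directly:

```python
from typing import Any, Callable, Dict, Optional, Tuple

OK = 200
UNAUTHORIZED = 401


def guarded(
    token: Optional[str],
    handler: Callable[[Dict[str, Any]], Any],
) -> Callable[[Dict[str, Any]], Tuple[int, Any]]:
    """Wrap a handler so it only runs when the request carries the expected token."""

    def call(data: Dict[str, Any]) -> Tuple[int, Any]:
        # An empty or missing server-side token disables the check, as in the removed code.
        if token and data.get("token") != token:
            return UNAUTHORIZED, {"status": "failure", "reason": "Unauthorized request to the database."}
        return OK, handler(data)

    return call
```

A handler built with `guarded("secret", fn)` returns a `(401, ...)` pair for a wrong token and `(200, fn(data))` otherwise.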
diff --git a/src/lightning/app/components/multi_node/__init__.py b/src/lightning/app/components/multi_node/__init__.py
deleted file mode 100644
index e672a0d422ed9..0000000000000
--- a/src/lightning/app/components/multi_node/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from lightning.app.components.multi_node.base import MultiNode
-from lightning.app.components.multi_node.fabric import FabricMultiNode
-from lightning.app.components.multi_node.pytorch_spawn import PyTorchSpawnMultiNode
-from lightning.app.components.multi_node.trainer import LightningTrainerMultiNode
-
-__all__ = ["FabricMultiNode", "MultiNode", "PyTorchSpawnMultiNode", "LightningTrainerMultiNode"]
diff --git a/src/lightning/app/components/multi_node/base.py b/src/lightning/app/components/multi_node/base.py
deleted file mode 100644
index e5a651466c0a4..0000000000000
--- a/src/lightning/app/components/multi_node/base.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import warnings
-from typing import Any, Type
-
-from lightning.app.core.flow import LightningFlow
-from lightning.app.core.work import LightningWork
-from lightning.app.structures import List as _List
-from lightning.app.utilities.cloud import is_running_in_cloud
-from lightning.app.utilities.packaging.cloud_compute import CloudCompute
-
-
-class MultiNode(LightningFlow):
- def __init__(
- self,
- work_cls: Type["LightningWork"],
- num_nodes: int,
- cloud_compute: "CloudCompute",
- *work_args: Any,
- **work_kwargs: Any,
- ) -> None:
- """This component enables performing distributed multi-node multi-device training.
-
- Example::
-
- import torch
-
- from lightning.app import CloudCompute, LightningApp, LightningWork
- from lightning.app.components.multi_node import MultiNode
-
- class AnyDistributedComponent(LightningWork):
- def run(
- self,
- main_address: str,
- main_port: int,
- node_rank: int,
- ):
- print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {node_rank}")
-
-
- compute = CloudCompute("gpu")
- app = LightningApp(
- MultiNode(
- AnyDistributedComponent,
- num_nodes=8,
- cloud_compute=compute,
- )
- )
-
- Arguments:
- work_cls: The work to be executed
- num_nodes: Number of nodes. Gets ignored when running locally. Launch the app with --cloud to run on
- multiple cloud machines.
- cloud_compute: The cloud compute object used in the cloud. The value provided here gets ignored when
- running locally.
- work_args: Arguments to be provided to the work on instantiation.
- work_kwargs: Keyword arguments to be provided to the work on instantiation.
-
- """
- super().__init__()
- if num_nodes > 1 and not is_running_in_cloud():
- warnings.warn(
- f"You set `{type(self).__name__}(num_nodes={num_nodes}, ...)` but this app is running locally."
- " We assume you are debugging and will ignore the `num_nodes` argument."
- " To run on multiple nodes in the cloud, launch your app with `--cloud`."
- )
- num_nodes = 1
- self.ws = _List(*[
- work_cls(
- *work_args,
- cloud_compute=cloud_compute.clone(),
- **work_kwargs,
- parallel=True,
- )
- for _ in range(num_nodes)
- ])
-
- def run(self) -> None:
- # 1. Wait for all works to be started !
- if not all(w.internal_ip for w in self.ws):
- return
-
- # 2. Loop over all node machines
- for node_rank in range(len(self.ws)):
- # 3. Run the user code in a distributed way !
- self.ws[node_rank].run(
- main_address=self.ws[0].internal_ip,
- main_port=self.ws[0].port,
- num_nodes=len(self.ws),
- node_rank=node_rank,
- )
-
- # 4. Stop the machine when finished.
- if self.ws[node_rank].has_succeeded:
- self.ws[node_rank].stop()
diff --git a/src/lightning/app/components/multi_node/fabric.py b/src/lightning/app/components/multi_node/fabric.py
deleted file mode 100644
index 542e0fe3afc43..0000000000000
--- a/src/lightning/app/components/multi_node/fabric.py
+++ /dev/null
@@ -1,132 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import importlib
-import os
-import warnings
-from dataclasses import dataclass
-from typing import Any, Callable, Protocol, Type, runtime_checkable
-
-from lightning.app.components.multi_node.base import MultiNode
-from lightning.app.components.multi_node.pytorch_spawn import _PyTorchSpawnRunExecutor
-from lightning.app.core.work import LightningWork
-from lightning.app.utilities.packaging.cloud_compute import CloudCompute
-from lightning.app.utilities.tracer import Tracer
-
-
-@runtime_checkable
-class _FabricWorkProtocol(Protocol):
- @staticmethod
- def run() -> None:
- """Run."""
-
-
-@dataclass
-class _FabricRunExecutor(_PyTorchSpawnRunExecutor):
- @staticmethod
- def run(
- local_rank: int,
- work_run: Callable,
- main_address: str,
- main_port: int,
- num_nodes: int,
- node_rank: int,
- nprocs: int,
- ):
- fabrics = []
- strategies = []
- mps_accelerators = []
-
- for pkg_name in ("lightning.fabric", "lightning_" + "fabric"):
- try:
- pkg = importlib.import_module(pkg_name)
- fabrics.append(pkg.Fabric)
- strategies.append(pkg.strategies.DDPStrategy)
- mps_accelerators.append(pkg.accelerators.MPSAccelerator)
- except (ImportError, ModuleNotFoundError):
- continue
-
- # Used to configure the PyTorch process group
- os.environ["MASTER_ADDR"] = main_address
- os.environ["MASTER_PORT"] = str(main_port)
-
- # Used to hijack the TorchElastic cluster environment.
- os.environ["GROUP_RANK"] = str(node_rank)
- os.environ["RANK"] = str(local_rank + node_rank * nprocs)
- os.environ["LOCAL_RANK"] = str(local_rank)
- os.environ["WORLD_SIZE"] = str(num_nodes * nprocs)
- os.environ["LOCAL_WORLD_SIZE"] = str(nprocs)
- os.environ["TORCHELASTIC_RUN_ID"] = "1"
-
- # Used to force Fabric to set up the distributed environment.
- os.environ["LT_CLI_USED"] = "1"
-
- # Used to pass information to Fabric directly.
- def pre_fn(fabric, *args: Any, **kwargs: Any):
- kwargs["devices"] = nprocs
- kwargs["num_nodes"] = num_nodes
-
- if any(acc.is_available() for acc in mps_accelerators):
- old_acc_value = kwargs.get("accelerator", "auto")
- kwargs["accelerator"] = "cpu"
-
- if old_acc_value != kwargs["accelerator"]:
- warnings.warn("Forcing `accelerator=cpu` as MPS does not support distributed training.")
- else:
- kwargs["accelerator"] = "auto"
- strategy = kwargs.get("strategy", None)
- if strategy:
- if isinstance(strategy, str):
- if strategy == "ddp_spawn":
- strategy = "ddp"
- elif strategy == "ddp_sharded_spawn":
- strategy = "ddp_sharded"
- elif isinstance(strategy, tuple(strategies)) and strategy._start_method in ("spawn", "fork"):
- raise ValueError("DDP Spawned strategies aren't supported yet.")
-
- kwargs["strategy"] = strategy
-
- return {}, args, kwargs
-
- tracer = Tracer()
- for lf in fabrics:
- tracer.add_traced(lf, "__init__", pre_fn=pre_fn)
- tracer._instrument()
- ret_val = work_run()
- tracer._restore()
- return ret_val
-
-
-class FabricMultiNode(MultiNode):
- def __init__(
- self,
- work_cls: Type["LightningWork"],
- cloud_compute: "CloudCompute",
- num_nodes: int,
- *work_args: Any,
- **work_kwargs: Any,
- ) -> None:
- assert issubclass(work_cls, _FabricWorkProtocol)
-
- # Note: Private way to modify the work run executor
- # Probably exposed to the users in the future if needed.
- work_cls._run_executor_cls = _FabricRunExecutor
-
- super().__init__(
- work_cls,
- *work_args,
- num_nodes=num_nodes,
- cloud_compute=cloud_compute,
- **work_kwargs,
- )
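Note on `_FabricRunExecutor.run` above: the executor derives the TorchElastic-style variables from the node rank, the local rank and the number of processes per node. A minimal sketch of that arithmetic as a pure function, assuming nothing beyond the variable names shown in the removed code; `torchelastic_env` is a hypothetical helper that returns the mapping instead of mutating `os.environ`:

```python
from typing import Dict


def torchelastic_env(
    main_address: str,
    main_port: int,
    num_nodes: int,
    node_rank: int,
    local_rank: int,
    nprocs: int,
) -> Dict[str, str]:
    """Build the environment variables one worker process would see."""
    return {
        "MASTER_ADDR": main_address,
        "MASTER_PORT": str(main_port),
        "GROUP_RANK": str(node_rank),
        # Global rank: processes are numbered node by node.
        "RANK": str(local_rank + node_rank * nprocs),
        "LOCAL_RANK": str(local_rank),
        # Every node is assumed to run the same number of processes.
        "WORLD_SIZE": str(num_nodes * nprocs),
        "LOCAL_WORLD_SIZE": str(nprocs),
    }
```

For example, with 2 nodes of 4 processes each, local rank 3 on node 1 gets global rank 7 in a world of size 8.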
diff --git a/src/lightning/app/components/multi_node/pytorch_spawn.py b/src/lightning/app/components/multi_node/pytorch_spawn.py
deleted file mode 100644
index 9bccbad5cfc11..0000000000000
--- a/src/lightning/app/components/multi_node/pytorch_spawn.py
+++ /dev/null
@@ -1,119 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any, Callable, Protocol, Type, runtime_checkable
-
-from lightning.app.components.multi_node.base import MultiNode
-from lightning.app.core.queues import MultiProcessQueue
-from lightning.app.core.work import LightningWork
-from lightning.app.utilities.packaging.cloud_compute import CloudCompute
-from lightning.app.utilities.proxies import WorkRunExecutor, WorkStateObserver, _proxy_setattr, unwrap
-
-
-@runtime_checkable
-class _PyTorchSpawnWorkProtocol(Protocol):
- def run(
- self,
- world_size: int,
- node_rank: int,
- global_rank: int,
- local_rank: int,
- ) -> None:
- pass
-
-
-class _PyTorchSpawnRunExecutor(WorkRunExecutor):
- enable_start_observer: bool = False
-
- def __call__(
- self,
- main_address: str,
- main_port: int,
- num_nodes: int,
- node_rank: int,
- ):
- import torch
-
- with self.enable_spawn():
- nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1
- queue = self.delta_queue if isinstance(self.delta_queue, MultiProcessQueue) else self.delta_queue.to_dict()
- torch.multiprocessing.spawn(
- self.dispatch_run,
- args=(self.__class__, self.work, queue, main_address, main_port, num_nodes, node_rank, nprocs),
- nprocs=nprocs,
- )
-
- @staticmethod
- def dispatch_run(local_rank, cls, work, delta_queue, *args: Any, **kwargs: Any):
- if local_rank == 0:
- if isinstance(delta_queue, dict):
- delta_queue = cls.process_queue(delta_queue)
- work._request_queue = cls.process_queue(work._request_queue)
- work._response_queue = cls.process_queue(work._response_queue)
-
- state_observer = WorkStateObserver(work, delta_queue=delta_queue)
- state_observer.start()
- _proxy_setattr(work, delta_queue, state_observer)
-
- cls.run(local_rank, unwrap(work.run), *args, **kwargs)
-
- if local_rank == 0:
- state_observer.join(0)
-
- @staticmethod
- def run(
- local_rank: int,
- work_run: Callable,
- main_address: str,
- main_port: int,
- num_nodes: int,
- node_rank: int,
- nprocs: int,
- ):
- import torch
-
- # 1. Setting distributed environment
- global_rank = local_rank + node_rank * nprocs
- world_size = num_nodes * nprocs
-
- if torch.distributed.is_available():
- if not torch.distributed.is_initialized():
- torch.distributed.init_process_group(
- "nccl" if torch.cuda.is_available() else "gloo",
- rank=global_rank,
- world_size=world_size,
- init_method=f"tcp://{main_address}:{main_port}",
- )
- elif world_size > 1:
- raise Exception("Torch distributed should be available.")
-
- return work_run(world_size, node_rank, global_rank, local_rank)
-
-
-class PyTorchSpawnMultiNode(MultiNode):
- def __init__(
- self,
- work_cls: Type["LightningWork"],
- cloud_compute: "CloudCompute",
- num_nodes: int,
- *work_args: Any,
- **work_kwargs: Any,
- ) -> None:
- assert issubclass(work_cls, _PyTorchSpawnWorkProtocol)
-
- # Note: Private way to modify the work run executor
- # Probably exposed to the users in the future if needed.
- work_cls._run_executor_cls = _PyTorchSpawnRunExecutor
-
- super().__init__(work_cls, num_nodes, cloud_compute, *work_args, **work_kwargs)
diff --git a/src/lightning/app/components/multi_node/trainer.py b/src/lightning/app/components/multi_node/trainer.py
deleted file mode 100644
index ee7aae36f981b..0000000000000
--- a/src/lightning/app/components/multi_node/trainer.py
+++ /dev/null
@@ -1,130 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import importlib
-import os
-import warnings
-from dataclasses import dataclass
-from typing import Any, Callable, Protocol, Type, runtime_checkable
-
-from lightning.app.components.multi_node.base import MultiNode
-from lightning.app.components.multi_node.pytorch_spawn import _PyTorchSpawnRunExecutor
-from lightning.app.core.work import LightningWork
-from lightning.app.utilities.packaging.cloud_compute import CloudCompute
-from lightning.app.utilities.tracer import Tracer
-
-
-@runtime_checkable
-class _LightningTrainerWorkProtocol(Protocol):
- @staticmethod
- def run() -> None:
- """Run."""
-
-
-@dataclass
-class _LightningTrainerRunExecutor(_PyTorchSpawnRunExecutor):
- @staticmethod
- def run(
- local_rank: int,
- work_run: Callable,
- main_address: str,
- main_port: int,
- num_nodes: int,
- node_rank: int,
- nprocs: int,
- ):
- trainers = []
- strategies = []
- mps_accelerators = []
-
- for pkg_name in ("lightning.pytorch", "pytorch_" + "lightning"):
- try:
- pkg = importlib.import_module(pkg_name)
- trainers.append(pkg.Trainer)
- strategies.append(pkg.strategies.DDPStrategy)
- mps_accelerators.append(pkg.accelerators.MPSAccelerator)
- except (ImportError, ModuleNotFoundError):
- continue
-
- # Used to configure the PyTorch process group
- os.environ["MASTER_ADDR"] = main_address
- os.environ["MASTER_PORT"] = str(main_port)
-
- # Used to hijack the TorchElastic cluster environment.
- os.environ["GROUP_RANK"] = str(node_rank)
- os.environ["RANK"] = str(local_rank + node_rank * nprocs)
- os.environ["LOCAL_RANK"] = str(local_rank)
- os.environ["WORLD_SIZE"] = str(num_nodes * nprocs)
- os.environ["LOCAL_WORLD_SIZE"] = str(nprocs)
- os.environ["TORCHELASTIC_RUN_ID"] = "1"
-
- # Used to pass information to the Trainer directly.
- def pre_fn(trainer, *args: Any, **kwargs: Any):
- kwargs["devices"] = nprocs
- kwargs["num_nodes"] = num_nodes
- if any(acc.is_available() for acc in mps_accelerators):
- old_acc_value = kwargs.get("accelerator", "auto")
- kwargs["accelerator"] = "cpu"
-
- if old_acc_value != kwargs["accelerator"]:
- warnings.warn("Forcing `accelerator=cpu` as MPS does not support distributed training.")
- else:
- kwargs["accelerator"] = "auto"
-
- strategy = kwargs.get("strategy", None)
- if strategy:
- if isinstance(strategy, str):
- if strategy == "ddp_spawn":
- strategy = "ddp"
- elif strategy == "ddp_sharded_spawn":
- strategy = "ddp_sharded"
- elif isinstance(strategy, tuple(strategies)):
- raise ValueError("DDP Spawned strategies aren't supported yet.")
- kwargs["strategy"] = strategy
- return {}, args, kwargs
-
- tracer = Tracer()
- for trainer in trainers:
- tracer.add_traced(trainer, "__init__", pre_fn=pre_fn)
- tracer._instrument()
- ret_val = work_run()
- tracer._restore()
- return ret_val
-
-
-class LightningTrainerMultiNode(MultiNode):
- def __init__(
- self,
- work_cls: Type["LightningWork"],
- cloud_compute: "CloudCompute",
- num_nodes: int,
- *work_args: Any,
- **work_kwargs: Any,
- ) -> None:
- assert issubclass(work_cls, _LightningTrainerWorkProtocol)
-
- # Note: Private way to modify the work run executor
- # Probably exposed to the users in the future if needed.
- work_cls._run_executor_cls = _LightningTrainerRunExecutor
-
- super().__init__(
- work_cls,
- *work_args,
- num_nodes=num_nodes,
- cloud_compute=cloud_compute,
- **work_kwargs,
- )
-
- # the Trainer enables TensorBoard by default, so this is often an undesired directory to upload to the cloud
- self.lightningignore += ("lightning_logs",)
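Note on `pre_fn` above: string strategies that imply spawning are rewritten to their non-spawn equivalents before the Trainer is constructed, because the executor has already spawned the processes. A minimal sketch of just that string mapping; `normalize_strategy` is a hypothetical name, and strategy objects are passed through unchanged here, whereas the removed code raises a `ValueError` for `DDPStrategy` instances:

```python
from typing import Any

# Spawn-based aliases and the non-spawn strategy each one is replaced with.
_SPAWN_ALIASES = {
    "ddp_spawn": "ddp",
    "ddp_sharded_spawn": "ddp_sharded",
}


def normalize_strategy(strategy: Any) -> Any:
    """Map spawn-based strategy names to their non-spawn equivalents."""
    if isinstance(strategy, str):
        return _SPAWN_ALIASES.get(strategy, strategy)
    return strategy
```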
diff --git a/src/lightning/app/components/python/__init__.py b/src/lightning/app/components/python/__init__.py
deleted file mode 100644
index 86de268e86fcf..0000000000000
--- a/src/lightning/app/components/python/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from lightning.app.components.python.popen import PopenPythonScript
-from lightning.app.components.python.tracer import TracerPythonScript
-
-__all__ = ["PopenPythonScript", "TracerPythonScript"]
diff --git a/src/lightning/app/components/python/popen.py b/src/lightning/app/components/python/popen.py
deleted file mode 100644
index cb9d1ea14562d..0000000000000
--- a/src/lightning/app/components/python/popen.py
+++ /dev/null
@@ -1,113 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import signal
-import subprocess
-import sys
-from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
-
-from lightning.app.core.work import LightningWork
-from lightning.app.utilities.app_helpers import Logger, _collect_child_process_pids
-from lightning.app.utilities.tracer import Tracer
-
-logger = Logger(__name__)
-
-
-class PopenPythonScript(LightningWork):
- def on_before_run(self):
- """Called before the python script is executed."""
-
- def on_after_run(self):
- """Called after the python script is executed."""
-
- def configure_tracer(self) -> Tracer:
- """Override this hook to customize your tracer when running PythonScript with ``mode=tracer``."""
- return Tracer()
-
- def __init__(
- self,
- script_path: Union[str, Path],
- script_args: Optional[Union[str, List[str]]] = None,
- env: Optional[Dict] = None,
- **kwargs: Any,
- ):
- """The PopenPythonScript component makes it easy to run a Python script within a subprocess.
-
- Arguments:
- script_path: Path of the python script to run.
- script_args: The arguments to be passed to the script.
- env: Environment variables to be passed to the script.
- kwargs: LightningWork keyword arguments.
-
- Raises:
- FileNotFoundError: If the provided `script_path` doesn't exist.
-
- Example:
-
- >>> from lightning.app.components.python import PopenPythonScript
- >>> f = open("a.py", "w")
- >>> f.write("print('Hello World !')")
- 22
- >>> f.close()
- >>> python_script = PopenPythonScript("a.py")
- >>> python_script.run()
- >>> os.remove("a.py")
-
- In this example, the script will be launched with :class:`~subprocess.Popen`.
-
- .. literalinclude:: ../../../../examples/app/components/python/component_popen.py
- :language: python
-
- """
- super().__init__(**kwargs)
- if not os.path.exists(script_path):
- raise FileNotFoundError(f"The provided `script_path` `{script_path}` wasn't found.")
- self.script_path = str(script_path)
- if isinstance(script_args, str):
- script_args = script_args.split(" ")
- self.script_args = script_args if script_args else []
- self.env = env
- self.pid = None
- self.exit_code = None
-
- def run(self) -> None:
- self.on_before_run()
- self._run_with_subprocess_popen()
- self.on_after_run()
- return
-
- def _run_with_subprocess_popen(self) -> None:
- cmd = [sys.executable] + [self.script_path] + self.script_args
-
- with subprocess.Popen(
- cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=0, close_fds=True, env=self.env
- ) as proc:
- self.pid = proc.pid
- if proc.stdout:
- with proc.stdout:
- for line in iter(proc.stdout.readline, b""):
- logger.info("%s", line.decode().rstrip())
-
- self.exit_code = proc.wait()
- if self.exit_code != 0:
- raise Exception(self.exit_code)
-
- def on_exit(self):
- for child_pid in _collect_child_process_pids(os.getpid()):
- os.kill(child_pid, signal.SIGTERM)
-
-
-__all__ = ["PopenPythonScript"]
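Note on `_run_with_subprocess_popen` above: the component runs the script with the current interpreter, merges stderr into stdout, and reads the output line by line until the pipe closes. A minimal sketch of that loop, assuming only the standard library; `run_and_capture` is a hypothetical helper that collects the lines instead of logging them and returns the exit code instead of raising:

```python
import subprocess
import sys
from typing import List, Tuple


def run_and_capture(args: List[str]) -> Tuple[int, List[str]]:
    """Run the current Python interpreter with args and collect its output lines."""
    cmd = [sys.executable] + args
    lines: List[str] = []
    with subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=0, close_fds=True
    ) as proc:
        if proc.stdout:
            # readline() returns b"" once the child closes its end of the pipe.
            for raw in iter(proc.stdout.readline, b""):
                lines.append(raw.decode().rstrip())
        exit_code = proc.wait()
    return exit_code, lines
```

Calling `run_and_capture(["-c", "print('Hello World !')"])` mirrors the docstring example without writing `a.py` to disk.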
diff --git a/src/lightning/app/components/python/tracer.py b/src/lightning/app/components/python/tracer.py
deleted file mode 100644
index a33b8f97c11d0..0000000000000
--- a/src/lightning/app/components/python/tracer.py
+++ /dev/null
@@ -1,200 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import signal
-import sys
-from copy import deepcopy
-from typing import Any, Dict, List, Optional, Union
-
-from typing_extensions import TypedDict
-
-from lightning.app.core.work import LightningWork
-from lightning.app.storage.drive import Drive
-from lightning.app.storage.payload import Payload
-from lightning.app.utilities.app_helpers import Logger, _collect_child_process_pids
-from lightning.app.utilities.packaging.tarfile import clean_tarfile, extract_tarfile
-from lightning.app.utilities.tracer import Tracer
-
-logger = Logger(__name__)
-
-
-class Code(TypedDict):
- drive: Drive
- name: str
-
-
-class TracerPythonScript(LightningWork):
- _start_method = "spawn"
-
- def on_before_run(self):
- """Called before the python script is executed."""
-
- def on_after_run(self, res: Any):
- """Called after the python script is executed."""
- for name in self.outputs:
- setattr(self, name, Payload(res[name]))
-
- def configure_tracer(self) -> Tracer:
- """Override this hook to customize your tracer when running PythonScript."""
- return Tracer()
-
- def __init__(
- self,
- script_path: str,
- script_args: Optional[Union[list, str]] = None,
- outputs: Optional[List[str]] = None,
- env: Optional[Dict] = None,
- code: Optional[Code] = None,
- **kwargs: Any,
- ):
- """The TracerPythonScript class enables to easily run a python script.
-
- When subclassing this class, you can configure your own :class:`~lightning.app.utilities.tracer.Tracer`
- by :meth:`~lightning.app.components.python.tracer.TracerPythonScript.configure_tracer` method.
-
- The tracer is quite a magical class. It enables you to inject code into a script execution without changing it.
-
- Arguments:
- script_path: Path of the python script to run.
- script_args: The arguments to be passed to the script.
- outputs: Collection of object names to collect after the script execution.
- env: Environment variables to be passed to the script.
- kwargs: LightningWork Keyword arguments.
-
- Raises:
- FileNotFoundError: If the provided `script_path` doesn't exist.
-
- **How does it work?**
-
- It works by executing the python script with python built-in `runpy
- <https://docs.python.org/3/library/runpy.html>`_ run_path method.
- This method takes any python globals before executing the script,
- e.g., you can modify classes or function from the script.
-
- Example:
-
- >>> from lightning.app.components.python import TracerPythonScript
- >>> f = open("a.py", "w")
- >>> f.write("print('Hello World !')")
- 22
- >>> f.close()
- >>> python_script = TracerPythonScript("a.py")
- >>> python_script.run()
- Hello World !
- >>> os.remove("a.py")
-
- In the example below, we subclass the :class:`~lightning.app.components.python.TracerPythonScript`
- component and override its configure_tracer method.
-
- Using the Tracer, we are patching the ``__init__`` method of the PyTorch Lightning Trainer.
- Once the script starts running and if a Trainer is instantiated, the provided ``pre_fn`` is
- called and we inject a Lightning callback.
-
- This callback has a reference to the work and on every batch end, we are capturing the
- trainer ``global_step`` and ``best_model_path``.
-
- Even more interesting, this component works for ANY PyTorch Lightning script and
- its state can be used in real time in a UI.
-
- .. literalinclude:: ../../../../examples/app/components/python/component_tracer.py
- :language: python
-
-
- Once implemented, this component can easily be integrated within a larger app
- to execute a specific python script.
-
- .. literalinclude:: ../../../../examples/app/components/python/app.py
- :language: python
-
- """
- super().__init__(**kwargs)
- self.script_path = str(script_path)
- if isinstance(script_args, str):
- script_args = script_args.split(" ")
- self.script_args = script_args if script_args else []
- self.original_args = deepcopy(self.script_args)
- self.env = env
- self.outputs = outputs or []
- for name in self.outputs:
- setattr(self, name, None)
- self.params = None
- self.drive = code.get("drive") if code else None
- self.code_name = code.get("name") if code else None
- self.restart_count = 0
-
- def run(
- self,
- params: Optional[Dict[str, Any]] = None,
- restart_count: Optional[int] = None,
- code_dir: Optional[str] = ".",
- **kwargs: Any,
- ):
- """
- Arguments:
- params: A dictionary of arguments to be added to script_args.
- restart_count: Passes an incrementing counter to enable the re-execution of LightningWorks.
- code_dir: A path string determining where the source is extracted, default is current directory.
- """
- if restart_count:
- self.restart_count = restart_count
-
- if params:
- self.params = params
- self.script_args = self.original_args + [self._to_script_args(k, v) for k, v in params.items()]
-
- if self.drive:
- assert self.code_name
- if os.path.exists(self.code_name):
- clean_tarfile(self.code_name, "r:gz")
-
- if self.code_name in self.drive.list():
- self.drive.get(self.code_name)
- extract_tarfile(self.code_name, code_dir, "r:gz")
-
- prev_cwd = os.getcwd()
- os.chdir(code_dir)
-
- if not os.path.exists(self.script_path):
- raise FileNotFoundError(f"The provided `script_path` `{self.script_path}` wasn't found.")
-
- kwargs = {k: v.value if isinstance(v, Payload) else v for k, v in kwargs.items()}
-
- init_globals = globals()
- init_globals.update(kwargs)
-
- self.on_before_run()
- env_copy = os.environ.copy()
- if self.env:
- os.environ.update(self.env)
- res = self._run_tracer(init_globals)
- os.chdir(prev_cwd)
- os.environ = env_copy
- return self.on_after_run(res)
-
- def _run_tracer(self, init_globals):
- sys.argv = [self.script_path]
- tracer = self.configure_tracer()
- return tracer.trace(self.script_path, *self.script_args, init_globals=init_globals)
-
- def on_exit(self):
- for child_pid in _collect_child_process_pids(os.getpid()):
- os.kill(child_pid, signal.SIGTERM)
-
- @staticmethod
- def _to_script_args(k: str, v: str) -> str:
- return f"{k}={v}"
-
-
-__all__ = ["TracerPythonScript"]
diff --git a/src/lightning/app/components/serve/__init__.py b/src/lightning/app/components/serve/__init__.py
deleted file mode 100644
index bd4cf07082aa1..0000000000000
--- a/src/lightning/app/components/serve/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from lightning.app.components.serve.auto_scaler import AutoScaler
-from lightning.app.components.serve.cold_start_proxy import ColdStartProxy
-from lightning.app.components.serve.gradio_server import ServeGradio
-from lightning.app.components.serve.python_server import Category, Image, Number, PythonServer, Text
-from lightning.app.components.serve.streamlit import ServeStreamlit
-
-__all__ = [
- "ServeGradio",
- "ServeStreamlit",
- "PythonServer",
- "Image",
- "Number",
- "Category",
- "Text",
- "AutoScaler",
- "ColdStartProxy",
-]
diff --git a/src/lightning/app/components/serve/auto_scaler.py b/src/lightning/app/components/serve/auto_scaler.py
deleted file mode 100644
index 129bb4e50b635..0000000000000
--- a/src/lightning/app/components/serve/auto_scaler.py
+++ /dev/null
@@ -1,753 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import asyncio
-import logging
-import time
-import uuid
-from itertools import cycle
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
-from typing import SupportsFloat as Numeric
-
-import requests
-import uvicorn
-from fastapi import FastAPI, HTTPException, Request
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import RedirectResponse
-from pydantic import BaseModel
-from starlette.staticfiles import StaticFiles
-
-from lightning.app.components.serve.cold_start_proxy import ColdStartProxy
-from lightning.app.core.flow import LightningFlow
-from lightning.app.core.work import LightningWork
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.cloud import is_running_in_cloud
-from lightning.app.utilities.imports import _is_aiohttp_available, requires
-from lightning.app.utilities.packaging.cloud_compute import CloudCompute
-
-if _is_aiohttp_available():
- import aiohttp
- import aiohttp.client_exceptions
-
-logger = Logger(__name__)
-
-
-class _TrackableFastAPI(FastAPI):
- """A FastAPI subclass that tracks the request metadata."""
-
- def __init__(self, *args: Any, **kwargs: Any):
- super().__init__(*args, **kwargs)
- self.global_request_count = 0
- self.num_current_requests = 0
- self.last_processing_time = 0
-
-
-def _maybe_raise_granular_exception(exception: Exception) -> None:
- """Handle an exception from hitting the model servers."""
- if not isinstance(exception, Exception):
- return
-
- if isinstance(exception, HTTPException):
- raise exception
-
- if isinstance(exception, aiohttp.client_exceptions.ServerDisconnectedError):
- raise HTTPException(500, "Worker Server Disconnected") from exception
-
- if isinstance(exception, aiohttp.client_exceptions.ClientError):
- logging.exception(exception)
- raise HTTPException(500, "Worker Server error") from exception
-
- if isinstance(exception, asyncio.TimeoutError):
- raise HTTPException(408, "Request timed out") from exception
-
- if isinstance(exception, Exception) and exception.args[0] == "Server disconnected":
- raise HTTPException(500, "Worker Server disconnected") from exception
-
- logging.exception(exception)
- raise HTTPException(500, exception.args[0]) from exception
-
-
-class _SysInfo(BaseModel):
- num_workers: int
- servers: List[str]
- num_requests: int
- processing_time: int
- global_request_count: int
-
-
-class _BatchRequestModel(BaseModel):
- inputs: List[Any]
-
-
-def _create_fastapi(title: str) -> _TrackableFastAPI:
- fastapi_app = _TrackableFastAPI(title=title)
-
- fastapi_app.add_middleware(
- CORSMiddleware,
- allow_origins=["*"],
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
- )
-
- @fastapi_app.get("/", include_in_schema=False)
- async def docs():
- return RedirectResponse("/docs")
-
- @fastapi_app.get("/num-requests")
- async def num_requests() -> int:
- return fastapi_app.num_current_requests
-
- return fastapi_app
-
-
-class _LoadBalancer(LightningWork):
- r"""The LoadBalancer is a LightningWork component that collects the requests and sends them to the prediction API
- asynchronously using RoundRobin scheduling. It also performs auto batching of the incoming requests.
-
- After enabling you will require to send username and password from the request header for the private endpoints.
-
- Args:
- input_type: Input type.
- output_type: Output type.
- endpoint: The REST API path.
- max_batch_size: The number of requests processed at once.
- timeout_batching: The number of seconds to wait before sending the requests to process in order to allow for
- requests to be batched. In any case, requests are processed as soon as `max_batch_size` is reached.
- timeout_keep_alive: The number of seconds until it closes Keep-Alive connections if no new data is received.
- timeout_inference_request: The number of seconds to wait for inference.
- api_name: The name to be displayed on the UI. Normally, it is the name of the work class
- cold_start_proxy: The proxy service to use while the work is cold starting.
- **kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc.
-
- """
-
- @requires(["aiohttp"])
- def __init__(
- self,
- input_type: Type[BaseModel],
- output_type: Type[BaseModel],
- endpoint: str,
- max_batch_size: int = 8,
- # all timeout args are in seconds
- timeout_batching: float = 1,
- timeout_keep_alive: int = 60,
- timeout_inference_request: int = 60,
- api_name: Optional[str] = "API", # used for displaying the name in the UI
- cold_start_proxy: Union[ColdStartProxy, str, None] = None,
- **kwargs: Any,
- ) -> None:
- super().__init__(cloud_compute=CloudCompute("default"), **kwargs)
- self._input_type = input_type
- self._output_type = output_type
- self._timeout_keep_alive = timeout_keep_alive
- self._timeout_inference_request = timeout_inference_request
- self.servers = []
- self.max_batch_size = max_batch_size
- self.timeout_batching = timeout_batching
- self._iter = None
- self._batch = []
- self._responses = {} # {request_id: response}
- self._last_batch_sent = None
- self._server_status = {}
- self._api_name = api_name
- self.ready = False
-
- if not endpoint.startswith("/"):
- endpoint = "/" + endpoint
-
- self.endpoint = endpoint
- self._fastapi_app = None
-
- self._cold_start_proxy = None
- if cold_start_proxy:
- if isinstance(cold_start_proxy, str):
- self._cold_start_proxy = ColdStartProxy(proxy_url=cold_start_proxy)
- elif isinstance(cold_start_proxy, ColdStartProxy):
- self._cold_start_proxy = cold_start_proxy
- else:
- raise ValueError("cold_start_proxy must be of type ColdStartProxy or str")
-
- def get_internal_url(self) -> str:
- if not self._public_ip:
- raise ValueError("Public IP not set")
- return f"http://{self._public_ip}:{self._port}"
-
- async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]], server_url: str):
- request_data: List[_LoadBalancer._input_type] = [b[1] for b in batch]
- batch_request_data = _BatchRequestModel(inputs=request_data)
-
- try:
- self._server_status[server_url] = False
- async with aiohttp.ClientSession() as session:
- headers = {
- "accept": "application/json",
- "Content-Type": "application/json",
- }
- async with session.post(
- f"{server_url}{self.endpoint}",
- json=batch_request_data.dict(),
- timeout=self._timeout_inference_request,
- headers=headers,
- ) as response:
- if response.status == 408:
- raise HTTPException(408, "Request timed out")
- response.raise_for_status()
- response = await response.json()
- outputs = response["outputs"]
- if len(batch) != len(outputs):
- raise RuntimeError(f"result has {len(outputs)} items but batch is {len(batch)}")
- result = {request[0]: r for request, r in zip(batch, outputs)}
- self._responses.update(result)
- except Exception as ex:
- result = {request[0]: ex for request in batch}
- self._responses.update(result)
- finally:
- # resetting the server status so other requests can be
- # scheduled on this node
- if server_url in self._server_status:
- # TODO - if the server returns an error, track that so
- # we don't send more requests to it
- self._server_status[server_url] = True
-
- def _find_free_server(self) -> Optional[str]:
- existing = set(self._server_status.keys())
- for server in existing:
- status = self._server_status.get(server, None)
- if status is None:
- logger.error("Server is not found in the status list. This should not happen.")
- if status:
- return server
- return None
-
- async def consumer(self):
- """The consumer process that continuously checks for new requests and sends them to the API.
-
- Two instances of this function should not be running with shared `_state_server` as that would create race
- conditions
-
- """
- while True:
- await asyncio.sleep(0.05)
- batch = self._batch[: self.max_batch_size]
- is_batch_ready = len(batch) == self.max_batch_size
- if len(batch) > 0 and self._last_batch_sent is None:
- self._last_batch_sent = time.time()
-
- if self._last_batch_sent:
- is_batch_timeout = time.time() - self._last_batch_sent > self.timeout_batching
- else:
- is_batch_timeout = False
-
- server_url = self._find_free_server()
- # setting the server status to be busy! This will be reset by
- # the send_batch function after the server responds
- if server_url is None:
- continue
- if batch and (is_batch_ready or is_batch_timeout):
- self._server_status[server_url] = False
- # find server with capacity
- # Saving a reference to the result of this function, protects the task disappearing mid-execution
- # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
- task = asyncio.create_task(self.send_batch(batch, server_url)) # noqa: F841
- # resetting the batch array, TODO - not locking the array
- self._batch = self._batch[len(batch) :]
- self._last_batch_sent = time.time()
-
- async def process_request(self, data: BaseModel, request_id=None):
- if request_id is None:
- request_id = uuid.uuid4().hex
- if not self.servers and not self._cold_start_proxy:
- # sleeping to trigger the scale up
- raise HTTPException(503, "None of the workers are healthy!, try again in a few seconds")
-
- # if no servers are available, proxy the request to cold start proxy handler
- if not self.servers and self._cold_start_proxy:
- return await self._cold_start_proxy.handle_request(data)
-
- # if out of capacity, proxy the request to cold start proxy handler
- if not self._has_processing_capacity() and self._cold_start_proxy:
- return await self._cold_start_proxy.handle_request(data)
-
- # if we have capacity, process the request
- self._batch.append((request_id, data))
- while True:
- await asyncio.sleep(0.05)
- if request_id in self._responses:
- result = self._responses[request_id]
- del self._responses[request_id]
- _maybe_raise_granular_exception(result)
- return result
-
- def _has_processing_capacity(self):
- """This function checks if we have processing capacity for one more request or not.
-
- Depending on the value returned here, we decide whether to proxy the request.
-
- """
- if not self._fastapi_app:
- return False
- active_server_count = len(self.servers)
- max_processable = self.max_batch_size * active_server_count
- current_req_count = self._fastapi_app.num_current_requests
- return current_req_count < max_processable
-
- def run(self):
- logger.info(f"servers: {self.servers}")
-
- self._iter = cycle(self.servers)
-
- fastapi_app = _create_fastapi("Load Balancer")
- fastapi_app.SEND_TASK = None
- self._fastapi_app = fastapi_app
-
- input_type = self._input_type
-
- @fastapi_app.middleware("http")
- async def current_request_counter(request: Request, call_next):
- if request.scope["path"] != self.endpoint:
- return await call_next(request)
- fastapi_app.global_request_count += 1
- fastapi_app.num_current_requests += 1
- start_time = time.time()
- response = await call_next(request)
- processing_time = time.time() - start_time
- fastapi_app.last_processing_time = processing_time
- fastapi_app.num_current_requests -= 1
- return response
-
- @fastapi_app.on_event("startup")
- async def startup_event():
- fastapi_app.SEND_TASK = asyncio.create_task(self.consumer())
-
- @fastapi_app.on_event("shutdown")
- def shutdown_event():
- fastapi_app.SEND_TASK.cancel()
-
- @fastapi_app.get("/system/info", response_model=_SysInfo)
- async def sys_info():
- return _SysInfo(
- num_workers=len(self.servers),
- servers=self.servers,
- num_requests=fastapi_app.num_current_requests,
- processing_time=fastapi_app.last_processing_time,
- global_request_count=fastapi_app.global_request_count,
- )
-
- @fastapi_app.put("/system/update-servers")
- async def update_servers(servers: List[str]):
- self.servers = servers
- self._iter = cycle(self.servers)
- updated_servers = set()
- # do not try to loop over the dict keys as the dict might change from other places
- existing_servers = list(self._server_status.keys())
- for server in servers:
- updated_servers.add(server)
- if server not in existing_servers:
- self._server_status[server] = True
- logger.info(f"Registering server {server}", self._server_status)
- for existing in existing_servers:
- if existing not in updated_servers:
- logger.info(f"De-Registering server {existing}", self._server_status)
- del self._server_status[existing]
-
- @fastapi_app.post(self.endpoint, response_model=self._output_type)
- async def balance_api(inputs: input_type):
- return await self.process_request(inputs)
-
- endpoint_info_page = self._get_endpoint_info_page()
- if endpoint_info_page:
- fastapi_app.mount(
- "/endpoint-info", StaticFiles(directory=endpoint_info_page.serve_dir, html=True), name="static"
- )
-
- logger.info(f"Your load balancer has started. The endpoint is 'http://{self.host}:{self.port}{self.endpoint}'")
- self.ready = True
- uvicorn.run(
- fastapi_app,
- host=self.host,
- port=self.port,
- loop="uvloop",
- timeout_keep_alive=self._timeout_keep_alive,
- access_log=False,
- )
-
- def update_servers(self, server_works: List[LightningWork]):
- """Updates works that load balancer distributes requests to.
-
- AutoScaler uses this method to increase/decrease the number of works.
-
- """
- old_server_urls = set(self.servers)
- current_server_urls = {
- f"http://{server._public_ip}:{server.port}" for server in server_works if server._internal_ip
- }
-
- # doing nothing if no server work has been added/removed
- if old_server_urls == current_server_urls:
- return
-
- # checking if the url is ready or not
- available_urls = set()
- for url in current_server_urls:
- try:
- _ = requests.get(url)
- except requests.exceptions.ConnectionError:
- continue
- else:
- available_urls.add(url)
- if old_server_urls == available_urls:
- return
-
- newly_added = available_urls - old_server_urls
- if newly_added:
- logger.info(f"servers added: {newly_added}")
-
- deleted = old_server_urls - available_urls
- if deleted:
- logger.info(f"servers deleted: {deleted}")
- self.send_request_to_update_servers(list(available_urls))
-
- def send_request_to_update_servers(self, servers: List[str]):
- try:
- internal_url = self.get_internal_url()
- except ValueError:
- logger.warn("Cannot update servers as internal_url is not set")
- return
- response = requests.put(f"{internal_url}/system/update-servers", json=servers, timeout=10)
- response.raise_for_status()
-
- @staticmethod
- def _get_sample_dict_from_datatype(datatype: Any) -> dict:
- if not hasattr(datatype, "schema"):
- # not a pydantic model
- raise TypeError(f"datatype must be a pydantic model, for the UI to be generated. but got {datatype}")
-
- if hasattr(datatype, "get_sample_data"):
- return datatype.get_sample_data()
-
- datatype_props = datatype.schema()["properties"]
- out: Dict[str, Any] = {}
- lut = {"string": "data string", "number": 0.0, "integer": 0, "boolean": False}
- for k, v in datatype_props.items():
- if v["type"] not in lut:
- raise TypeError("Unsupported type")
- out[k] = lut[v["type"]]
- return out
-
- def get_code_sample(self, url: str) -> Optional[str]:
- input_type: Any = self._input_type
- output_type: Any = self._output_type
-
- if not (hasattr(input_type, "request_code_sample") and hasattr(output_type, "response_code_sample")):
- return None
- return f"{input_type.request_code_sample(url)}\n{output_type.response_code_sample()}"
-
- def _get_endpoint_info_page(self) -> Optional["APIAccessFrontend"]: # noqa: F821
- try:
- from lightning_api_access import APIAccessFrontend
- except ModuleNotFoundError:
- logger.warn(
- "Some dependencies to run the UI are missing. To resolve, run `pip install lightning-api-access`"
- )
- return None
-
- if is_running_in_cloud():
- url = f"{self._future_url}{self.endpoint}"
- else:
- url = f"http://localhost:{self.port}{self.endpoint}"
-
- frontend_objects = {"name": self._api_name, "url": url, "method": "POST", "request": None, "response": None}
- code_samples = self.get_code_sample(url)
- if code_samples:
- frontend_objects["code_sample"] = code_samples
- # TODO also set request/response for JS UI
- else:
- try:
- request = self._get_sample_dict_from_datatype(self._input_type)
- response = self._get_sample_dict_from_datatype(self._output_type)
- except TypeError:
- return None
- else:
- frontend_objects["request"] = request
- frontend_objects["response"] = response
- return APIAccessFrontend(apis=[frontend_objects])
-
-
-class AutoScaler(LightningFlow):
- """The ``AutoScaler`` can be used to automatically change the number of replicas of the given server in response to
- changes in the number of incoming requests. Incoming requests will be batched and balanced across the replicas.
-
- Args:
- min_replicas: The number of works to start when app initializes.
- max_replicas: The max number of works to spawn to handle the incoming requests.
- scale_out_interval: The number of seconds to wait before checking whether to increase the number of servers.
- scale_in_interval: The number of seconds to wait before checking whether to decrease the number of servers.
- endpoint: Provide the REST API path.
- max_batch_size: (auto-batching) The number of requests to process at once.
- timeout_batching: (auto-batching) The number of seconds to wait before sending the requests to process.
- input_type: Input type.
- output_type: Output type.
- cold_start_proxy: If provided, the proxy will be used while the worker machines are warming up.
-
- .. testcode::
-
- from lightning.app import LightningApp
- from lightning.app.components import AutoScaler
-
- # Example 1: Auto-scaling serve component out-of-the-box
- app = LightningApp(
- app.components.AutoScaler(
- MyPythonServer,
- min_replicas=1,
- max_replicas=8,
- scale_out_interval=10,
- scale_in_interval=10,
- )
- )
-
-
- # Example 2: Customizing the scaling logic
- class MyAutoScaler(AutoScaler):
- def scale(self, replicas: int, metrics: dict) -> int:
- pending_requests_per_running_or_pending_work = metrics["pending_requests"] / (
- replicas + metrics["pending_works"]
- )
-
- # upscale
- max_requests_per_work = self.max_batch_size
- if pending_requests_per_running_or_pending_work >= max_requests_per_work:
- return replicas + 1
-
- # downscale
- min_requests_per_work = max_requests_per_work * 0.25
- if pending_requests_per_running_or_pending_work < min_requests_per_work:
- return replicas - 1
-
- return replicas
-
-
- app = LightningApp(
- MyAutoScaler(
- MyPythonServer,
- min_replicas=1,
- max_replicas=8,
- scale_out_interval=10,
- scale_in_interval=10,
- max_batch_size=8, # for auto batching
- timeout_batching=1, # for auto batching
- )
- )
-
- """
-
- def __init__(
- self,
- work_cls: Type[LightningWork],
- min_replicas: int = 1,
- max_replicas: int = 4,
- scale_out_interval: Numeric = 10,
- scale_in_interval: Numeric = 10,
- max_batch_size: int = 8,
- timeout_batching: float = 1,
- endpoint: str = "api/predict",
- input_type: Type[BaseModel] = Dict,
- output_type: Type[BaseModel] = Dict,
- cold_start_proxy: Union[ColdStartProxy, str, None] = None,
- *work_args: Any,
- **work_kwargs: Any,
- ) -> None:
- super().__init__()
- self.num_replicas = 0
- self._work_registry = {}
-
- self._work_cls = work_cls
- self._work_args = work_args
- self._work_kwargs = work_kwargs
-
- self._input_type = input_type
- self._output_type = output_type
- self.scale_out_interval = scale_out_interval
- self.scale_in_interval = scale_in_interval
- self.max_batch_size = max_batch_size
-
- if max_replicas < min_replicas:
- raise ValueError(
- f"`max_replicas={max_replicas}` must be greater than or equal to `min_replicas={min_replicas}`."
- )
- self.max_replicas = max_replicas
- self.min_replicas = min_replicas
- self._last_autoscale = time.time()
- self.fake_trigger = 0
-
- self.load_balancer = _LoadBalancer(
- input_type=self._input_type,
- output_type=self._output_type,
- endpoint=endpoint,
- max_batch_size=max_batch_size,
- timeout_batching=timeout_batching,
- cache_calls=True,
- parallel=True,
- api_name=self._work_cls.__name__,
- cold_start_proxy=cold_start_proxy,
- )
-
- @property
- def ready(self) -> bool:
- return self.load_balancer.ready
-
- @property
- def workers(self) -> List[LightningWork]:
- return [self.get_work(i) for i in range(self.num_replicas)]
-
- def create_work(self) -> LightningWork:
- """Replicates a LightningWork instance with args and kwargs provided via ``__init__``."""
- cloud_compute = self._work_kwargs.get("cloud_compute", None)
- self._work_kwargs.update({
- "start_with_flow": False,
- "cloud_compute": cloud_compute.clone() if cloud_compute else None,
- })
- return self._work_cls(*self._work_args, **self._work_kwargs)
-
- def add_work(self, work) -> str:
- """Adds a new LightningWork instance.
-
- Returns:
- The name of the new work attribute.
-
- """
- work_attribute = uuid.uuid4().hex
- work_attribute = f"worker_{self.num_replicas}_{str(work_attribute)}"
- setattr(self, work_attribute, work)
- self._work_registry[self.num_replicas] = work_attribute
- self.num_replicas += 1
- return work_attribute
-
- def remove_work(self, index: int) -> str:
- """Removes the ``index`` th LightningWork instance."""
- work_attribute = self._work_registry[index]
- del self._work_registry[index]
- work = getattr(self, work_attribute)
- work.stop()
- self.num_replicas -= 1
- return work_attribute
-
- def get_work(self, index: int) -> LightningWork:
- """Returns the ``LightningWork`` instance with the given index."""
- work_attribute = self._work_registry[index]
- return getattr(self, work_attribute)
-
- def run(self):
- if not self.load_balancer.is_running:
- self.load_balancer.run()
- for work in self.workers:
- work.run()
- if self.load_balancer.url:
- self.fake_trigger += 1 # Note: change state to keep calling `run`.
- self.autoscale()
-
- def scale(self, replicas: int, metrics: dict) -> int:
- """The default scaling logic that users can override.
-
- Args:
- replicas: The number of running works.
- metrics: ``metrics['pending_requests']`` is the total number of requests that are currently pending.
- ``metrics['pending_works']`` is the number of pending works.
-
- Returns:
- The target number of running works. The value will be adjusted after this method runs
- so that it satisfies ``min_replicas<=replicas<=max_replicas``.
-
- """
- pending_requests = metrics["pending_requests"]
- active_or_pending_works = replicas + metrics["pending_works"]
-
- if active_or_pending_works == 0:
- return 1 if pending_requests > 0 else 0
-
- pending_requests_per_running_or_pending_work = pending_requests / active_or_pending_works
-
- # scale out if the number of pending requests exceeds max batch size.
- max_requests_per_work = self.max_batch_size
- if pending_requests_per_running_or_pending_work >= max_requests_per_work:
- return replicas + 1
-
- # scale in if the number of pending requests is below 25% of max_requests_per_work
- min_requests_per_work = max_requests_per_work * 0.25
- if pending_requests_per_running_or_pending_work < min_requests_per_work:
- return replicas - 1
-
- return replicas
-
- @property
- def num_pending_requests(self) -> int:
- """Fetches the number of pending requests via load balancer."""
- try:
- load_balancer_url = self.load_balancer.get_internal_url()
- except ValueError:
- logger.warn("Cannot update servers as internal_url is not set")
- return 0
- return int(requests.get(f"{load_balancer_url}/num-requests").json())
-
- @property
- def num_pending_works(self) -> int:
- """The number of pending works."""
- return sum(work.is_pending for work in self.workers)
-
- def autoscale(self) -> None:
- """Adjust the number of works based on the target number returned by ``self.scale``."""
- metrics = {
- "pending_requests": self.num_pending_requests,
- "pending_works": self.num_pending_works,
- }
-
- # ensure min_replicas <= num_replicas <= max_replicas
- num_target_workers = max(
- self.min_replicas,
- min(self.max_replicas, self.scale(self.num_replicas, metrics)),
- )
-
- # scale-out
- if time.time() - self._last_autoscale > self.scale_out_interval:
- # TODO figuring out number of workers to add only based on num_replicas isn't right because pending works
- # are not added to num_replicas
- num_workers_to_add = num_target_workers - self.num_replicas
- for _ in range(num_workers_to_add):
- logger.info(f"Scaling out from {self.num_replicas} to {self.num_replicas + 1}")
- work = self.create_work()
- # TODO: move works into structures
- new_work_id = self.add_work(work)
- logger.info(f"Work created: '{new_work_id}'")
- if num_workers_to_add > 0:
- self._last_autoscale = time.time()
-
- # scale-in
- if time.time() - self._last_autoscale > self.scale_in_interval:
- # TODO figuring out number of workers to remove only based on num_replicas isn't right because pending works
- # are not added to num_replicas
- num_workers_to_remove = self.num_replicas - num_target_workers
- for _ in range(num_workers_to_remove):
- logger.info(f"Scaling in from {self.num_replicas} to {self.num_replicas - 1}")
- removed_work_id = self.remove_work(self.num_replicas - 1)
- logger.info(f"Work removed: '{removed_work_id}'")
- if num_workers_to_remove > 0:
- self._last_autoscale = time.time()
-
- self.load_balancer.update_servers(self.workers)
-
- def configure_layout(self):
- return [
- {"name": "Endpoint Info", "content": f"{self.load_balancer.url}/endpoint-info"},
- {"name": "Swagger", "content": self.load_balancer.url},
- ]
diff --git a/src/lightning/app/components/serve/catimage.png b/src/lightning/app/components/serve/catimage.png
deleted file mode 100644
index a76a35bdb88fb..0000000000000
Binary files a/src/lightning/app/components/serve/catimage.png and /dev/null differ
diff --git a/src/lightning/app/components/serve/cold_start_proxy.py b/src/lightning/app/components/serve/cold_start_proxy.py
deleted file mode 100644
index 2b56e3cd0ca44..0000000000000
--- a/src/lightning/app/components/serve/cold_start_proxy.py
+++ /dev/null
@@ -1,72 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import asyncio
-from typing import Any
-
-from fastapi import HTTPException
-from pydantic import BaseModel
-
-from lightning.app.utilities.imports import _is_aiohttp_available, requires
-
-if _is_aiohttp_available():
- import aiohttp
- import aiohttp.client_exceptions
-
-
-class ColdStartProxy:
- """ColdStartProxy allows users to configure the load balancer to use a proxy service while the work is cold
- starting. This is useful for services that receive real-time requests but whose workers take a long time to start.
-
- If the proxy service accepts the same request body via POST,
- then the default implementation of `handle_request` can be used. In that case,
- initialize the proxy with the proxy URL. Otherwise, override `handle_request`.
-
- Args:
- proxy_url (str): The url of the proxy service
-
- """
-
- @requires(["aiohttp"])
- def __init__(self, proxy_url: str):
- self.proxy_url = proxy_url
- self.proxy_timeout = 50
- if not asyncio.iscoroutinefunction(self.handle_request):
- raise TypeError("handle_request must be an `async` function")
-
- async def handle_request(self, request: BaseModel) -> Any:
- """This method is called when a request is received while the work is cold starting. The default
- implementation forwards the request body to the proxy service with a POST request, but the user
- can override this method to handle the request in any way.
-
- Args:
- request: The request body, a pydantic model forwarded by the load balancer, which
- is a FastAPI service
-
- """
- try:
- async with aiohttp.ClientSession() as session:
- headers = {
- "accept": "application/json",
- "Content-Type": "application/json",
- }
- async with session.post(
- self.proxy_url,
- json=request.dict(),
- timeout=self.proxy_timeout,
- headers=headers,
- ) as response:
- return await response.json()
- except Exception as ex:
- raise HTTPException(status_code=500, detail=f"Error in proxy: {ex}")
diff --git a/src/lightning/app/components/serve/gradio_server.py b/src/lightning/app/components/serve/gradio_server.py
deleted file mode 100644
index a413136fc6432..0000000000000
--- a/src/lightning/app/components/serve/gradio_server.py
+++ /dev/null
@@ -1,200 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import abc
-from functools import partial
-from types import ModuleType
-from typing import Any, List, Optional
-
-from lightning.app.core.work import LightningWork
-from lightning.app.utilities.imports import _is_gradio_available, requires
-
-if _is_gradio_available():
- import gradio
-else:
- gradio = ModuleType("gradio")
- gradio.themes = ModuleType("gradio.themes")
-
- class __DummyBase:
- pass
-
- gradio.themes.Base = __DummyBase
-
-
-class ServeGradio(LightningWork, abc.ABC):
- """The ServeGradio class enables you to quickly create a ``gradio``-based UI for your LightningApp.
-
- In the example below, the ``ServeGradio`` is subclassed to deploy ``AnimeGANv2``.
-
- .. literalinclude:: ../../../../examples/app/components/serve/gradio/app.py
- :language: python
-
- The result would be the following:
-
- .. image:: https://pl-public-data.s3.amazonaws.com/assets_lightning/anime_gan.gif
- :alt: Animation showing what the AnimeGANv2 UI would look like.
-
- """
-
- inputs: Any
- outputs: Any
- examples: Optional[List] = None
- enable_queue: bool = False
- title: Optional[str] = None
- description: Optional[str] = None
-
- _start_method = "spawn"
-
- def __init__(self, *args: Any, theme: Optional[gradio.themes.Base] = None, **kwargs: Any):
- requires("gradio")(super().__init__(*args, **kwargs))
- assert self.inputs
- assert self.outputs
- self._model = None
- self._theme = theme or ServeGradio.__get_lightning_gradio_theme()
-
- self.ready = False
-
- @property
- def model(self):
- return self._model
-
- @abc.abstractmethod
- def predict(self, *args: Any, **kwargs: Any):
- """Override with your logic to make a prediction."""
-
- @abc.abstractmethod
- def build_model(self) -> Any:
- """Override to instantiate and return your model.
-
- The model will be accessible under ``self.model``.
-
- """
-
- def run(self, *args: Any, **kwargs: Any):
- if self._model is None:
- self._model = self.build_model()
- fn = partial(self.predict, *args, **kwargs)
- fn.__name__ = self.predict.__name__
- self.ready = True
- gradio.Interface(
- fn=fn,
- inputs=self.inputs,
- outputs=self.outputs,
- examples=self.examples,
- title=self.title,
- description=self.description,
- theme=self._theme,
- ).launch(
- server_name=self.host,
- server_port=self.port,
- enable_queue=self.enable_queue,
- )
-
- def configure_layout(self) -> str:
- return self.url
-
- @staticmethod
- def __get_lightning_gradio_theme():
- return gradio.themes.Default(
- primary_hue=gradio.themes.Color(
- "#ffffff",
- "#e9d5ff",
- "#d8b4fe",
- "#c084fc",
- "#fcfcfc",
- "#a855f7",
- "#9333ea",
- "#8823e1",
- "#6b21a8",
- "#2c2730",
- "#1c1c1c",
- ),
- secondary_hue=gradio.themes.Color(
- "#c3a1e8",
- "#e9d5ff",
- "#d3bbec",
- "#c795f9",
- "#9174af",
- "#a855f7",
- "#9333ea",
- "#6700c2",
- "#000000",
- "#991ef1",
- "#33243d",
- ),
- neutral_hue=gradio.themes.Color(
- "#ede9fe",
- "#ddd6fe",
- "#c4b5fd",
- "#a78bfa",
- "#fafafa",
- "#8b5cf6",
- "#7c3aed",
- "#6d28d9",
- "#6130b0",
- "#8a4ce6",
- "#3b3348",
- ),
- ).set(
- body_background_fill="*primary_50",
- body_background_fill_dark="*primary_950",
- body_text_color_dark="*primary_100",
- body_text_size="*text_sm",
- body_text_color_subdued_dark="*primary_100",
- background_fill_primary="*primary_50",
- background_fill_primary_dark="*primary_950",
- background_fill_secondary="*primary_50",
- background_fill_secondary_dark="*primary_950",
- border_color_accent="*primary_400",
- border_color_accent_dark="*primary_900",
- border_color_primary="*primary_600",
- border_color_primary_dark="*primary_800",
- color_accent="*primary_400",
- color_accent_soft="*primary_300",
- color_accent_soft_dark="*primary_700",
- link_text_color="*primary_500",
- link_text_color_dark="*primary_50",
- link_text_color_active="*secondary_800",
- link_text_color_active_dark="*primary_500",
- link_text_color_hover="*primary_400",
- link_text_color_hover_dark="*primary_400",
- link_text_color_visited="*primary_500",
- link_text_color_visited_dark="*secondary_100",
- block_background_fill="*primary_50",
- block_background_fill_dark="*primary_900",
- block_border_color_dark="*primary_800",
- checkbox_background_color="*primary_50",
- checkbox_background_color_dark="*primary_50",
- checkbox_background_color_focus="*primary_100",
- checkbox_background_color_focus_dark="*primary_100",
- checkbox_background_color_hover="*primary_400",
- checkbox_background_color_hover_dark="*primary_500",
- checkbox_background_color_selected="*primary_300",
- checkbox_background_color_selected_dark="*primary_500",
- checkbox_border_color_dark="*primary_200",
- checkbox_border_radius="*radius_md",
- input_background_fill="*primary_50",
- input_background_fill_dark="*primary_900",
- input_radius="*radius_xxl",
- slider_color="*primary_600",
- slider_color_dark="*primary_700",
- button_large_radius="*radius_xxl",
- button_large_text_size="*text_md",
- button_small_radius="*radius_xxl",
- button_primary_background_fill_dark="*primary_800",
- button_primary_background_fill_hover_dark="*primary_700",
- button_primary_border_color_dark="*primary_800",
- button_secondary_background_fill="*neutral_200",
- button_secondary_background_fill_dark="*primary_600",
- )
diff --git a/src/lightning/app/components/serve/python_server.py b/src/lightning/app/components/serve/python_server.py
deleted file mode 100644
index 4c1621de7197d..0000000000000
--- a/src/lightning/app/components/serve/python_server.py
+++ /dev/null
@@ -1,328 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import abc
-import asyncio
-import base64
-import os
-import platform
-from typing import TYPE_CHECKING, Any, Dict, Optional
-
-import requests
-import uvicorn
-from fastapi import FastAPI
-from lightning_utilities.core.imports import compare_version, module_available
-from pydantic import BaseModel
-
-from lightning.app.core.work import LightningWork
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.imports import _is_torch_available, requires
-
-if TYPE_CHECKING:
- from lightning.app.frontend.frontend import Frontend
-
-logger = Logger(__name__)
-
-# Skip doctests if requirements aren't available
-if not module_available("lightning_api_access") or not _is_torch_available():
- __doctest_skip__ = ["PythonServer", "PythonServer.*"]
-
-
-def _get_device():
- import operator
-
- import torch
-
- _TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0")
-
- local_rank = int(os.getenv("LOCAL_RANK", "0"))
-
- if _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64"):
- return torch.device("mps", local_rank)
-
- return torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
-
-
-class _DefaultInputData(BaseModel):
- payload: str
-
-
-class _DefaultOutputData(BaseModel):
- prediction: str
-
-
-class Image(BaseModel):
- image: Optional[str] = None
-
- @staticmethod
- def get_sample_data() -> Dict[Any, Any]:
- url = "https://raw.githubusercontent.com/Lightning-AI/LAI-Triton-Server-Component/main/catimage.png"
- img = requests.get(url).content
- img = base64.b64encode(img).decode("UTF-8")
- return {"image": img}
-
- @staticmethod
- def request_code_sample(url: str) -> str:
- return f"""
-import base64
-from pathlib import Path
-import requests
-
-imgurl = "https://raw.githubusercontent.com/Lightning-AI/LAI-Triton-Server-Component/main/catimage.png"
-img = requests.get(imgurl).content
-img = base64.b64encode(img).decode("UTF-8")
-response = requests.post('{url}', json={{"image": img}})
-# If you are using basic authentication for your app, you should add your credentials to the request:
-# auth = requests.auth.HTTPBasicAuth('your_username', 'your_password')
-# response = requests.post('{url}', json={{"image": img}}, auth=auth)
-"""
-
- @staticmethod
- def response_code_sample() -> str:
- return """img = response.json()["image"]
-img = base64.b64decode(img.encode("utf-8"))
-Path("response.png").write_bytes(img)
-"""
-
-
-class Category(BaseModel):
- category: Optional[int] = None
-
- @staticmethod
- def get_sample_data() -> Dict[Any, Any]:
- return {"category": 463}
-
- @staticmethod
- def response_code_sample() -> str:
- return """print("Predicted category is: ", response.json()["category"])
-"""
-
-
-class Text(BaseModel):
- text: Optional[str] = None
-
- @staticmethod
- def get_sample_data() -> Dict[Any, Any]:
- return {"text": "A portrait of a person looking away from the camera"}
-
- @staticmethod
- def request_code_sample(url: str) -> str:
- return f"""
-import base64
-from pathlib import Path
-import requests
-
-response = requests.post('{url}', json={{
- "text": "A portrait of a person looking away from the camera"
-}})
-# If you are using basic authentication for your app, you should add your credentials to the request:
-# response = requests.post('{url}', json={{
-# "text": "A portrait of a person looking away from the camera"
-# }}, auth=requests.auth.HTTPBasicAuth('your_username', 'your_password'))
-"""
-
-
-class Number(BaseModel):
- # deprecated - TODO remove this in favour of Category
- prediction: Optional[int] = None
-
- @staticmethod
- def get_sample_data() -> Dict[Any, Any]:
- return {"prediction": 463}
-
-
-class PythonServer(LightningWork, abc.ABC):
- _start_method = "spawn"
-
- @requires(["torch"])
- def __init__( # type: ignore
- self,
- input_type: type = _DefaultInputData,
- output_type: type = _DefaultOutputData,
- **kwargs: Any,
- ):
- """The PythonServer class enables you to easily get your machine learning server up and running.
-
- Arguments:
- input_type: Optional `input_type` to be provided. This needs to be a pydantic BaseModel class.
- The default data type is good enough for basic use cases; it expects the data
- to be a JSON object that has one key called `payload`
-
- .. code-block:: python
-
- input_data = {"payload": "some data"}
-
- and this can be accessed as `request.payload` in the `predict` method.
-
- .. code-block:: python
-
- def predict(self, request):
- data = request.payload
-
- output_type: Optional `output_type` to be provided. This needs to be a pydantic BaseModel class.
- The default data type is good enough for basic use cases. It expects the return value of
- the `predict` method to be a dictionary with one key called `prediction`.
-
- .. code-block:: python
-
- def predict(self, request):
- # some code
- return {"prediction": "some data"}
-
- and this can be accessed as `response.json()["prediction"]` in the client if
- you are using the requests library.
-
- Example:
-
- >>> from lightning.app.components.serve.python_server import PythonServer
- >>> from lightning.app import LightningApp
- ...
- >>> class SimpleServer(PythonServer):
- ...
- ... def setup(self):
- ... self._model = lambda x: x + " " + x
- ...
- ... def predict(self, request):
- ... return {"prediction": self._model(request.payload)}
- ...
- >>> app = LightningApp(SimpleServer())
-
- """
- super().__init__(parallel=True, **kwargs)
- if not issubclass(input_type, BaseModel):
- raise TypeError("input_type must be a pydantic BaseModel class")
- if not issubclass(output_type, BaseModel):
- raise TypeError("output_type must be a pydantic BaseModel class")
- self._input_type = input_type
- self._output_type = output_type
-
- self.ready = False
-
- def setup(self, *args: Any, **kwargs: Any) -> None:
- """This method is called before the server starts. Override this if you need to download the model,
- initialize the weights, set up pipelines, etc.
-
- Note that this will be called exactly once on every work machine. So if you have multiple machines for serving,
- this will be called on each of them.
-
- """
- return
-
- def configure_input_type(self) -> type:
- return self._input_type
-
- def configure_output_type(self) -> type:
- return self._output_type
-
- @abc.abstractmethod
- def predict(self, request: Any) -> Any:
- """This method is called when a request is made to the server.
-
- This method must be overridden by the user with the prediction logic. The pre/post-processing, the actual
- prediction using the model(s), etc. go here.
-
- """
- pass
-
- @staticmethod
- def _get_sample_dict_from_datatype(datatype: Any) -> dict:
- if hasattr(datatype, "get_sample_data"):
- return datatype.get_sample_data()
-
- datatype_props = datatype.schema()["properties"]
- out: Dict[str, Any] = {}
- for k, v in datatype_props.items():
- if v["type"] == "string":
- out[k] = "data string"
- elif v["type"] == "number":
- out[k] = 0.0
- elif v["type"] == "integer":
- out[k] = 0
- elif v["type"] == "boolean":
- out[k] = False
- else:
- raise TypeError("Unsupported type")
- return out
-
- def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
- input_type: type = self.configure_input_type()
- output_type: type = self.configure_output_type()
-
- def predict_fn_sync(request: input_type): # type: ignore
- return self.predict(request)
-
- async def async_predict_fn(request: input_type): # type: ignore
- return await self.predict(request)
-
- if asyncio.iscoroutinefunction(self.predict):
- fastapi_app.post("/predict", response_model=output_type)(async_predict_fn)
- else:
- fastapi_app.post("/predict", response_model=output_type)(predict_fn_sync)
-
- def get_code_sample(self, url: str) -> Optional[str]:
- input_type: Any = self.configure_input_type()
- output_type: Any = self.configure_output_type()
-
- if not (hasattr(input_type, "request_code_sample") and hasattr(output_type, "response_code_sample")):
- return None
- return f"{input_type.request_code_sample(url)}\n{output_type.response_code_sample()}"
-
- def configure_layout(self) -> Optional["Frontend"]:
- try:
- from lightning_api_access import APIAccessFrontend
- except ModuleNotFoundError:
- logger.warn(
- "Some dependencies to run the UI are missing. To resolve, run `pip install lightning-api-access`"
- )
- return None
-
- class_name = self.__class__.__name__
- url = f"{self.url}/predict"
-
- try:
- request = self._get_sample_dict_from_datatype(self.configure_input_type())
- response = self._get_sample_dict_from_datatype(self.configure_output_type())
- except TypeError:
- return None
-
- frontend_payload = {
- "name": class_name,
- "url": url,
- "method": "POST",
- "request": request,
- "response": response,
- }
-
- code_sample = self.get_code_sample(url)
- if code_sample:
- frontend_payload["code_sample"] = code_sample
-
- return APIAccessFrontend(apis=[frontend_payload])
-
- def run(self, *args: Any, **kwargs: Any) -> Any:
- """Run method takes care of configuring and setting up a FastAPI server behind the scenes.
-
- Normally, you don't need to override this method.
-
- """
- self.setup(*args, **kwargs)
-
- fastapi_app = FastAPI()
- self._attach_predict_fn(fastapi_app)
-
- self.ready = True
- logger.info(
- f"Your {self.__class__.__qualname__} has started. View it in your browser: http://{self.host}:{self.port}"
- )
- uvicorn.run(app=fastapi_app, host=self.host, port=self.port, log_level="error")
diff --git a/src/lightning/app/components/serve/serve.py b/src/lightning/app/components/serve/serve.py
deleted file mode 100644
index 0ae40302293f3..0000000000000
--- a/src/lightning/app/components/serve/serve.py
+++ /dev/null
@@ -1,170 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import abc
-import inspect
-import os
-import pydoc
-import subprocess
-import sys
-from typing import Any, Callable, Optional
-
-import fastapi # noqa E511
-import uvicorn
-from fastapi import FastAPI
-from fastapi.responses import JSONResponse
-
-from lightning.app.components.serve.types import _DESERIALIZER, _SERIALIZER
-from lightning.app.core.work import LightningWork
-from lightning.app.utilities.app_helpers import Logger
-
-logger = Logger(__name__)
-
-
-fastapi_service = FastAPI()
-
-
-class _InferenceCallable:
- def __init__(
- self,
- deserialize: Callable,
- predict: Callable,
- serialize: Callable,
- ):
- self.deserialize = deserialize
- self.predict = predict
- self.serialize = serialize
-
- async def run(self, data) -> Any:
- return self.serialize(self.predict(self.deserialize(data)))
-
-
-class ModelInferenceAPI(LightningWork, abc.ABC):
- def __init__(
- self,
- input: Optional[str] = None,
- output: Optional[str] = None,
- host: str = "127.0.0.1",
- port: int = 7777,
- workers: int = 0,
- ):
- """The ModelInferenceAPI class enables you to easily get your model served.
-
- Arguments:
- input: Optional `input` to be provided. This would provide a built-in deserializer.
- output: Optional `output` to be provided. This would provide a built-in serializer.
- host: Address to be used to serve the model.
- port: Port to be used to serve the model.
- workers: Number of uvicorn workers. Warning: this won't work if your subclass takes more arguments.
-
- """
- super().__init__(parallel=True, host=host, port=port)
- if input and input not in _DESERIALIZER:
- raise Exception(f"Only input in {_DESERIALIZER.keys()} are supported.")
- if output and output not in _SERIALIZER:
- raise Exception(f"Only output in {_SERIALIZER.keys()} are supported.")
- self.input = input
- self.output = output
- self.workers = workers
- self._model = None
-
- self.ready = False
-
- @property
- def model(self):
- return self._model
-
- @abc.abstractmethod
- def build_model(self) -> Any:
- """Override to define your model."""
-
- def deserialize(self, data) -> Any:
- return data
-
- @abc.abstractmethod
- def predict(self, data) -> Any:
- """Override to add your predict logic."""
-
- def serialize(self, data) -> Any:
- return data
-
- def run(self):
- global fastapi_service
- if self.workers > 1:
- # TODO: This is quite limited.
- # Find a more reliable solution to enable multi-worker serving.
- env = os.environ.copy()
- module = inspect.getmodule(self).__file__
- env["LIGHTNING_MODEL_INFERENCE_API_FILE"] = module
- env["LIGHTNING_MODEL_INFERENCE_API_CLASS_NAME"] = self.__class__.__name__
- if self.input:
- env["LIGHTNING_MODEL_INFERENCE_API_INPUT"] = self.input
- if self.output:
- env["LIGHTNING_MODEL_INFERENCE_API_OUTPUT"] = self.output
- command = [
- sys.executable,
- "-m",
- "uvicorn",
- "--workers",
- str(self.workers),
- "--host",
- str(self.host),
- "--port",
- str(self.port),
- "serve:fastapi_service",
- ]
- process = subprocess.Popen(command, env=env, cwd=os.path.dirname(__file__))
- self.ready = True
- process.wait()
- else:
- self._populate_app(fastapi_service)
- self.ready = True
- self._launch_server(fastapi_service)
-
- def _populate_app(self, fastapi_service: FastAPI):
- self._model = self.build_model()
-
- fastapi_service.post("/predict", response_class=JSONResponse)(
- _InferenceCallable(
- deserialize=_DESERIALIZER[self.input] if self.input else self.deserialize,
- predict=self.predict,
- serialize=_SERIALIZER[self.output] if self.output else self.serialize,
- ).run
- )
-
- def _launch_server(self, fastapi_service: FastAPI):
- logger.info(f"Your app has started. View it in your browser: http://{self.host}:{self.port}")
- uvicorn.run(app=fastapi_service, host=self.host, port=self.port, log_level="error")
-
- def configure_layout(self) -> str:
- return f"{self.url}/docs"
-
-
-def _maybe_create_instance() -> Optional[ModelInferenceAPI]:
- """This function tries to re-create the user `ModelInferenceAPI` if the environment variables associated with
- multi-worker serving are present."""
- render_fn_name = os.getenv("LIGHTNING_MODEL_INFERENCE_API_CLASS_NAME", None)
- render_fn_module_file = os.getenv("LIGHTNING_MODEL_INFERENCE_API_FILE", None)
- if render_fn_name is None or render_fn_module_file is None:
- return None
- module = pydoc.importfile(render_fn_module_file)
- cls = getattr(module, render_fn_name)
- input = os.getenv("LIGHTNING_MODEL_INFERENCE_API_INPUT", None)
- output = os.getenv("LIGHTNING_MODEL_INFERENCE_API_OUTPUT", None)
- return cls(input=input, output=output)
-
-
-instance = _maybe_create_instance()
-if instance:
- instance._populate_app(fastapi_service)
diff --git a/src/lightning/app/components/serve/streamlit.py b/src/lightning/app/components/serve/streamlit.py
deleted file mode 100644
index ff0ea9ed8a8bb..0000000000000
--- a/src/lightning/app/components/serve/streamlit.py
+++ /dev/null
@@ -1,174 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import abc
-import inspect
-import os
-import pydoc
-import subprocess
-import sys
-from typing import Any, Callable, Type
-
-from lightning.app.core.work import LightningWork
-from lightning.app.utilities.app_helpers import StreamLitStatePlugin
-from lightning.app.utilities.state import AppState
-
-
-class ServeStreamlit(LightningWork, abc.ABC):
- """The ``ServeStreamlit`` work allows you to use streamlit from a work.
-
- You can optionally build a model in the ``build_model`` hook, which will only be called once per session.
-
- """
-
- def __init__(self, *args: Any, **kwargs: Any):
- super().__init__(*args, **kwargs)
-
- self.ready = False
-
- self._process = None
-
- @property
- def model(self) -> Any:
- return getattr(self, "_model", None)
-
- @abc.abstractmethod
- def render(self) -> None:
- """Override with your streamlit render function."""
-
- def build_model(self) -> Any:
- """Optionally override to instantiate and return your model.
-
- The model will be accessible under ``self.model``.
-
- """
- return None
-
- def run(self) -> None:
- env = os.environ.copy()
- env["LIGHTNING_COMPONENT_NAME"] = self.name
- env["LIGHTNING_WORK"] = self.__class__.__name__
- env["LIGHTNING_WORK_MODULE_FILE"] = inspect.getmodule(self).__file__
- self._process = subprocess.Popen(
- [
- sys.executable,
- "-m",
- "streamlit",
- "run",
- __file__,
- "--server.address",
- str(self.host),
- "--server.port",
- str(self.port),
- "--server.headless",
- "true", # do not open the browser window when running locally
- ],
- env=env,
- )
- self.ready = True
- self._process.wait()
-
- def on_exit(self) -> None:
- if self._process is not None:
- self._process.kill()
-
- def configure_layout(self) -> str:
- return self.url
-
-
-class _PatchedWork:
- """The ``_PatchedWork`` is used to emulate a work instance from a subprocess. This is achieved by patching the self
- reference in methods and properties to point to the AppState.
-
- Args:
- state: The work state to patch
- work_class: The work class to emulate
-
- """
-
- def __init__(self, state: AppState, work_class: Type):
- super().__init__()
- self._state = state
- self._work_class = work_class
-
- def __getattr__(self, name: str) -> Any:
- try:
- return getattr(self._state, name)
- except AttributeError:
- # The name isn't in the state, so check if it's a callable or a property
- attribute = inspect.getattr_static(self._work_class, name)
- if callable(attribute):
- attribute = attribute.__get__(self, self._work_class)
- return attribute
- if isinstance(attribute, (staticmethod, property)):
- return attribute.__get__(self, self._work_class)
-
- # Look for the name in the instance (e.g. for private variables)
- return object.__getattribute__(self, name)
-
- def __setattr__(self, name: str, value: Any) -> None:
- if name in ["_state", "_work_class"]:
- return object.__setattr__(self, name, value)
-
- if hasattr(self._state, name):
- return setattr(self._state, name, value)
- return object.__setattr__(self, name, value)
-
-
-def _reduce_to_component_scope(state: AppState, component_name: str) -> AppState:
- """Given the app state, this utility traverses down to the level of the given component name."""
- component_name_parts = component_name.split(".")[1:] # exclude root
- component_state = state
- for part in component_name_parts:
- component_state = getattr(component_state, part)
- return component_state
-
-
-def _get_work_class() -> Callable:
- """Import the work class specified in the environment."""
- work_name = os.environ["LIGHTNING_WORK"]
- work_module_file = os.environ["LIGHTNING_WORK_MODULE_FILE"]
- module = pydoc.importfile(work_module_file)
- return getattr(module, work_name)
-
-
-def _build_model(work: ServeStreamlit) -> None:
- import streamlit as st
-
- # Build the model (once per session, equivalent to gradio when enable_queue is False)
- if "_model" not in st.session_state:
- with st.spinner("Building model..."):
- st.session_state["_model"] = work.build_model()
-
- work._model = st.session_state["_model"]
-
-
-def _main() -> None:
- # Get the AppState
- app_state = AppState(plugin=StreamLitStatePlugin())
- work_state = _reduce_to_component_scope(app_state, os.environ["LIGHTNING_COMPONENT_NAME"])
-
- # Create the patched work
- work_class = _get_work_class()
- patched_work = _PatchedWork(work_state, work_class)
-
- # Build and attach the model
- _build_model(patched_work)
-
- # Render
- patched_work.render()
-
-
-if __name__ == "__main__":
- _main()
diff --git a/src/lightning/app/components/serve/types/__init__.py b/src/lightning/app/components/serve/types/__init__.py
deleted file mode 100644
index 059ca53b07682..0000000000000
--- a/src/lightning/app/components/serve/types/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from lightning.app.components.serve.types.image import Image
-
-_SERIALIZER = {"image": Image.serialize}
-_DESERIALIZER = {"image": Image.deserialize}
diff --git a/src/lightning/app/components/serve/types/image.py b/src/lightning/app/components/serve/types/image.py
deleted file mode 100644
index 32f123fe8b63b..0000000000000
--- a/src/lightning/app/components/serve/types/image.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import base64
-from io import BytesIO
-
-from lightning.app.components.serve.types.type import BaseType
-from lightning.app.utilities.imports import _is_pil_available, _is_torch_available
-
-if _is_torch_available():
- from torch import Tensor
-
-if _is_pil_available():
- from PIL import Image as PILImage
-
-
-class Image(BaseType):
- @staticmethod
- def deserialize(data: str):
- encoded_with_padding = (data + "===").encode("ascii")
- img = base64.b64decode(encoded_with_padding)
- buffer = BytesIO(img)
- return PILImage.open(buffer, mode="r")
-
- @staticmethod
- def serialize(tensor: "Tensor") -> str:
- tensor = tensor.squeeze(0).numpy()
- print(tensor.shape)
- image = PILImage.fromarray(tensor)
- buffer = BytesIO()
- image.save(buffer, format="PNG")
- buffer.seek(0)
- encoded = buffer.getvalue()
- return base64.b64encode(encoded).decode("ascii")
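Review note on the removed `Image` type: `deserialize` appends `"==="` before decoding, apparently so that payloads whose base64 padding was stripped in transit still decode. The following is a minimal, stdlib-only sketch of that trick, not the removed implementation. The helper names `encode_unpadded` and `decode_lenient` are hypothetical, and plain bytes stand in for the PNG buffer.

```python
import base64


def encode_unpadded(raw: bytes) -> str:
    """Base64-encode and strip the trailing '=' padding (hypothetical helper)."""
    return base64.b64encode(raw).decode("ascii").rstrip("=")


def decode_lenient(text: str) -> bytes:
    """Decode base64 text whether or not its padding survived (hypothetical helper)."""
    # In its default non-strict mode, b64decode ignores surplus '=' characters,
    # so appending three of them is harmless for input that is already padded.
    return base64.b64decode((text + "===").encode("ascii"))


if __name__ == "__main__":
    for payload in (b"a", b"ab", b"abc", b"\x89PNG\r\n"):
        assert decode_lenient(encode_unpadded(payload)) == payload
```

The sketch relies on the default `validate=False` behaviour of `base64.b64decode`; a strict decoder would likely reject the surplus padding.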
diff --git a/src/lightning/app/components/serve/types/type.py b/src/lightning/app/components/serve/types/type.py
deleted file mode 100644
index 157940a60f8e7..0000000000000
--- a/src/lightning/app/components/serve/types/type.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import abc
-from typing import Any
-
-
-class BaseType(abc.ABCMeta):
- """Base class for Types."""
-
- @abc.abstractmethod
- def serialize(self, data): # pragma: no cover
- """Serialize the incoming data to send it through the network."""
-
- @abc.abstractmethod
- def deserialize(self, *args: Any, **kwargs: Any): # pragma: no cover
- """Take the inputs from the network and deserialize/convert them.
-
- Output from this method will go to the exposed method as arguments.
-
- """
diff --git a/src/lightning/app/components/training.py b/src/lightning/app/components/training.py
deleted file mode 100644
index ba198f063fb42..0000000000000
--- a/src/lightning/app/components/training.py
+++ /dev/null
@@ -1,203 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
-
-from lightning.app.components.python import TracerPythonScript
-from lightning.app.core.flow import LightningFlow
-from lightning.app.storage.path import Path
-from lightning.app.structures import List as _List
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.packaging.cloud_compute import CloudCompute
-
-_logger = Logger(__name__)
-
-
-class PyTorchLightningScriptRunner(TracerPythonScript):
- def __init__(
- self,
- script_path: str,
- script_args: Optional[Union[list, str]] = None,
- node_rank: int = 1,
- num_nodes: int = 1,
- sanity_serving: bool = False,
- cloud_compute: Optional[CloudCompute] = None,
- parallel: bool = True,
- raise_exception: bool = True,
- env: Optional[Dict[str, Any]] = None,
- **kwargs: Any,
- ):
- super().__init__(
- script_path,
- script_args,
- raise_exception=raise_exception,
- parallel=parallel,
- cloud_compute=cloud_compute,
- **kwargs,
- )
- self.node_rank = node_rank
- self.num_nodes = num_nodes
- self.best_model_path = None
- self.best_model_score = None
- self.monitor = None
- self.sanity_serving = sanity_serving
- self.has_finished = False
- self.env = env
-
- def configure_tracer(self):
- from lightning.pytorch import Trainer
-
- tracer = super().configure_tracer()
- tracer.add_traced(Trainer, "__init__", pre_fn=self._trainer_init_pre_middleware)
- return tracer
-
- def run(self, internal_urls: Optional[List[Tuple[str, str]]] = None, **kwargs: Any) -> None:
- if not internal_urls:
- # Note: This is called only once.
- _logger.info(f"The node {self.node_rank} started!")
- return None
-
- if self.env:
- os.environ.update(self.env)
-
- distributed_env_vars = {
- "MASTER_ADDR": internal_urls[0][0],
- "MASTER_PORT": str(internal_urls[0][1]),
- "NODE_RANK": str(self.node_rank),
- "PL_TRAINER_NUM_NODES": str(self.num_nodes),
- "PL_TRAINER_DEVICES": "auto",
- "PL_TRAINER_ACCELERATOR": "auto",
- }
-
- os.environ.update(distributed_env_vars)
- return super().run(**kwargs)
-
- def on_after_run(self, script_globals):
- from lightning.pytorch import Trainer
- from lightning.pytorch.cli import LightningCLI
-
- for v in script_globals.values():
- if isinstance(v, LightningCLI):
- trainer = v.trainer
- break
- if isinstance(v, Trainer):
- trainer = v
- break
- else:
- raise RuntimeError("No trainer instance found.")
-
- self.monitor = trainer.checkpoint_callback.monitor
-
- if trainer.checkpoint_callback.best_model_score:
- self.best_model_path = Path(trainer.checkpoint_callback.best_model_path)
- self.best_model_score = float(trainer.checkpoint_callback.best_model_score)
- else:
- self.best_model_path = Path(trainer.checkpoint_callback.last_model_path)
-
- self.has_finished = True
-
- def _trainer_init_pre_middleware(self, trainer, *args: Any, **kwargs: Any):
- if self.node_rank != 0:
- return {}, args, kwargs
-
- from lightning.pytorch.serve import ServableModuleValidator
-
- callbacks = kwargs.get("callbacks", [])
- if self.sanity_serving:
- callbacks = callbacks + [ServableModuleValidator()]
- kwargs["callbacks"] = callbacks
- return {}, args, kwargs
-
- @property
- def is_running_in_cloud(self) -> bool:
- return "LIGHTNING_APP_STATE_URL" in os.environ
-
-
-class LightningTrainerScript(LightningFlow):
- def __init__(
- self,
- script_path: str,
- script_args: Optional[Union[list, str]] = None,
- num_nodes: int = 1,
- cloud_compute: CloudCompute = CloudCompute("default"),
- sanity_serving: bool = False,
- script_runner: Type[TracerPythonScript] = PyTorchLightningScriptRunner,
- **script_runner_kwargs,
- ):
- """This component enables performing distributed multi-node multi-device training.
-
- Example::
-
- from lightning.app import LightningApp
- from lightning.app.components.training import LightningTrainerScript
- from lightning.app.utilities.packaging.cloud_compute import CloudCompute
-
- app = LightningApp(
- LightningTrainerScript(
- "train.py",
- num_nodes=2,
- cloud_compute=CloudCompute("gpu"),
- ),
- )
-
- Arguments:
- script_path: Path to the script to be executed.
- script_args: The arguments to be passed to the script.
- num_nodes: Number of nodes.
- cloud_compute: The cloud compute object used in the cloud.
- sanity_serving: Whether to validate that the model correctly implements
- the ServableModule API
-
- """
- super().__init__()
- self.script_path = script_path
- self.script_args = script_args
- self.num_nodes = num_nodes
- self.sanity_serving = sanity_serving
- self._script_runner = script_runner
- self._script_runner_kwargs = script_runner_kwargs
-
- self.ws = _List()
- for node_rank in range(self.num_nodes):
- self.ws.append(
- self._script_runner(
- script_path=self.script_path,
- script_args=self.script_args,
- cloud_compute=cloud_compute,
- node_rank=node_rank,
- sanity_serving=self.sanity_serving,
- num_nodes=self.num_nodes,
- **self._script_runner_kwargs,
- )
- )
-
- def run(self, **run_kwargs):
- for work in self.ws:
- if all(w.internal_ip for w in self.ws):
- internal_urls = [(w.internal_ip, w.port) for w in self.ws]
- work.run(internal_urls=internal_urls, **run_kwargs)
- if all(w.has_finished for w in self.ws):
- for w in self.ws:
- w.stop()
- else:
- work.run()
-
- @property
- def best_model_score(self) -> Optional[float]:
- return self.ws[0].best_model_score
-
- @property
- def best_model_paths(self) -> List[Optional[Path]]:
- return [self.ws[node_idx].best_model_path for node_idx in range(len(self.ws))]
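Review note on the removed `PyTorchLightningScriptRunner.run`: every node exports the same rendezvous address (the first worker's host and port) and differs only in `NODE_RANK`. A minimal sketch of that mapping as a pure function follows. The function name `build_distributed_env` and the example addresses are hypothetical; the variable names are the ones set in the removed code.

```python
from typing import Dict, List, Tuple


def build_distributed_env(
    internal_urls: List[Tuple[str, int]], node_rank: int, num_nodes: int
) -> Dict[str, str]:
    """Return the env vars a node would export before launching the training script."""
    if not internal_urls:
        raise ValueError("at least one (host, port) pair is required")
    # The first worker acts as the rendezvous point for all nodes.
    master_host, master_port = internal_urls[0]
    return {
        "MASTER_ADDR": master_host,
        "MASTER_PORT": str(master_port),
        "NODE_RANK": str(node_rank),
        "PL_TRAINER_NUM_NODES": str(num_nodes),
        "PL_TRAINER_DEVICES": "auto",
        "PL_TRAINER_ACCELERATOR": "auto",
    }


if __name__ == "__main__":
    urls = [("10.0.0.1", 29500), ("10.0.0.2", 29500)]  # made-up addresses
    for rank in range(len(urls)):
        print(rank, build_distributed_env(urls, rank, len(urls)))
```

Returning a dictionary instead of mutating `os.environ` keeps the sketch testable; the removed code applied the values with `os.environ.update`.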
diff --git a/src/lightning/app/core/__init__.py b/src/lightning/app/core/__init__.py
deleted file mode 100644
index cdf8b6aee1029..0000000000000
--- a/src/lightning/app/core/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from lightning.app.core.app import LightningApp
-from lightning.app.core.flow import LightningFlow
-from lightning.app.core.work import LightningWork
-
-__all__ = ["LightningApp", "LightningFlow", "LightningWork"]
diff --git a/src/lightning/app/core/api.py b/src/lightning/app/core/api.py
deleted file mode 100644
index 5f50c6faa0a2b..0000000000000
--- a/src/lightning/app/core/api.py
+++ /dev/null
@@ -1,498 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import asyncio
-import contextlib
-import json
-import os
-import queue
-import socket
-import sys
-import traceback
-from copy import deepcopy
-from multiprocessing import Queue
-from pathlib import Path
-from tempfile import TemporaryDirectory
-from threading import Event, Lock, Thread
-from time import sleep
-from typing import Dict, List, Mapping, Optional, Union
-
-import uvicorn
-from deepdiff import DeepDiff, Delta
-from fastapi import FastAPI, File, HTTPException, Request, Response, UploadFile, WebSocket, status
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.params import Header
-from fastapi.responses import HTMLResponse, JSONResponse
-from fastapi.staticfiles import StaticFiles
-from fastapi.templating import Jinja2Templates
-from pydantic import BaseModel
-from websockets.exceptions import ConnectionClosed
-
-from lightning.app.api.http_methods import _HttpMethod
-from lightning.app.api.request_types import _DeltaRequest
-from lightning.app.core.constants import (
- ENABLE_PULLING_STATE_ENDPOINT,
- ENABLE_PUSHING_STATE_ENDPOINT,
- ENABLE_STATE_WEBSOCKET,
- ENABLE_UPLOAD_ENDPOINT,
- FRONTEND_DIR,
- get_cloud_queue_type,
-)
-from lightning.app.core.flow import LightningFlow
-from lightning.app.core.queues import QueuingSystem
-from lightning.app.core.work import LightningWork
-from lightning.app.storage import Drive
-from lightning.app.utilities.app_helpers import InMemoryStateStore, Logger, StateStore
-from lightning.app.utilities.app_status import AppStatus
-from lightning.app.utilities.cloud import is_running_in_cloud
-from lightning.app.utilities.component import _context
-from lightning.app.utilities.enum import ComponentContext, OpenAPITags
-
-# TODO: fixed uuid for now, it will come from the FastAPI session
-TEST_SESSION_UUID = "1234"
-
-STATE_EVENT = "State changed"
-
-frontend_static_dir = os.path.join(FRONTEND_DIR, "static")
-
-api_app_delta_queue: Optional[Queue] = None
-
-template: dict = {"ui": {}, "app": {}}
-templates = Jinja2Templates(directory=FRONTEND_DIR)
-
-# TODO: try to avoid using global var for state store
-global_app_state_store = InMemoryStateStore()
-global_app_state_store.add(TEST_SESSION_UUID)
-
-lock = Lock()
-
-app_spec: Optional[List] = None
-app_status: Optional[AppStatus] = None
-app_annotations: Optional[List] = None
-
-# In the future, this would be abstracted to support horizontal scaling.
-responses_store = {}
-
-logger = Logger(__name__)
-
-# This can be replaced with a consumer that publishes states in a kv-store
-# in a serverless architecture
-
-
-class UIRefresher(Thread):
- def __init__(
- self,
- api_publish_state_queue: Queue,
- api_response_queue: Queue,
- refresh_interval: float = 0.1,
- ) -> None:
- super().__init__(daemon=True)
- self.api_publish_state_queue = api_publish_state_queue
- self.api_response_queue = api_response_queue
- self._exit_event = Event()
- self.refresh_interval = refresh_interval
-
- def run(self) -> None:
- # TODO: Create multiple threads to handle the background logic
- # TODO: Investigate the use of `parallel=True`
- try:
- while not self._exit_event.is_set():
- self.run_once()
- # Note: Sleep to reduce queue calls.
- sleep(self.refresh_interval)
- except Exception as ex:
- traceback.print_exc()
- raise ex
-
- def run_once(self) -> None:
- with contextlib.suppress(queue.Empty):
- global app_status
- state, app_status = self.api_publish_state_queue.get(timeout=0)
- with lock:
- global_app_state_store.set_app_state(TEST_SESSION_UUID, state)
-
- with contextlib.suppress(queue.Empty):
- responses = self.api_response_queue.get(timeout=0)
- with lock:
- # TODO: Abstract the responses store to support horizontal scaling.
- global responses_store
- for response in responses:
- responses_store[response["id"]] = response["response"]
-
- def join(self, timeout: Optional[float] = None) -> None:
- self._exit_event.set()
- super().join(timeout)
-
-
-class StateUpdate(BaseModel):
- state: dict = {}
-
-
-openapi_tags = [
- {
- "name": OpenAPITags.APP_CLIENT_COMMAND,
- "description": "The App Endpoints to be triggered exclusively from the CLI",
- },
- {
- "name": OpenAPITags.APP_COMMAND,
- "description": "The App Endpoints that can be triggered equally from the CLI or from an HTTP Request",
- },
- {
- "name": OpenAPITags.APP_API,
- "description": "The App Endpoints that can be triggered exclusively from an HTTP Request",
- },
-]
-
-app = FastAPI(openapi_tags=openapi_tags)
-
-fastapi_service = FastAPI()
-
-fastapi_service.add_middleware(
- CORSMiddleware,
- allow_origins=["*"],
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
-)
-
-
-# General sequence is:
-# * an update is generated in the UI
-# * the value and the location in the state (or the whole state, easier)
-# is sent to the REST API along with the session UID
-# * the previous state is loaded from the cache, the delta is generated
-# * the previous state is set as set_state, the delta is provided as
-# delta
-# * the app applies the delta and runs the entry_fn, which eventually
-# leads to another state
-# * the new state is published through the API
-# * the UI is updated with the new value of the state
-# Before the above happens, we need to refactor App so that it doesn't
-# rely on timeouts, but on sequences of updates (and alignments between
-# ranks)
-@fastapi_service.get("/api/v1/state", response_class=JSONResponse)
-async def get_state(
- response: Response,
- x_lightning_type: Optional[str] = Header(None),
- x_lightning_session_uuid: Optional[str] = Header(None),
- x_lightning_session_id: Optional[str] = Header(None),
-) -> Mapping:
- if x_lightning_session_uuid is None:
- raise Exception("Missing X-Lightning-Session-UUID header")
- if x_lightning_session_id is None:
- raise Exception("Missing X-Lightning-Session-ID header")
-
- if not ENABLE_PULLING_STATE_ENDPOINT:
- response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
- return {"status": "failure", "reason": "This endpoint is disabled."}
-
- with lock:
- x_lightning_session_uuid = TEST_SESSION_UUID
- state = global_app_state_store.get_app_state(x_lightning_session_uuid)
- global_app_state_store.set_served_state(x_lightning_session_uuid, state)
- return state
-
-
-def _get_component_by_name(component_name: str, state: dict) -> Union[LightningFlow, LightningWork]:
- child = state
- for child_name in component_name.split(".")[1:]:
- try:
- child = child["flows"][child_name]
- except KeyError:
- child = child["structures"][child_name]
-
- if isinstance(child["vars"]["_layout"], list):
- assert len(child["vars"]["_layout"]) == 1
- return child["vars"]["_layout"][0]["target"]
- return child["vars"]["_layout"]["target"]
-
-
-@fastapi_service.get("/api/v1/layout", response_class=JSONResponse)
-async def get_layout() -> str:
- with lock:
- x_lightning_session_uuid = TEST_SESSION_UUID
- state = global_app_state_store.get_app_state(x_lightning_session_uuid)
- global_app_state_store.set_served_state(x_lightning_session_uuid, state)
- layout = deepcopy(state["vars"]["_layout"])
- for la in layout:
- if la["content"].startswith("root."):
- la["content"] = _get_component_by_name(la["content"], state)
- return json.dumps(layout)
-
-
-@fastapi_service.get("/api/v1/spec", response_class=JSONResponse)
-async def get_spec(
- response: Response,
- x_lightning_session_uuid: Optional[str] = Header(None),
- x_lightning_session_id: Optional[str] = Header(None),
-) -> Union[List, Dict]:
- if x_lightning_session_uuid is None:
- raise Exception("Missing X-Lightning-Session-UUID header")
- if x_lightning_session_id is None:
- raise Exception("Missing X-Lightning-Session-ID header")
-
- if not ENABLE_PULLING_STATE_ENDPOINT:
- response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
- return {"status": "failure", "reason": "This endpoint is disabled."}
-
- global app_spec
- return app_spec or []
-
-
-@fastapi_service.post("/api/v1/delta")
-async def post_delta(
- request: Request,
- response: Response,
- x_lightning_type: Optional[str] = Header(None),
- x_lightning_session_uuid: Optional[str] = Header(None),
- x_lightning_session_id: Optional[str] = Header(None),
-) -> Optional[Dict]:
- """This endpoint is used to make an update to the app state using delta diff, mainly used by streamlit to update
- the state."""
-
- if x_lightning_session_uuid is None:
- raise Exception("Missing X-Lightning-Session-UUID header")
- if x_lightning_session_id is None:
- raise Exception("Missing X-Lightning-Session-ID header")
-
- if not ENABLE_PUSHING_STATE_ENDPOINT:
- response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
- return {"status": "failure", "reason": "This endpoint is disabled."}
-
- body: Dict = await request.json()
- assert api_app_delta_queue is not None
- api_app_delta_queue.put(_DeltaRequest(delta=Delta(body["delta"])))
- return None
-
-
-@fastapi_service.post("/api/v1/state")
-async def post_state(
- request: Request,
- response: Response,
- x_lightning_type: Optional[str] = Header(None),
- x_lightning_session_uuid: Optional[str] = Header(None),
- x_lightning_session_id: Optional[str] = Header(None),
-) -> Optional[Dict]:
- if x_lightning_session_uuid is None:
- raise Exception("Missing X-Lightning-Session-UUID header")
- if x_lightning_session_id is None:
- raise Exception("Missing X-Lightning-Session-ID header")
- # This needs to be sent so that it can be set as last state
- # in app (see sequencing above)
- # Actually: we need to make sure last_state is actually
- # the latest state seen by the UI, that is, the last state
- # sent to the UI from the API, not the last state
- # obtained by the app.
- body: Dict = await request.json()
- x_lightning_session_uuid = TEST_SESSION_UUID
-
- if not ENABLE_PUSHING_STATE_ENDPOINT:
- response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
- return {"status": "failure", "reason": "This endpoint is disabled."}
-
- if "stage" in body:
- last_state = global_app_state_store.get_served_state(x_lightning_session_uuid)
- state = deepcopy(last_state)
- state["app_state"]["stage"] = body["stage"]
- deep_diff = DeepDiff(last_state, state, verbose_level=2)
- else:
- state = body["state"]
- last_state = global_app_state_store.get_served_state(x_lightning_session_uuid)
- deep_diff = DeepDiff(last_state, state, verbose_level=2)
- assert api_app_delta_queue is not None
- api_app_delta_queue.put(_DeltaRequest(delta=Delta(deep_diff)))
- return None
-
-
-@fastapi_service.put("/api/v1/upload_file/{filename}")
-async def upload_file(response: Response, filename: str, uploaded_file: UploadFile = File(...)) -> Union[str, dict]:
- if not ENABLE_UPLOAD_ENDPOINT:
- response.status_code = status.HTTP_405_METHOD_NOT_ALLOWED
- return {"status": "failure", "reason": "This endpoint is disabled."}
-
- with TemporaryDirectory() as tmp:
- drive = Drive(
- "lit://uploaded_files",
- component_name="file_server",
- allow_duplicates=True,
- root_folder=tmp,
- )
- tmp_file = os.path.join(tmp, filename)
-
- with open(tmp_file, "wb") as f:
- done = False
- while not done:
- # Note: The 8192-byte chunk size is arbitrary.
- content = await uploaded_file.read(8192)
- f.write(content)
- done = content == b""
-
- with _context(str(ComponentContext.WORK)):
- drive.put(filename)
- return f"Successfully uploaded '{filename}' to the Drive"
-
-
-@fastapi_service.get("/api/v1/status", response_model=AppStatus)
-async def get_status() -> AppStatus:
- """Get the current status of the app and works."""
- global app_status
- if app_status is None:
- raise HTTPException(
- status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="App status hasn't been reported yet."
- )
- return app_status
-
-
-@fastapi_service.get("/api/v1/annotations", response_class=JSONResponse)
-async def get_annotations() -> Union[List, Dict]:
- """Get the annotations associated with this app."""
- global app_annotations
- return app_annotations or []
-
-
-@fastapi_service.get("/healthz", status_code=200)
-async def healthz(response: Response) -> dict:
- """Health check endpoint used in the cloud FastAPI servers to check the status periodically."""
- # check the queue status only if running in cloud
- if is_running_in_cloud():
- queue_obj = QueuingSystem(get_cloud_queue_type()).get_queue(queue_name="healthz")
- # this is only being implemented on Redis Queue. For HTTP Queue, it doesn't make sense to have every single
- # app checking the status of the Queue server
- if not queue_obj.is_running:
- response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
- return {"status": "failure", "reason": "Redis is not available"}
- x_lightning_session_uuid = TEST_SESSION_UUID
- state = global_app_state_store.get_app_state(x_lightning_session_uuid)
- global_app_state_store.set_served_state(x_lightning_session_uuid, state)
- if not state:
- response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
- return {"status": "failure", "reason": f"State is empty {state}"}
- return {"status": "ok"}
-
-
-# Creates session websocket connection to notify client about any state changes
-# The websocket instance needs to be stored based on session id so it is accessible in the api layer
-@fastapi_service.websocket("/api/v1/ws")
-async def websocket_endpoint(websocket: WebSocket) -> None:
- await websocket.accept()
- if not ENABLE_STATE_WEBSOCKET:
- await websocket.close()
- return
- try:
- counter = global_app_state_store.counter
- while True:
- if global_app_state_store.counter != counter:
- await websocket.send_text(f"{global_app_state_store.counter}")
- counter = global_app_state_store.counter
- logger.debug("Updated websocket.")
- await asyncio.sleep(0.01)
- except ConnectionClosed:
- logger.debug("Websocket connection closed")
- await websocket.close()
-
-
-async def api_catch_all(request: Request, full_path: str) -> None:
- raise HTTPException(status_code=404, detail="Not found")
-
-
-# Serve frontend from a static directory using FastAPI
-fastapi_service.mount("/static", StaticFiles(directory=frontend_static_dir, check_dir=False), name="static")
-
-
-async def frontend_route(request: Request, full_path: str): # type: ignore[no-untyped-def]
- if "pytest" in sys.modules:
- return ""
- return templates.TemplateResponse("index.html", {"request": request})
-
-
-def register_global_routes() -> None:
- # Catch-all for nonexistent API routes (since we define a catch-all for client-side routing)
- fastapi_service.get("/api{full_path:path}", response_class=JSONResponse)(api_catch_all)
- fastapi_service.get("/{full_path:path}", response_class=HTMLResponse)(frontend_route)
-
-
-class LightningUvicornServer(uvicorn.Server):
- has_started_queue: Optional[Queue] = None
-
- def run(self, sockets: Optional[List[socket.socket]] = None) -> None:
- self.config.setup_event_loop()
- loop = asyncio.get_event_loop()
- asyncio.ensure_future(self.serve(sockets=sockets))
- if self.has_started_queue:
- asyncio.ensure_future(self.check_is_started(self.has_started_queue))
- loop.run_forever()
-
- async def check_is_started(self, queue: Queue) -> None:
- while not self.started:
- await asyncio.sleep(0.1)
- queue.put("SERVER_HAS_STARTED")
-
-
-def start_server(
- api_publish_state_queue: Queue,
- api_delta_queue: Queue,
- api_response_queue: Queue,
- has_started_queue: Optional[Queue] = None,
- host: str = "127.0.0.1",
- port: int = 8000,
- root_path: str = "",
- uvicorn_run: bool = True,
- spec: Optional[List] = None,
- apis: Optional[List[_HttpMethod]] = None,
- app_state_store: Optional[StateStore] = None,
-) -> UIRefresher:
- global api_app_delta_queue
- global global_app_state_store
- global app_spec
- global app_annotations
-
- app_spec = spec
- api_app_delta_queue = api_delta_queue
-
- if app_state_store is not None:
- global_app_state_store = app_state_store # type: ignore[assignment]
-
- global_app_state_store.add(TEST_SESSION_UUID)
-
- # Load annotations
- annotations_path = Path("lightning-annotations.json").resolve()
- if annotations_path.exists():
- with open(annotations_path) as f:
- app_annotations = json.load(f)
-
- refresher = UIRefresher(api_publish_state_queue, api_response_queue)
- refresher.daemon = True
- refresher.start()
-
- if uvicorn_run:
- host = host.split("//")[-1] if "//" in host else host
- if host == "0.0.0.0": # noqa: S104
- logger.info("Your app has started.")
- else:
- logger.info(f"Your app has started. View it in your browser: http://{host}:{port}/view")
- if has_started_queue:
- LightningUvicornServer.has_started_queue = has_started_queue
- # uvicorn replaces `uvicorn.main` with a click command, so patch the Server class on the module directly.
- sys.modules["uvicorn.main"].Server = LightningUvicornServer
-
- # Register the user API.
- if apis:
- for api in apis:
- api.add_route(fastapi_service, api_app_delta_queue, responses_store)
-
- register_global_routes()
-
- uvicorn.run(app=fastapi_service, host=host, port=port, log_level="error", root_path=root_path)
-
- return refresher
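Review note on the removed `UIRefresher.run_once`: it polls two queues without blocking and treats an empty queue as "nothing to do" by suppressing `queue.Empty`. A minimal sketch of that pattern is below. It swaps in the in-process `queue.Queue` and `get_nowait()` for the `multiprocessing.Queue` and `get(timeout=0)` used in the removed file, and the names `drain_once`, `store`, and `responses` are hypothetical.

```python
import contextlib
import queue
import threading
from typing import Any, Dict, List

lock = threading.Lock()


def drain_once(
    state_queue: "queue.Queue",
    response_queue: "queue.Queue",
    store: Dict[str, Any],
    responses: Dict[str, Any],
) -> None:
    """Poll both queues once without blocking; an empty queue is simply skipped."""
    with contextlib.suppress(queue.Empty):
        state = state_queue.get_nowait()
        with lock:
            store["state"] = state

    with contextlib.suppress(queue.Empty):
        batch: List[dict] = response_queue.get_nowait()
        with lock:
            # Index each response by its request id so a handler can look it up later.
            for item in batch:
                responses[item["id"]] = item["response"]


if __name__ == "__main__":
    sq, rq = queue.Queue(), queue.Queue()
    store, responses = {}, {}
    sq.put({"counter": 1})
    rq.put([{"id": "r1", "response": 200}])
    drain_once(sq, rq, store, responses)
    print(store, responses)
```

Keeping each queue in its own `suppress` block means an empty state queue does not prevent pending responses from being collected in the same pass.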
diff --git a/src/lightning/app/core/app.py b/src/lightning/app/core/app.py
deleted file mode 100644
index b9ff54f9a8852..0000000000000
--- a/src/lightning/app/core/app.py
+++ /dev/null
@@ -1,737 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import os
-import pickle
-import queue
-import threading
-import warnings
-from copy import deepcopy
-from time import time
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
-
-from deepdiff import DeepDiff, Delta
-from lightning_utilities.core.apply_func import apply_to_collection
-
-import lightning.app
-from lightning.app import _console
-from lightning.app.api.request_types import _APIRequest, _CommandRequest, _DeltaRequest
-from lightning.app.core.constants import (
- BATCH_DELTA_COUNT,
- DEBUG_ENABLED,
- FLOW_DURATION_SAMPLES,
- FLOW_DURATION_THRESHOLD,
- FRONTEND_DIR,
- STATE_ACCUMULATE_WAIT,
-)
-from lightning.app.core.queues import BaseQueue
-from lightning.app.core.work import LightningWork
-from lightning.app.frontend import Frontend
-from lightning.app.storage import Drive, Path, Payload
-from lightning.app.storage.path import _storage_root_dir
-from lightning.app.utilities import frontend
-from lightning.app.utilities.app_helpers import (
- Logger,
- _delta_to_app_state_delta,
- _LightningAppRef,
- _should_dispatch_app,
-)
-from lightning.app.utilities.app_status import AppStatus
-from lightning.app.utilities.commands.base import _process_requests
-from lightning.app.utilities.component import _convert_paths_after_init, _validate_root_flow
-from lightning.app.utilities.enum import AppStage, CacheCallsKeys
-from lightning.app.utilities.exceptions import CacheMissException, ExitAppException, LightningFlowException
-from lightning.app.utilities.layout import _collect_layout
-from lightning.app.utilities.proxies import ComponentDelta
-from lightning.app.utilities.scheduler import SchedulerThread
-from lightning.app.utilities.tree import breadth_first
-from lightning.app.utilities.warnings import LightningFlowWarning
-
-if TYPE_CHECKING:
- from lightning.app.core.flow import LightningFlow
- from lightning.app.runners.backends.backend import Backend, WorkManager
- from lightning.app.runners.runtime import Runtime
- from lightning.app.utilities.packaging.cloud_compute import CloudCompute
-
-
-logger = Logger(__name__)
-
-
-class LightningApp:
- def __init__(
- self,
- root: Union["LightningFlow", LightningWork],
- flow_cloud_compute: Optional["CloudCompute"] = None,
- log_level: str = "info",
- info: Optional[frontend.AppInfo] = None,
- root_path: str = "",
- ) -> None:
- """The Lightning App, or App for short, runs a tree of one or more components that interact to create end-to-end
- applications. There are two kinds of components: :class:`~lightning.app.core.flow.LightningFlow` and
- :class:`~lightning.app.core.work.LightningWork`. This modular design enables you to reuse components created by
- other users.
-
- The Lightning App runs an event loop triggered by delta changes sent from
- either :class:`~lightning.app.core.work.LightningWork` or from the Lightning UI.
- Once deltas are received, the Lightning App runs
- the :class:`~lightning.app.core.flow.LightningFlow` provided.
-
- Arguments:
- root: The root ``LightningFlow`` or ``LightningWork`` component, that defines all the app's nested
- components, running infinitely. It must define a `run()` method that the app can call.
- flow_cloud_compute: The default Cloud Compute used for the flow, REST API, and frontends.
- log_level: The log level for the app, one of [`info`, `debug`].
- This can be helpful when reporting bugs on Lightning repo.
- info: Provide additional info about the app which will be used to update html title,
- description and image meta tags and specify any additional tags as list of html strings.
- root_path: Set this to `/path` if you want to run your app behind a proxy at `/path`; leave empty for "/".
- For instance, if you want to run your app at `https://customdomain.com/myapp`,
- set `root_path` to `/myapp`.
- You can learn more about proxy `here `_.
-
- """
-
- self.root_path = root_path # when running behind a proxy
- self.info = info
-
- from lightning.app.core.flow import _RootFlow
-
- if isinstance(root, LightningWork):
- root = _RootFlow(root)
-
- _validate_root_flow(root)
- self._root = root
- self.flow_cloud_compute = flow_cloud_compute or lightning.app.CloudCompute(name="flow-lite")
-
- # queues definition.
- self.delta_queue: Optional[BaseQueue] = None
- self.readiness_queue: Optional[BaseQueue] = None
- self.api_response_queue: Optional[BaseQueue] = None
- self.api_publish_state_queue: Optional[BaseQueue] = None
- self.api_delta_queue: Optional[BaseQueue] = None
- self.error_queue: Optional[BaseQueue] = None
- self.request_queues: Optional[Dict[str, BaseQueue]] = None
- self.response_queues: Optional[Dict[str, BaseQueue]] = None
- self.copy_request_queues: Optional[Dict[str, BaseQueue]] = None
- self.copy_response_queues: Optional[Dict[str, BaseQueue]] = None
- self.caller_queues: Optional[Dict[str, BaseQueue]] = None
- self.flow_to_work_delta_queues: Optional[Dict[str, BaseQueue]] = None
- self.work_queues: Optional[Dict[str, BaseQueue]] = None
- self.commands: Optional[List] = None
-
- self.should_publish_changes_to_api = False
- self.component_affiliation = None
- self.backend: Optional["Backend"] = None
- _LightningAppRef.connect(self)
- self.processes: Dict[str, "WorkManager"] = {}
- self.frontends: Dict[str, Frontend] = {}
- self.stage = AppStage.RUNNING
- self._has_updated: bool = True
- self._schedules: Dict[str, Dict] = {}
- self.threads: List[threading.Thread] = []
- self.exception = None
- self.collect_changes: bool = True
-
- self.status: Optional[AppStatus] = None
- # TODO: Enable ready locally for opening the UI.
- self.ready = False
-
- # NOTE: Checkpointing is disabled by default for the time being. We
- # will enable it when resuming from full checkpoint is supported. Also,
- # we will need to revisit the logic at _should_snapshot, since right now
- # we are writing checkpoints too often, and this is expensive.
- self.checkpointing: bool = False
-
- self._update_layout()
- self._update_status()
-
- self.is_headless: Optional[bool] = None
-
- self._original_state: Optional[dict] = None
- self._last_state: dict = self.state
- self.state_accumulate_wait = STATE_ACCUMULATE_WAIT
-
- self._last_run_time: float = 0.0
- self._run_times: list = []
-
- # Path attributes can't get properly attached during the initialization, because the full name
- # is only available after all Flows and Works have been instantiated.
- _convert_paths_after_init(self.root) # type: ignore[arg-type]
-
- if log_level not in ("debug", "info"):
- raise Exception(f"Log Level should be in ['debug', 'info']. Found {log_level}")
-
- # Lazily enable debugging.
- if log_level == "debug" or DEBUG_ENABLED:
- if not DEBUG_ENABLED:
- os.environ["LIGHTNING_DEBUG"] = "2"
- _console.setLevel(logging.DEBUG)
-
- logger.debug(f"ENV: {os.environ}")
-
- if _should_dispatch_app():
- os.environ["LIGHTNING_DISPATCHED"] = "1"
- from lightning.app.runners import MultiProcessRuntime
-
- MultiProcessRuntime(self).dispatch()
-
- def _update_index_file(self) -> None:
- # update index.html,
- # this should happen once for all apps before the ui server starts running.
- frontend.update_index_file(FRONTEND_DIR, info=self.info, root_path=self.root_path)
-
- def get_component_by_name(self, component_name: str) -> Union["LightningFlow", LightningWork]:
- """Returns the instance corresponding to the given component name."""
- from lightning.app.structures import Dict as LightningDict
- from lightning.app.structures import List as LightningList
- from lightning.app.utilities.types import ComponentTuple
-
- if component_name == "root":
- return self.root
- if not component_name.startswith("root."):
- raise ValueError(f"Invalid component name {component_name}. Name must start with 'root'")
-
- current = self.root
- for child_name in component_name.split(".")[1:]:
- if isinstance(current, LightningDict):
- child = current[child_name]
- elif isinstance(current, LightningList):
- child = current[int(child_name)]
- else:
- child = getattr(current, child_name, None)
- if not isinstance(child, ComponentTuple):
- raise AttributeError(f"Component '{current.name}' has no child component with name '{child_name}'.")
- current = child # type: ignore[assignment]
- return current
-
- def _reset_original_state(self) -> None:
- assert self._original_state is not None
- self.set_state(self._original_state)
-
- @property
- def root(self) -> Union["LightningFlow", LightningWork]:
- """Returns the root component of the application."""
- return self._root
-
- @property
- def state(self) -> dict:
- """Return the current state of the application."""
- state = self.root.state
- state["app_state"] = {"stage": self.stage.value}
- return state
-
- @property
- def state_vars(self) -> dict:
- """Return the current state restricted to the user defined variables of the application."""
- state_vars = self.root.state_vars
- state_vars["app_state"] = {"stage": self.stage.value}
- return state_vars
-
- @property
- def state_with_changes(self) -> dict:
- """Return the current state with the new changes of the application."""
- state_with_changes = self.root.state_with_changes
- state_with_changes["app_state"] = {"stage": self.stage.value}
- return state_with_changes
-
- def set_state(self, state: dict) -> None:
- """Set a new app state on the application."""
- self.set_last_state(state)
- self.root.set_state(state)
- self.stage = AppStage(state["app_state"]["stage"])
-
- @property
- def last_state(self) -> dict:
- """Returns the latest state."""
- return self._last_state
-
- @property
- def checkpoint_dir(self) -> str:
- return os.path.join(str(_storage_root_dir()), "checkpoints")
-
- def remove_changes_(self, state: dict) -> None:
- for _, child in state["flows"].items():
- self.remove_changes(child)
- state["changes"] = {}
-
- def remove_changes(self, state: dict) -> dict:
- state = deepcopy(state)
- for _, child in state["flows"].items():
- self.remove_changes_(child)
- state["changes"] = {}
- return state
-
- def set_last_state(self, state: dict) -> None:
- self._last_state = self.remove_changes(state)
-
- @staticmethod
- def populate_changes(last_state: dict, new_state: dict) -> dict:
- diff = DeepDiff(last_state, new_state, view="tree", verbose_level=2)
-
- changes_categories = [diff[key] for key in diff.to_dict()]
-
- if not changes_categories:
- return new_state
-
- for change_category in changes_categories:
- for entry in change_category:
- state_el = new_state
- change = entry.path(output_format="list")
- if "vars" not in change:
- continue
- for change_el in change:
- if change_el == "vars":
- if "changes" not in state_el:
- state_el["changes"] = {}
- state_el["changes"][change[-1]] = {"from": entry.t1, "to": entry.t2}
- break
- # move down in the dictionary
- state_el = state_el[change_el]
- return new_state
-
- @staticmethod
- def get_state_changed_from_queue(q: BaseQueue, timeout: Optional[float] = None) -> Optional[dict]:
- try:
- timeout = timeout or q.default_timeout
- return q.get(timeout=timeout)
- except queue.Empty:
- return None
-
- @staticmethod
- def batch_get_state_changed_from_queue(q: BaseQueue, timeout: Optional[float] = None) -> List[dict]:
- try:
- timeout = timeout or q.default_timeout
- return q.batch_get(timeout=timeout, count=BATCH_DELTA_COUNT)
- except queue.Empty:
- return []
-
- def check_error_queue(self) -> None:
- exception: Exception = self.get_state_changed_from_queue(self.error_queue) # type: ignore[assignment,arg-type]
- if isinstance(exception, Exception):
- self.exception = exception
- self.stage = AppStage.FAILED
-
- @property
- def flows(self) -> List[Union[LightningWork, "LightningFlow"]]:
- """Returns all the flows defined within this application."""
- return [self.root] + list(self.root.flows.values())
-
- @property
- def works(self) -> List[LightningWork]:
- """Returns all the works defined within this application."""
- return self.root.works(recurse=True)
-
- @property
- def named_works(self) -> List[Tuple[str, LightningWork]]:
- """Returns all the works defined within this application with their names."""
- return self.root.named_works(recurse=True)
-
- def _collect_deltas_from_ui_and_work_queues(self) -> List[Union[Delta, _APIRequest, _CommandRequest]]:
- # The aggregation would try to get as many deltas as possible
- # from both the `api_delta_queue` and `delta_queue`
- # during the `state_accumulate_wait` time.
- # The while loop can exit sooner if both queues are empty.
-
- deltas = []
- api_or_command_request_deltas = []
- t0 = time()
-
- while (time() - t0) < self.state_accumulate_wait:
- # TODO: Fetch all available deltas at once to reduce queue calls.
- received_deltas: List[Union[_DeltaRequest, _APIRequest, _CommandRequest, ComponentDelta]] = (
- self.batch_get_state_changed_from_queue(
- self.delta_queue # type: ignore[assignment,arg-type]
- )
- )
- if len(received_deltas) == 0:
- break
-
- for delta in received_deltas:
- if isinstance(delta, _DeltaRequest):
- deltas.append(delta.delta)
- elif isinstance(delta, ComponentDelta):
- logger.debug(f"Received from {delta.id} : {delta.delta.to_dict()}")
- work = None
- try:
- work = self.get_component_by_name(delta.id)
- except (KeyError, AttributeError) as ex:
- logger.error(f"The component {delta.id} couldn't be accessed. Exception: {ex}")
-
- if work:
- delta = _delta_to_app_state_delta(
- self.root, # type: ignore[arg-type]
- work,
- deepcopy(delta.delta),
- )
- deltas.append(delta)
- else:
- api_or_command_request_deltas.append(delta)
-
- if api_or_command_request_deltas:
- _process_requests(self, api_or_command_request_deltas)
-
- for delta in deltas:
- # When aggregating deltas from the UI and the Works, and over the accumulation time window,
- # it can happen that deltas from these different sources disagree. Since deltas are computed on the Work
- # and UI side separately, correctness of the aggregation can only be guaranteed if both components compute
- # the delta based on the same base state. But this assumption does not hold in general, and there is no way
- # for the Flow to reject or resolve these deltas properly at the moment. Hence, we decide to ignore
- # errors coming from deepdiff when adding deltas together by setting:
- delta.log_errors = False # type: ignore[union-attr]
- delta.raise_errors = False # type: ignore[union-attr]
- return deltas
-
- def maybe_apply_changes(self) -> Optional[bool]:
- """Get the deltas from both the flow queue and the work queue, merge the two deltas and update the state."""
- self._send_flow_to_work_deltas(self.state)
-
- if not self.collect_changes:
- return None
-
- deltas = self._collect_deltas_from_ui_and_work_queues()
-
- if not deltas:
- # Path and Drive aren't processed by DeepDiff, so we need to convert them to dict.
- last_state = apply_to_collection(self.last_state, (Path, Drive), lambda x: x.to_dict())
- state = apply_to_collection(self.state, (Path, Drive), lambda x: x.to_dict())
- # When no deltas are received from the Rest API or work queues,
- # we need to check if the flow modified the state and populate changes.
- deep_diff = DeepDiff(last_state, state, verbose_level=2)
-
- if "unprocessed" in deep_diff:
- # pop the unprocessed key.
- unprocessed = deep_diff.pop("unprocessed")
- logger.warn(f"It seems delta differentiation resulted in {unprocessed}. Open an issue on Github.")
-
- if deep_diff:
- # TODO: Resolve changes with ``CacheMissException``.
- # new_state = self.populate_changes(self.last_state, self.state)
- self.set_last_state(self.state)
- self._has_updated = True
- return False
-
- logger.debug(f"Received {[d.to_dict() for d in deltas]}")
-
- # 2: Collect the state
- state = self.state
-
- # 3: Apply the state delta
- for delta in deltas:
- try:
- state += delta
- except Exception as ex:
- raise Exception(f"Current State {state}, {delta.to_dict()}") from ex
-
- # new_state = self.populate_changes(self.last_state, state)
- self.set_state(state)
- self._has_updated = True
- return None
-
- def run_once(self) -> bool:
- """Method used to collect changes and run the root Flow once."""
- done = False
- self._last_run_time = 0.0
-
- if self.backend is not None:
- self.backend.update_work_statuses(self.works)
-
- self._update_layout()
- self._update_status()
- self.maybe_apply_changes()
-
- if self.checkpointing and self._should_snapshot():
- self._dump_checkpoint()
-
- if self.stage == AppStage.BLOCKING:
- return done
-
- if self.stage in (AppStage.STOPPING, AppStage.FAILED):
- return True
-
- if self.stage == AppStage.RESTARTING:
- return self._apply_restarting()
-
- t0 = time()
-
- try:
- self.check_error_queue()
- # Execute the flow only if:
- # - There are state changes
- # - It is the first execution of the flow
- if self._has_updated:
- self.root.run()
- except CacheMissException:
- self._on_cache_miss_exception()
- except LightningFlowException:
- done = True
- self.stage = AppStage.FAILED
- except (ExitAppException, KeyboardInterrupt):
- done = True
- self.stage = AppStage.STOPPING
-
- if not self.ready:
- self.ready = self.root.ready
-
- self._last_run_time = time() - t0
-
- self.on_run_once_end()
- return done
-
- def _reset_run_time_monitor(self) -> None:
- self._run_times = [0.0] * FLOW_DURATION_SAMPLES
-
- def _update_run_time_monitor(self) -> None:
- self._run_times[:-1] = self._run_times[1:]
- self._run_times[-1] = self._last_run_time
-
- # Here we underestimate during the first FLOW_DURATION_SAMPLES
- # iterations, but that's ok for our purposes
- avg_elapsed_time = sum(self._run_times) / FLOW_DURATION_SAMPLES
-
- if avg_elapsed_time > FLOW_DURATION_THRESHOLD:
- warnings.warn(
- "The execution of the `run` method of the root flow is taking too long. "
- "Flow is supposed to only host coordination logic, while currently it is "
- "likely to contain long-running calls, code that performs meaningful "
- "computations or that makes blocking or asynchronous calls to third-party "
- "services. If that is the case, you should move those pieces to a Work, "
- "and make sure Flow can complete its execution in under a second.",
- LightningFlowWarning,
- )
-
- def _run(self) -> bool:
- """Entry point of the LightningApp.
-
- This would be dispatched by the Runtime objects.
-
- """
- self._original_state = deepcopy(self.state)
- done = False
-
- self.ready = self.root.ready
-
- self._start_with_flow_works()
-
- if self.should_publish_changes_to_api and self.api_publish_state_queue is not None:
- self.api_publish_state_queue.put((self.state_vars, self.status))
-
- self._reset_run_time_monitor()
-
- while not done:
- done = self.run_once()
-
- self._update_run_time_monitor()
-
- if self._has_updated and self.should_publish_changes_to_api and self.api_publish_state_queue is not None:
- self.api_publish_state_queue.put((self.state_vars, self.status))
-
- self._has_updated = False
-
- self._on_run_end()
-
- return True
-
- def _update_layout(self) -> None:
- if self.backend:
- self.backend.resolve_url(self, base_url=None)
-
- for component in breadth_first(self.root, types=(lightning.app.LightningFlow,)): # type: ignore[arg-type]
- layout = _collect_layout(self, component)
- component._layout = layout
-
- def _update_status(self) -> None:
- old_status = self.status
-
- work_statuses = {}
- assert self.root is not None
- for work in breadth_first(self.root, types=(lightning.app.LightningWork,)): # type: ignore[arg-type]
- work_statuses[work.name] = work.status
-
- self.status = AppStatus(
- is_ui_ready=self.ready,
- work_statuses=work_statuses,
- )
-
- # If the work statuses changed, the state delta will trigger an update.
- # If ready has changed, we trigger an update manually.
- if self.status != old_status:
- self._has_updated = True
-
- def _apply_restarting(self) -> bool:
- self._reset_original_state()
- # apply stage after restoring the original state.
- self.stage = AppStage.BLOCKING
- return False
-
- def _has_work_finished(self, work: LightningWork) -> bool:
- latest_call_hash = work._calls[CacheCallsKeys.LATEST_CALL_HASH]
- if latest_call_hash is None:
- return False
- return "ret" in work._calls[latest_call_hash]
-
- def _collect_work_finish_status(self) -> dict:
- work_finished_status = {work.name: self._has_work_finished(work) for work in self.works}
- assert len(work_finished_status) == len(self.works)
- return work_finished_status
-
- def _should_snapshot(self) -> bool:
- if len(self.works) == 0:
- return True
- if self._has_updated:
- work_finished_status = self._collect_work_finish_status()
- if work_finished_status:
- return all(work_finished_status.values())
- return True
- return False
-
- def state_dict(self) -> Dict:
- return self.state
-
- def load_state_dict(self, state: Dict) -> None:
- self.set_state(state)
-
- def load_state_dict_from_checkpoint_dir(
- self,
- checkpoints_dir: str,
- version: Optional[int] = None,
- ) -> None:
- if not os.path.exists(checkpoints_dir):
- raise FileNotFoundError(f"The provided directory `{checkpoints_dir}` doesn't exist.")
- checkpoints = [f for f in os.listdir(checkpoints_dir) if f.startswith("v_") and f.endswith(".json")]
- if not checkpoints:
- raise Exception(f"No checkpoints were found in `{checkpoints_dir}`.")
-
- if version is None:
- # take the latest checkpoint.
- version = sorted(int(c.split("_")[1]) for c in checkpoints)[-1]
-
- available_checkpoints = [c for c in checkpoints if c.startswith(f"v_{version}_")]
- if not available_checkpoints:
- raise FileNotFoundError(f"The version `{version}` wasn't found in {checkpoints}.")
- if len(available_checkpoints) > 1:
- raise Exception(f"Found multiple checkpoints `{available_checkpoints}` with the same version.")
- checkpoint_path = os.path.join(checkpoints_dir, available_checkpoints[0])
- with open(checkpoint_path, "rb") as fo:
- state = pickle.load(fo)
- self.load_state_dict(state)
-
- def _dump_checkpoint(self) -> Optional[str]:
- checkpoints_dir = self.checkpoint_dir
- # TODO: Add support for saving checkpoints remotely.
- if checkpoints_dir.startswith("s3:"):
- return None
- os.makedirs(checkpoints_dir, exist_ok=True)
-
- # Get all current versions within the provided folder and sort them
- checkpoint_versions = sorted(
- int(f.split("_")[1]) for f in os.listdir(checkpoints_dir) if f.startswith("v_") and f.endswith(".json")
- )
-
- previous_version = checkpoint_versions[-1] if checkpoint_versions else -1
-
- checkpoint_path = os.path.join(checkpoints_dir, f"v_{previous_version + 1}_{time()}.json")
-
- with open(checkpoint_path, "wb") as f:
- pickle.dump(self.state_dict(), f)
- return checkpoint_path
-
- def connect(self, runtime: "Runtime") -> None:
- """Override to customize your application to the runtime."""
- pass
-
- def _on_cache_miss_exception(self) -> None:
- if self._has_updated:
- self._update_layout()
-
- def _register_schedule(self, schedule_hash: str, schedule_metadata: Dict) -> None:
- # create a thread only if a user uses the flow's schedule method.
- if not self._schedules:
- scheduler_thread = SchedulerThread(self)
- scheduler_thread.daemon = True
- self.threads.append(scheduler_thread)
- self.threads[-1].start()
- self._schedules[schedule_hash] = deepcopy(schedule_metadata)
-
- def on_run_once_end(self) -> None:
- if not self._schedules:
- return
- # disable any flow schedules.
- for flow in self.flows:
- flow._disable_running_schedules()
-
- def _on_run_end(self) -> None:
- if os.getenv("LIGHTNING_DEBUG") == "2":
- del os.environ["LIGHTNING_DEBUG"]
- _console.setLevel(logging.INFO)
-
- @staticmethod
- def _extract_vars_from_component_name(component_name: str, state: dict) -> Optional[dict]:
- child = state
- for child_name in component_name.split(".")[1:]:
- if child_name in child["flows"]:
- child = child["flows"][child_name]
- elif "structures" in child and child_name in child["structures"]:
- child = child["structures"][child_name]
- elif child_name in child["works"]:
- child = child["works"][child_name]
- else:
- return None
-
- # Filter private keys and drives
- return {
- k: v
- for k, v in child["vars"].items()
- if (
- not k.startswith("_")
- and not (isinstance(v, dict) and v.get("type", None) == "__drive__")
- and not (isinstance(v, (Payload, Path)))
- )
- }
-
- def _send_flow_to_work_deltas(self, state: dict) -> None:
- if not self.flow_to_work_delta_queues:
- return
-
- for w in self.works:
- if not w.has_started:
- continue
-
- # Don't send changes when the state has just been sent.
- if w.run.has_sent:
- continue
-
- state_work = self._extract_vars_from_component_name(w.name, state)
- last_state_work = self._extract_vars_from_component_name(w.name, self._last_state)
-
- # Note: The work was dynamically created or deleted.
- if state_work is None or last_state_work is None:
- continue
-
- deep_diff = DeepDiff(last_state_work, state_work, verbose_level=2).to_dict()
-
- if "unprocessed" in deep_diff:
- deep_diff.pop("unprocessed")
-
- if deep_diff:
- logger.debug(f"Sending deep_diff to {w.name} : {deep_diff}")
- self.flow_to_work_delta_queues[w.name].put(deep_diff)
-
- def _start_with_flow_works(self) -> None:
- for w in self.works:
- if w._start_with_flow:
- parallel = w.parallel
- w._parallel = True
- w.start()
- w._parallel = parallel
diff --git a/src/lightning/app/core/constants.py b/src/lightning/app/core/constants.py
deleted file mode 100644
index f33278e5bf5ca..0000000000000
--- a/src/lightning/app/core/constants.py
+++ /dev/null
@@ -1,125 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-from pathlib import Path
-from typing import Optional
-
-import lightning_cloud.env
-
-
-def get_lightning_cloud_url() -> str:
- # detect local development
- if os.getenv("VSCODE_PROXY_URI", "").startswith("http://localhost:9800"):
- return "http://localhost:9800"
- # DO NOT CHANGE!
- return os.getenv("LIGHTNING_CLOUD_URL", "https://lightning.ai")
-
-
-SUPPORTED_PRIMITIVE_TYPES = (type(None), str, int, float, bool)
-STATE_UPDATE_TIMEOUT = 0.001
-STATE_ACCUMULATE_WAIT = 0.15
-# Duration in seconds of a moving average of a full flow execution
-# beyond which a ``LightningFlowWarning`` is emitted.
-FLOW_DURATION_THRESHOLD = 1.0
-# Number of samples for the moving average of the duration of flow execution
-FLOW_DURATION_SAMPLES = 5
-
-APP_SERVER_HOST = os.getenv("LIGHTNING_APP_STATE_URL", "http://127.0.0.1")
-APP_SERVER_IN_CLOUD = "http://lightningapp" in APP_SERVER_HOST
-APP_SERVER_PORT = 7501
-APP_STATE_MAX_SIZE_BYTES = 1024 * 1024 # 1 MB
-
-WARNING_QUEUE_SIZE = 1000
-# different flag because queue debug can be very noisy, and almost always not useful unless debugging the queue itself.
-QUEUE_DEBUG_ENABLED = bool(int(os.getenv("LIGHTNING_QUEUE_DEBUG_ENABLED", "0")))
-
-REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
-REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))
-REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", None)
-REDIS_QUEUES_READ_DEFAULT_TIMEOUT = 0.005
-
-HTTP_QUEUE_URL = os.getenv("LIGHTNING_HTTP_QUEUE_URL", "http://localhost:9801")
-HTTP_QUEUE_REFRESH_INTERVAL = float(os.getenv("LIGHTNING_HTTP_QUEUE_REFRESH_INTERVAL", "1"))
-HTTP_QUEUE_TOKEN = os.getenv("LIGHTNING_HTTP_QUEUE_TOKEN", None)
-HTTP_QUEUE_REQUESTS_PER_SECOND = float(os.getenv("LIGHTNING_HTTP_QUEUE_REQUESTS_PER_SECOND", "0.5"))
-
-USER_ID = os.getenv("USER_ID", "1234")
-FRONTEND_DIR = str(Path(__file__).parent.parent / "ui")
-PACKAGE_LIGHTNING = os.getenv("PACKAGE_LIGHTNING", None)
-CLOUD_UPLOAD_WARNING = int(os.getenv("CLOUD_UPLOAD_WARNING", "2"))
-DISABLE_DEPENDENCY_CACHE = bool(int(os.getenv("DISABLE_DEPENDENCY_CACHE", "0")))
-# Project under which the resources need to run in cloud. If this env is not set,
-# cloud runner will try to get the default project from the cloud
-LIGHTNING_CLOUD_PROJECT_ID = os.getenv("LIGHTNING_CLOUD_PROJECT_ID")
-LIGHTNING_CLOUD_PRINT_SPECS = os.getenv("LIGHTNING_CLOUD_PRINT_SPECS")
-LIGHTNING_DIR = os.getenv("LIGHTNING_DIR", str(Path.home() / ".lightning"))
-LIGHTNING_CREDENTIAL_PATH = os.getenv("LIGHTNING_CREDENTIAL_PATH", str(Path(LIGHTNING_DIR) / "credentials.json"))
-DOT_IGNORE_FILENAME = ".lightningignore"
-LIGHTNING_COMPONENT_PUBLIC_REGISTRY = "https://lightning.ai/v1/components"
-LIGHTNING_APPS_PUBLIC_REGISTRY = "https://lightning.ai/v1/apps"
-LIGHTNING_MODELS_PUBLIC_REGISTRY = "https://lightning.ai/v1/models"
-
-LIGHTNING_CLOUDSPACE_HOST = os.getenv("LIGHTNING_CLOUDSPACE_HOST")
-LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT = int(os.getenv("LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT", "0"))
-
-# EXPERIMENTAL: ENV VARIABLES TO ENABLE MULTIPLE WORKS IN THE SAME MACHINE
-DEFAULT_NUMBER_OF_EXPOSED_PORTS = int(os.getenv("DEFAULT_NUMBER_OF_EXPOSED_PORTS", "50"))
-ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER = bool(
- int(os.getenv("ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER", "0"))
-) # This isn't used in the cloud yet.
-
-# env var that triggers running setup commands in the app
-ENABLE_APP_COMMENT_COMMAND_EXECUTION = bool(int(os.getenv("ENABLE_APP_COMMENT_COMMAND_EXECUTION", "0")))
-
-
-DEBUG: bool = lightning_cloud.env.DEBUG
-DEBUG_ENABLED = bool(int(os.getenv("LIGHTNING_DEBUG", "0")))
-ENABLE_PULLING_STATE_ENDPOINT = bool(int(os.getenv("ENABLE_PULLING_STATE_ENDPOINT", "1")))
-ENABLE_PUSHING_STATE_ENDPOINT = ENABLE_PULLING_STATE_ENDPOINT and bool(
- int(os.getenv("ENABLE_PUSHING_STATE_ENDPOINT", "1"))
-)
-ENABLE_STATE_WEBSOCKET = bool(int(os.getenv("ENABLE_STATE_WEBSOCKET", "1")))
-ENABLE_UPLOAD_ENDPOINT = bool(int(os.getenv("ENABLE_UPLOAD_ENDPOINT", "1")))
-
-# directory where system customization sync files are stored
-SYS_CUSTOMIZATIONS_SYNC_ROOT = "/tmp/sys-customizations-sync" # todo
-# directory where system customization sync files will be copied to be packed into app tarball
-SYS_CUSTOMIZATIONS_SYNC_PATH = ".sys-customizations-sync"
-
-BATCH_DELTA_COUNT = int(os.getenv("BATCH_DELTA_COUNT", "128"))
-
-
-def enable_multiple_works_in_default_container() -> bool:
- return bool(int(os.getenv("ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER", "0")))
-
-
-def get_cloud_queue_type() -> Optional[str]:
- value = os.getenv("LIGHTNING_CLOUD_QUEUE_TYPE", None)
- if value is None and enable_interruptible_works():
- value = "http"
- return value
-
-
-# Number of seconds to wait between filesystem checks when waiting for files in remote storage
-REMOTE_STORAGE_WAIT = 0.5
-
-
-# interruptible support
-def enable_interruptible_works() -> bool:
- return bool(int(os.getenv("LIGHTNING_INTERRUPTIBLE_WORKS", "0")))
-
-
-def get_cluster_driver() -> Optional[str]:
- return "direct"
diff --git a/src/lightning/app/core/flow.py b/src/lightning/app/core/flow.py
deleted file mode 100644
index f9ffcca61c5a9..0000000000000
--- a/src/lightning/app/core/flow.py
+++ /dev/null
@@ -1,861 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import inspect
-import warnings
-from copy import deepcopy
-from datetime import datetime
-from types import FrameType
-from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, List, Optional, Tuple, Union, cast
-
-from deepdiff import DeepHash
-
-from lightning.app.core.work import LightningWork
-from lightning.app.frontend import Frontend
-from lightning.app.storage.drive import Drive, _maybe_create_drive
-from lightning.app.storage.path import Path
-from lightning.app.utilities.app_helpers import _is_json_serializable, _LightningAppRef, _set_child_name, is_overridden
-from lightning.app.utilities.component import _sanitize_state
-from lightning.app.utilities.exceptions import ExitAppException, LightningFlowException
-from lightning.app.utilities.introspection import _is_init_context, _is_run_context
-from lightning.app.utilities.packaging.cloud_compute import CloudCompute, _maybe_create_cloud_compute
-
-if TYPE_CHECKING:
- from lightning.app.runners.backends.backend import Backend
-
-
-class LightningFlow:
- _INTERNAL_STATE_VARS = {
- # Internal protected variables that are still part of the state (even though they are prefixed with "_")
- "_paths",
- "_layout",
- }
-
- def __init__(self) -> None:
- """The LightningFlow is used by the :class:`~lightning.app.core.app.LightningApp` to coordinate and manage
- long-running jobs contained in :class:`~lightning.app.core.work.LightningWork` components.
-
- A LightningFlow is characterized by:
-
- * A set of state variables.
- * Long-running jobs (:class:`~lightning.app.core.work.LightningWork`).
- * Its children ``LightningFlow`` or ``LightningWork`` with their state variables.
-
- **State variables**
-
- LightningFlows are special classes whose attributes must be
- JSON-serializable (e.g., int, float, bool, list, dict, ...).
-
- They also may not reach into global variables unless they are constant.
-
- All attributes need to be defined in the `__init__` method,
- and can later be assigned different values throughout the lifetime of the object.
- However, defining new attributes outside of `__init__` is not allowed.
-
- Attributes taken together represent the state of the component.
- Components are capable of retrieving their state and that of their
- children recursively at any time. They are also capable of applying
- an externally provided state recursively to their children.
-
- **Execution model and work**
-
- The entry point for execution is the ``run`` method at the root component.
- The ``run`` method of the root component may call the ``run`` method of its children, and the children
- may call the ``run`` methods of their children and so on.
-
- The ``run`` method of the root component is called repeatedly in a while loop forever until the app gets
- terminated. In this programming model (reminiscent of React, Vue or Streamlit from the JavaScript world),
- the values of the state variables, or their changes, are translated into actions throughout the component
- hierarchy. This means the flow of execution will only be affected by state changes in a component or one of
- its children, and otherwise remain idempotent.
-
- The actions themselves are self-contained within :class:`~lightning.app.core.work.LightningWork`.
- The :class:`~lightning.app.core.work.LightningWork` are typically used for long-running jobs,
- like downloading a dataset, performing a query, or starting a computationally heavy script.
- While one may access any state variable in a LightningWork from a LightningFlow, one may not
- directly call methods of other components from within a LightningWork as LightningWork can't have any children.
- This limitation allows applications to be distributed at scale.
-
- **Component hierarchy and App**
-
- Given the above characteristics, a root LightningFlow, potentially containing
- children components, can be passed to an App object and its execution
- can be distributed (each LightningWork will be run within its own process
- or different arrangements).
-
- Example:
-
- >>> from lightning.app import LightningFlow
- >>> class RootFlow(LightningFlow):
- ... def __init__(self):
- ... super().__init__()
- ... self.counter = 0
- ... def run(self):
- ... self.counter += 1
- ...
- >>> flow = RootFlow()
- >>> flow.run()
- >>> assert flow.counter == 1
- >>> assert flow.state["vars"]["counter"] == 1
-
- """
- self._state: set = set()
- self._name: str = ""
- self._flows: set = set()
- self._works: set = set()
- self._structures: set = set()
- self._calls: dict = {}
- self._changes: dict = {}
- self._layout: Union[List[Dict], Dict] = {}
- self._paths: dict = {}
- self._backend: Optional["Backend"] = None
- # tuple instead of a list so that it cannot be modified without using the setter
- self._lightningignore: Tuple[str, ...] = ()
-
- @property
- def name(self) -> str:
- """Return the current LightningFlow name."""
- return self._name or "root"
-
- def __setattr__(self, name: str, value: Any) -> None:
- attr = getattr(self.__class__, name, None)
- if isinstance(attr, property) and attr.fset is not None:
- return attr.fset(self, value)
-
- from lightning.app.structures import Dict as ComponentDict
- from lightning.app.structures import List as ComponentList
-
- if (
- not _is_init_context(self)
- and name not in self._state
- and name not in self._paths
- and (
- not isinstance(value, (LightningWork, LightningFlow))
- or (isinstance(value, (LightningWork, LightningFlow)) and not _is_run_context(self))
- )
- and name not in self._works.union(self._flows)
- and self._is_state_attribute(name)
- ):
- raise AttributeError(f"Cannot set attributes that were not defined in __init__: {name}")
-
- if isinstance(value, str) and value.startswith("lit://"):
- value = Path(value)
-
- if self._is_state_attribute(name):
- if hasattr(self, name):
- if name in self._flows and value != getattr(self, name):
- raise AttributeError(f"Cannot set attributes as the flow can't be changed once defined: {name}")
-
- if name in self._works and value != getattr(self, name):
- raise AttributeError(f"Cannot set attributes as the work can't be changed once defined: {name}")
-
- if isinstance(value, (list, dict)) and value:
- _type = (LightningFlow, LightningWork, ComponentList, ComponentDict)
- if isinstance(value, list) and all(isinstance(va, _type) for va in value):
- value = ComponentList(*value)
-
- if isinstance(value, dict) and all(isinstance(va, _type) for va in value.values()):
- value = ComponentDict(**value)
-
- if isinstance(value, LightningFlow):
- self._flows.add(name)
- _set_child_name(self, value, name)
- if name in self._state:
- self._state.remove(name)
- # Attach the backend to the flow and its children work.
- if self._backend:
- LightningFlow._attach_backend(value, self._backend)
- for work in value.works():
- work._register_cloud_compute()
-
- elif isinstance(value, LightningWork):
- self._works.add(name)
- _set_child_name(self, value, name)
- if name in self._state:
- self._state.remove(name)
- if self._backend:
- self._backend._wrap_run_method(_LightningAppRef().get_current(), value) # type: ignore[arg-type]
- value._register_cloud_compute()
-
- elif isinstance(value, (ComponentDict, ComponentList)):
- self._structures.add(name)
- _set_child_name(self, value, name)
-
- _backend = getattr(self, "backend", None)
- if _backend is not None:
- value._backend = _backend
-
- for flow in value.flows:
- if _backend is not None:
- LightningFlow._attach_backend(flow, _backend)
-
- for work in value.works:
- work._register_cloud_compute()
- if _backend is not None:
- _backend._wrap_run_method(_LightningAppRef().get_current(), work)
-
- elif isinstance(value, Path):
- # In the init context, the full name of the Flow and Work is not known, i.e., we can't serialize
- # the path without losing the information of origin and consumer. Hence, we delay the serialization
- # of the path object until the app is instantiated.
- if not _is_init_context(self):
- self._paths[name] = value.to_dict()
- self._state.add(name)
-
- elif isinstance(value, Drive):
- value = deepcopy(value)
- value.component_name = self.name
- self._state.add(name)
-
- elif isinstance(value, CloudCompute):
- self._state.add(name)
-
- elif _is_json_serializable(value):
- self._state.add(name)
-
- if not isinstance(value, Path) and hasattr(self, "_paths") and name in self._paths:
- # The attribute changed type from Path to another
- self._paths.pop(name)
-
- else:
- raise AttributeError(
- f"Only JSON-serializable attributes are currently supported"
- f" (str, int, float, bool, tuple, list, dict etc.) to be part of {self} state. "
- f"Found the attribute {name} with {value} instead. \n"
- "HINT: Private attributes defined as follows `self._x = y` won't be shared between components "
- "and therefore don't need to be JSON-serializable."
- )
-
- super().__setattr__(name, value)
- return None
-
- @staticmethod
- def _attach_backend(flow: "LightningFlow", backend: "Backend") -> None:
- """Attach the backend to the flow and all of its children."""
- flow._backend = backend
-
- for name in flow._structures:
- getattr(flow, name)._backend = backend
-
- for child_flow in flow.flows.values():
- child_flow._backend = backend
- for name in child_flow._structures:
- getattr(child_flow, name)._backend = backend
-
- app = _LightningAppRef().get_current()
-
- for child_work in flow.works():
- child_work._backend = backend
- backend._wrap_run_method(app, child_work) # type: ignore[arg-type]
-
- def __getattr__(self, item: str) -> Any:
- if item in self.__dict__.get("_paths", {}):
- return Path.from_dict(self._paths[item])
- return self.__getattribute__(item)
-
- @property
- def ready(self) -> bool:
- """Override to customize when your App should be ready."""
- flows = self.flows
- return all(flow.ready for flow in flows.values()) if flows else True
-
- @property
- def changes(self) -> dict:
- return self._changes.copy()
-
- @property
- def state(self) -> dict:
- """Returns the current flow state along with its children."""
- children_state = {child: getattr(self, child).state for child in self._flows}
- works_state = {work: getattr(self, work).state for work in self._works}
- return {
- "vars": _sanitize_state({el: getattr(self, el) for el in self._state}),
- # this may have the challenge that ret cannot be pickled, we'll need to handle this
- "calls": self._calls.copy(),
- "flows": children_state,
- "works": works_state,
- "structures": {child: getattr(self, child).state for child in self._structures},
- "changes": {},
- }
-
- @property
- def state_vars(self) -> dict:
- children_state = {child: getattr(self, child).state_vars for child in self._flows}
- works_state = {work: getattr(self, work).state_vars for work in self._works}
- return {
- "vars": _sanitize_state({el: getattr(self, el) for el in self._state}),
- "flows": children_state,
- "works": works_state,
- "structures": {child: getattr(self, child).state_vars for child in self._structures},
- }
-
- @property
- def state_with_changes(self) -> dict:
- children_state = {child: getattr(self, child).state_with_changes for child in self._flows}
- works_state = {work: getattr(self, work).state_with_changes for work in self._works}
- return {
- "vars": _sanitize_state({el: getattr(self, el) for el in self._state}),
- # this may have the challenge that ret cannot be pickled, we'll need to handle this
- "calls": self._calls.copy(),
- "flows": children_state,
- "works": works_state,
- "structures": {child: getattr(self, child).state_with_changes for child in self._structures},
- "changes": self.changes,
- }
-
- @property
- def flows(self) -> Dict[str, "LightningFlow"]:
- """Return its children LightningFlow."""
- flows = {}
- for el in sorted(self._flows):
- flow = getattr(self, el)
- flows[flow.name] = flow
- flows.update(flow.flows)
- for struct_name in sorted(self._structures):
- flows.update(getattr(self, struct_name).flows)
- return flows
-
- @property
- def lightningignore(self) -> Tuple[str, ...]:
- """Programmatic equivalent of the ``.lightningignore`` file."""
- return self._lightningignore
-
- @lightningignore.setter
- def lightningignore(self, lightningignore: Tuple[str, ...]) -> None:
- if self._backend is not None:
- raise RuntimeError(
- f"Your app has been already dispatched, so modifying the `{self.name}.lightningignore` does not have an"
- " effect"
- )
- self._lightningignore = lightningignore
-
- def works(self, recurse: bool = True) -> List[LightningWork]:
- """Return its :class:`~lightning.app.core.work.LightningWork`."""
- works = [getattr(self, el) for el in sorted(self._works)]
- if not recurse:
- return works
- for child_name in sorted(self._flows):
- for w in getattr(self, child_name).works(recurse=recurse):
- works.append(w)
- for struct_name in sorted(self._structures):
- for w in getattr(self, struct_name).works:
- works.append(w)
- return works
-
- def named_works(self, recurse: bool = True) -> List[Tuple[str, LightningWork]]:
- """Return its :class:`~lightning.app.core.work.LightningWork` with their names."""
- return [(w.name, w) for w in self.works(recurse=recurse)]
-
- def set_state(self, provided_state: Dict, recurse: bool = True) -> None:
- """Method to set the state to this LightningFlow, its children and
- :class:`~lightning.app.core.work.LightningWork`.
-
- Arguments:
- provided_state: The state to be reloaded
- recurse: Whether to also apply the state to the children.
-
- """
- for k, v in provided_state["vars"].items():
- if isinstance(v, Dict):
- v = _maybe_create_drive(self.name, v)
- if isinstance(v, Dict):
- v = _maybe_create_cloud_compute(v)
- setattr(self, k, v)
- self._changes = provided_state["changes"]
- self._calls.update(provided_state["calls"])
-
- if not recurse:
- return
-
- for child, state in provided_state["flows"].items():
- getattr(self, child).set_state(state)
- for work, state in provided_state["works"].items():
- getattr(self, work).set_state(state)
- for structure, state in provided_state["structures"].items():
- getattr(self, structure).set_state(state)
-
- def stop(self, end_msg: str = "") -> None:
- """Method used to exit the application."""
- if end_msg:
- print(end_msg)
- raise ExitAppException
-
- def fail(self, end_msg: str = "") -> None:
- """Method used to exit and fail the application."""
- if end_msg:
- print(end_msg)
- raise LightningFlowException
-
- def _exit(self, end_msg: str = "") -> None:
- """Used to exit the application.
-
- Private method.
-
- .. deprecated:: 1.9.0
- This function is deprecated and will be removed in 2.0.0. Use :meth:`stop` instead.
-
- """
- warnings.warn(
- DeprecationWarning(
- "This function is deprecated and will be removed in 2.0.0. Use `LightningFlow.stop` instead."
- )
- )
-
- return self.stop(end_msg=end_msg)
-
- @staticmethod
- def _is_state_attribute(name: str) -> bool:
- """Every public attribute is part of the state by default and all protected (prefixed by '_') or private
- (prefixed by '__') attributes are not.
-
- Exceptions are listed in the `_INTERNAL_STATE_VARS` class variable.
-
- """
- return name in LightningFlow._INTERNAL_STATE_VARS or not name.startswith("_")
-
- def run(self, *args: Any, **kwargs: Any) -> None:
- """Override with your own logic."""
- pass
-
- def schedule(
- self, cron_pattern: str, start_time: Optional[datetime] = None, user_key: Optional[str] = None
- ) -> bool:
- """The schedule method is used to run a part of the flow logic at scheduled times.
-
- .. code-block:: python
-
- from lightning.app import LightningFlow
-
-
- class Flow(LightningFlow):
- def run(self):
- if self.schedule("hourly"):
- print("run some code every hour")
-
- Arguments:
- cron_pattern: The cron pattern to provide. Learn more at https://crontab.guru/.
- start_time: The start time of the cron job.
- user_key: Optional key used to improve the caching mechanism.
-
- A best practice is to avoid running a dynamic flow or work under the self.schedule method.
- Instead, instantiate them within the condition, but run them outside.
-
- .. code-block:: python
-
- from lightning.app import LightningFlow
- from lightning.app.structures import List
-
-
- class SchedulerDAG(LightningFlow):
- def __init__(self):
- super().__init__()
- self.dags = List()
-
- def run(self):
- if self.schedule("hourly"):
- self.dags.append(DAG(...))
-
- for dag in self.dags:
- payload = dag.run()
-
- **Learn more about Scheduling**
-
- .. raw:: html
-
-
-
-
- .. displayitem::
- :header: Schedule your components
- :description: Learn more about scheduling.
- :col_css: col-md-4
- :button_link: ../../../glossary/scheduling.html
- :height: 180
- :tag: Basic
-
- .. displayitem::
- :header: Build your own DAG
- :description: Learn more about DAG scheduling with examples.
- :col_css: col-md-4
- :button_link: ../../../examples/app/dag/dag.html
- :height: 180
- :tag: Basic
-
- .. raw:: html
-
-
-
-
-
- """
- if not user_key:
- frame = cast(FrameType, inspect.currentframe()).f_back
- assert frame is not None
- cache_key = f"{cron_pattern}.{frame.f_code.co_filename}.{frame.f_lineno}"
- else:
- cache_key = user_key
-
- call_hash = f"{self.schedule.__name__}:{DeepHash(cache_key)[cache_key]}"
-
- if "scheduling" not in self._calls:
- self._calls["scheduling"] = {}
-
- entered = call_hash in self._calls["scheduling"]
-
- expr_aliases = {
- "midnight": "@midnight",
- "hourly": "@hourly",
- "daily": "@daily",
- "weekly": "@weekly",
- "monthly": "@monthly",
- "yearly": "@yearly",
- "annually": "@annually",
- }
-
- if cron_pattern in expr_aliases:
- cron_pattern = expr_aliases[cron_pattern]
-
- if not entered:
- if not start_time:
- start_time = datetime.now()
-
- schedule_metadata = {
- "running": False,
- "cron_pattern": cron_pattern,
- "start_time": str(start_time.isoformat()),
- "name": self.name,
- }
-
- self._calls["scheduling"][call_hash] = schedule_metadata
- app = _LightningAppRef().get_current()
- if app:
- app._register_schedule(call_hash, schedule_metadata)
- return True
-
- return self._calls["scheduling"][call_hash]["running"]
-
- def _enable_schedule(self, call_hash: str) -> None:
- self._calls["scheduling"][call_hash]["running"] = True
-
- def _disable_running_schedules(self) -> None:
- if "scheduling" not in self._calls:
- return
- for call_hash in self._calls["scheduling"]:
- self._calls["scheduling"][call_hash]["running"] = False
-
- def configure_layout(self) -> Union[Dict[str, Any], List[Dict[str, Any]], Frontend]:
- """Configure the UI layout of this LightningFlow.
-
- You can either
-
- 1. Return a single :class:`~lightning.app.frontend.frontend.Frontend` object to serve a user interface
- for this Flow.
- 2. Return a single dictionary to expose the UI of a child flow.
- 3. Return a list of dictionaries to arrange the children of this flow in one or multiple tabs.
-
- **Example:** Serve a static directory (with at least a file index.html inside).
-
- .. code-block:: python
-
- from lightning.app.frontend import StaticWebFrontend
-
-
- class Flow(LightningFlow):
- ...
-
- def configure_layout(self):
- return StaticWebFrontend("path/to/folder/to/serve")
-
- **Example:** Serve a streamlit UI (needs the streamlit package to be installed).
-
- .. code-block:: python
-
- from lightning.app.frontend import StreamlitFrontend
-
-
- class Flow(LightningFlow):
- ...
-
- def configure_layout(self):
- return StreamlitFrontend(render_fn=my_streamlit_ui)
-
-
- def my_streamlit_ui(state):
- # add your streamlit code here!
- import streamlit as st
-
-
- **Example:** Arrange the UI of my children in tabs (default UI by Lightning).
-
- .. code-block:: python
-
- class Flow(LightningFlow):
- def configure_layout(self):
- return [
- dict(name="First Tab", content=self.child0),
- dict(name="Second Tab", content=self.child1),
- dict(name="Lightning", content="https://lightning.ai"),
- ]
-
- If you don't implement ``configure_layout``, Lightning will collect all children and display their UI in a tab
- (if they have their own ``configure_layout`` implemented).
-
- Note:
- This hook gets called at the time of app creation and then again as part of the loop. If desired, the
- returned layout configuration can depend on the state. The only exceptions are flows that return a
- :class:`~lightning.app.frontend.frontend.Frontend`. These need to be provided at the time of app creation
- in order for the runtime to start the server.
-
- **Learn more about adding UI**
-
- .. raw:: html
-
-
-
-
- .. displayitem::
- :header: Add a web user interface (UI)
- :description: Learn more about how to integrate several UIs.
- :col_css: col-md-4
- :button_link: ../../../workflows/add_web_ui/index.html
- :height: 180
- :tag: Basic
-
- .. raw:: html
-
-
-
-
-
- """
- return [{"name": name, "content": component} for (name, component) in self.flows.items()]
-
- def experimental_iterate(self, iterable: Iterable, run_once: bool = True, user_key: str = "") -> Generator:
- """This method should always be used with any kind of iterable to ensure it is fault tolerant.
-
- If you want your iterable to always be consumed from scratch, you shouldn't use this method.
-
- Arguments:
- iterable: Iterable to iterate over. The iterable shouldn't have side effects or be random.
- run_once: Whether to run the entire iteration only once.
- Otherwise, it would restart from the beginning.
- user_key: Key to be used to track the caching mechanism.
-
- """
- if not isinstance(iterable, Iterable):
- raise TypeError(f"An iterable should be provided to `self.iterate` method. Found {iterable}")
-
- # TODO: Find a better way. Investigated using __reduce__, but state change invalidate the cache.
- if not user_key:
- frame = cast(FrameType, inspect.currentframe()).f_back
- assert frame is not None
- cache_key = f"{frame.f_code.co_filename}.{frame.f_code.co_firstlineno}"
- else:
- cache_key = user_key
-
- call_hash = f"{self.experimental_iterate.__name__}:{DeepHash(cache_key)[cache_key]}"
- entered = call_hash in self._calls
- has_started = entered and self._calls[call_hash]["counter"] > 0
- has_finished = entered and self._calls[call_hash]["has_finished"]
-
- if has_finished:
- if not run_once:
- self._calls[call_hash].update({"counter": 0, "has_finished": False})
- else:
- return range(0)
-
- if not has_started:
- self._calls[call_hash] = {
- "name": self.experimental_iterate.__name__,
- "call_hash": call_hash,
- "counter": 0,
- "has_finished": False,
- }
-
- skip_counter = max(self._calls[call_hash]["counter"], 0)
-
- for counter, value in enumerate(iterable):
- if skip_counter:
- skip_counter -= 1
- continue
- self._calls[call_hash].update({"counter": counter})
- yield value
-
- self._calls[call_hash].update({"has_finished": True})
-
- def configure_commands(self) -> None:
- """Configure the commands of this LightningFlow.
-
- Returns a list of dictionaries mapping a command name to a flow method.
-
- .. code-block:: python
-
- class Flow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.names = []
-
- def configure_commands(self):
- return [{"my_command_name": self.my_remote_method}]
-
- def my_remote_method(self, name):
- self.names.append(name)
-
- Once the app is running with the following command:
-
- .. code-block:: bash
-
- lightning_app run app app.py
-
- .. code-block:: bash
-
- lightning_app my_command_name --args name=my_own_name
-
- """
- raise NotImplementedError
-
- def configure_api(self) -> None:
- """Configure the API routes of the LightningFlow.
-
- Returns a list of HttpMethod such as Post or Get.
-
- .. code-block:: python
-
- from lightning.app import LightningFlow
- from lightning.app.api import Post
-
- from pydantic import BaseModel
-
-
- class HandlerModel(BaseModel):
- name: str
-
-
- class Flow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.names = []
-
- def handler(self, config: HandlerModel) -> None:
- self.names.append(config.name)
-
- def configure_api(self):
- return [Post("/v1/api/request", self.handler)]
-
- Once the app is running, you can access the Swagger UI of the app
- under the ``/docs`` route.
-
- """
- raise NotImplementedError
-
- def state_dict(self) -> dict:
- """Returns the current flow state but not its children."""
- return {
- "vars": _sanitize_state({el: getattr(self, el) for el in self._state}),
- "calls": self._calls.copy(),
- "changes": {},
- "flows": {},
- "works": {},
- "structures": {},
- }
-
- def load_state_dict(
- self,
- flow_state: Dict[str, Any],
- children_states: Dict[str, Any],
- strict: bool = True,
- ) -> None:
- """Reloads the state of this flow and its children.
-
- .. code-block:: python
-
-
- class Work(LightningWork):
- def __init__(self):
- super().__init__()
- self.counter = 0
-
- def run(self):
- self.counter += 1
-
-
- class Flow(LightningFlow):
- def run(self):
- # dynamically create a work.
- if not getattr(self, "w", None):
- self.w = Work()
-
- self.w.run()
-
- def load_state_dict(self, flow_state, children_states, strict) -> None:
- # 1: Re-instantiate the dynamic work
- self.w = Work()
-
- # 2: Make any state modifications / migrations.
- ...
-
- # 3: Call the parent ``load_state_dict`` to
- # recursively reload the states.
- super().load_state_dict(
- flow_state,
- children_states,
- strict,
- )
-
- Arguments:
- flow_state: The state of the current flow.
- children_states: The state of the dynamic children of this flow.
- strict: Whether to raise an exception if a dynamic
- child hasn't been re-created.
-
- """
- self.set_state(flow_state, recurse=False)
- direct_children_states = {k: v for k, v in children_states.items() if "." not in k}
- for child_name, state in direct_children_states.items():
- child = getattr(self, child_name, None)
- if isinstance(child, LightningFlow):
- lower_children_states = {
- k.replace(child_name + ".", ""): v
- for k, v in children_states.items()
- if k.startswith(child_name) and k != child_name
- }
- child.load_state_dict(state, lower_children_states, strict=strict)
- elif isinstance(child, LightningWork):
- child.set_state(state)
- elif strict:
- raise ValueError(f"The component {child_name} wasn't instantiated for the component {self.name}")
-
-
-class _RootFlow(LightningFlow):
- def __init__(self, work: LightningWork) -> None:
- super().__init__()
- self.work = work
-
- @property
- def ready(self) -> bool:
- ready = getattr(self.work, "ready", None)
- if ready is not None:
- return ready
- return self.work.url != ""
-
- def run(self) -> None:
- if self.work.has_succeeded:
- self.work.stop()
- self.stop()
- self.work.run()
-
- def configure_layout(self) -> list:
- if is_overridden("configure_layout", self.work):
- return [{"name": "Main", "content": self.work}]
- return []
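
Review note on the removed `experimental_iterate` above: its resume behaviour is easier to follow in isolation. The sketch below is a hypothetical standalone version and is not part of the deleted file. The function name `resumable_iterate`, the plain `calls` dict, and the explicit `key` argument are assumptions made for illustration; the original derives the key from the caller's frame and stores the bookkeeping in `self._calls`. As in the original, the stored counter is the index of the last item yielded, so that item is replayed after an interruption.

```python
from typing import Any, Dict, Iterable, Iterator


def resumable_iterate(
    calls: Dict[str, Dict[str, Any]], key: str, iterable: Iterable, run_once: bool = True
) -> Iterator:
    """Hypothetical standalone sketch of the bookkeeping in ``experimental_iterate``."""
    entry = calls.get(key)
    if entry is not None and entry["has_finished"]:
        if run_once:
            return  # already fully consumed once: yield nothing
        # Otherwise rewind so the iterable is consumed from scratch again.
        entry.update({"counter": 0, "has_finished": False})
    if entry is None or entry["counter"] == 0:
        entry = calls[key] = {"counter": 0, "has_finished": False}

    # Skip the items before the last one that was yielded in a previous pass.
    skip = entry["counter"]
    for index, value in enumerate(iterable):
        if skip:
            skip -= 1
            continue
        entry["counter"] = index  # remember the item currently being processed
        yield value

    entry["has_finished"] = True
```

If a consumer stops after the third item and later calls the function again with the same `calls` dict and `key`, iteration restarts at that third item rather than at the beginning.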
diff --git a/src/lightning/app/core/queues.py b/src/lightning/app/core/queues.py
deleted file mode 100644
index d37251c824616..0000000000000
--- a/src/lightning/app/core/queues.py
+++ /dev/null
@@ -1,574 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import base64
-import multiprocessing
-import pickle
-import queue  # imported as a module (not via from/import) so it can be mocked in tests
-import time
-import warnings
-from abc import ABC, abstractmethod
-from enum import Enum
-from pathlib import Path
-from typing import Any, List, Optional, Tuple
-from urllib.parse import urljoin
-
-import backoff
-import requests
-from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout
-
-from lightning.app.core.constants import (
- BATCH_DELTA_COUNT,
- HTTP_QUEUE_REFRESH_INTERVAL,
- HTTP_QUEUE_REQUESTS_PER_SECOND,
- HTTP_QUEUE_TOKEN,
- HTTP_QUEUE_URL,
- LIGHTNING_DIR,
- QUEUE_DEBUG_ENABLED,
- REDIS_HOST,
- REDIS_PASSWORD,
- REDIS_PORT,
- REDIS_QUEUES_READ_DEFAULT_TIMEOUT,
- STATE_UPDATE_TIMEOUT,
- WARNING_QUEUE_SIZE,
-)
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.imports import _is_redis_available, requires
-from lightning.app.utilities.network import HTTPClient
-
-if _is_redis_available():
- import redis
-
-logger = Logger(__name__)
-
-
-READINESS_QUEUE_CONSTANT = "READINESS_QUEUE"
-ERROR_QUEUE_CONSTANT = "ERROR_QUEUE"
-DELTA_QUEUE_CONSTANT = "DELTA_QUEUE"
-HAS_SERVER_STARTED_CONSTANT = "HAS_SERVER_STARTED_QUEUE"
-CALLER_QUEUE_CONSTANT = "CALLER_QUEUE"
-API_STATE_PUBLISH_QUEUE_CONSTANT = "API_STATE_PUBLISH_QUEUE"
-API_DELTA_QUEUE_CONSTANT = "API_DELTA_QUEUE"
-API_REFRESH_QUEUE_CONSTANT = "API_REFRESH_QUEUE"
-ORCHESTRATOR_REQUEST_CONSTANT = "ORCHESTRATOR_REQUEST"
-ORCHESTRATOR_RESPONSE_CONSTANT = "ORCHESTRATOR_RESPONSE"
-ORCHESTRATOR_COPY_REQUEST_CONSTANT = "ORCHESTRATOR_COPY_REQUEST"
-ORCHESTRATOR_COPY_RESPONSE_CONSTANT = "ORCHESTRATOR_COPY_RESPONSE"
-WORK_QUEUE_CONSTANT = "WORK_QUEUE"
-API_RESPONSE_QUEUE_CONSTANT = "API_RESPONSE_QUEUE"
-FLOW_TO_WORKS_DELTA_QUEUE_CONSTANT = "FLOW_TO_WORKS_DELTA_QUEUE"
-
-
-class QueuingSystem(Enum):
- MULTIPROCESS = "multiprocess"
- REDIS = "redis"
- HTTP = "http"
-
- def get_queue(self, queue_name: str) -> "BaseQueue":
- if self == QueuingSystem.MULTIPROCESS:
- return MultiProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
- if self == QueuingSystem.REDIS:
- return RedisQueue(queue_name, default_timeout=REDIS_QUEUES_READ_DEFAULT_TIMEOUT)
- return RateLimitedQueue(
- HTTPQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT), HTTP_QUEUE_REQUESTS_PER_SECOND
- )
-
- def get_api_response_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
- queue_name = f"{queue_id}_{API_RESPONSE_QUEUE_CONSTANT}" if queue_id else API_RESPONSE_QUEUE_CONSTANT
- return self.get_queue(queue_name)
-
- def get_readiness_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
- queue_name = f"{queue_id}_{READINESS_QUEUE_CONSTANT}" if queue_id else READINESS_QUEUE_CONSTANT
- return self.get_queue(queue_name)
-
- def get_delta_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
- queue_name = f"{queue_id}_{DELTA_QUEUE_CONSTANT}" if queue_id else DELTA_QUEUE_CONSTANT
- return self.get_queue(queue_name)
-
- def get_error_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
- queue_name = f"{queue_id}_{ERROR_QUEUE_CONSTANT}" if queue_id else ERROR_QUEUE_CONSTANT
- return self.get_queue(queue_name)
-
- def get_has_server_started_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
- queue_name = f"{queue_id}_{HAS_SERVER_STARTED_CONSTANT}" if queue_id else HAS_SERVER_STARTED_CONSTANT
- return self.get_queue(queue_name)
-
- def get_caller_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
- queue_name = (
- f"{queue_id}_{CALLER_QUEUE_CONSTANT}_{work_name}" if queue_id else f"{CALLER_QUEUE_CONSTANT}_{work_name}"
- )
- return self.get_queue(queue_name)
-
- def get_api_state_publish_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
- queue_name = f"{queue_id}_{API_STATE_PUBLISH_QUEUE_CONSTANT}" if queue_id else API_STATE_PUBLISH_QUEUE_CONSTANT
- return self.get_queue(queue_name)
-
- # TODO: This is a hack; this queue can be removed entirely once fully optimized.
- def get_api_delta_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
- queue_name = f"{queue_id}_{DELTA_QUEUE_CONSTANT}" if queue_id else DELTA_QUEUE_CONSTANT
- return self.get_queue(queue_name)
-
- def get_orchestrator_request_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
- queue_name = (
- f"{queue_id}_{ORCHESTRATOR_REQUEST_CONSTANT}_{work_name}"
- if queue_id
- else f"{ORCHESTRATOR_REQUEST_CONSTANT}_{work_name}"
- )
- return self.get_queue(queue_name)
-
- def get_orchestrator_response_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
- queue_name = (
- f"{queue_id}_{ORCHESTRATOR_RESPONSE_CONSTANT}_{work_name}"
- if queue_id
- else f"{ORCHESTRATOR_RESPONSE_CONSTANT}_{work_name}"
- )
- return self.get_queue(queue_name)
-
- def get_orchestrator_copy_request_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
- queue_name = (
- f"{queue_id}_{ORCHESTRATOR_COPY_REQUEST_CONSTANT}_{work_name}"
- if queue_id
- else f"{ORCHESTRATOR_COPY_REQUEST_CONSTANT}_{work_name}"
- )
- return self.get_queue(queue_name)
-
- def get_orchestrator_copy_response_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
- queue_name = (
- f"{queue_id}_{ORCHESTRATOR_COPY_RESPONSE_CONSTANT}_{work_name}"
- if queue_id
- else f"{ORCHESTRATOR_COPY_RESPONSE_CONSTANT}_{work_name}"
- )
- return self.get_queue(queue_name)
-
- def get_work_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
- queue_name = (
- f"{queue_id}_{WORK_QUEUE_CONSTANT}_{work_name}" if queue_id else f"{WORK_QUEUE_CONSTANT}_{work_name}"
- )
- return self.get_queue(queue_name)
-
- def get_flow_to_work_delta_queue(self, work_name: str, queue_id: Optional[str] = None) -> "BaseQueue":
- queue_name = (
- f"{queue_id}_{FLOW_TO_WORKS_DELTA_QUEUE_CONSTANT}_{work_name}"
- if queue_id
- else f"{FLOW_TO_WORKS_DELTA_QUEUE_CONSTANT}_{work_name}"
- )
- return self.get_queue(queue_name)
-
-
-class BaseQueue(ABC):
- """Base Queue class that has a similar API to the Queue class in python."""
-
- @abstractmethod
- def __init__(self, name: str, default_timeout: float):
- self.name = name
- self.default_timeout = default_timeout
-
- @abstractmethod
- def put(self, item: Any) -> None:
- pass
-
- @abstractmethod
- def get(self, timeout: Optional[float] = None) -> Any:
- """Returns the leftmost element of the queue.
-
- Parameters
- ----------
- timeout:
- Read timeout in seconds. If the given timeout is 0, `self.default_timeout` is used.
- A timeout of None can be used to block indefinitely.
-
- """
- pass
-
- @abstractmethod
- def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]:
- """Returns the leftmost elements of the queue.
-
- Parameters
- ----------
- timeout:
- Read timeout in seconds. If the given timeout is 0, `self.default_timeout` is used.
- A timeout of None can be used to block indefinitely.
- count:
- The number of elements to get from the queue.
-
- """
-
- @property
- def is_running(self) -> bool:
- """Returns True if the queue is running, False otherwise.
-
- Child classes should override this property and implement custom logic as required
-
- """
- return True
-
-
-class MultiProcessQueue(BaseQueue):
- def __init__(self, name: str, default_timeout: float) -> None:
- self.name = name
- self.default_timeout = default_timeout
- context = multiprocessing.get_context("spawn")
- self.queue = context.Queue()
-
- def put(self, item: Any) -> None:
- self.queue.put(item)
-
- def get(self, timeout: Optional[float] = None) -> Any:
- if timeout == 0:
- timeout = self.default_timeout
- return self.queue.get(timeout=timeout, block=(timeout is None))
-
- def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]:
- if timeout == 0:
- timeout = self.default_timeout
- # For multiprocessing, we can simply collect the next element
- return [self.queue.get(timeout=timeout, block=(timeout is None))]
-
-
-class RedisQueue(BaseQueue):
- @requires("redis")
- def __init__(
- self,
- name: str,
- default_timeout: float,
- host: Optional[str] = None,
- port: Optional[int] = None,
- password: Optional[str] = None,
- ):
- """
- Parameters
- ----------
- name:
- The name of the list to use
- default_timeout:
- Default timeout for redis read
- host:
- The hostname of the redis server
- port:
- The port of the redis server
- password:
- Redis password
- """
- if name is None:
- raise ValueError("You must specify a name for the queue")
- self.host = host or REDIS_HOST
- self.port = port or REDIS_PORT
- self.password = password or REDIS_PASSWORD
- self.name = name
- self.default_timeout = default_timeout
- self.redis = redis.Redis(host=self.host, port=self.port, password=self.password)
-
- def put(self, item: Any) -> None:
- from lightning.app.core.work import LightningWork
-
- is_work = isinstance(item, LightningWork)
-
- # TODO: Be careful to handle with a lock if another thread needs
- # to access the work backend one day.
- # The backend isn't picklable
- # Raises a TypeError: cannot pickle '_thread.RLock' object
- if is_work:
- backend = item._backend
- item._backend = None
-
- value = pickle.dumps(item)
- queue_len = self.length()
- if queue_len >= WARNING_QUEUE_SIZE:
- warnings.warn(
- f"The Redis Queue {self.name} length is larger than the "
- f"recommended length of {WARNING_QUEUE_SIZE}. "
- f"Found {queue_len}. This might cause your application to crash, "
- "please investigate this."
- )
- try:
- self.redis.rpush(self.name, value)
- except redis.exceptions.ConnectionError:
- raise ConnectionError(
- "Your app failed because it couldn't connect to Redis. "
- "Please try running your app again. "
- "If the issue persists, please contact support@lightning.ai"
- )
-
- # Restore the backend, which isn't picklable.
- if is_work:
- item._backend = backend
-
- def get(self, timeout: Optional[float] = None) -> Any:
- """Returns the left most element of the redis queue.
-
- Parameters
- ----------
- timeout:
- Read timeout in seconds. If the timeout is 0, `self.default_timeout` is used.
- A timeout of None can be used to block indefinitely.
-
- """
- if timeout is None:
- # this means it's blocking in redis
- timeout = 0
- elif timeout == 0:
- timeout = self.default_timeout
-
- try:
- out = self.redis.blpop([self.name], timeout=timeout)
- except redis.exceptions.ConnectionError:
- raise ConnectionError(
- "Your app failed because it couldn't connect to Redis. "
- "Please try running your app again. "
- "If the issue persists, please contact support@lightning.ai"
- )
-
- if out is None:
- raise queue.Empty
- return pickle.loads(out[1])
-
- def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> Any:
- return [self.get(timeout=timeout)]
-
- def clear(self) -> None:
- """Clear all elements in the queue."""
- self.redis.delete(self.name)
-
- def length(self) -> int:
- """Returns the number of elements in the queue."""
- try:
- return self.redis.llen(self.name)
- except redis.exceptions.ConnectionError:
- raise ConnectionError(
- "Your app failed because it couldn't connect to Redis. "
- "Please try running your app again. "
- "If the issue persists, please contact support@lightning.ai"
- )
-
- @property
- def is_running(self) -> bool:
- """Pinging the redis server to see if it is alive."""
- try:
- return self.redis.ping()
- except redis.exceptions.ConnectionError:
- return False
-
- def to_dict(self) -> dict:
- return {
- "type": "redis",
- "name": self.name,
- "default_timeout": self.default_timeout,
- "host": self.host,
- "port": self.port,
- "password": self.password,
- }
-
- @classmethod
- def from_dict(cls, state: dict) -> "RedisQueue":
- return cls(**state)
-
-
-class RateLimitedQueue(BaseQueue):
- def __init__(self, queue: BaseQueue, requests_per_second: float):
- """This is a queue wrapper that will block on get or put calls if they are made too quickly.
-
- Args:
- queue: The queue to wrap.
- requests_per_second: The target number of get or put requests per second.
-
- """
- self.name = queue.name
- self.default_timeout = queue.default_timeout
-
- self._queue = queue
- self._seconds_per_request = 1 / requests_per_second
-
- self._last_get = 0.0
-
- @property
- def is_running(self) -> bool:
- return self._queue.is_running
-
- def _wait_until_allowed(self, last_time: float) -> None:
- t = time.time()
- diff = t - last_time
- if diff < self._seconds_per_request:
- time.sleep(self._seconds_per_request - diff)
-
- def get(self, timeout: Optional[float] = None) -> Any:
- self._wait_until_allowed(self._last_get)
- self._last_get = time.time()
- return self._queue.get(timeout=timeout)
-
- def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> Any:
- self._wait_until_allowed(self._last_get)
- self._last_get = time.time()
- return self._queue.batch_get(timeout=timeout)
-
- def put(self, item: Any) -> None:
- return self._queue.put(item)
-
-
-class HTTPQueue(BaseQueue):
- def __init__(self, name: str, default_timeout: float) -> None:
- """
- Parameters
- ----------
- name:
- The name of the Queue to use. In the current implementation, we expect the name to be of the format
- `appID_queueName`. Based on this assumption, we try to fetch the app id and the queue name by splitting
- the `name` argument.
- default_timeout:
- Default timeout for queue reads
- """
- if name is None:
- raise ValueError("You must specify a name for the queue")
- self.app_id, self._name_suffix = self._split_app_id_and_queue_name(name)
- self.name = name # keeping the name for debugging
- self.default_timeout = default_timeout
- self.client = HTTPClient(base_url=HTTP_QUEUE_URL, auth_token=HTTP_QUEUE_TOKEN, log_callback=debug_log_callback)
-
- @property
- def is_running(self) -> bool:
- """Pinging the http redis server to see if it is alive."""
- try:
- url = urljoin(HTTP_QUEUE_URL, "health")
- resp = requests.get(
- url,
- headers={"Authorization": f"Bearer {HTTP_QUEUE_TOKEN}"},
- timeout=1,
- )
- if resp.status_code == 200:
- return True
- except (ConnectionError, ConnectTimeout, ReadTimeout):
- return False
- return False
-
- def get(self, timeout: Optional[float] = None) -> Any:
- if not self.app_id:
- raise ValueError(f"App ID couldn't be extracted from the queue name: {self.name}")
-
- # it's a blocking call, we need to loop and call the backend to mimic this behavior
- if timeout is None:
- while True:
- try:
- try:
- return self._get()
- except requests.exceptions.HTTPError:
- pass
- except queue.Empty:
- time.sleep(HTTP_QUEUE_REFRESH_INTERVAL)
-
- # make one request and return the result
- if timeout == 0:
- try:
- return self._get()
- except requests.exceptions.HTTPError:
- return None
-
- # timeout is some value - loop until the timeout is reached
- start_time = time.time()
- while (time.time() - start_time) < timeout:
- try:
- try:
- return self._get()
- except requests.exceptions.HTTPError:
- if timeout > self.default_timeout:
- return None
- raise queue.Empty
- except queue.Empty:
- # Note: In theory, there isn't a need for a sleep as the queue shouldn't
- # block the flow if the queue is empty.
- # However, as the Http Server can saturate,
- # let's add a sleep here if a higher timeout is provided
- # than the default timeout
- if timeout > self.default_timeout:
- time.sleep(0.05)
- return None
-
- def _get(self) -> Any:
- try:
- resp = self.client.post(f"v1/{self.app_id}/{self._name_suffix}", query_params={"action": "pop"})
- if resp.status_code == 204:
- raise queue.Empty
- return pickle.loads(resp.content)
- except ConnectionError:
- # Note: If the Http Queue service isn't available,
- # we consider the queue is empty to avoid failing the app.
- raise queue.Empty
-
- def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]:
- try:
- resp = self.client.post(
- f"v1/{self.app_id}/{self._name_suffix}",
- query_params={"action": "popCount", "count": str(count or BATCH_DELTA_COUNT)},
- )
- if resp.status_code == 204:
- raise queue.Empty
- return [pickle.loads(base64.b64decode(data)) for data in resp.json()]
- except ConnectionError:
- # Note: If the Http Queue service isn't available,
- # we consider the queue is empty to avoid failing the app.
- raise queue.Empty
-
- @backoff.on_exception(backoff.expo, (RuntimeError, requests.exceptions.HTTPError))
- def put(self, item: Any) -> None:
- if not self.app_id:
- raise ValueError(f"The Lightning App ID couldn't be extracted from the queue name: {self.name}")
-
- value = pickle.dumps(item)
- queue_len = self.length()
- if queue_len >= WARNING_QUEUE_SIZE:
- warnings.warn(
- f"The Queue {self._name_suffix} length is larger than the recommended length of {WARNING_QUEUE_SIZE}. "
- f"Found {queue_len}. This might cause your application to crash, please investigate this."
- )
- resp = self.client.post(f"v1/{self.app_id}/{self._name_suffix}", data=value, query_params={"action": "push"})
- if resp.status_code != 201:
- raise RuntimeError(f"Failed to push to queue: {self._name_suffix}")
-
- def length(self) -> int:
- if not self.app_id:
- raise ValueError(f"App ID couldn't be extracted from the queue name: {self.name}")
-
- try:
- val = self.client.get(f"/v1/{self.app_id}/{self._name_suffix}/length")
- return int(val.text)
- except requests.exceptions.HTTPError:
- return 0
-
- @staticmethod
- def _split_app_id_and_queue_name(queue_name: str) -> Tuple[str, str]:
- """This splits the app id and the queue name into two parts.
-
- This can be brittle, as if the queue name creation logic changes, the response values from here wouldn't be
- accurate. Remove this eventually and let the Queue class take app id and name of the queue as arguments
-
- """
- if "_" not in queue_name:
- return "", queue_name
- app_id, queue_name = queue_name.split("_", 1)
- return app_id, queue_name
-
- def to_dict(self) -> dict:
- return {
- "type": "http",
- "name": self.name,
- "default_timeout": self.default_timeout,
- }
-
- @classmethod
- def from_dict(cls, state: dict) -> "HTTPQueue":
- return cls(**state)
-
-
-def debug_log_callback(message: str, *args: Any, **kwargs: Any) -> None:
- if QUEUE_DEBUG_ENABLED or (Path(LIGHTNING_DIR) / "QUEUE_DEBUG_ENABLED").exists():
- logger.info(message, *args, **kwargs)
diff --git a/src/lightning/app/core/work.py b/src/lightning/app/core/work.py
deleted file mode 100644
index 9b4ada4144649..0000000000000
--- a/src/lightning/app/core/work.py
+++ /dev/null
@@ -1,769 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import sys
-import time
-import warnings
-from copy import deepcopy
-from functools import partial, wraps
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union
-
-from deepdiff import DeepHash, Delta
-
-from lightning.app.core.queues import BaseQueue
-from lightning.app.storage.drive import Drive, _maybe_create_drive
-from lightning.app.storage.path import Path
-from lightning.app.storage.payload import Payload
-from lightning.app.utilities.app_helpers import _is_json_serializable, _LightningAppRef, is_overridden
-from lightning.app.utilities.app_status import WorkStatus
-from lightning.app.utilities.component import _is_flow_context, _sanitize_state
-from lightning.app.utilities.enum import (
- CacheCallsKeys,
- WorkFailureReasons,
- WorkStageStatus,
- WorkStopReasons,
- make_status,
-)
-from lightning.app.utilities.exceptions import LightningWorkException
-from lightning.app.utilities.introspection import _is_init_context
-from lightning.app.utilities.network import find_free_network_port
-from lightning.app.utilities.packaging.build_config import BuildConfig
-from lightning.app.utilities.packaging.cloud_compute import (
- _CLOUD_COMPUTE_STORE,
- CloudCompute,
- _CloudComputeStore,
- _maybe_create_cloud_compute,
-)
-from lightning.app.utilities.proxies import Action, LightningWorkSetAttrProxy, ProxyWorkRun, WorkRunExecutor, unwrap
-
-if TYPE_CHECKING:
- from lightning.app.frontend import Frontend
-
-
-class LightningWork:
- _INTERNAL_STATE_VARS = (
- # Internal protected variables that are still part of the state (even though they are prefixed with "_")
- "_paths",
- "_host",
- "_port",
- "_url",
- "_restarting",
- "_internal_ip",
- "_public_ip",
- )
-
- _run_executor_cls: Type[WorkRunExecutor] = WorkRunExecutor
- # TODO: Move to spawn for all operating systems.
- _start_method = "spawn" if sys.platform in ("darwin", "win32") else "fork"
-
- def __init__(
- self,
- parallel: bool = False,
- cache_calls: bool = True,
- raise_exception: bool = True,
- host: str = "127.0.0.1",
- port: Optional[int] = None,
- local_build_config: Optional[BuildConfig] = None,
- cloud_build_config: Optional[BuildConfig] = None,
- cloud_compute: Optional[CloudCompute] = None,
- run_once: Optional[bool] = None, # TODO: Remove run_once
- start_with_flow: bool = True,
- ):
- """LightningWork, or Work in short, is a building block for long-running jobs.
-
- The LightningApp runs its :class:`~lightning.app.core.flow.LightningFlow` component
- within an infinite loop and tracks the ``LightningWork`` status updates.
-
- Use LightningWork for third-party services or for launching heavy jobs such as
- downloading data, training or serving a model.
-
- Each LightningWork is running in its own independent process. Works are self-isolated from the rest,
- e.g. any state changes happening within the work will be reflected within the flow but not the other way around.
-
- Arguments:
- parallel: Whether to run in parallel mode or not. When False, the flow waits for the work to finish.
- cache_calls: Whether the ``run`` method should cache its input arguments and not run again when provided
- with the same arguments in subsequent calls.
- raise_exception: Whether to re-raise an exception in the flow when raised from within the work run method.
- host: Bind socket to this host
- port: Bind socket to this port. By default, this is None and should be called within your run method.
- local_build_config: The local BuildConfig isn't used until Lightning supports DockerRuntime.
- cloud_build_config: The cloud BuildConfig enables the user to easily configure the machine before running this work.
- run_once: Deprecated in favor of cache_calls. This will be removed soon.
- start_with_flow: Whether the work should be started at the same time as the root flow. Only applies to works
- defined in ``__init__``.
-
- **Learn More About Lightning Work Inner Workings**
-
- .. raw:: html
-
-
-
-
- .. displayitem::
- :header: The Lightning Work inner workings.
- :description: Learn more about Lightning Work.
- :col_css: col-md-4
- :button_link: ../../core_api/lightning_work/index.html
- :height: 180
- :tag: Basic
-
- .. raw:: html
-
-
-
-
-
- """
- from lightning.app.runners.backends.backend import Backend
-
- if run_once is not None:
- warnings.warn(
- "The `run_once` argument to LightningWork is deprecated in favor of `cache_calls` and will be removed"
- " in the next version. Use `cache_calls` instead."
- )
- self._cache_calls = run_once if run_once is not None else cache_calls
- self._state = {
- "_host",
- "_port",
- "_url",
- "_future_url",
- "_internal_ip",
- "_public_ip",
- "_restarting",
- "_cloud_compute",
- "_display_name",
- }
- self._parallel: bool = parallel
- self._host: str = host
- self._port: Optional[int] = port
- self._url: str = ""
- self._future_url: str = "" # The cache URL is meant to defer resolving the url values.
- self._internal_ip: str = ""
- self._public_ip: str = ""
- # setattr_replacement is used by the multiprocessing runtime to send the latest changes to the main coordinator
- self._setattr_replacement: Optional[Callable[[str, Any], None]] = None
- self._name: str = ""
- self._display_name: str = ""
- # The ``self._calls`` is used to track whether the run
- # method with a given set of input arguments has already been called.
- # Example of its usage:
- # {
- # 'latest_call_hash': '167fe2e',
- # '167fe2e': {
- # 'statuses': [
- # {'stage': 'pending', 'timestamp': 1659433519.851271},
- # {'stage': 'running', 'timestamp': 1659433519.956482},
- # {'stage': 'stopped', 'timestamp': 1659433520.055768}
- # ]
- # },
- # ...
- # }
- self._calls: dict = {CacheCallsKeys.LATEST_CALL_HASH: None}
- self._changes: dict = {}
- self._raise_exception = raise_exception
- self._paths: dict = {}
- self._request_queue: Optional[BaseQueue] = None
- self._response_queue: Optional[BaseQueue] = None
- self._restarting: bool = False
- self._start_with_flow = start_with_flow
- self._local_build_config = local_build_config or BuildConfig()
- self._cloud_build_config = cloud_build_config or BuildConfig()
- self._cloud_compute = cloud_compute or CloudCompute()
- # tuple instead of a list so that it cannot be modified without using the setter
- self._lightningignore: Tuple[str, ...] = ()
- self._backend: Optional[Backend] = None
- self._check_run_is_implemented()
- self._on_init_end()
-
- @property
- def url(self) -> str:
- """Returns the current url of the work."""
- return self._url
-
- @url.setter
- def url(self, url: str) -> None:
- self._url = url
-
- @property
- def host(self) -> str:
- """Returns the current host of the work."""
- return self._host
-
- @property
- def port(self) -> int:
- if self._port is None:
- self._port = find_free_network_port()
- return self._port
-
- @property
- def internal_ip(self) -> str:
- """The internal ip address of this LightningWork, reachable by other Work locally and in the cloud.
-
- By default, this attribute returns the empty string and the ip address will only be returned once the work runs.
- Locally, the address is 127.0.0.1 and in the cloud it will be determined by the cluster.
-
- """
- return self._internal_ip
-
- @property
- def public_ip(self) -> str:
- """The public ip address of this LightningWork, reachable from the internet.
-
- By default, this attribute returns the empty string and the ip address will only be returned once the work runs.
- Locally, this address is undefined (empty string) and in the cloud it will be determined by the cluster.
-
- """
- return self._public_ip
-
- def _on_init_end(self) -> None:
- self._local_build_config.on_work_init(self)
- self._cloud_build_config.on_work_init(self, self._cloud_compute)
-
- @staticmethod
- def _is_state_attribute(name: str) -> bool:
- """Every public attribute is part of the state by default and all protected (prefixed by '_') or private
- (prefixed by '__') attributes are not.
-
- Exceptions are listed in the `_INTERNAL_STATE_VARS` class variable.
-
- """
- return name in LightningWork._INTERNAL_STATE_VARS or not name.startswith("_")
-
- @property
- def name(self) -> str:
- """Returns the name of the LightningWork."""
- return self._name
-
- @property
- def display_name(self) -> str:
- """Returns the display name of the LightningWork in the cloud.
-
- The display name needs to be set before the run method of the work is called.
-
- """
- return self._display_name
-
- @display_name.setter
- def display_name(self, display_name: str) -> None:
- """Sets the display name of the LightningWork in the cloud."""
- if not self.has_started:
- self._display_name = display_name
- elif self._display_name != display_name:
- raise RuntimeError("The display name can be set only before the work has started.")
-
- @property
- def cache_calls(self) -> bool:
- """Returns whether the ``run`` method should cache its input arguments and not run again when provided with the
- same arguments in subsequent calls."""
- return self._cache_calls
-
- @property
- def parallel(self) -> bool:
- """Whether to run in parallel mode or not.
-
- When parallel is False, the flow waits for the work to finish.
-
- """
- return self._parallel
-
- @property
- def local_build_config(self) -> BuildConfig:
- return self._local_build_config
-
- @local_build_config.setter
- def local_build_config(self, build_config: BuildConfig) -> None:
- self._local_build_config = build_config
- self._local_build_config.on_work_init(self)
-
- @property
- def cloud_build_config(self) -> BuildConfig:
- """Returns the cloud build config used to prepare the selected cloud hardware."""
- return self._cloud_build_config
-
- @cloud_build_config.setter
- def cloud_build_config(self, build_config: BuildConfig) -> None:
- self._cloud_build_config = build_config
- self._cloud_build_config.on_work_init(self, cloud_compute=self._cloud_compute)
-
- @property
- def cloud_compute(self) -> CloudCompute:
- return self._cloud_compute
-
- @cloud_compute.setter
- def cloud_compute(self, cloud_compute: CloudCompute) -> None:
- """Sets the cloud compute used to select the cloud hardware."""
- # A new ID
- current_id = self._cloud_compute.id
- new_id = cloud_compute.id
- if current_id != new_id:
- compute_store: _CloudComputeStore = _CLOUD_COMPUTE_STORE[current_id]
- compute_store.remove(self.name)
- self._cloud_compute = cloud_compute
-
- @property
- def lightningignore(self) -> Tuple[str, ...]:
- """Programmatic equivalent of the ``.lightningignore`` file."""
- return self._lightningignore
-
- @lightningignore.setter
- def lightningignore(self, lightningignore: Tuple[str, ...]) -> None:
- if self._backend is not None:
- raise RuntimeError(
- f"Your app has been already dispatched, so modifying the `{self.name}.lightningignore` does not have an"
- " effect"
- )
- self._lightningignore = lightningignore
-
- @property
- def status(self) -> WorkStatus:
- """Return the current status of the work.
-
- All statuses are stored in the state.
-
- """
- call_hash = self._calls[CacheCallsKeys.LATEST_CALL_HASH]
- if call_hash in self._calls:
- statuses = self._calls[call_hash]["statuses"]
- # deltas aren't necessarily coming in the expected order.
- statuses = sorted(statuses, key=lambda x: x["timestamp"])
- latest_status = statuses[-1]
- if latest_status.get("reason") == WorkFailureReasons.TIMEOUT:
- return self._aggregate_status_timeout(statuses)
- return WorkStatus(**latest_status)
- return WorkStatus(stage=WorkStageStatus.NOT_STARTED, timestamp=time.time())
-
- @property
- def statuses(self) -> List[WorkStatus]:
- """Return all the statuses of the work."""
- call_hash = self._calls[CacheCallsKeys.LATEST_CALL_HASH]
- if call_hash in self._calls:
- statuses = self._calls[call_hash]["statuses"]
- # deltas aren't necessarily coming in the expected order.
- statuses = sorted(statuses, key=lambda x: x["timestamp"])
- return [WorkStatus(**status) for status in statuses]
- return []
-
- @property
- def has_started(self) -> bool:
- """Return whether the work has started."""
- return self.status.stage != WorkStageStatus.NOT_STARTED
-
- @property
- def has_stopped(self) -> bool:
- """Return whether the work has stopped."""
- return self.status.stage == WorkStageStatus.STOPPED
-
- @property
- def has_succeeded(self) -> bool:
- """Return whether the work has succeeded."""
- return self.status.stage == WorkStageStatus.SUCCEEDED
-
- @property
- def has_failed(self) -> bool:
- """Return whether the work has failed."""
- return self.status.stage == WorkStageStatus.FAILED
-
- @property
- def has_timeout(self) -> bool:
- """Return whether the work has timed out."""
- return self.has_failed and self.status.reason == WorkFailureReasons.TIMEOUT
-
- @property
- def is_running(self) -> bool:
- """Return whether the work is running."""
- return self.status.stage == WorkStageStatus.RUNNING
-
- @property
- def is_pending(self) -> bool:
- """Return whether the work is pending."""
- return self.status.stage == WorkStageStatus.PENDING
-
- @property
- def num_timeouts(self) -> int:
- """Return the number of timeout statuses since the latest succeeded run."""
- status = self.status
- if status.reason == WorkFailureReasons.TIMEOUT:
- return status.count
- return 0
-
- @property
- def num_successes(self) -> int:
- """Returns the number of successful runs."""
- # FIXME: Resolve this within single process runtime.
- run_keys = [key for key in self._calls if key.startswith("run:")]
- if not run_keys:
- return 0
-
- has_succeeded_counter = 0
- for run_key in run_keys:
- c = len([s for s in self._calls[run_key]["statuses"] if s["stage"] == WorkStageStatus.SUCCEEDED])
- has_succeeded_counter += c
-
- return has_succeeded_counter
-
- def _get_property_if_exists(self, name: str) -> Union[property, None]:
- attr = getattr(self.__class__, name, None)
- return attr if isinstance(attr, property) else None
-
- def __setattr__(self, name: str, value: Any) -> None:
- property_object = self._get_property_if_exists(name)
- if property_object is not None and property_object.fset is not None:
- property_object.fset(self, value)
- else:
- setattr_fn = getattr(self, "_setattr_replacement", None) or self._default_setattr
- setattr_fn(name, value)
-
- def _default_setattr(self, name: str, value: Any) -> None:
- from lightning.app.core.flow import LightningFlow
-
- # Allow the run method to be patched with ProxyWorkRun (done by certain Runtime implementations).
- allowed_to_set_run = name == "run" and (
- isinstance(value, ProxyWorkRun)
- or (unwrap(value) == unwrap(self.run))
- or (isinstance(value, partial) and value.func.__name__ == "_dynamic_run_wrapper")
- )
-
- is_proxy_setattr = isinstance(value, LightningWorkSetAttrProxy)
- is_init_context = _is_init_context(self)
-
- if (
- not is_init_context
- and name not in self._state
- and name not in self._paths
- and self._is_state_attribute(name)
- and not allowed_to_set_run
- ):
- raise AttributeError(f"Cannot set attributes that were not defined in __init__: {name}.")
-
- if isinstance(value, str) and value.startswith("lit://"):
- value = Path(value)
-
- if self._is_state_attribute(name):
- if isinstance(value, (LightningFlow, LightningWork)):
- raise LightningWorkException(
- "A ``LightningWork`` isn't allowed to take any children "
- f"such as ``LightningWork`` or ``LightningFlow``. Found {value}."
- )
-
- if isinstance(value, Path):
- value._attach_work(work=self)
- value._attach_queues(self._request_queue, self._response_queue) # type: ignore[arg-type]
- value._name = name
- # In the init context, the full name of the Flow and Work is not known, i.e., we can't serialize
- # the path without losing the information of origin and consumer. Hence, we delay the serialization
- # of the path object until the app is instantiated.
- if not is_init_context:
- self._paths[name] = value.to_dict()
- self._state.add(name)
-
- elif isinstance(value, Payload):
- if is_init_context:
- raise AttributeError("The Payload object should be set only within the run method of the work.")
- value._attach_work(work=self)
- value._name = name
- self._state.add(name)
-
- elif isinstance(value, Drive):
- value = deepcopy(value)
- value.component_name = self.name
- self._state.add(name)
-
- elif allowed_to_set_run or is_proxy_setattr:
- # enable overriding the run method (dispatcher)
- pass
-
- elif _is_json_serializable(value):
- self._state.add(name)
-
- else:
- raise AttributeError(
- f"Only JSON-serializable attributes are currently supported"
- f" (str, int, float, bool, tuple, list, dict etc.) to be part of {self} state. "
- f"Found the attribute {name} with {value} instead. \n"
- "HINT: Private attributes defined as follows `self._x = y` won't be shared between components "
- "and therefore don't need to be JSON-serializable. If you need to include non-JSON serializable "
- "objects in the state, you can use the `lightning.app.storage.Payload` API."
- )
-
- super().__setattr__(name, value)
-
- def __getattribute__(self, name: str) -> Any:
- try:
- attr = object.__getattribute__(self, name)
- except AttributeError as ex:
- if str(ex).endswith("'_state'"):
- raise AttributeError(f"Did you forget to call super().__init__() in {self}")
- raise ex
-
- if isinstance(attr, ProxyWorkRun):
- return attr
-
- if callable(attr) and getattr(attr, "__name__", "") == "run" and getattr(self, "_cache_calls", False):
- # disable while building the class.
- return self._wrap_run_for_caching(attr)
- return attr
-
- def __getattr__(self, item: str) -> Any:
- if item in self.__dict__.get("_paths", {}) and not _is_init_context(self):
- path = Path.from_dict(self._paths[item])
- path._attach_work(work=self)
- path._attach_queues(self._request_queue, self._response_queue) # type: ignore[arg-type]
- return path
- return self.__getattribute__(item)
-
- def _call_hash(self, fn: Callable, args: Any, kwargs: Any) -> str:
- hash_args = args[1:] if len(args) > 0 and args[0] == self else args
- call_obj = {"args": hash_args, "kwargs": kwargs}
- # Note: Generate a short hash such as 167fe2e.
- # Seven characters were chosen to match GitHub's default abbreviated SHA length
- # and to minimize the hidden state size.
- return str(DeepHash(call_obj)[call_obj])[:7]
-
- def _wrap_run_for_caching(self, fn: Callable) -> Callable:
- @wraps(fn)
- def new_fn(*args: Any, **kwargs: Any) -> Any:
- call_hash = self._call_hash(fn, args, kwargs)
-
- entered = call_hash in self._calls
- returned = entered and "ret" in self._calls[call_hash]
-
- if returned:
- entry = self._calls[call_hash]
- return entry["ret"]
-
- self._calls[call_hash] = {}
-
- result = fn(*args, **kwargs)
-
- self._calls[call_hash] = {"ret": result}
-
- return result
-
- return new_fn
-
- @property
- def changes(self) -> dict:
- return self._changes.copy()
-
- @property
- def state(self) -> dict:
- """Returns the current state of this LightningWork."""
- return {
- "vars": _sanitize_state({el: getattr(self, el) for el in self._state}),
- # this may have the challenge that ret cannot be pickled, we'll need to handle this
- "calls": self._calls.copy(),
- "changes": {},
- }
-
- @property
- def state_vars(self) -> dict:
- return {"vars": _sanitize_state({el: getattr(self, el) for el in self._state})}
-
- @property
- def state_with_changes(self) -> dict:
- return {
- "vars": _sanitize_state({el: getattr(self, el) for el in self._state}),
- # this may have the challenge that ret cannot be pickled, we'll need to handle this
- "calls": self._calls.copy(),
- "changes": self.changes,
- }
-
- def set_state(self, provided_state: dict) -> None:
- for k, v in provided_state["vars"].items():
- if isinstance(v, Dict):
- v = _maybe_create_drive(self.name, v)
- if isinstance(v, Dict):
- v = _maybe_create_cloud_compute(v)
- setattr(self, k, v)
-
- self._changes = provided_state["changes"]
-
- # Note, this is handled by the flow only.
- if _is_flow_context():
- self._cleanup_calls(provided_state["calls"])
-
- self._calls = provided_state["calls"]
-
- @staticmethod
- def _cleanup_calls(calls: Dict[str, Any]) -> None:
- # 1: Collect all the in_progress call hashes
- in_progress_call_hash = [k for k in list(calls) if k not in (CacheCallsKeys.LATEST_CALL_HASH)]
-
- for call_hash in in_progress_call_hash:
- if "statuses" not in calls[call_hash]:
- continue
-
- # 2: Filter the statuses by timestamp
- statuses = sorted(calls[call_hash]["statuses"], key=lambda x: x["timestamp"])
-
- # If the latest status is succeeded, then drop everything before.
- if statuses[-1]["stage"] == WorkStageStatus.SUCCEEDED:
- status = statuses[-1]
- status["timestamp"] = int(status["timestamp"])
- calls[call_hash]["statuses"] = [status]
- else:
- # TODO: Some statuses are being duplicated,
- # this seems related to the StateObserver.
- final_statuses = []
- for status in statuses:
- if status not in final_statuses:
- final_statuses.append(status)
- calls[call_hash]["statuses"] = final_statuses
-
- def start(self) -> None:
- """Starts the LightningWork component via CloudCompute."""
- if self.status.stage == WorkStageStatus.STOPPED:
- raise Exception("A work can be started only once for now.")
-
- # This makes it possible to start the run method with a phony input and exit.
- self.run(Action(method="start"))
-
- def run(self, *args: Any, **kwargs: Any) -> None:
- """Override to add your own logic.
-
- Raises:
- LightningPlatformException: If resource exceeds platform quotas or other constraints.
-
- """
-
- def on_exception(self, exception: BaseException) -> None:
- """Override to customize how to handle exception in the run method."""
- if self._raise_exception:
- raise exception
-
- def _aggregate_status_timeout(self, statuses: List[Dict]) -> WorkStatus:
- """Method used to return the first request and the total count of timeout after the latest succeeded status."""
- succeeded_statuses = [
- status_idx for status_idx, status in enumerate(statuses) if status["stage"] == WorkStageStatus.SUCCEEDED
- ]
- if succeeded_statuses:
- succeed_status_id = succeeded_statuses[-1] + 1
- statuses = statuses[succeed_status_id:]
- timeout_statuses = [status for status in statuses if status.get("reason") == WorkFailureReasons.TIMEOUT]
- assert statuses[0]["stage"] == WorkStageStatus.PENDING
- status = {**timeout_statuses[-1], "timestamp": statuses[0]["timestamp"]}
- return WorkStatus(**status, count=len(timeout_statuses))
-
- def on_exit(self) -> None:
- """Override this hook to add your logic when the work is exiting.
-
- Note: This hook is not guaranteed to be called when running in the cloud.
-
- """
- pass
-
- def stop(self) -> None:
- """Stops the LightningWork component and shuts down hardware provisioned via CloudCompute.
-
- This can only be called from a ``LightningFlow``.
-
- """
- if not self._backend:
- raise RuntimeError(f"Only the `LightningFlow` can request this work ({self.name!r}) to stop.")
- if self.status.stage == WorkStageStatus.STOPPED:
- return
- latest_hash = self._calls[CacheCallsKeys.LATEST_CALL_HASH]
- stop_status = make_status(WorkStageStatus.STOPPED, reason=WorkStopReasons.PENDING)
- self._calls[latest_hash]["statuses"].append(stop_status)
- app = _LightningAppRef().get_current()
- self._backend.stop_work(app, self) # type: ignore[arg-type]
-
- def delete(self) -> None:
- """Deletes the LightningWork component and shuts down hardware provisioned via CloudCompute.
-
- Locally, ``work.delete()`` behaves like ``work.stop()``.
-
- """
- if not self._backend:
- raise Exception(
- "Can't delete the work, it looks like it isn't attached to a LightningFlow. "
- "Make sure to assign the Work to a flow instance."
- )
- app = _LightningAppRef().get_current()
- self._backend.delete_work(app, self)
-
- def _check_run_is_implemented(self) -> None:
- if not is_overridden("run", instance=self, parent=LightningWork):
- raise TypeError(
- f"The work `{self.__class__.__name__}` is missing the `run()` method. This is required. Implement it"
- " first and then call it in your Flow."
- )
-
- def _register_cloud_compute(self) -> None:
- internal_id = self.cloud_compute.id
- assert internal_id
- if internal_id not in _CLOUD_COMPUTE_STORE:
- _CLOUD_COMPUTE_STORE[internal_id] = _CloudComputeStore(id=internal_id, component_names=[])
- _CLOUD_COMPUTE_STORE[internal_id].add_component_name(self.name)
-
- def apply_flow_delta(self, delta: Delta) -> None:
- """Override to customize how the flow should update the work state."""
- # TODO: Add support for thread safe locking over JSON Serializable objects.
- if any(k not in ["values_changed", "type_changed"] for k in delta.to_dict()):
- raise Exception(
- "A forbidden operation to update the work from the flow was detected."
- f" Found {delta.to_dict()}, only `values_changed` and `type_changed` are currently allowed."
- )
-
- vars = self.state["vars"] + delta
- for name, value in vars.items():
- property_object = self._get_property_if_exists(name)
- if property_object is not None and property_object.fset is not None:
- property_object.fset(self, value)
- else:
- self._default_setattr(name, value)
-
- def configure_layout(self) -> Union[None, str, "Frontend"]:
- """Configure the UI of this LightningWork.
-
- You can either
-
- 1. Return a single :class:`~lightning.app.frontend.frontend.Frontend` object to serve a user interface
- for this Work.
- 2. Return a string containing a URL to act as the user interface for this Work.
- 3. Return ``None`` to indicate that this Work doesn't currently have a user interface.
-
- **Example:** Serve a static directory (with at least a file index.html inside).
-
- .. code-block:: python
-
- from lightning.app.frontend import StaticWebFrontend
-
-
- class Work(LightningWork):
- def configure_layout(self):
- return StaticWebFrontend("path/to/folder/to/serve")
-
- **Example:** Arrange the UI of my children in tabs (default UI by Lightning).
-
- .. code-block:: python
-
- class Work(LightningWork):
- def configure_layout(self):
- return [
- dict(name="First Tab", content=self.child0),
- dict(name="Second Tab", content=self.child1),
- dict(name="Lightning", content="https://lightning.ai"),
- ]
-
- If you don't implement ``configure_layout``, Lightning will use ``self.url``.
-
- Note:
- This hook gets called at the time of app creation and then again as part of the loop. If desired, a
- returned URL can depend on the state. This is not the case if the work returns a
- :class:`~lightning.app.frontend.frontend.Frontend`. These need to be provided at the time of app creation
- in order for the runtime to start the server.
-
- """
diff --git a/src/lightning/app/frontend/__init__.py b/src/lightning/app/frontend/__init__.py
deleted file mode 100644
index b4f5f5a1ba022..0000000000000
--- a/src/lightning/app/frontend/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from lightning.app.frontend.frontend import Frontend
-from lightning.app.frontend.just_py.just_py import JustPyFrontend
-from lightning.app.frontend.panel import AppStateWatcher, PanelFrontend
-from lightning.app.frontend.stream_lit import StreamlitFrontend
-from lightning.app.frontend.web import StaticWebFrontend
-
-__all__ = ["AppStateWatcher", "Frontend", "JustPyFrontend", "PanelFrontend", "StaticWebFrontend", "StreamlitFrontend"]
diff --git a/src/lightning/app/frontend/frontend.py b/src/lightning/app/frontend/frontend.py
deleted file mode 100644
index fef945f6ff690..0000000000000
--- a/src/lightning/app/frontend/frontend.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Optional
-
-if TYPE_CHECKING:
- from lightning.app.core.flow import LightningFlow
-
-
-class Frontend(ABC):
- """Base class for any frontend that gets exposed by LightningFlows.
-
- The flow attribute will be set by the app while bootstrapping.
-
- """
-
- def __init__(self) -> None:
- self.flow: Optional["LightningFlow"] = None
-
- @abstractmethod
- def start_server(self, host: str, port: int, root_path: str = "") -> None:
- """Start the process that serves the UI at the given hostname and port number.
-
- Arguments:
- host: The hostname where the UI will be served. This gets determined by the dispatcher (e.g., cloud),
- but defaults to localhost when running locally.
- port: The port number where the UI will be served. This gets determined by the dispatcher, which by default
- chooses any free port when running locally.
- root_path: root_path for the server if the app is exposed via a proxy at `/`
-
-
- Example:
-
- A custom implementation could look like this:
-
- .. code-block:: python
-
- def start_server(self, host, port, root_path=""):
- self._process = subprocess.Popen(["flask", "run", "--host", host, "--port", str(port)])
-
- """
-
- @abstractmethod
- def stop_server(self) -> None:
- """Stop the process that was started with :meth:`start_server` so the App can shut down.
-
- This method gets called when the LightningApp terminates.
-
- Example:
-
- .. code-block:: python
-
- def stop_server(self):
- self._process.kill()
-
- """
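Note on the removed `Frontend` contract: the two abstract hooks above amount to "spawn a server process" and "stop it again". The following is a minimal sketch of a concrete subclass, not code from this repository. `HttpServerFrontend` and its `_command` helper are hypothetical names, the `Frontend` class here is a reduced stand-in for the deleted base class, and the server is Python's built-in `http.server` rather than any framework used by Lightning.

```python
import subprocess
import sys
from abc import ABC, abstractmethod
from typing import List, Optional


class Frontend(ABC):
    """Stand-in for the removed base class, reduced to its two abstract hooks."""

    @abstractmethod
    def start_server(self, host: str, port: int, root_path: str = "") -> None: ...

    @abstractmethod
    def stop_server(self) -> None: ...


class HttpServerFrontend(Frontend):
    """Hypothetical frontend that serves a directory with Python's built-in http.server."""

    def __init__(self, directory: str) -> None:
        self.directory = directory
        self._process: Optional[subprocess.Popen] = None

    def _command(self, host: str, port: int) -> List[str]:
        # Every argument is its own list element. Adjacent string literals
        # without a comma would be silently concatenated by Python.
        return [sys.executable, "-m", "http.server", str(port), "--bind", host, "--directory", self.directory]

    def start_server(self, host: str, port: int, root_path: str = "") -> None:
        self._process = subprocess.Popen(self._command(host, port))

    def stop_server(self) -> None:
        if self._process is None:
            raise RuntimeError("Server is not running. Call `start_server()` first.")
        self._process.terminate()
        self._process.wait()  # Reap the child so it does not linger as a zombie.
        self._process = None
```

Keeping the command construction in its own method makes it possible to check the arguments without actually starting a process.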
diff --git a/src/lightning/app/frontend/just_py/__init__.py b/src/lightning/app/frontend/just_py/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/src/lightning/app/frontend/just_py/just_py.py b/src/lightning/app/frontend/just_py/just_py.py
deleted file mode 100644
index 11a9d55799544..0000000000000
--- a/src/lightning/app/frontend/just_py/just_py.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import inspect
-import os
-import sys
-from subprocess import Popen
-from time import sleep
-from typing import Callable, Optional
-
-import lightning.app
-from lightning.app.frontend.frontend import Frontend
-from lightning.app.utilities.log import get_logfile
-
-
-class JustPyFrontend(Frontend):
- """A frontend for wrapping JustPy code in your LightningFlow.
-
- Return this in your `LightningFlow.configure_layout()` method if you wish to build the UI with ``justpy``.
- To use this frontend, you must first install the `justpy` package (if running locally):
-
- .. code-block:: bash
-
- pip install justpy
-
- Arguments:
- render_fn: A function that contains your justpy code. This function must accept exactly one argument, a
- callable returning the ``AppState`` object, which you can use to access variables in your flow (see example below).
-
- Example:
-
- In your LightningFlow, override the method `configure_layout`:
-
- .. code-block:: python
-
- from typing import Callable
- from lightning import LightningApp, LightningFlow
- from lightning.app.frontend import JustPyFrontend
-
-
- class Flow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.counter = 0
-
- def run(self):
- print(self.counter)
-
- def configure_layout(self):
- return JustPyFrontend(render_fn=render_fn)
-
-
- def render_fn(get_state: Callable) -> Callable:
- import justpy as jp
-
- def my_click(self, *_):
- state = get_state()
- old_counter = state.counter
- state.counter += 1
- self.text = f"Click Me ! Old Counter: {old_counter} New Counter: {state.counter}"
-
- def webpage():
- wp = jp.WebPage()
- d = jp.Div(text="Hello ! Click Me!")
- d.on("click", my_click)
- wp.add(d)
- return wp
-
- return webpage
-
-
- app = LightningApp(Flow())
-
- """
-
- def __init__(self, render_fn: Callable) -> None:
- super().__init__()
-
- if inspect.ismethod(render_fn):
- raise TypeError(
- "The `JustPyFrontend` doesn't support `render_fn` being a method. Please, use a pure function."
- )
-
- self.render_fn = render_fn
- self._process: Optional[Popen] = None
-
- def start_server(self, host: str, port: int, root_path: str = "") -> None:
- env = os.environ.copy()
- env["LIGHTNING_FLOW_NAME"] = self.flow.name # type: ignore
- env["LIGHTNING_RENDER_FUNCTION"] = self.render_fn.__name__
- env["LIGHTNING_RENDER_MODULE_FILE"] = inspect.getmodule(self.render_fn).__file__ # type: ignore
- env["LIGHTNING_HOST"] = host
- env["LIGHTNING_PORT"] = str(port)
- std_out_out = get_logfile("output.log")
- path = os.path.join(os.path.dirname(lightning.app.frontend.just_py.__file__), "just_py_base.py")
- with open(std_out_out, "wb") as stdout:
- self._process = Popen(f"{sys.executable} {path}", env=env, stdout=stdout, stderr=sys.stderr, shell=True)
-
- sleep(1)
-
- def stop_server(self) -> None:
- assert self._process
- self._process.terminate()
diff --git a/src/lightning/app/frontend/just_py/just_py_base.py b/src/lightning/app/frontend/just_py/just_py_base.py
deleted file mode 100644
index ce6d009ef7905..0000000000000
--- a/src/lightning/app/frontend/just_py/just_py_base.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import pydoc
-from typing import Any, Callable
-
-from lightning.app.frontend.utils import _reduce_to_flow_scope
-from lightning.app.utilities.state import AppState
-
-
-def _get_state() -> AppState:
- app_state = AppState()
- return _reduce_to_flow_scope(app_state, flow=os.environ["LIGHTNING_FLOW_NAME"])
-
-
-def _webpage() -> Any:
- import justpy as jp
-
- wp = jp.WebPage()
- d = jp.Div(text="")
- wp.add(d)
- return wp
-
-
-def _get_render_fn_from_environment() -> Callable:
- render_fn_name = os.environ["LIGHTNING_RENDER_FUNCTION"]
- render_fn_module_file = os.environ["LIGHTNING_RENDER_MODULE_FILE"]
- module = pydoc.importfile(render_fn_module_file)
- return getattr(module, render_fn_name)
-
-
-def _main() -> None:
- """Run the render_fn with the current flow_state."""
- import justpy as jp
-
- # Fetch the information of which flow attaches to this justpy instance
- flow_name = os.environ["LIGHTNING_FLOW_NAME"]
-
- # Call the provided render function.
- # Pass it the state, scoped to the current flow.
- render_fn = _get_render_fn_from_environment()
- host = os.environ["LIGHTNING_HOST"]
- port = int(os.environ["LIGHTNING_PORT"])
- entry_fn = render_fn(_get_state)
- if not isinstance(entry_fn, Callable): # type: ignore
- raise Exception("You need to return a function with JustPy Frontend.")
-
- jp.app.add_jproute(f"/{flow_name}", entry_fn)
-
- jp.justpy(_webpage, host=host, port=port)
-
-
-if __name__ == "__main__":
- _main()
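Note on the removed render-function lookup: the parent process passes a function name and a module file through environment variables, and the child process re-imports that file to find the function. The sketch below illustrates the same idea with `importlib` instead of the `pydoc.importfile` call used in the deleted code. `load_function` and the module name `_render_module` are hypothetical, and the callable check is an added assumption rather than behaviour of the original helper.

```python
import importlib.util
from typing import Callable


def load_function(module_file: str, name: str) -> Callable:
    """Import `module_file` (a path to a .py file) and return its attribute `name`."""
    spec = importlib.util.spec_from_file_location("_render_module", module_file)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import {module_file!r}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # Runs the module's top-level code.
    fn = getattr(module, name)
    if not callable(fn):
        raise TypeError(f"{name!r} in {module_file!r} is not callable")
    return fn
```

Because the module is executed on import, the file passed here should keep side effects behind an `if __name__ == "__main__":` guard.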
diff --git a/src/lightning/app/frontend/panel/__init__.py b/src/lightning/app/frontend/panel/__init__.py
deleted file mode 100644
index 8d11a3d2815c4..0000000000000
--- a/src/lightning/app/frontend/panel/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""The PanelFrontend and AppStateWatcher make it easy to create Lightning Apps with the Panel data app framework."""
-
-from lightning.app.frontend.panel.app_state_watcher import AppStateWatcher
-from lightning.app.frontend.panel.panel_frontend import PanelFrontend
-
-__all__ = ["AppStateWatcher", "PanelFrontend"]
diff --git a/src/lightning/app/frontend/panel/app_state_comm.py b/src/lightning/app/frontend/panel/app_state_comm.py
deleted file mode 100644
index ae3316167d4a4..0000000000000
--- a/src/lightning/app/frontend/panel/app_state_comm.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""The watch_app_state function enables us to trigger a callback function whenever the app state changes."""
-
-# Todo: Refactor with Streamlit
-# Note: It would be nice one day to just watch changes within the Flow scope instead of whole app
-from __future__ import annotations
-
-import asyncio
-import os
-from threading import Thread
-from typing import Callable
-
-import websockets
-
-from lightning.app.core.constants import APP_SERVER_PORT
-from lightning.app.utilities.app_helpers import Logger
-
-_logger = Logger(__name__)
-
-_CALLBACKS = []
-_THREAD: Thread = None
-
-
-def _get_ws_port():
- if "LIGHTNING_APP_STATE_URL" in os.environ:
- return 8080
- return APP_SERVER_PORT
-
-
-def _get_ws_url():
- port = _get_ws_port()
- return f"ws://localhost:{port}/api/v1/ws"
-
-
-def _run_callbacks():
- for callback in _CALLBACKS:
- callback()
-
-
-def _target_fn():
- async def update_fn():
- ws_url = _get_ws_url()
- _logger.debug("connecting to web socket %s", ws_url)
- async with websockets.connect(ws_url) as websocket: # pylint: disable=no-member
- while True:
- await websocket.recv()
- # Note: I have not seen use cases where the two lines below are needed
- # Changing '< 0.2' to '< 1' makes the App very sluggish to the end user
- # Also the implementation can cause the App state to lag behind because only 1 update
- # is received per 0.2 second (or 1 second).
- # while (time.time() - last_updated) < 0.2:
- # time.sleep(0.05)
-
- # Todo: Add some kind of throttling. If 10 messages are received within 100ms then
- # there is no need to trigger the app state changed, request state and update
- # 10 times.
- _logger.debug("App State Changed. Running callbacks")
- _run_callbacks()
-
- asyncio.run(update_fn())
-
-
-def _start_websocket():
- global _THREAD # pylint: disable=global-statement
- if not _THREAD:
- _logger.debug("Starting the watch_app_state thread.")
- _THREAD = Thread(target=_target_fn, daemon=True)
- _THREAD.start()
- _logger.debug("thread started")
-
-
-def _watch_app_state(callback: Callable):
- """Register a callback and start the background thread that runs it whenever the App state changes.
-
- Arguments:
- callback: A function to run when the App state changes. Must be thread safe.
-
- Example:
-
- .. code-block:: python
-
- def handle_state_change():
- print("The App State changed.")
- _watch_app_state(handle_state_change)
-
- """
- _CALLBACKS.append(callback)
- _start_websocket()
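Note on the throttling Todo in the removed `_target_fn`: the comment suggests that a burst of messages should not trigger one callback run per message. One possible shape for that is a leading-edge throttle, sketched below. `ThrottledCallback` is a hypothetical helper that never existed in this module, and the injectable `clock` parameter is only there to make the behaviour easy to check.

```python
import time
from typing import Callable, Optional


class ThrottledCallback:
    """Run `callback` at most once per `min_interval` seconds (hypothetical helper)."""

    def __init__(
        self,
        callback: Callable[[], None],
        min_interval: float = 0.1,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._callback = callback
        self._min_interval = min_interval
        self._clock = clock
        self._last_run: Optional[float] = None

    def __call__(self) -> bool:
        """Invoke the callback unless it ran too recently; return whether it ran."""
        now = self._clock()
        if self._last_run is not None and now - self._last_run < self._min_interval:
            return False  # Dropped: a call already ran inside the current window.
        self._last_run = now
        self._callback()
        return True
```

A leading-edge throttle drops the later messages of a burst, so the final state change in a burst may go unseen until the next message arrives. A real implementation would likely also need a trailing call to avoid the lag that the original comment warns about.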
diff --git a/src/lightning/app/frontend/panel/app_state_watcher.py b/src/lightning/app/frontend/panel/app_state_watcher.py
deleted file mode 100644
index 8abc9cb52e272..0000000000000
--- a/src/lightning/app/frontend/panel/app_state_watcher.py
+++ /dev/null
@@ -1,123 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""The ``AppStateWatcher`` enables a Frontend to:
-
-- subscribe to App state changes
-- access and change the App state.
-
-This is particularly useful for the ``PanelFrontend`` but can be used by other frontends too.
-
-"""
-
-from __future__ import annotations
-
-import os
-
-from lightning.app.frontend.panel.app_state_comm import _watch_app_state
-from lightning.app.frontend.utils import _get_flow_state
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.imports import _is_param_available, requires
-from lightning.app.utilities.state import AppState
-
-_logger = Logger(__name__)
-
-
-if _is_param_available():
- from param import ClassSelector, Parameterized, edit_constant
-else:
- Parameterized = object
- ClassSelector = dict
-
-
-class AppStateWatcher(Parameterized):
- """The `AppStateWatcher` enables a Frontend to:
-
- - Subscribe to any App state changes.
- - Access and change the App state from the UI.
-
- This is particularly useful for the `PanelFrontend`, but can be used by
- other frontends too.
-
- Example
- -------
-
- .. code-block:: python
-
- import param
-
- app = AppStateWatcher()
-
- app.state.counter = 1
-
-
- @param.depends(app.param.state, watch=True)
- def update(state):
- print(f"The counter was updated to {state.counter}")
-
-
- app.state.counter += 1
-
- This would print ``The counter was updated to 2``.
-
- The ``AppStateWatcher`` is built on top of Param, which is a framework like dataclass, attrs and
- Pydantic which additionally provides powerful and unique features for building reactive apps.
-
- Please note the ``AppStateWatcher`` is a singleton, i.e., only one instance is ever created.
-
- """
-
- state: AppState = ClassSelector(
- class_=AppState,
- constant=True,
- doc="The AppState holds the state of the app reduced to the scope of the Flow",
- )
-
- def __new__(cls):
- # This makes the AppStateWatcher a *singleton*.
- # The AppStateWatcher is a singleton to minimize the number of requests, etc.
- if not hasattr(cls, "_instance"):
- cls._instance = super().__new__(cls)
- return cls._instance
-
- @requires("param")
- def __init__(self):
- # It is critical to initialize only once
- # See https://github.com/holoviz/param/issues/643
- if not hasattr(self, "_initialized"):
- super().__init__(name="singleton")
- self._start_watching()
- self.param.state.allow_None = False
- self._initialized = True
-
- # The below was observed when using mocks during testing
- if not self.state:
- raise Exception(".state has not been set.")
- if not self.state._state:
- raise Exception(".state._state has not been set.")
-
- def _start_watching(self):
- # Create a thread listening to state changes.
- _watch_app_state(self._update_flow_state)
- self._update_flow_state()
-
- def _get_flow_state(self) -> AppState:
- flow = os.environ["LIGHTNING_FLOW_NAME"]
- return _get_flow_state(flow)
-
- def _update_flow_state(self):
- # Todo: Consider whether to only update if ._state changed
- # This might be much more performant.
- with edit_constant(self):
- self.state = self._get_flow_state()
- _logger.debug("Requested App State.")
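Note on the singleton in the removed `AppStateWatcher`: it combines a cached instance in `__new__` with an "initialize only once" guard in `__init__`. The guard matters because Python calls `__init__` on every construction, including when `__new__` returns the cached object. A minimal sketch of the pattern, with the hypothetical class `Watcher` and no dependency on Param:

```python
class Watcher:
    """Hypothetical singleton: one shared instance, set up exactly once."""

    def __new__(cls):
        # Reuse the cached instance if one already exists.
        if not hasattr(cls, "_instance"):
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # __init__ runs on every `Watcher()` call, so skip the one-time
        # setup when the instance has already been initialized.
        if getattr(self, "_initialized", False):
            return
        self.subscribers = []
        self._initialized = True
```

Without the guard, the second `Watcher()` call would reset `subscribers` on the shared instance.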
diff --git a/src/lightning/app/frontend/panel/panel_frontend.py b/src/lightning/app/frontend/panel/panel_frontend.py
deleted file mode 100644
index 6185ab7ce6429..0000000000000
--- a/src/lightning/app/frontend/panel/panel_frontend.py
+++ /dev/null
@@ -1,193 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""The PanelFrontend wraps your Panel code in your LightningFlow."""
-
-from __future__ import annotations
-
-import inspect
-import os
-import pathlib
-import subprocess
-import sys
-from typing import Callable, TextIO
-
-from lightning.app.frontend.frontend import Frontend
-from lightning.app.frontend.utils import _get_frontend_environment
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.cloud import is_running_in_cloud
-from lightning.app.utilities.imports import requires
-from lightning.app.utilities.log import get_logfile
-
-_logger = Logger(__name__)
-
-
-def _has_panel_autoreload() -> bool:
- """Returns True if the PANEL_AUTORELOAD environment variable is set to 'yes', 'y' or 'true'.
-
- The value is matched case-insensitively.
-
- """
- return os.environ.get("PANEL_AUTORELOAD", "no").lower() in ["yes", "y", "true"]
-
-
-class PanelFrontend(Frontend):
- """The `PanelFrontend` enables you to serve Panel code as a Frontend for your LightningFlow.
-
- Reference: https://lightning.ai/lightning-docs/workflows/add_web_ui/panel/
-
- Args:
- entry_point: The path to a .py or .ipynb file, or a pure function. The file or function must contain your Panel
- code. The function can optionally accept an ``AppStateWatcher`` argument.
-
- Raises:
- TypeError: Raised if the ``entry_point`` provided is a class method
-
- Example:
-
- To use the `PanelFrontend`, you must first install the `panel` package:
-
- .. code-block:: bash
-
- pip install panel
-
- Create the files `panel_app_basic.py` and `app_basic.py` with the content below.
-
- **panel_app_basic.py**
-
- .. code-block:: python
-
- import panel as pn
-
- pn.panel("Hello **Panel ⚡** World").servable()
-
- **app_basic.py**
-
- .. code-block:: python
-
- from lightning.app import LightningFlow, LightningApp
- from lightning.app.frontend.panel import PanelFrontend
-
-
- class LitPanel(LightningFlow):
- def configure_layout(self):
- return PanelFrontend("panel_app_basic.py")
-
-
- class LitApp(LightningFlow):
- def __init__(self):
- super().__init__()
- self.lit_panel = LitPanel()
-
- def configure_layout(self):
- return {"name": "home", "content": self.lit_panel}
-
-
- app = LightningApp(LitApp())
-
- Start the Lightning server with `lightning run app app_basic.py`.
-
- For development you can get Panel autoreload by setting the ``PANEL_AUTORELOAD``
- environment variable to 'yes', i.e. run
- ``PANEL_AUTORELOAD=yes lightning run app app_basic.py``
-
- """
-
- @requires("panel")
- def __init__(self, entry_point: str | Callable):
- super().__init__()
-
- if inspect.ismethod(entry_point):
- raise TypeError(
- "The `PanelFrontend` doesn't support `entry_point` being a method. Please, use a pure function."
- )
-
- self.entry_point = entry_point
- self._process: None | subprocess.Popen = None
- self._log_files: dict[str, TextIO] = {}
- _logger.debug("PanelFrontend Frontend with %s is initialized.", entry_point)
-
- def start_server(self, host: str, port: int, root_path: str = "") -> None:
- _logger.debug("PanelFrontend starting server on %s:%s", host, port)
-
- # 1: Prepare environment variables and arguments.
- env = _get_frontend_environment(
- self.flow.name,
- self.entry_point,
- port,
- host,
- )
- command = self._get_popen_args(host, port)
-
- if is_running_in_cloud():
- self._open_log_files()
-
- self._process = subprocess.Popen(command, env=env, **self._log_files) # pylint: disable=consider-using-with
-
- def stop_server(self) -> None:
- if self._process is None:
- raise RuntimeError("Server is not running. Call `PanelFrontend.start_server()` first.")
- self._process.kill()
- self._close_log_files()
-
- def _close_log_files(self):
- for file_ in self._log_files.values():
- if not file_.closed:
- file_.close()
- self._log_files = {}
-
- def _open_log_files(self) -> None:
- # Don't log to file when developing locally. Makes it harder to debug.
- self._close_log_files()
-
- std_err_out = get_logfile("error.log")
- std_out_out = get_logfile("output.log")
- stderr = std_err_out.open("wb")
- stdout = std_out_out.open("wb")
- self._log_files = {"stdout": stdout, "stderr": stderr}
-
- def _get_popen_args(self, host: str, port: int) -> list:
- if callable(self.entry_point):
- path = str(pathlib.Path(__file__).parent / "panel_serve_render_fn.py")
- else:
- path = pathlib.Path(self.entry_point)
-
- abs_path = str(path)
- # The app is served at http://localhost:{port}/{flow}/{entry_point}
- # Lightning embeds http://localhost:{port}/{flow} but this redirects to the above and
- # seems to work fine.
- command = [
- sys.executable,
- "-m",
- "panel",
- "serve",
- abs_path,
- "--port",
- str(port),
- "--address",
- host,
- "--prefix",
- self.flow.name,
- "--allow-websocket-origin",
- _get_allowed_hosts(),
- ]
- if _has_panel_autoreload():
- command.append("--autoreload")
- _logger.debug("PanelFrontend command %s", command)
- return command
-
-
-def _get_allowed_hosts() -> str:
- """Returns a comma separated list of host[:port] that should be allowed to connect."""
- # TODO: Enable only lightning.ai domain in the cloud
- return "*"
diff --git a/src/lightning/app/frontend/panel/panel_serve_render_fn.py b/src/lightning/app/frontend/panel/panel_serve_render_fn.py
deleted file mode 100644
index 4ba8de45813a0..0000000000000
--- a/src/lightning/app/frontend/panel/panel_serve_render_fn.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""This file gets run by Python to launch a Panel Server with Lightning.
-
-We will call the ``render_fn`` that the user provided to the PanelFrontend.
-
-It requires the following environment variables to be set
-
-
-- LIGHTNING_RENDER_FUNCTION
-- LIGHTNING_RENDER_MODULE_FILE
-
-Example:
-
-.. code-block:: bash
-
- python -m panel serve panel_serve_render_fn.py
-
-"""
-
-import inspect
-import os
-import pydoc
-from typing import Callable
-
-from lightning.app.frontend.panel.app_state_watcher import AppStateWatcher
-
-
-def _get_render_fn_from_environment(render_fn_name: str, render_fn_module_file: str) -> Callable:
- """Returns the render_fn function to serve in the Frontend."""
- module = pydoc.importfile(render_fn_module_file)
- return getattr(module, render_fn_name)
-
-
-def _get_render_fn():
- render_fn_name = os.environ["LIGHTNING_RENDER_FUNCTION"]
- render_fn_module_file = os.environ["LIGHTNING_RENDER_MODULE_FILE"]
- render_fn = _get_render_fn_from_environment(render_fn_name, render_fn_module_file)
- if inspect.signature(render_fn).parameters:
-
- def _render_fn_wrapper():
- app = AppStateWatcher()
- return render_fn(app)
-
- return _render_fn_wrapper
- return render_fn
-
-
-def _main():
- import panel as pn
-
- # Cache the render function for efficiency: this shaves roughly 10ms off
- # each run by not calling `_get_render_fn` every time.
- if "lightning_render_fn" not in pn.state.cache:
- pn.state.cache["lightning_render_fn"] = _get_render_fn()
- pn.state.cache["lightning_render_fn"]()
-
-
-if __name__.startswith("bokeh"):
- _main()
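Note on the optional argument in the removed `_get_render_fn`: the user's render function may take either no arguments or a single state argument, and the code decides which by inspecting its signature. The sketch below isolates that dispatch. `wrap_render_fn` and `make_state` are hypothetical names, and `make_state` stands in for constructing the `AppStateWatcher`.

```python
import inspect
from typing import Any, Callable


def wrap_render_fn(render_fn: Callable, make_state: Callable[[], Any]) -> Callable[[], Any]:
    """Return a zero-argument callable that supplies state only if `render_fn` accepts it."""
    if inspect.signature(render_fn).parameters:

        def _wrapper() -> Any:
            # Build the state lazily, once per render call.
            return render_fn(make_state())

        return _wrapper
    return render_fn
```

For example, `wrap_render_fn(lambda: "plain", dict)` returns the function unchanged, while a one-argument function is wrapped so callers can always invoke the result with no arguments.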
diff --git a/src/lightning/app/frontend/stream_lit.py b/src/lightning/app/frontend/stream_lit.py
deleted file mode 100644
index 0cdc37296d931..0000000000000
--- a/src/lightning/app/frontend/stream_lit.py
+++ /dev/null
@@ -1,113 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import inspect
-import os
-import subprocess
-import sys
-from typing import Callable, Optional
-
-import lightning.app
-from lightning.app.frontend.frontend import Frontend
-from lightning.app.utilities.cloud import is_running_in_cloud
-from lightning.app.utilities.imports import requires
-from lightning.app.utilities.log import get_logfile
-
-
-class StreamlitFrontend(Frontend):
- """A frontend for wrapping Streamlit code in your LightningFlow.
-
- Return this in your `LightningFlow.configure_layout()` method if you wish to build the UI with ``streamlit``.
- To use this frontend, you must first install the `streamlit` package (if running locally):
-
- .. code-block:: bash
-
- pip install streamlit
-
- Arguments:
- render_fn: A function that contains your streamlit code. This function must accept exactly one argument, the
- `AppState` object which you can use to access variables in your flow (see example below).
-
- Example:
-
- In your LightningFlow, override the method `configure_layout`:
-
- .. code-block:: python
-
- class MyFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.counter = 0
-
- def configure_layout(self):
- return StreamlitFrontend(render_fn=my_streamlit_ui)
-
-
- # define this function anywhere you want
- # this gets called anytime the UI needs to refresh
- def my_streamlit_ui(state):
- import streamlit as st
-
- st.write("Hello from streamlit!")
- st.write(state.counter)
-
- """
-
- @requires("streamlit")
- def __init__(self, render_fn: Callable) -> None:
- super().__init__()
-
- if inspect.ismethod(render_fn):
- raise TypeError(
- "The `StreamlitFrontend` doesn't support `render_fn` being a method. Please, use a pure function."
- )
-
- self.render_fn = render_fn
- self._process: Optional[subprocess.Popen] = None
-
- def start_server(self, host: str, port: int, root_path: str = "") -> None:
- env = os.environ.copy()
- env["LIGHTNING_FLOW_NAME"] = self.flow.name
- env["LIGHTNING_RENDER_FUNCTION"] = self.render_fn.__name__
- env["LIGHTNING_RENDER_MODULE_FILE"] = inspect.getmodule(self.render_fn).__file__
- std_err_out = get_logfile("error.log")
- std_out_out = get_logfile("output.log")
- with open(std_err_out, "wb") as stderr, open(std_out_out, "wb") as stdout:
- self._process = subprocess.Popen(
- [
- sys.executable,
- "-m",
- "streamlit",
- "run",
- os.path.join(os.path.dirname(lightning.app.frontend.__file__), "streamlit_base.py"),
- "--server.address",
- str(host),
- "--server.port",
- str(port),
- "--server.baseUrlPath",
- self.flow.name,
- "--server.headless",
- "true", # do not open the browser window when running locally
- "--server.enableXsrfProtection",
- "true" if is_running_in_cloud() else "false",
- ],
- env=env,
- stdout=stdout,
- stderr=stderr,
- )
-
- def stop_server(self) -> None:
- if self._process is None:
- raise RuntimeError("Server is not running. Call `StreamlitFrontend.start_server()` first.")
- self._process.kill()
diff --git a/src/lightning/app/frontend/streamlit_base.py b/src/lightning/app/frontend/streamlit_base.py
deleted file mode 100644
index a0ebc4b7cf66d..0000000000000
--- a/src/lightning/app/frontend/streamlit_base.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""This file gets run by streamlit, which we launch within Lightning.
-
-From here, we will call the render function that the user provided in ``configure_layout``.
-
-"""
-
-import os
-import pydoc
-from typing import Callable
-
-from lightning.app.frontend.utils import _reduce_to_flow_scope
-from lightning.app.utilities.app_helpers import StreamLitStatePlugin
-from lightning.app.utilities.state import AppState
-
-
-def _get_render_fn_from_environment() -> Callable:
- render_fn_name = os.environ["LIGHTNING_RENDER_FUNCTION"]
- render_fn_module_file = os.environ["LIGHTNING_RENDER_MODULE_FILE"]
- module = pydoc.importfile(render_fn_module_file)
- return getattr(module, render_fn_name)
-
-
-def _main():
- """Run the render_fn with the current flow_state."""
- app_state = AppState(plugin=StreamLitStatePlugin())
-
- # Fetch the information of which flow attaches to this streamlit instance
- flow_state = _reduce_to_flow_scope(app_state, flow=os.environ["LIGHTNING_FLOW_NAME"])
-
- # Call the provided render function.
- # Pass it the state, scoped to the current flow.
- render_fn = _get_render_fn_from_environment()
- render_fn(flow_state)
-
-
-if __name__ == "__main__":
- _main()
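
For illustration, the hand-off that the removed `streamlit_base.py` relied on can be sketched outside of Lightning: the parent exports the render function's name and module file through environment variables, and the child re-imports that file and looks the function up by name. This is a minimal sketch under stated assumptions, not the removed implementation: `export_render_fn`, `load_render_fn`, and the `demo_ui.py` module are hypothetical, and a plain dict stands in for the process environment.

```python
import os
import pydoc
import tempfile


def export_render_fn(env, fn_name, module_file):
    """Return a copy of `env` carrying the two variables the child process reads."""
    env = dict(env)
    env["LIGHTNING_RENDER_FUNCTION"] = fn_name
    env["LIGHTNING_RENDER_MODULE_FILE"] = module_file
    return env


def load_render_fn(env):
    """Import the module file named in `env` and fetch the function by name."""
    module = pydoc.importfile(env["LIGHTNING_RENDER_MODULE_FILE"])
    return getattr(module, env["LIGHTNING_RENDER_FUNCTION"])


with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical user module containing a module-level render function.
    module_file = os.path.join(tmp, "demo_ui.py")
    with open(module_file, "w") as f:
        f.write("def my_ui(state):\n    return f'counter={state}'\n")

    env = export_render_fn({}, "my_ui", module_file)
    render_fn = load_render_fn(env)
    result = render_fn(3)

print(result)
```

Because only a name and a file path cross the process boundary, the function has to be reachable as a module-level attribute, which is consistent with `StreamlitFrontend` rejecting bound methods.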
diff --git a/src/lightning/app/frontend/utils.py b/src/lightning/app/frontend/utils.py
deleted file mode 100644
index c05f7e143fe11..0000000000000
--- a/src/lightning/app/frontend/utils.py
+++ /dev/null
@@ -1,72 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Utility functions for lightning Frontends."""
-
-from __future__ import annotations
-
-import inspect
-import os
-from typing import Callable
-
-from lightning.app.core.flow import LightningFlow
-from lightning.app.utilities.state import AppState
-
-
-def _reduce_to_flow_scope(state: AppState, flow: str | LightningFlow) -> AppState:
- """Returns a new AppState with the scope reduced to the given flow."""
- flow_name = flow.name if isinstance(flow, LightningFlow) else flow
- flow_name_parts = flow_name.split(".")[1:] # exclude root
- flow_state = state
- for part in flow_name_parts:
- flow_state = getattr(flow_state, part)
- return flow_state
-
-
-def _get_flow_state(flow: str) -> AppState:
- """Returns an AppState scoped to the current Flow.
-
- Returns:
- AppState: An AppState scoped to the current Flow.
-
- """
- app_state = AppState()
- app_state._request_state() # pylint: disable=protected-access
- return _reduce_to_flow_scope(app_state, flow)
-
-
-def _get_frontend_environment(flow: str, render_fn_or_file: Callable | str, port: int, host: str) -> os._Environ:
- """Returns an _Environ with the environment variables for serving a Frontend app set.
-
- Args:
- flow: The name of the flow, for example root.lit_frontend
- render_fn_or_file: A function to render
- port: The port number, for example 54321
- host: The host, for example 'localhost'
-
- Returns:
- os._Environ: An environment
-
- """
- env = os.environ.copy()
- env["LIGHTNING_FLOW_NAME"] = flow
- env["LIGHTNING_RENDER_PORT"] = str(port)
- env["LIGHTNING_RENDER_ADDRESS"] = str(host)
-
- if isinstance(render_fn_or_file, str):
- env["LIGHTNING_RENDER_FILE"] = render_fn_or_file
- else:
- env["LIGHTNING_RENDER_FUNCTION"] = render_fn_or_file.__name__
- env["LIGHTNING_RENDER_MODULE_FILE"] = inspect.getmodule(render_fn_or_file).__file__
-
- return env
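
The scope reduction done by `_reduce_to_flow_scope` amounts to walking a dotted flow name one attribute at a time while skipping the root segment. A minimal sketch of that idea, assuming `SimpleNamespace` objects as stand-ins for `AppState` (the `reduce_to_scope` helper and the sample attribute names are hypothetical):

```python
from types import SimpleNamespace


def reduce_to_scope(state, dotted_name):
    """Follow `dotted_name` attribute by attribute, ignoring the leading root segment."""
    for part in dotted_name.split(".")[1:]:  # exclude root
        state = getattr(state, part)
    return state


# Stand-in state tree: root -> lit_frontend -> child
root = SimpleNamespace(
    counter=0,
    lit_frontend=SimpleNamespace(counter=7, child=SimpleNamespace(counter=42)),
)

same_root = reduce_to_scope(root, "root")
nested = reduce_to_scope(root, "root.lit_frontend.child")
print(nested.counter)
```

A name with a single segment returns the state unchanged, so the root flow needs no special case.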
diff --git a/src/lightning/app/frontend/web.py b/src/lightning/app/frontend/web.py
deleted file mode 100644
index 2e7d9f3f2f8e3..0000000000000
--- a/src/lightning/app/frontend/web.py
+++ /dev/null
@@ -1,140 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import multiprocessing as mp
-from argparse import ArgumentParser
-from typing import Optional
-from urllib.parse import urljoin
-
-import uvicorn
-from fastapi import FastAPI
-from fastapi.middleware.cors import CORSMiddleware
-from starlette.staticfiles import StaticFiles
-
-from lightning.app.frontend.frontend import Frontend
-from lightning.app.utilities.log import get_logfile
-from lightning.app.utilities.network import find_free_network_port
-
-
-class StaticWebFrontend(Frontend):
- """A frontend that serves static files from a directory using FastAPI.
-
- Return this in your `LightningFlow.configure_layout()` method if you wish to serve an HTML page.
-
- Arguments:
- serve_dir: A local directory to serve files from. This directory should at least contain a file `index.html`.
- root_path: A path prefix when routing traffic from behind a proxy at `/`
-
- Example:
-
- In your LightningFlow, override the method `configure_layout`:
-
- .. code-block:: python
-
- def configure_layout(self):
- return StaticWebFrontend("path/to/folder/to/serve")
-
- """
-
- def __init__(self, serve_dir: str) -> None:
- super().__init__()
- self.serve_dir = serve_dir
- self._process: Optional[mp.Process] = None
-
- def start_server(self, host: str, port: int, root_path: str = "") -> None:
- log_file = str(get_logfile())
- self._process = mp.Process(
- target=_start_server,
- kwargs={
- "host": host,
- "port": port,
- "serve_dir": self.serve_dir,
- "path": f"/{self.flow.name}",
- "log_file": log_file,
- "root_path": root_path,
- },
- )
- self._process.start()
-
- def stop_server(self) -> None:
- if self._process is None:
- raise RuntimeError("Server is not running. Call `StaticWebFrontend.start_server()` first.")
- self._process.kill()
-
-
-def _healthz():
- """Health check endpoint used in the cloud FastAPI servers to check the status periodically."""
- return {"status": "ok"}
-
-
-def _start_server(
- serve_dir: str, host: str = "localhost", port: int = -1, path: str = "/", log_file: str = "", root_path: str = ""
-) -> None:
- if port == -1:
- port = find_free_network_port()
- fastapi_service = FastAPI()
-
- fastapi_service.add_middleware(
- CORSMiddleware,
- allow_origins=["*"],
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
- )
- # trailing / is required for urljoin to properly join the path. In case of
- # multiple trailing /, urljoin removes them
- fastapi_service.get(urljoin(f"{path}/", "healthz"), status_code=200)(_healthz)
- fastapi_service.mount(urljoin(path, root_path), StaticFiles(directory=serve_dir, html=True), name="static")
-
- log_config = _get_log_config(log_file) if log_file else uvicorn.config.LOGGING_CONFIG
-
- uvicorn.run(app=fastapi_service, host=host, port=port, log_config=log_config, root_path=root_path)
-
-
-def _get_log_config(log_file: str) -> dict:
- """Returns a logger configuration in the format expected by uvicorn that sends all logs to the given logfile."""
- # Modified from the default config found in uvicorn.config.LOGGING_CONFIG
- return {
- "version": 1,
- "disable_existing_loggers": False,
- "formatters": {
- "default": {
- "()": "uvicorn.logging.DefaultFormatter",
- "fmt": "%(levelprefix)s %(message)s",
- "use_colors": False,
- },
- },
- "handlers": {
- "default": {
- "formatter": "default",
- "class": "logging.FileHandler",
- "filename": log_file,
- },
- },
- "loggers": {
- "uvicorn": {"handlers": ["default"], "level": "INFO", "propagate": False},
- "uvicorn.error": {"handlers": ["default"], "level": "INFO", "propagate": False},
- "uvicorn.access": {"handlers": ["default"], "level": "INFO", "propagate": False},
- },
- }
-
-
-if __name__ == "__main__": # pragma: no-cover
- parser = ArgumentParser()
- parser.add_argument("serve_dir", type=str)
- parser.add_argument("root_path", type=str, default="")
- parser.add_argument("--host", type=str, default="localhost")
- parser.add_argument("--port", type=int, default=-1)
- args = parser.parse_args()
- _start_server(serve_dir=args.serve_dir, host=args.host, port=args.port, root_path=args.root_path)
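
The comment in `_start_server` about the trailing slash can be checked directly with the standard library. The sketch below uses a hypothetical flow path and only shows how `urljoin` treats the last path segment; it does not start a server.

```python
from urllib.parse import urljoin

flow_path = "/root.my_flow"  # hypothetical value of f"/{self.flow.name}"

# With a trailing slash the base is treated as a directory, so the endpoint is appended.
with_slash = urljoin(f"{flow_path}/", "healthz")

# Without it, the last segment is replaced instead of extended.
without_slash = urljoin(flow_path, "healthz")

# An empty second argument (the default root_path) leaves the base untouched.
unchanged = urljoin(flow_path, "")

print(with_slash)
print(without_slash)
print(unchanged)
```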
diff --git a/src/lightning/app/launcher/__init__.py b/src/lightning/app/launcher/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/src/lightning/app/launcher/launcher.py b/src/lightning/app/launcher/launcher.py
deleted file mode 100644
index 7dc9fca11db42..0000000000000
--- a/src/lightning/app/launcher/launcher.py
+++ /dev/null
@@ -1,440 +0,0 @@
-import inspect
-import logging
-import os
-import signal
-import sys
-import time
-import traceback
-from functools import partial
-from multiprocessing import Process
-from typing import Callable, Dict, List, Optional, Tuple, TypedDict
-
-ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER = bool(int(os.getenv("ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER", "0")))
-
-if True: # ToDo: Avoid Module level import not at top of file
- from lightning.app.core import constants
- from lightning.app.core.api import start_server
- from lightning.app.core.flow import LightningFlow
- from lightning.app.core.queues import MultiProcessQueue, QueuingSystem
- from lightning.app.storage.orchestrator import StorageOrchestrator
- from lightning.app.utilities.app_commands import run_app_commands
- from lightning.app.utilities.cloud import _sigterm_flow_handler
- from lightning.app.utilities.component import _set_flow_context, _set_frontend_context
- from lightning.app.utilities.enum import AppStage
- from lightning.app.utilities.exceptions import ExitAppException
- from lightning.app.utilities.load_app import extract_metadata_from_app, load_app_from_file
- from lightning.app.utilities.proxies import WorkRunner
- from lightning.app.utilities.redis import check_if_redis_running
-
-if ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER:
- from lightning.app.launcher.lightning_hybrid_backend import CloudHybridBackend as CloudBackend
-else:
- from lightning.app.launcher.lightning_backend import CloudBackend
-
-if True: # Avoid Module level import not at top of file
- from lightning.app.utilities.app_helpers import convert_print_to_logger_info
- from lightning.app.utilities.packaging.lightning_utils import enable_debugging
-
-if hasattr(constants, "get_cloud_queue_type"):
- CLOUD_QUEUE_TYPE = constants.get_cloud_queue_type() or "redis"
-else:
- CLOUD_QUEUE_TYPE = "redis"
-
-logger = logging.getLogger(__name__)
-
-
-class FlowRestAPIQueues(TypedDict):
- api_publish_state_queue: MultiProcessQueue
- api_response_queue: MultiProcessQueue
-
-
-@convert_print_to_logger_info
-@enable_debugging
-def start_application_server(
- entrypoint_file: str, host: str, port: int, queue_id: str, queues: Optional[FlowRestAPIQueues] = None
-):
- logger.debug(f"Run Application Server {entrypoint_file} {host} {port} {queue_id}")
- queue_system = QueuingSystem(CLOUD_QUEUE_TYPE)
-
- wait_for_queues(queue_system)
-
- kwargs = {
- "api_delta_queue": queue_system.get_api_delta_queue(queue_id=queue_id),
- }
-
- # Note: Override the queues if provided
- if isinstance(queues, Dict):
- kwargs.update(queues)
- else:
- kwargs.update({
- "api_publish_state_queue": queue_system.get_api_state_publish_queue(queue_id=queue_id),
- "api_response_queue": queue_system.get_api_response_queue(queue_id=queue_id),
- })
-
- app = load_app_from_file(entrypoint_file)
-
- from lightning.app.api.http_methods import _add_tags_to_api, _validate_api
- from lightning.app.utilities.app_helpers import is_overridden
- from lightning.app.utilities.commands.base import _commands_to_api, _prepare_commands
-
- apis = []
- if is_overridden("configure_api", app.root):
- apis = app.root.configure_api()
- _validate_api(apis)
- _add_tags_to_api(apis, ["app_api"])
-
- if is_overridden("configure_commands", app.root):
- commands = _prepare_commands(app)
- apis += _commands_to_api(commands)
-
- start_server(
- host=host,
- port=port,
- apis=apis,
- **kwargs,
- spec=extract_metadata_from_app(app),
- )
-
-
-@convert_print_to_logger_info
-@enable_debugging
-def run_lightning_work(
- file: str,
- work_name: str,
- queue_id: str,
-):
- """This function runs the specified work in the current process.
-
- It is organized under cloud runtime to indicate that it will be used by the cloud runner but otherwise, no
- cloud-specific logic is implemented here.
-
- """
- logger.debug(f"Run Lightning Work {file} {work_name} {queue_id}")
-
- queues = QueuingSystem(CLOUD_QUEUE_TYPE)
- wait_for_queues(queues)
-
- caller_queue = queues.get_caller_queue(work_name=work_name, queue_id=queue_id)
- readiness_queue = queues.get_readiness_queue(queue_id=queue_id)
- delta_queue = queues.get_delta_queue(queue_id=queue_id)
- error_queue = queues.get_error_queue(queue_id=queue_id)
-
- request_queues = queues.get_orchestrator_request_queue(work_name=work_name, queue_id=queue_id)
- response_queues = queues.get_orchestrator_response_queue(work_name=work_name, queue_id=queue_id)
- copy_request_queues = queues.get_orchestrator_copy_request_queue(work_name=work_name, queue_id=queue_id)
- copy_response_queues = queues.get_orchestrator_copy_response_queue(work_name=work_name, queue_id=queue_id)
-
- run_app_commands(file)
-
- load_app_from_file(file)
-
- queue = queues.get_work_queue(work_name=work_name, queue_id=queue_id)
- work = queue.get()
-
- extras = {}
-
- if hasattr(work, "_run_executor_cls"):
- extras["run_executor_cls"] = work._run_executor_cls
-
- WorkRunner(
- work=work,
- work_name=work_name,
- caller_queue=caller_queue,
- delta_queue=delta_queue,
- readiness_queue=readiness_queue,
- error_queue=error_queue,
- request_queue=request_queues,
- response_queue=response_queues,
- copy_request_queue=copy_request_queues,
- copy_response_queue=copy_response_queues,
- **extras,
- )()
-
-
-@convert_print_to_logger_info
-@enable_debugging
-def run_lightning_flow(entrypoint_file: str, queue_id: str, base_url: str, queues: Optional[FlowRestAPIQueues] = None):
- _set_flow_context()
-
- logger.debug(f"Run Lightning Flow {entrypoint_file} {queue_id} {base_url}")
-
- app = load_app_from_file(entrypoint_file)
- app.backend = CloudBackend(entrypoint_file, queue_id=queue_id)
-
- queue_system = app.backend.queues
- app.backend.update_lightning_app_frontend(app)
- wait_for_queues(queue_system)
-
- app.backend.resolve_url(app, base_url)
- if app.root_path != "":
- app._update_index_file()
- app.backend._prepare_queues(app)
-
- # Note: Override the queues if provided
- if queues:
- app.api_publish_state_queue = queues["api_publish_state_queue"]
- app.api_response_queue = queues["api_response_queue"]
-
- LightningFlow._attach_backend(app.root, app.backend)
-
- app.should_publish_changes_to_api = True
-
- storage_orchestrator = StorageOrchestrator(
- app,
- app.request_queues,
- app.response_queues,
- app.copy_request_queues,
- app.copy_response_queues,
- )
- storage_orchestrator.daemon = True
- storage_orchestrator.start()
-
- # refresh the layout with the populated urls.
- app._update_layout()
-
- # register a signal handler to clean all works.
- if sys.platform != "win32":
- signal.signal(signal.SIGTERM, partial(_sigterm_flow_handler, app=app))
-
- if "apis" in inspect.signature(start_server).parameters:
- from lightning.app.utilities.commands.base import _prepare_commands
-
- _prepare_commands(app)
-
- # Once the bootstrapping is done, running the rank 0
- # app with all the components inactive
- try:
- app._run()
- except ExitAppException:
- pass
- except Exception:
- app.stage = AppStage.FAILED
- print(traceback.format_exc())
-
- storage_orchestrator.join(0)
- app.backend.stop_all_works(app.works)
-
- exit_code = 1 if app.stage == AppStage.FAILED else 0
- print(f"Finishing the App with exit_code: {str(exit_code)}...")
-
- if not exit_code:
- app.backend.stop_app(app)
-
- sys.exit(exit_code)
-
-
-@convert_print_to_logger_info
-@enable_debugging
-def serve_frontend(file: str, flow_name: str, host: str, port: int):
- """This function runs the specified frontend for a given flow in a new process.
-
- It is organized under cloud runtime to indicate that it will be used by the cloud runner but otherwise, no
- cloud-specific logic is implemented here.
-
- """
- _set_frontend_context()
- logger.debug(f"Run Serve Frontend {file} {flow_name} {host} {port}")
- app = load_app_from_file(file)
- if flow_name not in app.frontends:
- raise ValueError(f"Could not find frontend for flow with name {flow_name}.")
- frontend = app.frontends[flow_name]
- assert frontend.flow.name == flow_name
-
- frontend.start_server(host, port)
-
-
- def start_server_in_process(target: Callable, args: Tuple = (), kwargs: Optional[Dict] = None) -> Process:
- p = Process(target=target, args=args, kwargs=kwargs or {})
- p.start()
- return p
-
-
-def format_row(elements, col_widths, padding=1):
- elements = [el.ljust(w - padding * 2) for el, w in zip(elements, col_widths)]
- pad = " " * padding
- elements = [f"{pad}{el}{pad}" for el in elements]
- return f'|{"|".join(elements)}|'
-
-
-def tabulate(data, headers):
- data = [[str(el) for el in row] for row in data]
- col_widths = [len(el) for el in headers]
- for row in data:
- col_widths = [max(len(el), curr) for el, curr in zip(row, col_widths)]
- col_widths = [w + 2 for w in col_widths]
- seps = ["-" * w for w in col_widths]
- lines = [format_row(headers, col_widths), format_row(seps, col_widths, padding=0)] + [
- format_row(row, col_widths) for row in data
- ]
- return "\n".join(lines)
-
-
-def manage_server_processes(processes: List[Tuple[str, Process]]) -> None:
- if not processes:
- return
-
- sigterm_called = [False]
-
- def _sigterm_handler(*_):
- sigterm_called[0] = True
-
- if sys.platform != "win32":
- signal.signal(signal.SIGTERM, _sigterm_handler)
-
- # Since frontends run user code, any of them could fail. In that case,
- # we want to fail all of them, as well as the application server, and
- # exit the command with an error status code.
-
- exitcode = 0
-
- while True:
- # We loop until one of the following happens:
- # 1. We receive a SIGTERM
- # 2. All the children exited with exit code 0
- # 3. At least one child exited with a non-zero exit code
-
- # sleep briefly at the start of every iteration;
- # moving this to the end of the loop might result in some flaky tests
- time.sleep(1)
-
- if sigterm_called[0]:
- print("Got SIGTERM. Exiting execution!!!")
- break
- if all(not p.is_alive() and p.exitcode == 0 for _, p in processes):
- print("All the components are inactive with exitcode 0. Exiting execution!!!")
- break
- if any((not p.is_alive() and p.exitcode != 0) for _, p in processes):
- print("Found dead components with non-zero exit codes, exiting execution!!! Components: ")
- print(
- tabulate(
- [(name, p.exitcode) for name, p in processes if not p.is_alive() and p.exitcode != 0],
- headers=["Name", "Exit Code"],
- )
- )
- exitcode = 1
- break
-
- # sleeping for the last set of logs to reach stdout
- time.sleep(2)
-
- # Cleanup
- for _, p in processes:
- if p.is_alive():
- os.kill(p.pid, signal.SIGTERM)
-
- # Give processes time to terminate
- for _, p in processes:
- p.join(5)
-
- # clean the remaining ones.
- if any(p.is_alive() for _, p in processes):
- for _, p in processes:
- if p.is_alive():
- os.kill(p.pid, signal.SIGKILL)
-
- # this sleep is just a precaution - signals might take a while to get raised.
- time.sleep(1)
- sys.exit(1)
-
- sys.exit(exitcode)
-
-
-def _get_frontends_from_app(entrypoint_file):
- """This function is used to get the frontends from the app. It will be used to start the frontends in a separate
- process if the backend cannot provide flow_names_and_ports. This is useful if the app cannot be loaded locally to
- set the frontend before dispatching to the cloud. The backend exposes by default 10 ports from 8081 if the
- app.spec.frontends is not set.
-
- NOTE: frontend names are sorted to ensure that they get consistent ports.
-
- :param entrypoint_file: The entrypoint file for the app
- :return: A list of tuples of the form (frontend_name, port_number)
-
- """
- app = load_app_from_file(entrypoint_file)
-
- frontends = []
- # This value of the port should be synced with the port value in the backend.
- # If you change this value, you should also change the value in the backend.
- flow_frontends_starting_port = 8081
- for frontend in sorted(app.frontends.keys()):
- frontends.append((frontend, flow_frontends_starting_port))
- flow_frontends_starting_port += 1
-
- return frontends
-
-
-@convert_print_to_logger_info
-@enable_debugging
-def start_flow_and_servers(
- entrypoint_file: str,
- base_url: str,
- queue_id: str,
- host: str,
- port: int,
- flow_names_and_ports: Tuple[Tuple[str, int]],
-):
- processes: List[Tuple[str, Process]] = []
-
- # Queues between Flow and its Rest API are using multiprocessing to:
- # - reduce redis load
- # - increase UI responsiveness and RPS
- queue_system = QueuingSystem.MULTIPROCESS
- queues = {
- "api_publish_state_queue": queue_system.get_api_state_publish_queue(queue_id=queue_id),
- "api_response_queue": queue_system.get_api_response_queue(queue_id=queue_id),
- }
-
- # In order to avoid running this function 3 separate times while executing the
- # `run_lightning_flow`, `start_application_server`, & `serve_frontend` functions
- # in a subprocess we extract this to the top level. If we intend to make changes
- # to be able to start these components in separate containers, the implementation
- # will have to move a call to this function within the initialization process.
- run_app_commands(entrypoint_file)
-
- flow_process = start_server_in_process(
- run_lightning_flow,
- args=(
- entrypoint_file,
- queue_id,
- base_url,
- ),
- kwargs={"queues": queues},
- )
- processes.append(("Flow", flow_process))
-
- server_process = start_server_in_process(
- target=start_application_server,
- args=(
- entrypoint_file,
- host,
- port,
- queue_id,
- ),
- kwargs={"queues": queues},
- )
- processes.append(("Server", server_process))
-
- if not flow_names_and_ports:
- flow_names_and_ports = _get_frontends_from_app(entrypoint_file)
-
- for name, fe_port in flow_names_and_ports:
- frontend_process = start_server_in_process(target=serve_frontend, args=(entrypoint_file, name, host, fe_port))
- processes.append((name, frontend_process))
-
- manage_server_processes(processes)
-
-
-def wait_for_queues(queue_system: QueuingSystem) -> None:
- queue_check_start_time = int(time.time())
-
- if hasattr(queue_system, "get_queue"):
- while not queue_system.get_queue("healthz").is_running:
- if (int(time.time()) - queue_check_start_time) % 10 == 0:
- logger.warning("Waiting for http queues to start...")
- time.sleep(1)
- else:
- while not check_if_redis_running():
- if (int(time.time()) - queue_check_start_time) % 10 == 0:
- logger.warning("Waiting for redis queues to start...")
- time.sleep(1)
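
The three stop conditions polled by `manage_server_processes` can be isolated as a pure function, which makes the ordering of the checks easier to see. This is a sketch under stated assumptions rather than the removed code: `ChildStatus` and `decide_exit` are hypothetical names, process state is passed in as plain data instead of `multiprocessing.Process` objects, and the cleanup phase (SIGTERM, join, SIGKILL) is left out.

```python
from typing import List, NamedTuple, Optional


class ChildStatus(NamedTuple):
    name: str
    alive: bool
    exitcode: Optional[int]  # None while the child is still running


def decide_exit(children: List[ChildStatus], sigterm_received: bool) -> Optional[int]:
    """Return the exit code to stop with, or None to keep polling."""
    if sigterm_received:
        return 0  # the loop breaks with its initial exit code
    if all(not c.alive and c.exitcode == 0 for c in children):
        return 0  # every child finished cleanly
    if any(not c.alive and c.exitcode != 0 for c in children):
        return 1  # at least one child failed
    return None


running = [ChildStatus("Flow", True, None), ChildStatus("Server", True, None)]
finished = [ChildStatus("Flow", False, 0), ChildStatus("Server", False, 0)]
failed = [ChildStatus("Flow", True, None), ChildStatus("ui", False, 1)]

print(decide_exit(running, False), decide_exit(finished, False), decide_exit(failed, False))
```

A single failed child is enough to stop the loop even while the others are still alive, matching the intent stated in the source comment that one failing frontend should bring down the rest.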
diff --git a/src/lightning/app/launcher/lightning_backend.py b/src/lightning/app/launcher/lightning_backend.py
deleted file mode 100644
index 0a974057ef070..0000000000000
--- a/src/lightning/app/launcher/lightning_backend.py
+++ /dev/null
@@ -1,525 +0,0 @@
-import inspect
-import json
-import logging
-import os
-import random
-import string
-import urllib
-from time import monotonic, sleep, time
-from typing import List, Optional
-
-from lightning_cloud.openapi import (
- AppinstancesIdBody,
- Externalv1LightningappInstance,
- Externalv1Lightningwork,
- V1BuildSpec,
- V1Drive,
- V1DriveSpec,
- V1DriveStatus,
- V1DriveType,
- V1Flowserver,
- V1LightningappInstanceState,
- V1LightningappRestartPolicy,
- V1LightningworkClusterDriver,
- V1LightningworkDrives,
- V1LightningworkSpec,
- V1LightningworkState,
- V1ListLightningworkResponse,
- V1Metadata,
- V1NetworkConfig,
- V1PackageManager,
- V1PythonDependencyInfo,
- V1SourceType,
- V1UserRequestedComputeConfig,
-)
-from lightning_cloud.openapi.rest import ApiException
-
-from lightning.app.core import LightningApp, LightningWork
-from lightning.app.core.queues import QueuingSystem
-from lightning.app.runners.backends.backend import Backend
-from lightning.app.storage import Drive, Mount
-from lightning.app.utilities.enum import WorkStageStatus, WorkStopReasons, make_status
-from lightning.app.utilities.exceptions import LightningPlatformException
-from lightning.app.utilities.network import LightningClient, _check_service_url_is_ready
-
-logger = logging.getLogger(__name__)
-
-from lightning_cloud.openapi import SpecLightningappInstanceIdWorksBody, WorksIdBody # noqa: E402
-
-LIGHTNING_STOP_TIMEOUT = int(os.getenv("LIGHTNING_STOP_TIMEOUT", 2 * 60))
-
-
-def cloud_work_stage_to_work_status_stage(stage: V1LightningworkState) -> str:
- """Maps the Work stage names from the cloud backend to the status names in the Lightning framework."""
- mapping = {
- V1LightningworkState.STOPPED: WorkStageStatus.STOPPED,
- V1LightningworkState.PENDING: WorkStageStatus.PENDING,
- V1LightningworkState.NOT_STARTED: WorkStageStatus.PENDING,
- V1LightningworkState.IMAGE_BUILDING: WorkStageStatus.PENDING,
- V1LightningworkState.RUNNING: WorkStageStatus.RUNNING,
- V1LightningworkState.FAILED: WorkStageStatus.FAILED,
- }
- if stage not in mapping:
- raise ValueError(f"Cannot map the lightning-cloud work state {stage} to the lightning status stage.")
- return mapping[stage]
-
-
-class CloudBackend(Backend):
- def __init__(
- self,
- entrypoint_file,
- queue_id: Optional[str] = None,
- status_update_interval: int = 5,
- ) -> None:
- # TODO: Properly handle queue_id in the cloud.
- super().__init__(entrypoint_file, queues=QueuingSystem("http"), queue_id=queue_id)
- self._status_update_interval = status_update_interval
- self._last_time_updated = None
- self.client = LightningClient(retry=True)
- self.base_url: Optional[str] = None
-
- @staticmethod
- def _work_to_spec(work: LightningWork) -> V1LightningworkSpec:
- work_requirements = "\n".join(work.cloud_build_config.requirements)
-
- build_spec = V1BuildSpec(
- commands=work.cloud_build_config.build_commands(),
- python_dependencies=V1PythonDependencyInfo(
- package_manager=V1PackageManager.PIP, packages=work_requirements
- ),
- image=work.cloud_build_config.image,
- )
-
- drive_specs: List[V1LightningworkDrives] = []
- for drive_attr_name, drive in [
- (k, getattr(work, k)) for k in work._state if isinstance(getattr(work, k), Drive)
- ]:
- if drive.protocol == "lit://":
- drive_type = V1DriveType.NO_MOUNT_S3
- source_type = V1SourceType.S3
- else:
- drive_type = V1DriveType.UNSPECIFIED
- source_type = V1SourceType.UNSPECIFIED
-
- drive_specs.append(
- V1LightningworkDrives(
- drive=V1Drive(
- metadata=V1Metadata(name=f"{work.name}.{drive_attr_name}"),
- spec=V1DriveSpec(
- drive_type=drive_type,
- source_type=source_type,
- source=f"{drive.protocol}{drive.id}",
- ),
- status=V1DriveStatus(),
- ),
- mount_location=str(drive.root_folder),
- ),
- )
-
- # this should really be part of the work.cloud_compute struct, but to save
- # time we are not going to modify the backend in this set of PRs & instead
- # use the same s3 drives API which we used before.
- if work.cloud_compute.mounts is not None:
- if isinstance(work.cloud_compute.mounts, Mount):
- drive_specs.append(
- _create_mount_drive_spec(
- work_name=work.name,
- mount=work.cloud_compute.mounts,
- )
- )
- else:
- for mount in work.cloud_compute.mounts:
- drive_specs.append(
- _create_mount_drive_spec(
- work_name=work.name,
- mount=mount,
- )
- )
-
- if hasattr(work.cloud_compute, "interruptible"):
- preemptible = work.cloud_compute.interruptible
- else:
- preemptible = work.cloud_compute.preemptible
-
- colocation_group_id = None
- if hasattr(work.cloud_compute, "colocation_group_id"):
- colocation_group_id = work.cloud_compute.colocation_group_id
-
- user_compute_config = V1UserRequestedComputeConfig(
- name=work.cloud_compute.name,
- count=1,
- disk_size=work.cloud_compute.disk_size,
- preemptible=preemptible,
- shm_size=work.cloud_compute.shm_size,
- affinity_identifier=colocation_group_id,
- )
-
- random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5)) # noqa: S311
-
- return V1LightningworkSpec(
- build_spec=build_spec,
- drives=drive_specs,
- user_requested_compute_config=user_compute_config,
- network_config=[V1NetworkConfig(name=random_name, port=work.port)],
- desired_state=V1LightningworkState.RUNNING,
- restart_policy=V1LightningappRestartPolicy.NEVER,
- cluster_driver=V1LightningworkClusterDriver.DIRECT,
- )
-
- def create_work(self, app: LightningApp, work: LightningWork) -> None:
- app_id = self._get_app_id()
- project_id = self._get_project_id()
- list_response: V1ListLightningworkResponse = self.client.lightningwork_service_list_lightningwork(
- project_id=project_id, app_id=app_id
- )
- external_specs: List[Externalv1Lightningwork] = list_response.lightningworks
-
- # Find THIS work in the list of all registered works
- external_spec = None
- for es in external_specs:
- if es.name == work.name:
- external_spec = es
- break
-
- if external_spec is None:
- spec = self._work_to_spec(work)
- try:
- fn = SpecLightningappInstanceIdWorksBody.__init__
- params = list(inspect.signature(fn).parameters)
- extras = {}
- if "display_name" in params:
- extras["display_name"] = getattr(work, "display_name", "")
-
- external_spec = self.client.lightningwork_service_create_lightningwork(
- project_id=project_id,
- spec_lightningapp_instance_id=app_id,
- body=SpecLightningappInstanceIdWorksBody(
- name=work.name,
- spec=spec,
- **extras,
- ),
- )
- # overwriting spec with return value
- spec = external_spec.spec
- except ApiException as e:
- # We might exceed quotas or be out of credits.
- message = json.loads(e.body).get("message")
- raise LightningPlatformException(message) from None
- elif external_spec.spec.desired_state == V1LightningworkState.RUNNING:
- spec = external_spec.spec
- work._port = spec.network_config[0].port
- else:
- # Signal the LightningWorkState to go into state RUNNING
- spec = external_spec.spec
-
- # getting the updated spec but only carrying over the port, drives, compute config, build spec & env
- new_spec = self._work_to_spec(work)
-
- spec.desired_state = V1LightningworkState.RUNNING
- spec.network_config[0].port = new_spec.network_config[0].port
- spec.drives = new_spec.drives
- spec.user_requested_compute_config = new_spec.user_requested_compute_config
- spec.build_spec = new_spec.build_spec
- spec.env = new_spec.env
- try:
- self.client.lightningwork_service_update_lightningwork(
- project_id=project_id,
- id=external_spec.id,
- spec_lightningapp_instance_id=app_id,
- body=WorksIdBody(spec),
- )
- except ApiException as e:
- # We might exceed quotas or be out of credits.
- message = json.loads(e.body).get("message")
- raise LightningPlatformException(message) from None
-
- # Replace the undefined url and host by the known one.
- work._host = "0.0.0.0" # noqa: S104
- work._future_url = f"{self._get_proxy_scheme()}://{spec.network_config[0].host}"
-
- # removing the backend to avoid the threadlock error
- _backend = work._backend
- work._backend = None
- app.work_queues[work.name].put(work)
- work._backend = _backend
-
- logger.info(f"Starting work {work.name}")
- logger.debug(f"With the following external spec: {external_spec}")
-
- def update_work_statuses(self, works: List[LightningWork]) -> None:
- """Pulls the status of each Work instance in the cloud.
-
- Normally, the Lightning framework communicates statuses through the queues, but while the Work instance is
- being provisioned, the queues don't exist yet and hence we need to make API calls directly to the backend to
- fetch the status and update it in the states.
-
- """
- if not works:
- return
-
- # TODO: should this run in a timer thread instead?
- if self._last_time_updated is not None and monotonic() - self._last_time_updated < self._status_update_interval:
- return
-
- cloud_work_specs = self._get_cloud_work_specs(self.client)
- local_works = works
- for cloud_work_spec in cloud_work_specs:
- for local_work in local_works:
- # TODO (tchaton) Better resolve pending status after succeeded
-
- # 1. Skip if the work isn't the current one.
- if local_work.name != cloud_work_spec.name:
- continue
-
- # 2. Logic for idle timeout
- self._handle_idle_timeout(
- local_work.cloud_compute.idle_timeout,
- local_work,
- cloud_work_spec,
- )
-
- # 3. Map the cloud phase to the local one
- cloud_stage = cloud_work_stage_to_work_status_stage(
- cloud_work_spec.status.phase,
- )
-
- # 4. Detect if the work failed during pending phase
- if local_work.status.stage == WorkStageStatus.PENDING and cloud_stage in WorkStageStatus.FAILED:
- if local_work._raise_exception:
- raise Exception(f"The work {local_work.name} failed during pending phase.")
- logger.error(f"The work {local_work.name} failed during pending phase.")
-
- # 5. Skip the pending and running as this is already handled by Lightning.
- if cloud_stage in (WorkStageStatus.PENDING, WorkStageStatus.RUNNING):
- continue
-
- # TODO: Add the logic for wait_timeout
- if local_work.status.stage != cloud_stage:
- latest_hash = local_work._calls["latest_call_hash"]
- if latest_hash is None:
- continue
- local_work._calls[latest_hash]["statuses"].append(make_status(cloud_stage))
-
- self._last_time_updated = monotonic()
-
- def stop_all_works(self, works: List[LightningWork]) -> None:
- """Stop resources for all LightningWorks in this app.
-
- The Works are stopped rather than deleted so that they can be inspected for debugging.
-
- """
- cloud_works = self._get_cloud_work_specs(self.client)
-
- for cloud_work in cloud_works:
- self._stop_work(cloud_work)
-
- def all_works_stopped(works: List[Externalv1Lightningwork]) -> bool:
- for work in works:
- # deleted work won't be in the request hence only checking for stopped & failed
- if work.status.phase not in (
- V1LightningworkState.STOPPED,
- V1LightningworkState.FAILED,
- ):
- return False
- return True
-
- t0 = time()
- while not all_works_stopped(self._get_cloud_work_specs(self.client)):
- # Wait a little..
- print("Waiting for works to stop...")
- sleep(3)
-
- # Break if we reached timeout.
- if time() - t0 > LIGHTNING_STOP_TIMEOUT:
- break
-
- def resolve_url(self, app, base_url: Optional[str] = None) -> None:
- if not self.base_url:
- self.base_url = base_url
-
- for flow in app.flows:
- if self.base_url:
- # Replacing the path with complete URL
- if not (self.base_url.startswith("http://") or self.base_url.startswith("https://")):
- raise ValueError(
- "Base URL doesn't have a valid scheme, expected it to start with 'http://' or 'https://' "
- )
- if isinstance(flow._layout, dict) and "target" not in flow._layout:
- # FIXME: Why doesn't _check_service_url_is_ready work?
- frontend_url = urllib.parse.urljoin(self.base_url, flow.name + "/")
- flow._layout["target"] = frontend_url
-
- for work in app.works:
- if (
- work._url == ""
- and work.status.stage
- in (
- WorkStageStatus.RUNNING,
- WorkStageStatus.SUCCEEDED,
- )
- and work._internal_ip != ""
- and _check_service_url_is_ready(f"http://{work._internal_ip}:{work._port}")
- ):
- work._url = work._future_url
-
- @staticmethod
- def _get_proxy_scheme() -> str:
- return os.environ.get("LIGHTNING_PROXY_SCHEME", "https")
-
- @staticmethod
- def _get_app_id() -> str:
- return os.environ["LIGHTNING_CLOUD_APP_ID"]
-
- @staticmethod
- def _get_project_id() -> str:
- return os.environ["LIGHTNING_CLOUD_PROJECT_ID"]
-
- @staticmethod
- def _get_cloud_work_specs(client: LightningClient) -> List[Externalv1Lightningwork]:
- list_response: V1ListLightningworkResponse = client.lightningwork_service_list_lightningwork(
- project_id=CloudBackend._get_project_id(),
- app_id=CloudBackend._get_app_id(),
- )
- return list_response.lightningworks
-
- def _handle_idle_timeout(self, idle_timeout: float, work: LightningWork, resp: Externalv1Lightningwork) -> None:
- if idle_timeout is None:
- return
-
- if work.status.stage != WorkStageStatus.SUCCEEDED:
- return
-
- if time() > (idle_timeout + work.status.timestamp):
- logger.info(f"Idle Timeout {idle_timeout} has triggered. Stopping gracefully the {work.name}.")
- latest_hash = work._calls["latest_call_hash"]
- status = make_status(WorkStageStatus.STOPPED, reason=WorkStopReasons.PENDING)
- work._calls[latest_hash]["statuses"].append(status)
- self._stop_work(resp)
- logger.debug(f"Stopping work: {resp.id}")
-
- def _register_queues(self, app, work):
- super()._register_queues(app, work)
- kw = {"queue_id": self.queue_id, "work_name": work.name}
- app.work_queues.update({work.name: self.queues.get_work_queue(**kw)})
-
- def stop_work(self, app: LightningApp, work: LightningWork) -> None:
- cloud_works = self._get_cloud_work_specs(self.client)
- for cloud_work in cloud_works:
- if work.name == cloud_work.name:
- self._stop_work(cloud_work)
-
- def _stop_work(self, work_resp: Externalv1Lightningwork) -> None:
- spec: V1LightningworkSpec = work_resp.spec
- if spec.desired_state == V1LightningworkState.DELETED:
- # work is set to be deleted. Do nothing
- return
- if spec.desired_state == V1LightningworkState.STOPPED:
- # work is set to be stopped already. Do nothing
- return
- if work_resp.status.phase == V1LightningworkState.FAILED:
- # work is already failed. Do nothing
- return
- spec.desired_state = V1LightningworkState.STOPPED
- self.client.lightningwork_service_update_lightningwork(
- project_id=CloudBackend._get_project_id(),
- id=work_resp.id,
- spec_lightningapp_instance_id=CloudBackend._get_app_id(),
- body=WorksIdBody(spec),
- )
- print(f"Stopping {work_resp.name} ...")
-
- def delete_work(self, app: LightningApp, work: LightningWork) -> None:
- cloud_works = self._get_cloud_work_specs(self.client)
- for cloud_work in cloud_works:
- if work.name == cloud_work.name:
- self._delete_work(cloud_work)
-
- def _delete_work(self, work_resp: Externalv1Lightningwork) -> None:
- spec: V1LightningworkSpec = work_resp.spec
- if spec.desired_state == V1LightningworkState.DELETED:
- # work is set to be deleted. Do nothing
- return
- spec.desired_state = V1LightningworkState.DELETED
- self.client.lightningwork_service_update_lightningwork(
- project_id=CloudBackend._get_project_id(),
- id=work_resp.id,
- spec_lightningapp_instance_id=CloudBackend._get_app_id(),
- body=WorksIdBody(spec),
- )
- print(f"Deleting {work_resp.name} ...")
-
- def update_lightning_app_frontend(self, app: "lightning.LightningApp"): # noqa: F821
- """Used to create frontends if the app couldn't be loaded locally."""
- if not len(app.frontends.keys()):
- return
-
- external_app_spec: "Externalv1LightningappInstance" = (
- self.client.lightningapp_instance_service_get_lightningapp_instance(
- project_id=CloudBackend._get_project_id(),
- id=CloudBackend._get_app_id(),
- )
- )
-
- frontend_specs = external_app_spec.spec.flow_servers
- spec = external_app_spec.spec
- if len(frontend_specs) != len(app.frontends.keys()):
- frontend_specs: List[V1Flowserver] = []
- for flow_name in sorted(app.frontends.keys()):
- frontend_spec = V1Flowserver(name=flow_name)
- frontend_specs.append(frontend_spec)
-
- spec.flow_servers = frontend_specs
- spec.enable_app_server = True
-
- logger.info("Found new frontends. Updating the app spec.")
-
- self.client.lightningapp_instance_service_update_lightningapp_instance(
- project_id=CloudBackend._get_project_id(),
- id=CloudBackend._get_app_id(),
- body=AppinstancesIdBody(spec=spec),
- )
-
- def stop_app(self, app: "lightning.LightningApp"): # noqa: F821
- """Used to mark the App as stopped if everything went fine."""
-
- external_app_spec: "Externalv1LightningappInstance" = (
- self.client.lightningapp_instance_service_get_lightningapp_instance(
- project_id=CloudBackend._get_project_id(),
- id=CloudBackend._get_app_id(),
- )
- )
-
- spec = external_app_spec.spec
- spec.desired_state = V1LightningappInstanceState.STOPPED
-
- self.client.lightningapp_instance_service_update_lightningapp_instance(
- project_id=CloudBackend._get_project_id(),
- id=CloudBackend._get_app_id(),
- body=AppinstancesIdBody(spec=spec),
- )
-
-
-def _create_mount_drive_spec(work_name: str, mount: "Mount") -> V1LightningworkDrives:
- if mount.protocol == "s3://":
- drive_type = V1DriveType.INDEXED_S3
- source_type = V1SourceType.S3
- else:
- raise RuntimeError(
- f"unknown mounts protocol `{mount.protocol}`. Please verify this "
- f"drive type has been configured for use in the cloud dispatcher."
- )
-
- return V1LightningworkDrives(
- drive=V1Drive(
- metadata=V1Metadata(
- name=work_name,
- ),
- spec=V1DriveSpec(
- drive_type=drive_type,
- source_type=source_type,
- source=mount.source,
- ),
- status=V1DriveStatus(),
- ),
- mount_location=str(mount.mount_path),
- )
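Note on the removed `stop_all_works` loop above: it polls the cloud work specs every 3 seconds and gives up once `LIGHTNING_STOP_TIMEOUT` has elapsed. A minimal sketch of the same poll-with-timeout pattern follows. `wait_until` and its parameters are hypothetical names chosen for illustration, not part of the Lightning API, and the sketch assumes a monotonic clock instead of the `time()` call used in the removed code.

```python
import time
from typing import Callable


def wait_until(
    predicate: Callable[[], bool],
    timeout: float,
    interval: float = 3.0,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> bool:
    """Poll `predicate` until it returns True or `timeout` seconds elapse.

    Returns True if the predicate succeeded, False if the timeout was reached.
    `sleep` and `clock` are injectable so the loop can be tested without waiting.
    """
    deadline = clock() + timeout
    while not predicate():
        # Give up once the deadline has passed, mirroring the `break` in the loop above.
        if clock() >= deadline:
            return False
        sleep(interval)
    return True
```

Returning a boolean, rather than silently breaking out of the loop, lets the caller tell a clean stop apart from a timeout.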
diff --git a/src/lightning/app/launcher/lightning_hybrid_backend.py b/src/lightning/app/launcher/lightning_hybrid_backend.py
deleted file mode 100644
index a5b82cd602601..0000000000000
--- a/src/lightning/app/launcher/lightning_hybrid_backend.py
+++ /dev/null
@@ -1,155 +0,0 @@
-import os
-from typing import Optional
-
-from lightning_cloud.openapi import AppinstancesIdBody, Externalv1LightningappInstance
-
-from lightning.app.core import constants
-from lightning.app.core.queues import QueuingSystem
-from lightning.app.launcher.lightning_backend import CloudBackend
-from lightning.app.runners.backends.backend import Backend
-from lightning.app.runners.backends.mp_process import MultiProcessingBackend
-from lightning.app.utilities.network import LightningClient
-
-if hasattr(constants, "get_cloud_queue_type"):
- CLOUD_QUEUE_TYPE = constants.get_cloud_queue_type() or "redis"
-else:
- CLOUD_QUEUE_TYPE = "redis"
-
-
-class CloudHybridBackend(Backend):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, queues=QueuingSystem(CLOUD_QUEUE_TYPE), **kwargs)
- cloud_backend = CloudBackend(*args, **kwargs)
- kwargs.pop("queue_id")
- multiprocess_backend = MultiProcessingBackend(*args, **kwargs)
-
- self.backends = {"cloud": cloud_backend, "multiprocess": multiprocess_backend}
- self.work_to_network_configs = {}
-
- def create_work(self, app, work) -> None:
- backend = self._get_backend(work)
- if isinstance(backend, MultiProcessingBackend):
- self._prepare_work_creation(app, work)
- backend.create_work(app, work)
-
- def _prepare_work_creation(self, app, work) -> None:
- app_id = self._get_app_id()
- project_id = self._get_project_id()
- assert project_id
-
- client = LightningClient()
- list_apps_resp = client.lightningapp_instance_service_list_lightningapp_instances(project_id=project_id)
- lit_app: Optional[Externalv1LightningappInstance] = None
-
- for lapp in list_apps_resp.lightningapps:
- if lapp.id == app_id:
- lit_app = lapp
-
- assert lit_app
-
- network_configs = lit_app.spec.network_config
-
- index = len(self.work_to_network_configs)
-
- if work.name not in self.work_to_network_configs:
- self.work_to_network_configs[work.name] = network_configs[index]
-
- # Enable Ingress and update the specs.
- lit_app.spec.network_config[index].enable = True
-
- client.lightningapp_instance_service_update_lightningapp_instance(
- project_id=project_id,
- id=lit_app.id,
- body=AppinstancesIdBody(name=lit_app.name, spec=lit_app.spec),
- )
-
- work_network_config = self.work_to_network_configs[work.name]
-
- work._host = "0.0.0.0" # noqa: S104
- work._port = work_network_config.port
- work._future_url = f"{self._get_proxy_scheme()}://{work_network_config.host}"
-
- def update_work_statuses(self, works) -> None:
- if works:
- backend = self._get_backend(works[0])
- backend.update_work_statuses(works)
-
- def stop_all_works(self, works) -> None:
- if works:
- backend = self._get_backend(works[0])
- backend.stop_all_works(works)
-
- def resolve_url(self, app, base_url: Optional[str] = None) -> None:
- works = app.works
- if works:
- backend = self._get_backend(works[0])
- backend.resolve_url(app, base_url)
-
- def update_lightning_app_frontend(self, app: "lightning.LightningApp"): # noqa: F821
- self.backends["cloud"].update_lightning_app_frontend(app)
-
- def stop_work(self, app, work) -> None:
- backend = self._get_backend(work)
- if isinstance(backend, MultiProcessingBackend):
- self._prepare_work_stop(app, work)
- backend.stop_work(app, work)
-
- def delete_work(self, app, work) -> None:
- backend = self._get_backend(work)
- if isinstance(backend, MultiProcessingBackend):
- self._prepare_work_stop(app, work)
- backend.delete_work(app, work)
-
- def _prepare_work_stop(self, app, work):
- app_id = self._get_app_id()
- project_id = self._get_project_id()
- assert project_id
-
- client = LightningClient()
- list_apps_resp = client.lightningapp_instance_service_list_lightningapp_instances(project_id=project_id)
- lit_app: Optional[Externalv1LightningappInstance] = None
-
- for lapp in list_apps_resp.lightningapps:
- if lapp.id == app_id:
- lit_app = lapp
-
- assert lit_app
-
- network_config = self.work_to_network_configs[work.name]
-
- for nc in lit_app.spec.network_config:
- if nc.host == network_config.host:
- nc.enable = False
-
- client.lightningapp_instance_service_update_lightningapp_instance(
- project_id=project_id,
- id=lit_app.id,
- body=AppinstancesIdBody(name=lit_app.name, spec=lit_app.spec),
- )
-
- del self.work_to_network_configs[work.name]
-
- def _register_queues(self, app, work):
- backend = self._get_backend(work)
- backend._register_queues(app, work)
-
- def _get_backend(self, work):
- if work.cloud_compute.id == "default":
- return self.backends["multiprocess"]
- return self.backends["cloud"]
-
- @staticmethod
- def _get_proxy_scheme() -> str:
- return os.environ.get("LIGHTNING_PROXY_SCHEME", "https")
-
- @staticmethod
- def _get_app_id() -> str:
- return os.environ["LIGHTNING_CLOUD_APP_ID"]
-
- @staticmethod
- def _get_project_id() -> str:
- return os.environ["LIGHTNING_CLOUD_PROJECT_ID"]
-
- def stop_app(self, app: "lightning.LightningApp"): # noqa: F821
- """Used to mark the App as stopped if everything went fine."""
- self.backends["cloud"].stop_app(app)
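Note on the removed `CloudHybridBackend._get_backend` above: it routes a work to the multiprocess backend when `work.cloud_compute.id == "default"` and to the cloud backend otherwise. The routing rule can be sketched in isolation as below. `FakeCompute`, `FakeWork`, and `select_backend` are hypothetical stand-ins for illustration only, and plain strings take the place of the real backend objects.

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class FakeCompute:
    # Hypothetical stand-in for the work's cloud_compute object.
    id: str = "default"


@dataclass
class FakeWork:
    # Hypothetical stand-in for a LightningWork.
    name: str
    cloud_compute: FakeCompute = field(default_factory=FakeCompute)


def select_backend(work: FakeWork, backends: Dict[str, str]) -> str:
    """Return the local backend for default compute, the cloud backend otherwise."""
    if work.cloud_compute.id == "default":
        return backends["multiprocess"]
    return backends["cloud"]


backends = {"cloud": "cloud-backend", "multiprocess": "multiprocess-backend"}
local_choice = select_backend(FakeWork("w0"), backends)
remote_choice = select_backend(FakeWork("w1", FakeCompute("gpu")), backends)
```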
diff --git a/src/lightning/app/pdb/__init__.py b/src/lightning/app/pdb/__init__.py
deleted file mode 100644
index 260b2505c005d..0000000000000
--- a/src/lightning/app/pdb/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from lightning.app.pdb.pdb import MPPdb, set_trace
-
-# Enable breakpoint within forked processes.
-__builtins__["breakpoint"] = set_trace
-
-__all__ = ["set_trace", "MPPdb"]
diff --git a/src/lightning/app/pdb/pdb.py b/src/lightning/app/pdb/pdb.py
deleted file mode 100644
index a9249437cc698..0000000000000
--- a/src/lightning/app/pdb/pdb.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import multiprocessing
-import os
-import pdb
-import sys
-
-_stdin = [None]
-_stdin_lock = multiprocessing.Lock()
-try:
- _stdin_fd = sys.stdin.fileno()
-except Exception:
- _stdin_fd = None
-
-
-# Taken from https://github.com/facebookresearch/metaseq/blob/main/metaseq/pdb.py
-class MPPdb(pdb.Pdb):
- """A Pdb wrapper that works in a multiprocessing environment."""
-
- def __init__(self) -> None:
- pdb.Pdb.__init__(self, nosigint=True)
-
- def _cmdloop(self) -> None:
- stdin_back = sys.stdin
- with _stdin_lock:
- try:
- if _stdin_fd is not None:
- if not _stdin[0]:
- _stdin[0] = os.fdopen(_stdin_fd)
- sys.stdin = _stdin[0]
- self.cmdloop()
- finally:
- sys.stdin = stdin_back
-
-
-def set_trace() -> None:
- pdb = MPPdb()
- pdb.set_trace(sys._getframe().f_back)
diff --git a/src/lightning/app/plugin/__init__.py b/src/lightning/app/plugin/__init__.py
deleted file mode 100644
index 8d08c38eb0d5c..0000000000000
--- a/src/lightning/app/plugin/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from lightning.app.plugin.plugin import LightningPlugin
-
-__all__ = ["LightningPlugin"]
diff --git a/src/lightning/app/plugin/plugin.py b/src/lightning/app/plugin/plugin.py
deleted file mode 100644
index 78165ba065d2d..0000000000000
--- a/src/lightning/app/plugin/plugin.py
+++ /dev/null
@@ -1,237 +0,0 @@
-# Copyright The Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os
-import shutil
-import tarfile
-import tempfile
-from pathlib import Path
-from typing import Any, Dict
-from urllib.parse import urlparse
-
-import requests
-import uvicorn
-from fastapi import FastAPI, HTTPException, status
-from fastapi.middleware.cors import CORSMiddleware
-from lightning_cloud.openapi import Externalv1LightningappInstance
-from pydantic import BaseModel
-
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.component import _set_flow_context
-from lightning.app.utilities.enum import AppStage
-from lightning.app.utilities.load_app import _load_plugin_from_file
-
-logger = Logger(__name__)
-
-_PLUGIN_MAX_CLIENT_TRIES: int = 3
-_PLUGIN_INTERNAL_DIR_PATH: str = f"{os.environ.get('HOME', '')}/internal"
-
-
-class LightningPlugin:
- """A ``LightningPlugin`` is a single-file Python class that can be executed within a cloudspace to perform
- actions."""
-
- def __init__(self) -> None:
- self.project_id = ""
- self.cloudspace_id = ""
- self.cluster_id = ""
- self.source_app = ""
- self.keep_machines_after_stop = False
-
- def run(self, *args: str, **kwargs: str) -> Externalv1LightningappInstance:
- """Override with the logic to execute on the cloudspace."""
- raise NotImplementedError
-
- def run_job(self, name: str, app_entrypoint: str, env_vars: Dict[str, str] = {}) -> Externalv1LightningappInstance:
- """Run a job in the cloudspace associated with this plugin.
-
- Args:
- name: The name of the job.
- app_entrypoint: The path of the file containing the app to run.
- env_vars: Additional env vars to set when running the app.
-
- Returns:
- The spec of the created LightningappInstance.
-
- """
- from lightning.app.runners.backends.cloud import CloudBackend
- from lightning.app.runners.cloud import CloudRuntime
-
- logger.info(f"Processing job run request. name: {name}, app_entrypoint: {app_entrypoint}, env_vars: {env_vars}")
-
- # Dispatch the job
- _set_flow_context()
-
- entrypoint_file = Path(app_entrypoint)
-
- app = CloudRuntime.load_app_from_file(str(entrypoint_file.resolve().absolute()), env_vars=env_vars)
-
- app.stage = AppStage.BLOCKING
-
- runtime = CloudRuntime(
- app=app,
- entrypoint=entrypoint_file,
- start_server=True,
- env_vars=env_vars,
- secrets={},
- run_app_comment_commands=True,
- backend=CloudBackend(entrypoint_file, client_max_tries=_PLUGIN_MAX_CLIENT_TRIES),
- )
- # Used to indicate Lightning has been dispatched
- os.environ["LIGHTNING_DISPATCHED"] = "1"
-
- return runtime.cloudspace_dispatch(
- project_id=self.project_id,
- cloudspace_id=self.cloudspace_id,
- name=name,
- cluster_id=self.cluster_id,
- source_app=self.source_app,
- keep_machines_after_stop=self.keep_machines_after_stop,
- )
-
- def _setup(
- self,
- project_id: str,
- cloudspace_id: str,
- cluster_id: str,
- source_app: str,
- keep_machines_after_stop: bool,
- ) -> None:
- self.source_app = source_app
- self.project_id = project_id
- self.cloudspace_id = cloudspace_id
- self.cluster_id = cluster_id
- self.keep_machines_after_stop = keep_machines_after_stop
-
-
-class _Run(BaseModel):
- plugin_entrypoint: str
- source_code_url: str
- project_id: str
- cloudspace_id: str
- cluster_id: str
- plugin_arguments: Dict[str, str]
- source_app: str
- keep_machines_after_stop: bool
-
-
-def _run_plugin(run: _Run) -> Dict[str, Any]:
- """Create a run with the given name and entrypoint under the cloudspace with the given ID."""
- from lightning.app.runners.cloud import _to_clean_dict
-
- with tempfile.TemporaryDirectory() as tmpdir:
- download_path = os.path.join(tmpdir, "source.tar.gz")
- source_path = os.path.join(tmpdir, "source")
- os.makedirs(source_path)
-
- # Download the tarball
- try:
- logger.info(f"Downloading plugin source: {run.source_code_url}")
-
- # Sometimes the URL gets encoded, so we parse it here
- source_code_url = urlparse(run.source_code_url).geturl()
-
- response = requests.get(source_code_url)
-
- # TODO: Backoff retry a few times in case the URL is flaky
- response.raise_for_status()
-
- with open(download_path, "wb") as f:
- f.write(response.content)
- except Exception as ex:
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail=f"Error downloading plugin source: {str(ex)}.",
- )
-
- # Extract
- try:
- logger.info("Extracting plugin source.")
-
- with tarfile.open(download_path, "r:gz") as tf:
- tf.extractall(source_path) # noqa: S202
- except Exception as ex:
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail=f"Error extracting plugin source: {str(ex)}.",
- )
-
- # Import the plugin
- try:
- logger.info(f"Importing plugin: {run.plugin_entrypoint}")
-
- plugin = _load_plugin_from_file(os.path.join(source_path, run.plugin_entrypoint))
- except Exception as ex:
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error loading plugin: {str(ex)}."
- )
-
- # Allow devs to add files to the app source
- if os.path.isdir(_PLUGIN_INTERNAL_DIR_PATH):
- shutil.copytree(_PLUGIN_INTERNAL_DIR_PATH, source_path, dirs_exist_ok=True)
-
- # Ensure that apps are dispatched from the temp directory
- cwd = os.getcwd()
- os.chdir(source_path)
-
- # Setup and run the plugin
- try:
- logger.info(
- "Running plugin. "
- f"project_id: {run.project_id}, cloudspace_id: {run.cloudspace_id}, cluster_id: {run.cluster_id}."
- )
-
- plugin._setup(
- project_id=run.project_id,
- cloudspace_id=run.cloudspace_id,
- cluster_id=run.cluster_id,
- source_app=run.source_app,
- keep_machines_after_stop=run.keep_machines_after_stop,
- )
- app_instance = plugin.run(**run.plugin_arguments)
- return _to_clean_dict(app_instance, True)
- except Exception as ex:
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error running plugin: {str(ex)}."
- )
- finally:
- os.chdir(cwd)
-
-
-async def _healthz() -> Dict[str, str]:
- """Health check endpoint."""
- return {"status": "ok"}
-
-
-def _start_plugin_server(port: int) -> None:
- """Start the plugin server which can be used to dispatch apps or run plugins."""
-
- fastapi_service = FastAPI()
-
- fastapi_service.add_middleware(
- CORSMiddleware,
- allow_origins=["*"],
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
- )
-
- fastapi_service.post("/v1/runs")(_run_plugin)
- fastapi_service.get("/healthz", status_code=200)(_healthz)
-
- uvicorn.run(
- app=fastapi_service,
- host="127.0.0.1",
- port=port,
- log_level="error",
- )
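Note on the removed `run_job` signature above: it declares `env_vars: Dict[str, str] = {}`. A mutable default is created once, at function definition time, and shared by every call that omits the argument. The removed method only logs and forwards `env_vars`, so it may never have triggered the problem in practice. The sketch below illustrates the general Python pitfall and the usual `None` default alternative; both function names are hypothetical.

```python
from typing import Dict, Optional


def run_job_shared(name: str, env_vars: Dict[str, str] = {}) -> Dict[str, str]:
    # The default dict is created once, at definition time, and reused by later calls.
    env_vars.setdefault("JOB_NAME", name)
    return env_vars


def run_job_isolated(name: str, env_vars: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    # A fresh dict is built on every call, so no state leaks between calls.
    env_vars = dict(env_vars or {})
    env_vars.setdefault("JOB_NAME", name)
    return env_vars


shared_first = run_job_shared("a")
shared_second = run_job_shared("b")  # reuses the dict already populated by the first call
isolated_first = run_job_isolated("a")
isolated_second = run_job_isolated("b")
```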
diff --git a/src/lightning/app/runners/__init__.py b/src/lightning/app/runners/__init__.py
deleted file mode 100644
index 985203033da09..0000000000000
--- a/src/lightning/app/runners/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from lightning.app.runners.cloud import CloudRuntime
-from lightning.app.runners.multiprocess import MultiProcessRuntime
-from lightning.app.runners.runtime import Runtime, dispatch
-from lightning.app.utilities.app_commands import run_app_commands
-from lightning.app.utilities.load_app import load_app_from_file
-
-__all__ = [
- "dispatch",
- "load_app_from_file",
- "run_app_commands",
- "Runtime",
- "MultiProcessRuntime",
- "CloudRuntime",
-]
diff --git a/src/lightning/app/runners/backends/__init__.py b/src/lightning/app/runners/backends/__init__.py
deleted file mode 100644
index be7b3a7457976..0000000000000
--- a/src/lightning/app/runners/backends/__init__.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from enum import Enum
-
-from lightning.app.core.constants import APP_SERVER_IN_CLOUD
-from lightning.app.runners.backends.backend import Backend
-from lightning.app.runners.backends.cloud import CloudBackend
-from lightning.app.runners.backends.docker import DockerBackend
-from lightning.app.runners.backends.mp_process import CloudMultiProcessingBackend, MultiProcessingBackend
-
-
-class BackendType(Enum):
- MULTIPROCESSING = "multiprocessing"
- DOCKER = "docker"
- CLOUD = "cloud"
-
- def get_backend(self, entrypoint_file: str) -> "Backend":
- if self == BackendType.MULTIPROCESSING:
- if APP_SERVER_IN_CLOUD:
- return CloudMultiProcessingBackend(entrypoint_file)
- return MultiProcessingBackend(entrypoint_file)
- if self == BackendType.DOCKER:
- return DockerBackend(entrypoint_file)
- if self == BackendType.CLOUD:
- return CloudBackend(entrypoint_file)
- raise ValueError("Unknown client type")
diff --git a/src/lightning/app/runners/backends/backend.py b/src/lightning/app/runners/backends/backend.py
deleted file mode 100644
index 4b50f0d171482..0000000000000
--- a/src/lightning/app/runners/backends/backend.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from abc import ABC, abstractmethod
-from functools import partial
-from typing import TYPE_CHECKING, Any, Callable, List, Optional
-
-from lightning.app.core.queues import QueuingSystem
-from lightning.app.utilities.proxies import ProxyWorkRun, unwrap
-
-if TYPE_CHECKING:
- import lightning.app
-
-
-class Backend(ABC):
- """The Backend provides an interface for the framework to communicate with resources in the cloud."""
-
- def __init__(self, entrypoint_file: str, queues: QueuingSystem, queue_id: str) -> None:
- self.queues: QueuingSystem = queues
- self.queue_id = queue_id
- self.entrypoint_file = entrypoint_file
-
- @abstractmethod
- def create_work(self, app: "lightning.app.LightningApp", work: "lightning.app.LightningWork") -> None:
- pass
-
- @abstractmethod
- def update_work_statuses(self, works: List["lightning.app.LightningWork"]) -> None:
- pass
-
- @abstractmethod
- def stop_all_works(self, works: List["lightning.app.LightningWork"]) -> None:
- pass
-
- @abstractmethod
- def resolve_url(self, app, base_url: Optional[str] = None) -> None:
- pass
-
- @abstractmethod
- def stop_work(self, app: "lightning.app.LightningApp", work: "lightning.app.LightningWork") -> None:
- pass
-
- def _dynamic_run_wrapper(
- self,
- *args: Any,
- app: "lightning.app.LightningApp",
- work: "lightning.app.LightningWork",
- work_run: Callable,
- **kwargs: Any,
- ) -> None:
- if not work.name:
- # the name is empty, which means this work was never assigned to a parent flow
- raise AttributeError(
- f"Failed to create process for {work.__class__.__name__}."
- f" Make sure to set this work as an attribute of a `LightningFlow` before calling the run method."
- )
-
- # 1. Create and register the queues associated with the work
- self._register_queues(app, work)
-
- work.run = work_run
-
- # 2. Create the work
- self.create_work(app, work)
-
- # 3. Attach backend
- work._backend = self
-
- # 4. Create the work proxy to manipulate the work
- work.run = ProxyWorkRun(
- work_run=work_run,
- work_name=work.name,
- work=work,
- caller_queue=app.caller_queues[work.name],
- )
-
- # 5. Run the work proxy
- return work.run(*args, **kwargs)
-
- def _wrap_run_method(self, app: "lightning.app.LightningApp", work: "lightning.app.LightningWork"):
- if work.run.__name__ == "_dynamic_run_wrapper":
- return
-
- work.run = partial(self._dynamic_run_wrapper, app=app, work=work, work_run=unwrap(work.run))
-
- def _prepare_queues(self, app: "lightning.app.LightningApp"):
- kw = {"queue_id": self.queue_id}
- app.delta_queue = self.queues.get_delta_queue(**kw)
- app.readiness_queue = self.queues.get_readiness_queue(**kw)
- app.api_response_queue = self.queues.get_api_response_queue(**kw)
- app.error_queue = self.queues.get_error_queue(**kw)
- app.api_publish_state_queue = self.queues.get_api_state_publish_queue(**kw)
- app.api_delta_queue = app.delta_queue
- app.request_queues = {}
- app.response_queues = {}
- app.copy_request_queues = {}
- app.copy_response_queues = {}
- app.caller_queues = {}
- app.work_queues = {}
- app.flow_to_work_delta_queues = {}
-
- def _register_queues(self, app, work):
- kw = {"queue_id": self.queue_id, "work_name": work.name}
- app.request_queues.update({work.name: self.queues.get_orchestrator_request_queue(**kw)})
- app.response_queues.update({work.name: self.queues.get_orchestrator_response_queue(**kw)})
- app.copy_request_queues.update({work.name: self.queues.get_orchestrator_copy_request_queue(**kw)})
- app.copy_response_queues.update({work.name: self.queues.get_orchestrator_copy_response_queue(**kw)})
- app.caller_queues.update({work.name: self.queues.get_caller_queue(**kw)})
- app.flow_to_work_delta_queues.update({work.name: self.queues.get_flow_to_work_delta_queue(**kw)})
-
-
-class WorkManager(ABC):
- """The work manager is an interface for the backend and runtime to control the LightningWork."""
-
- def __init__(self, app: "lightning.app.LightningApp", work: "lightning.app.LightningWork"):
- pass
-
- @abstractmethod
- def start(self) -> None:
- pass
-
- @abstractmethod
- def kill(self) -> None:
- pass
-
- @abstractmethod
- def restart(self) -> None:
- pass
-
- @abstractmethod
- def is_alive(self) -> bool:
- pass
diff --git a/src/lightning/app/runners/backends/cloud.py b/src/lightning/app/runners/backends/cloud.py
deleted file mode 100644
index efae58233e04f..0000000000000
--- a/src/lightning/app/runners/backends/cloud.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import TYPE_CHECKING, List, Optional
-
-from lightning.app.core.queues import QueuingSystem
-from lightning.app.runners.backends import Backend
-from lightning.app.utilities.network import LightningClient
-
-if TYPE_CHECKING:
- import lightning.app
-
-
-class CloudBackend(Backend):
- def __init__(
- self,
- entrypoint_file,
- queue_id: Optional[str] = None,
- status_update_interval: Optional[int] = None,
- client_max_tries: Optional[int] = None,
- ):
- super().__init__(entrypoint_file, queues=QueuingSystem.MULTIPROCESS, queue_id=queue_id)
- self.client = LightningClient(max_tries=client_max_tries)
-
- def create_work(self, app: "lightning.app.LightningApp", work: "lightning.app.LightningWork") -> None:
- raise NotImplementedError
-
- def update_work_statuses(self, works: List["lightning.app.LightningWork"]) -> None:
- raise NotImplementedError
-
- def stop_all_works(self, works: List["lightning.app.LightningWork"]) -> None:
- raise NotImplementedError
-
- def resolve_url(self, app, base_url: Optional[str] = None) -> None:
- raise NotImplementedError
-
- def stop_work(self, app: "lightning.app.LightningApp", work: "lightning.app.LightningWork") -> None:
- raise NotImplementedError
diff --git a/src/lightning/app/runners/backends/docker.py b/src/lightning/app/runners/backends/docker.py
deleted file mode 100644
index 3d76d65a74ff1..0000000000000
--- a/src/lightning/app/runners/backends/docker.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from time import time
-from typing import List, Optional
-
-import lightning.app
-from lightning.app.core.queues import QueuingSystem
-from lightning.app.runners.backends.backend import Backend
-
-
-class DockerBackend(Backend):
- def resolve_url(self, app, base_url: Optional[str] = None) -> None:
- pass
-
- def stop_work(self, app: "lightning.app.LightningApp", work: "lightning.app.LightningWork") -> None:
- pass
-
- def __init__(self, entrypoint_file: str):
- super().__init__(entrypoint_file=entrypoint_file, queues=QueuingSystem.REDIS, queue_id=str(int(time())))
-
- def create_work(self, app, work):
- pass
-
- def update_work_statuses(self, works) -> None:
- pass
-
- def stop_all_works(self, works: List["lightning.app.LightningWork"]) -> None:
- pass
diff --git a/src/lightning/app/runners/backends/mp_process.py b/src/lightning/app/runners/backends/mp_process.py
deleted file mode 100644
index 554a03c5c8e06..0000000000000
--- a/src/lightning/app/runners/backends/mp_process.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import multiprocessing
-from typing import Any, List, Optional
-
-import lightning.app
-from lightning.app.core import constants
-from lightning.app.core.queues import QueuingSystem
-from lightning.app.runners.backends.backend import Backend, WorkManager
-from lightning.app.utilities.enum import WorkStageStatus
-from lightning.app.utilities.network import _check_service_url_is_ready, find_free_network_port
-from lightning.app.utilities.port import disable_port, enable_port
-from lightning.app.utilities.proxies import ProxyWorkRun, WorkRunner
-
-
-class MultiProcessWorkManager(WorkManager):
- def __init__(self, app, work):
- self.app = app
- self.work = work
- self._process = None
-
- def start(self):
- self._work_runner = WorkRunner(
- work=self.work,
- work_name=self.work.name,
- caller_queue=self.app.caller_queues[self.work.name],
- delta_queue=self.app.delta_queue,
- readiness_queue=self.app.readiness_queue,
- error_queue=self.app.error_queue,
- request_queue=self.app.request_queues[self.work.name],
- response_queue=self.app.response_queues[self.work.name],
- copy_request_queue=self.app.copy_request_queues[self.work.name],
- copy_response_queue=self.app.copy_response_queues[self.work.name],
- flow_to_work_delta_queue=self.app.flow_to_work_delta_queues[self.work.name],
- run_executor_cls=self.work._run_executor_cls,
- )
-
- start_method = self.work._start_method
- context = multiprocessing.get_context(start_method)
- self._process = context.Process(target=self._work_runner)
- self._process.start()
-
- def kill(self):
- self._process.terminate()
-
- def restart(self):
- assert not self.is_alive()
- work = self._work_runner.work
- # un-wrap ProxyRun.
- is_proxy = isinstance(work.run, ProxyWorkRun)
- if is_proxy:
- work_run = work.run
- work.run = work_run.work_run
- work._restarting = True
- self.start()
- if is_proxy:
- work.run = work_run
-
- def is_alive(self) -> bool:
- return self._process.is_alive()
-
-
-class MultiProcessingBackend(Backend):
- def __init__(self, entrypoint_file: str):
- super().__init__(entrypoint_file=entrypoint_file, queues=QueuingSystem.MULTIPROCESS, queue_id="0")
-
- def create_work(self, app, work) -> None:
- if constants.LIGHTNING_CLOUDSPACE_HOST is not None:
- # Override the port if set by the user
- work._port = find_free_network_port()
- work._host = "0.0.0.0" # noqa: S104
- work._future_url = f"https://{work.port}-{constants.LIGHTNING_CLOUDSPACE_HOST}"
-
- app.processes[work.name] = MultiProcessWorkManager(app, work)
- app.processes[work.name].start()
- self.resolve_url(app)
- app._update_layout()
-
- def update_work_statuses(self, works) -> None:
- pass
-
- def stop_all_works(self, works: List["lightning.app.LightningWork"]) -> None:
- pass
-
- def resolve_url(self, app, base_url: Optional[str] = None) -> None:
- for work in app.works:
- if (
- work.status.stage in (WorkStageStatus.RUNNING, WorkStageStatus.SUCCEEDED)
- and work._url == ""
- and work._port
- ):
- url = work._future_url if work._future_url else f"http://{work._host}:{work._port}"
- if _check_service_url_is_ready(url, metadata=f"Checking {work.name}"):
- work._url = url
-
- def stop_work(self, app, work: "lightning.app.LightningWork") -> None:
- work_manager: MultiProcessWorkManager = app.processes[work.name]
- work_manager.kill()
-
- def delete_work(self, app, work: "lightning.app.LightningWork") -> None:
- self.stop_work(app, work)
-
-
-class CloudMultiProcessingBackend(MultiProcessingBackend):
- def __init__(self, *args: Any, **kwargs: Any):
- super().__init__(*args, **kwargs)
-
- # Note: Track the open ports to close them on termination.
- self.ports = []
-
- def create_work(self, app, work) -> None:
- work._host = "0.0.0.0" # noqa: S104
- nc = enable_port()
- self.ports.append(nc.port)
- work._port = nc.port
- work._future_url = f"https://{nc.host}"
- return super().create_work(app, work)
-
- def stop_work(self, app, work: "lightning.app.LightningWork") -> None:
- disable_port(work._port)
- self.ports = [port for port in self.ports if port != work._port]
- return super().stop_work(app, work)
-
- def delete_work(self, app, work: "lightning.app.LightningWork") -> None:
- self.stop_work(app, work)
diff --git a/src/lightning/app/runners/cloud.py b/src/lightning/app/runners/cloud.py
deleted file mode 100644
index c488014450b9b..0000000000000
--- a/src/lightning/app/runners/cloud.py
+++ /dev/null
@@ -1,1109 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import fnmatch
-import json
-import os
-import random
-import re
-import string
-import sys
-import time
-from dataclasses import dataclass
-from functools import partial
-from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
-from urllib.parse import quote
-
-import click
-import rich
-from lightning_cloud.openapi import (
- Body3,
- Body4,
- CloudspaceIdRunsBody,
- Externalv1LightningappInstance,
- Gridv1ImageSpec,
- IdGetBody1,
- ProjectIdCloudspacesBody,
- V1BuildSpec,
- V1CloudSpace,
- V1DataConnectionMount,
- V1DependencyFileInfo,
- V1Drive,
- V1DriveSpec,
- V1DriveStatus,
- V1DriveType,
- V1EnvVar,
- V1Flowserver,
- V1LightningappInstanceSpec,
- V1LightningappInstanceState,
- V1LightningAuth,
- V1LightningBasicAuth,
- V1LightningRun,
- V1LightningworkDrives,
- V1LightningworkSpec,
- V1Membership,
- V1Metadata,
- V1NetworkConfig,
- V1PackageManager,
- V1PythonDependencyInfo,
- V1QueueServerType,
- V1SourceType,
- V1UserRequestedComputeConfig,
- V1UserRequestedFlowComputeConfig,
- V1Work,
-)
-from lightning_cloud.openapi.rest import ApiException
-
-from lightning.app.core.app import LightningApp
-from lightning.app.core.constants import (
- CLOUD_UPLOAD_WARNING,
- DEFAULT_NUMBER_OF_EXPOSED_PORTS,
- DISABLE_DEPENDENCY_CACHE,
- ENABLE_APP_COMMENT_COMMAND_EXECUTION,
- ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER,
- ENABLE_PULLING_STATE_ENDPOINT,
- ENABLE_PUSHING_STATE_ENDPOINT,
- LIGHTNING_CLOUD_PRINT_SPECS,
- SYS_CUSTOMIZATIONS_SYNC_ROOT,
- enable_interruptible_works,
- enable_multiple_works_in_default_container,
- get_cloud_queue_type,
- get_lightning_cloud_url,
-)
-from lightning.app.core.work import LightningWork
-from lightning.app.runners.backends.cloud import CloudBackend
-from lightning.app.runners.runtime import Runtime
-from lightning.app.source_code import LocalSourceCodeDir
-from lightning.app.source_code.copytree import _IGNORE_FUNCTION, _filter_ignored, _parse_lightningignore
-from lightning.app.storage import Drive, Mount
-from lightning.app.utilities.app_helpers import Logger, _is_headless
-from lightning.app.utilities.auth import _credential_string_to_basic_auth_params
-from lightning.app.utilities.cloud import _get_project
-from lightning.app.utilities.clusters import _ensure_cluster_project_binding, _get_default_cluster
-from lightning.app.utilities.dependency_caching import get_hash
-from lightning.app.utilities.load_app import load_app_from_file
-from lightning.app.utilities.packaging.app_config import AppConfig, _get_config_file
-from lightning.app.utilities.packaging.lightning_utils import _prepare_lightning_wheels_and_requirements
-from lightning.app.utilities.secrets import _names_to_ids
-
-logger = Logger(__name__)
-
-
-def _to_clean_dict(swagger_object, map_attributes):
- """Returns the swagger object properties as a dict with correct object names."""
- if hasattr(swagger_object, "to_dict"):
- attribute_map = swagger_object.attribute_map
- result = {}
- for key in attribute_map:
- value = getattr(swagger_object, key)
- value = _to_clean_dict(value, map_attributes)
- if value is not None and value != {}:
- key = attribute_map[key] if map_attributes else key
- result[key] = value
- return result
- if isinstance(swagger_object, list):
- return [_to_clean_dict(x, map_attributes) for x in swagger_object]
- if isinstance(swagger_object, dict):
- return {key: _to_clean_dict(value, map_attributes) for key, value in swagger_object.items()}
- return swagger_object
-
-
-@dataclass
-class CloudRuntime(Runtime):
- backend: Union[str, CloudBackend] = "cloud"
-
- def open(self, name: str, cluster_id: Optional[str] = None):
- """Method to open a CloudSpace with the root folder uploaded."""
- try:
- # Check for feature support
- user = self.backend.client.auth_service_get_user()
- if not user.features.code_tab:
- rich.print(
- "[red]The `lightning_app open` command has not been enabled for your account. "
- "To request access, please contact support@lightning.ai[/red]"
- )
- sys.exit(1)
-
- # Dispatch in four phases: resolution, validation, spec creation, API transactions
- # Resolution
- cloudspace_config = self._resolve_config(name, load=False)
- root = self._resolve_root()
- ignore_functions = self._resolve_open_ignore_functions()
- repo = self._resolve_repo(root, ignore_functions)
- project = self._resolve_project()
- existing_cloudspaces = self._resolve_existing_cloudspaces(project.project_id, cloudspace_config.name)
- cluster_id = self._resolve_cluster_id(cluster_id, project.project_id, existing_cloudspaces)
- existing_cloudspace, existing_run_instance = self._resolve_existing_run_instance(
- cluster_id, project.project_id, existing_cloudspaces
- )
- cloudspace_name = self._resolve_cloudspace_name(
- cloudspace_config.name,
- existing_cloudspace,
- existing_cloudspaces,
- )
- needs_credits = self._resolve_needs_credits(project)
-
- # Validation
- # Note: We do not validate the repo here since open only uploads a directory if asked explicitly
- self._validate_cluster_id(cluster_id, project.project_id)
-
- # Spec creation
- run_body = self._get_run_body(cluster_id, [], None, [], True, root, self.start_server)
-
- if existing_run_instance is not None:
- print(
- f"Re-opening the CloudSpace {cloudspace_config.name}. "
- "This operation will create a new run but will not overwrite the files in your CloudSpace."
- )
- else:
- print(f"The name of the CloudSpace is: {cloudspace_config.name}")
-
- # API transactions
- cloudspace_id = self._api_create_cloudspace_if_not_exists(
- project.project_id,
- cloudspace_name,
- existing_cloudspace,
- )
- self._api_stop_existing_run_instance(project.project_id, existing_run_instance)
- run = self._api_create_run(project.project_id, cloudspace_id, run_body)
- self._api_package_and_upload_repo(repo, run)
-
- if getattr(run, "cluster_id", None):
- print(f"Running on {run.cluster_id}")
-
- if "PYTEST_CURRENT_TEST" not in os.environ:
- click.launch(self._get_cloudspace_url(project, cloudspace_name, "code", needs_credits))
-
- except ApiException as ex:
- logger.error(ex.body)
- sys.exit(1)
-
- def cloudspace_dispatch(
- self,
- project_id: str,
- cloudspace_id: str,
- name: str,
- cluster_id: str,
- source_app: Optional[str] = None,
- keep_machines_after_stop: Optional[bool] = None,
- ) -> Externalv1LightningappInstance:
- """Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties such
- as the project and cluster IDs that are instead passed directly.
-
- Args:
- project_id: The ID of the project.
- cloudspace_id: The ID of the cloudspace.
- name: The name for the run.
- cluster_id: The ID of the cluster to run on.
- source_app: Name of the source app that triggered the run.
- keep_machines_after_stop: If true, machines will be left running after the run is finished and reused after
-
- Raises:
- ApiException: If there was an issue in the backend.
- RuntimeError: If there are validation errors.
- ValueError: If there are validation errors.
-
- Returns:
- The spec of the created app instance.
-
- """
- # Dispatch in four phases: resolution, validation, spec creation, API transactions
- # Resolution
- root = self._resolve_root()
- # If the root will already be there, we don't need to upload and preserve the absolute entrypoint
- top_folder = os.getenv("FILESYSTEM_TOP_FOLDER_NAME", "project")
- absolute_entrypoint = str(root).startswith(f"/{top_folder}")
- # If system customization files are found, resolve the path to their location
- sys_customizations_root = self._resolve_env_root()
- repo = self._resolve_repo(
- root,
- default_ignore=False,
- package_source=not absolute_entrypoint,
- sys_customizations_root=sys_customizations_root,
- )
- existing_instances = self._resolve_run_instances_by_name(project_id, name)
- name = self._resolve_run_name(name, existing_instances)
- cloudspace = self._resolve_cloudspace(project_id, cloudspace_id)
- queue_server_type = self._resolve_queue_server_type()
-
- self.app._update_index_file()
-
- # Validation
- # TODO: Validate repo and surface to the user
- # self._validate_repo(root, repo)
- self._validate_work_build_specs_and_compute()
- self._validate_drives()
- self._validate_mounts()
-
- # Spec creation
- flow_servers = self._get_flow_servers()
- network_configs = self._get_network_configs(flow_servers)
- works = self._get_works(cloudspace=cloudspace)
- run_body = self._get_run_body(
- cluster_id,
- flow_servers,
- network_configs,
- works,
- False,
- root,
- True,
- True,
- absolute_entrypoint,
- )
- env_vars = self._get_env_vars(self.env_vars, self.secrets, self.run_app_comment_commands)
-
- # API transactions
- logger.info(f"Creating cloudspace run. run_body: {run_body}")
- run = self._api_create_run(project_id, cloudspace_id, run_body)
-
- self._api_package_and_upload_repo(repo, run)
-
- logger.info(f"Creating cloudspace run instance. name: {name}")
- return self._api_create_run_instance(
- cluster_id,
- project_id,
- name,
- cloudspace_id,
- run.id,
- V1LightningappInstanceState.RUNNING,
- queue_server_type,
- env_vars,
- source_app=source_app,
- keep_machines_after_stop=keep_machines_after_stop,
- )
-
- def dispatch(
- self,
- name: str = "",
- cluster_id: Optional[str] = None,
- open_ui: bool = True,
- no_cache: bool = False,
- **kwargs: Any,
- ) -> None:
- """Method to dispatch and run the :class:`~lightning.app.core.app.LightningApp` in the cloud."""
- # Ideally not a user-facing error: this should never happen in a normal user workflow
- if not self.entrypoint:
- raise ValueError(
- "Entrypoint file not provided. Did you forget to "
- "initialize the Runtime object with `entrypoint` argument?"
- )
-
- cleanup_handle = None
-
- try:
- # Dispatch in four phases: resolution, validation, spec creation, API transactions
- # Resolution
- cloudspace_config = self._resolve_config(name)
- root = self._resolve_root()
- repo = self._resolve_repo(root)
- project = self._resolve_project()
- existing_cloudspaces = self._resolve_existing_cloudspaces(project.project_id, cloudspace_config.name)
- cluster_id = self._resolve_cluster_id(cluster_id, project.project_id, existing_cloudspaces)
- existing_cloudspace, existing_run_instance = self._resolve_existing_run_instance(
- cluster_id, project.project_id, existing_cloudspaces
- )
- cloudspace_name = self._resolve_cloudspace_name(
- cloudspace_config.name,
- existing_cloudspace,
- existing_cloudspaces,
- )
- queue_server_type = self._resolve_queue_server_type()
- needs_credits = self._resolve_needs_credits(project)
-
- # TODO: Move these
- cleanup_handle = _prepare_lightning_wheels_and_requirements(root)
- self.app._update_index_file()
-
- # Validation
- self._validate_repo(root, repo)
- self._validate_cluster_id(cluster_id, project.project_id)
- self._validate_work_build_specs_and_compute()
- self._validate_drives()
- self._validate_mounts()
-
- # Spec creation
- flow_servers = self._get_flow_servers()
- network_configs = self._get_network_configs(flow_servers)
- works = self._get_works()
- run_body = self._get_run_body(
- cluster_id, flow_servers, network_configs, works, no_cache, root, self.start_server
- )
- auth = self._get_auth(self.enable_basic_auth)
- env_vars = self._get_env_vars(self.env_vars, self.secrets, self.run_app_comment_commands)
-
- if LIGHTNING_CLOUD_PRINT_SPECS is not None:
- self._print_specs(run_body, LIGHTNING_CLOUD_PRINT_SPECS)
- sys.exit(0)
-
- print(f"The name of the app is: {cloudspace_name}")
-
- # API transactions
- cloudspace_id = self._api_create_cloudspace_if_not_exists(
- project.project_id,
- cloudspace_name,
- existing_cloudspace,
- )
- self._api_stop_existing_run_instance(project.project_id, existing_run_instance)
- run = self._api_create_run(project.project_id, cloudspace_id, run_body)
- self._api_package_and_upload_repo(repo, run)
-
- if getattr(run, "cluster_id", None):
- print(f"Running app on {run.cluster_id}")
-
- # Save the config for re-runs
- cloudspace_config.save_to_dir(root)
-
- desired_state = (
- V1LightningappInstanceState.STOPPED if needs_credits else V1LightningappInstanceState.RUNNING
- )
-
- if existing_run_instance is not None:
- run_instance = self._api_transfer_run_instance(
- project.project_id,
- run.id,
- existing_run_instance.id,
- desired_state,
- queue_server_type,
- env_vars,
- auth,
- )
- else:
- run_instance = self._api_create_run_instance(
- cluster_id,
- project.project_id,
- cloudspace_name,
- cloudspace_id,
- run.id,
- desired_state,
- queue_server_type,
- env_vars,
- auth,
- )
-
- if run_instance.status.phase == V1LightningappInstanceState.FAILED:
- raise RuntimeError("Failed to create the application. Cannot upload the source code.")
-
- # TODO: Remove testing dependency, but this would open a tab for each test...
- if open_ui and "PYTEST_CURRENT_TEST" not in os.environ:
- click.launch(
- self._get_app_url(project, run_instance, "logs" if run.is_headless else "web-ui", needs_credits)
- )
-
- if bool(int(os.getenv("LIGHTING_TESTING", "0"))):
- print(f"APP_LOGS_URL: {self._get_app_url(project, run_instance, 'logs')}")
-
- except ApiException as ex:
- logger.error(ex.body)
- sys.exit(1)
- finally:
- if cleanup_handle:
- cleanup_handle()
-
- @classmethod
- def load_app_from_file(cls, filepath: str, env_vars: Optional[Dict[str, str]] = None) -> "LightningApp":
- """Load a LightningApp from a file, mocking the imports."""
- # Pretend we are running in the cloud when loading the app locally
- os.environ["LAI_RUNNING_IN_CLOUD"] = "1"
-
- try:
- app = load_app_from_file(filepath, raise_exception=True, mock_imports=True, env_vars=env_vars or {})
- except FileNotFoundError as ex:
- raise ex
- except Exception:
- from lightning.app.testing.helpers import EmptyFlow
-
- # Create a generic app.
- logger.info("Could not load the app locally. Starting the app directly on the cloud.")
- app = LightningApp(EmptyFlow())
- finally:
- del os.environ["LAI_RUNNING_IN_CLOUD"]
- return app
-
- def _resolve_config(self, name: Optional[str], load: bool = True) -> AppConfig:
- """Find and load the config file if it exists (otherwise create an empty config).
-
- Override the name if provided.
-
- """
- config_file = _get_config_file(self.entrypoint)
- cloudspace_config = AppConfig.load_from_file(config_file) if config_file.exists() and load else AppConfig()
- if name:
- # Override the name if provided
- cloudspace_config.name = name
- return cloudspace_config
-
- def _resolve_root(self) -> Path:
- """Determine the root of the project."""
- root = Path(self.entrypoint).absolute()
- if root.is_file():
- root = root.parent
- return root
-
- def _resolve_env_root(self) -> Optional[Path]:
- """Determine whether the root of environment sync files exists."""
- root = Path(SYS_CUSTOMIZATIONS_SYNC_ROOT)
- if root.exists():
- return root
- return None
-
- def _resolve_open_ignore_functions(self) -> List[_IGNORE_FUNCTION]:
- """Used by the ``open`` method.
-
- If the entrypoint is a file, return an ignore function that will ignore everything except that file so only the
- file gets uploaded.
-
- """
- entrypoint = self.entrypoint.absolute()
- if entrypoint.is_file():
- return [lambda src, paths: [path for path in paths if path.absolute() == entrypoint]]
- return []
-
- def _resolve_repo(
- self,
- root: Path,
- ignore_functions: Optional[List[_IGNORE_FUNCTION]] = None,
- default_ignore: bool = True,
- package_source: bool = True,
- sys_customizations_root: Optional[Path] = None,
- ) -> LocalSourceCodeDir:
- """Gather and merge all lightningignores from the app children and create the ``LocalSourceCodeDir`` object."""
- if ignore_functions is None:
- ignore_functions = []
-
- if self.app is not None:
- flow_lightningignores = [flow.lightningignore for flow in self.app.flows]
- work_lightningignores = [work.lightningignore for work in self.app.works]
- lightningignores = flow_lightningignores + work_lightningignores
- if lightningignores:
- merged = sum(lightningignores, ())
- logger.debug(f"Found the following lightningignores: {merged}")
- patterns = _parse_lightningignore(merged)
- ignore_functions = [*ignore_functions, partial(_filter_ignored, root, patterns)]
-
- return LocalSourceCodeDir(
- path=root,
- ignore_functions=ignore_functions,
- default_ignore=default_ignore,
- package_source=package_source,
- sys_customizations_root=sys_customizations_root,
- )
-
- def _resolve_project(self, project_id: Optional[str] = None) -> V1Membership:
- """Determine the project to run on, choosing a default if multiple projects are found."""
- return _get_project(self.backend.client, project_id=project_id)
-
- def _resolve_existing_cloudspaces(self, project_id: str, cloudspace_name: str) -> List[V1CloudSpace]:
- """Lists all the cloudspaces with a name matching the provided cloudspace name."""
- # TODO: Add pagination, otherwise this could break if users have a lot of cloudspaces.
- existing_cloudspaces = self.backend.client.cloud_space_service_list_cloud_spaces(
- project_id=project_id
- ).cloudspaces
-
- # Search for cloudspaces with the given name (possibly with some random characters appended)
- pattern = re.escape(f"{cloudspace_name}-") + ".{4}"
- return [
- cloudspace
- for cloudspace in existing_cloudspaces
- if cloudspace.name == cloudspace_name or (re.fullmatch(pattern, cloudspace.name) is not None)
- ]
-
- def _resolve_cluster_id(
- self, cluster_id: Optional[str], project_id: str, existing_cloudspaces: List[V1CloudSpace]
- ) -> Optional[str]:
- """If cloudspaces exist and cluster is None, mimic cluster selection logic to choose a default."""
- # 1. Use the environment variables
- if cluster_id is None:
- cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None)
-
- # 2. Use the project bindings
- # TODO: Use the user's preferred cluster.
- if cluster_id is None and len(existing_cloudspaces) > 0:
- # Determine the cluster ID
- cluster_id = _get_default_cluster(self.backend.client, project_id)
- return cluster_id
-
- def _resolve_existing_run_instance(
- self, cluster_id: Optional[str], project_id: str, existing_cloudspaces: List[V1CloudSpace]
- ) -> Tuple[Optional[V1CloudSpace], Optional[Externalv1LightningappInstance]]:
- """Look for an existing run and instance from one of the provided cloudspaces on the provided cluster."""
- existing_cloudspace = None
- existing_run_instance = None
-
- if cluster_id is not None:
- for cloudspace in existing_cloudspaces:
- run_instances = self.backend.client.lightningapp_instance_service_list_lightningapp_instances(
- project_id=project_id,
- app_id=cloudspace.id,
- ).lightningapps
- if run_instances and run_instances[0].spec.cluster_id == cluster_id:
- existing_cloudspace = cloudspace
- existing_run_instance = run_instances[0]
- break
- return existing_cloudspace, existing_run_instance
-
- def _resolve_run_instances_by_name(self, project_id: str, name: str) -> List[Externalv1LightningappInstance]:
- """Get all existing instances in the given project with the given name."""
- run_instances = self.backend.client.lightningapp_instance_service_list_lightningapp_instances(
- project_id=project_id,
- ).lightningapps
-
- return [run_instance for run_instance in run_instances if run_instance.display_name == name]
-
- def _resolve_cloudspace_name(
- self,
- cloudspace_name: str,
- existing_cloudspace: Optional[V1CloudSpace],
- existing_cloudspaces: List[V1CloudSpace],
- ) -> str:
- """If there are existing cloudspaces but not on the cluster - choose a randomised name."""
- if len(existing_cloudspaces) > 0 and existing_cloudspace is None:
- name_exists = True
- while name_exists:
- random_name = cloudspace_name + "-" + "".join(random.sample(string.ascii_letters, 4))
- name_exists = any(app.name == random_name for app in existing_cloudspaces)
-
- cloudspace_name = random_name
- return cloudspace_name
-
- def _resolve_run_name(
- self,
- name: str,
- existing_instances: List[Externalv1LightningappInstance],
- ) -> str:
- """If there are existing instances with the same name - choose a randomised name."""
- if len(existing_instances) > 0:
- name_exists = True
- while name_exists:
- random_name = name + "-" + "".join(random.sample(string.ascii_letters, 4))
- name_exists = any(app.name == random_name for app in existing_instances)
-
- name = random_name
- return name
-
- def _resolve_cloudspace(self, project_id: str, cloudspace_id: str) -> Optional[V1CloudSpace]:
- """Returns a cloudspace by project_id and cloudspace_id, if it exists."""
- return self.backend.client.cloud_space_service_get_cloud_space(
- project_id=project_id,
- id=cloudspace_id,
- )
-
- def _resolve_queue_server_type(self) -> V1QueueServerType:
- """Resolve the cloud queue type from the environment."""
- queue_server_type = V1QueueServerType.UNSPECIFIED
- # Note: Enable app to select their own queue type.
- queue_type = get_cloud_queue_type()
- if queue_type == "http":
- queue_server_type = V1QueueServerType.HTTP
- elif queue_type == "redis":
- queue_server_type = V1QueueServerType.REDIS
- return queue_server_type
-
- @staticmethod
- def _resolve_needs_credits(project: V1Membership):
- """Check if the user likely needs credits to run the app with its hardware.
-
- Returns False if user has 1 or more credits.
-
- """
- balance = project.balance
- if balance is None:
- balance = 0 # value is missing in some tests
-
- needs_credits = balance < 1
- if needs_credits:
- logger.warn("You may need Lightning credits to run your apps on the cloud.")
- return needs_credits
-
- @staticmethod
- def _validate_repo(root: Path, repo: LocalSourceCodeDir) -> None:
- """Warn users when their app folder is large and explain how to filter out files."""
- excludes = set(fnmatch.filter(repo.files, "*lightning-*.tar.gz"))
- excludes.update(fnmatch.filter(repo.files, ".lightningignore"))
- files = [Path(f) for f in repo.files if f not in excludes]
- file_sizes = {f: f.stat().st_size for f in files}
- mb = 1000_000
- app_folder_size_in_mb = sum(file_sizes.values()) / mb
- if app_folder_size_in_mb > CLOUD_UPLOAD_WARNING:
- # filter out files under 0.01mb
- relevant_files = {f: sz for f, sz in file_sizes.items() if sz > 0.01 * mb}
- if relevant_files:
- by_largest = dict(sorted(relevant_files.items(), key=lambda x: x[1], reverse=True))
- by_largest = dict(list(by_largest.items())[:25]) # trim
- largest_paths_msg = "\n".join(
- f"{round(sz / mb, 5)} MB: {p.relative_to(root)}" for p, sz in by_largest.items()
- )
- largest_paths_msg = f"Here are the largest files:\n{largest_paths_msg}\n"
- else:
- largest_paths_msg = ""
- warning_msg = (
- f"Your application folder '{root.absolute()}' is more than {CLOUD_UPLOAD_WARNING} MB. "
- f"The total size is {round(app_folder_size_in_mb, 2)} MB. {len(files)} files were uploaded.\n"
- + largest_paths_msg
- + "Perhaps you should try running the app in an empty directory.\n"
- + "You can ignore some files or folders by adding them to `.lightningignore`.\n"
- + " You can also set the `self.lightningignore` attribute in a Flow or Work."
- )
-
- logger.warn(warning_msg)
-
- def _validate_cluster_id(self, cluster_id: Optional[str], project_id: str):
- """Check that the provided cluster exists and ensure that it is bound to the given project."""
- if cluster_id is not None:
- # Verify that the cluster exists
- list_clusters_resp = self.backend.client.cluster_service_list_clusters()
- cluster_ids = [cluster.id for cluster in list_clusters_resp.clusters]
- if cluster_id not in cluster_ids:
- raise ValueError(
- f"You requested to run on cluster {cluster_id}, but that cluster doesn't exist."
- f" Found {list_clusters_resp} with project_id: {project_id}"
- )
-
- _ensure_cluster_project_binding(self.backend.client, project_id, cluster_id)
-
- def _validate_work_build_specs_and_compute(self) -> None:
- """Check that the cloud compute and build configs are valid for all works in the app."""
- for work in self.app.works:
- if work.cloud_build_config.image is not None and work.cloud_compute.name == "default":
- raise ValueError(
- f"You requested a custom base image for the Work with name '{work.name}', but custom images are "
- "currently not supported on the default cloud compute instance. Please choose a different "
- "configuration, for example `CloudCompute('cpu-medium')`."
- )
-
- def _validate_drives(self) -> None:
- """Check that all drives in the app have a valid protocol."""
- for work in self.app.works:
- for drive_attr_name, drive in [
- (k, getattr(work, k)) for k in work._state if isinstance(getattr(work, k), Drive)
- ]:
- if drive.protocol != "lit://":
- raise RuntimeError(
- f"Unknown drive protocol `{drive.protocol}` for drive `{work.name}.{drive_attr_name}`."
- )
-
- def _validate_mounts(self) -> None:
- """Check that all mounts in the app have a valid protocol."""
- for work in self.app.works:
- if work.cloud_compute.mounts is not None:
- mounts = work.cloud_compute.mounts
- for mount in [mounts] if isinstance(mounts, Mount) else mounts:
- if mount.protocol != "s3://":
- raise RuntimeError(f"Unknown mount protocol `{mount.protocol}` for work `{work.name}`.")
-
- def _get_flow_servers(self) -> List[V1Flowserver]:
- """Collect a spec for each flow that contains a frontend so that the backend knows for which flows it needs to
- start servers."""
- flow_servers: List[V1Flowserver] = []
- for flow_name in self.app.frontends:
- flow_server = V1Flowserver(name=flow_name)
- flow_servers.append(flow_server)
- return flow_servers
-
- @staticmethod
- def _get_network_configs(flow_servers: List[V1Flowserver]) -> Optional[List[V1NetworkConfig]]:
- """Get the list of network configs for the run if multiple works in default container is enabled."""
- network_configs = None
- if enable_multiple_works_in_default_container():
- network_configs = []
- initial_port = 8080 + 1 + len(flow_servers)
- for _ in range(DEFAULT_NUMBER_OF_EXPOSED_PORTS):
- network_configs.append(
- V1NetworkConfig(
- name="w" + str(initial_port),
- port=initial_port,
- )
- )
- initial_port += 1
- return network_configs
-
- @staticmethod
- def _get_drives(work: LightningWork) -> List[V1LightningworkDrives]:
- """Get the list of drive specifications for the provided work."""
- drives: List[V1LightningworkDrives] = []
- for drive_attr_name, drive in [
- (k, getattr(work, k)) for k in work._state if isinstance(getattr(work, k), Drive)
- ]:
- drives.append(
- V1LightningworkDrives(
- drive=V1Drive(
- metadata=V1Metadata(
- name=f"{work.name}.{drive_attr_name}",
- ),
- spec=V1DriveSpec(
- drive_type=V1DriveType.NO_MOUNT_S3,
- source_type=V1SourceType.S3,
- source=f"{drive.protocol}{drive.id}",
- ),
- status=V1DriveStatus(),
- ),
- ),
- )
-
- return drives
-
- @staticmethod
- def _get_mounts(work: LightningWork) -> List[V1LightningworkDrives]:
- """Get the list of mount specifications for the provided work."""
- mounts = []
- if work.cloud_compute.mounts is not None:
- mount_objects = work.cloud_compute.mounts
- for mount in [mount_objects] if isinstance(mount_objects, Mount) else mount_objects:
- mounts.append(
- V1LightningworkDrives(
- drive=V1Drive(
- metadata=V1Metadata(
- name=work.name,
- ),
- spec=V1DriveSpec(
- drive_type=V1DriveType.INDEXED_S3,
- source_type=V1SourceType.S3,
- source=mount.source,
- ),
- status=V1DriveStatus(),
- ),
- mount_location=str(mount.mount_path),
- )
- )
- return mounts
-
- def _get_works(self, cloudspace: Optional[V1CloudSpace] = None) -> List[V1Work]:
- """Get the list of work specs from the app."""
- works: List[V1Work] = []
- for work in self.app.works:
- if not work._start_with_flow:
- continue
-
- work_requirements = "\n".join(work.cloud_build_config.requirements)
- build_spec = V1BuildSpec(
- commands=work.cloud_build_config.build_commands(),
- python_dependencies=V1PythonDependencyInfo(
- package_manager=V1PackageManager.PIP, packages=work_requirements
- ),
- image=work.cloud_build_config.image,
- )
- user_compute_config = V1UserRequestedComputeConfig(
- name=work.cloud_compute.name,
- count=1,
- disk_size=work.cloud_compute.disk_size,
- preemptible=work.cloud_compute.interruptible,
- shm_size=work.cloud_compute.shm_size,
- affinity_identifier=work.cloud_compute.colocation_group_id,
- )
-
- drives = self._get_drives(work)
- mounts = self._get_mounts(work)
-
- data_connection_mounts: list[V1DataConnectionMount] = []
- if cloudspace is not None and cloudspace.code_config is not None:
- data_connection_mounts = cloudspace.code_config.data_connection_mounts
-
- random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5)) # noqa: S311
- work_spec = V1LightningworkSpec(
- build_spec=build_spec,
- drives=drives + mounts,
- user_requested_compute_config=user_compute_config,
- network_config=[V1NetworkConfig(name=random_name, port=work.port)],
- data_connection_mounts=data_connection_mounts,
- )
- works.append(V1Work(name=work.name, display_name=work.display_name, spec=work_spec))
-
- return works
-
- def _get_run_body(
- self,
- cluster_id: str,
- flow_servers: List[V1Flowserver],
- network_configs: Optional[List[V1NetworkConfig]],
- works: List[V1Work],
- no_cache: bool,
- root: Path,
- start_server: bool,
- should_mount_cloudspace_content: bool = False,
- absolute_entrypoint: bool = False,
- ) -> CloudspaceIdRunsBody:
- """Get the specification of the run creation request."""
- if absolute_entrypoint:
- # If the entrypoint will already exist in the cloud then we can choose to keep it as an absolute path.
- app_entrypoint_file = Path(self.entrypoint).absolute()
- else:
- # The entry point file needs to be relative to the root of the uploaded source file directory,
- # because the backend will invoke the lightning commands relative to said source directory
- # TODO: we shouldn't set this if the entrypoint isn't a file but the backend gives an error if we don't
- app_entrypoint_file = Path(self.entrypoint).absolute().relative_to(root)
-
- run_body = CloudspaceIdRunsBody(
- cluster_id=cluster_id,
- app_entrypoint_file=str(app_entrypoint_file),
- enable_app_server=start_server,
- flow_servers=flow_servers,
- network_config=network_configs,
- works=works,
- local_source=True,
- should_mount_cloudspace_content=should_mount_cloudspace_content,
- )
-
- if self.app is not None:
- run_body.user_requested_flow_compute_config = V1UserRequestedFlowComputeConfig(
- name=self.app.flow_cloud_compute.name,
- shm_size=self.app.flow_cloud_compute.shm_size,
- preemptible=False,
- )
-
- run_body.is_headless = _is_headless(self.app)
-
- # if requirements file at the root of the repository is present,
- # we pass just the file name to the backend, so backend can find it in the relative path
- requirements_file = root / "requirements.txt"
- if requirements_file.is_file() and requirements_file.exists():
- requirements_path = requirements_file if absolute_entrypoint else "requirements.txt"
- run_body.image_spec = Gridv1ImageSpec(
- dependency_file_info=V1DependencyFileInfo(package_manager=V1PackageManager.PIP, path=requirements_path)
- )
- if not DISABLE_DEPENDENCY_CACHE and not no_cache:
- # hash used for caching the dependencies
- run_body.dependency_cache_key = get_hash(requirements_file)
-
- return run_body
-
- @staticmethod
- def _get_auth(credentials: str) -> Optional[V1LightningAuth]:
- """If credentials are provided, parse them and return the auth spec."""
- auth = None
- if credentials != "":
- parsed_credentials = _credential_string_to_basic_auth_params(credentials)
- auth = V1LightningAuth(
- basic=V1LightningBasicAuth(
- username=parsed_credentials["username"], password=parsed_credentials["password"]
- )
- )
- return auth
-
- @staticmethod
- def _get_env_vars(
- env_vars: Dict[str, str], secrets: Dict[str, str], run_app_comment_commands: bool
- ) -> List[V1EnvVar]:
- """Generate the list of environment variable specs for the app, including variables set by the framework."""
- v1_env_vars = [V1EnvVar(name=k, value=v) for k, v in env_vars.items()]
-
- if len(secrets.values()) > 0:
- secret_names_to_ids = _names_to_ids(secrets.values())
- env_vars_from_secrets = [V1EnvVar(name=k, from_secret=secret_names_to_ids[v]) for k, v in secrets.items()]
- v1_env_vars.extend(env_vars_from_secrets)
-
- if run_app_comment_commands or ENABLE_APP_COMMENT_COMMAND_EXECUTION:
- v1_env_vars.append(V1EnvVar(name="ENABLE_APP_COMMENT_COMMAND_EXECUTION", value="1"))
-
- if enable_multiple_works_in_default_container():
- v1_env_vars.append(V1EnvVar(name="ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER", value="1"))
-
- if ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER:
- v1_env_vars.append(V1EnvVar(name="ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER", value="1"))
-
- if not ENABLE_PULLING_STATE_ENDPOINT:
- v1_env_vars.append(V1EnvVar(name="ENABLE_PULLING_STATE_ENDPOINT", value="0"))
-
- if not ENABLE_PUSHING_STATE_ENDPOINT:
- v1_env_vars.append(V1EnvVar(name="ENABLE_PUSHING_STATE_ENDPOINT", value="0"))
-
- if enable_interruptible_works():
- v1_env_vars.append(
- V1EnvVar(
- name="LIGHTNING_INTERRUPTIBLE_WORKS",
- value=os.getenv("LIGHTNING_INTERRUPTIBLE_WORKS", "0"),
- )
- )
-
- return v1_env_vars
-
- def _api_create_cloudspace_if_not_exists(
- self, project_id: str, name: str, existing_cloudspace: Optional[V1CloudSpace]
- ) -> str:
- """Create the cloudspace if it doesn't exist.
-
- Return the cloudspace ID.
-
- """
- if existing_cloudspace is None:
- cloudspace_body = ProjectIdCloudspacesBody(name=name, can_download_source_code=True)
- cloudspace = self.backend.client.cloud_space_service_create_cloud_space(
- project_id=project_id, body=cloudspace_body
- )
- return cloudspace.id
- return existing_cloudspace.id
-
- def _api_stop_existing_run_instance(
- self, project_id: str, existing_run_instance: Optional[Externalv1LightningappInstance]
- ) -> None:
- """If an existing instance is provided and it isn't stopped, stop it."""
- if existing_run_instance and existing_run_instance.status.phase != V1LightningappInstanceState.STOPPED:
- # TODO(yurij): Implement release switching in the UI and remove this
- # We can only switch release of the stopped instance
- existing_run_instance = self.backend.client.lightningapp_instance_service_update_lightningapp_instance(
- project_id=project_id,
- id=existing_run_instance.id,
- body=Body3(spec=V1LightningappInstanceSpec(desired_state=V1LightningappInstanceState.STOPPED)),
- )
- # wait for the instance to stop for up to 150 seconds
- for _ in range(150):
- existing_run_instance = self.backend.client.lightningapp_instance_service_get_lightningapp_instance(
- project_id=project_id, id=existing_run_instance.id
- )
- if existing_run_instance.status.phase == V1LightningappInstanceState.STOPPED:
- break
- time.sleep(1)
- if existing_run_instance.status.phase != V1LightningappInstanceState.STOPPED:
- raise RuntimeError("Failed to stop the existing instance.")
-
- def _api_create_run(self, project_id: str, cloudspace_id: str, run_body: CloudspaceIdRunsBody) -> V1LightningRun:
- """Create and return the run."""
- return self.backend.client.cloud_space_service_create_lightning_run(
- project_id=project_id, cloudspace_id=cloudspace_id, body=run_body
- )
-
- def _api_transfer_run_instance(
- self,
- project_id: str,
- run_id: str,
- instance_id: str,
- desired_state: V1LightningappInstanceState,
- queue_server_type: Optional[V1QueueServerType] = None,
- env_vars: Optional[List[V1EnvVar]] = None,
- auth: Optional[V1LightningAuth] = None,
- ) -> Externalv1LightningappInstance:
- """Transfer an existing instance to the given run ID and update its specification.
-
- Return the instance.
-
- """
- run_instance = self.backend.client.lightningapp_instance_service_update_lightningapp_instance_release(
- project_id=project_id,
- id=instance_id,
- body=Body4(release_id=run_id),
- )
-
- self.backend.client.lightningapp_instance_service_update_lightningapp_instance(
- project_id=project_id,
- id=instance_id,
- body=Body3(
- spec=V1LightningappInstanceSpec(
- desired_state=desired_state,
- queue_server_type=queue_server_type,
- env=env_vars,
- auth=auth,
- )
- ),
- )
-
- return run_instance
-
- def _api_create_run_instance(
- self,
- cluster_id: str,
- project_id: str,
- run_name: str,
- cloudspace_id: str,
- run_id: str,
- desired_state: V1LightningappInstanceState,
- queue_server_type: Optional[V1QueueServerType] = None,
- env_vars: Optional[List[V1EnvVar]] = None,
- auth: Optional[V1LightningAuth] = None,
- source_app: Optional[str] = None,
- keep_machines_after_stop: Optional[bool] = None,
- ) -> Externalv1LightningappInstance:
- """Create a new instance of the given run with the given specification."""
- return self.backend.client.cloud_space_service_create_lightning_run_instance(
- project_id=project_id,
- cloudspace_id=cloudspace_id,
- id=run_id,
- body=IdGetBody1(
- cluster_id=cluster_id,
- name=run_name,
- desired_state=desired_state,
- queue_server_type=queue_server_type,
- env=env_vars,
- auth=auth,
- source_app=source_app,
- keep_machines_after_stop=keep_machines_after_stop,
- ),
- )
-
- @staticmethod
- def _api_package_and_upload_repo(
- repo: LocalSourceCodeDir,
- run: V1LightningRun,
- ) -> None:
- """Package and upload the provided local source code directory to the provided run."""
- if run.source_upload_url == "":
- raise RuntimeError("The source upload url is empty.")
- repo.package()
- repo.upload(url=run.source_upload_url)
-
- @staticmethod
- def _print_specs(run_body: CloudspaceIdRunsBody, print_format: str) -> None:
- """Print the given run body in either `web` or `gallery` format."""
- if print_format not in ("web", "gallery"):
- raise ValueError(
- f"`LIGHTNING_CLOUD_PRINT_SPECS` should be either `web` or `gallery`. You provided: {print_format}"
- )
-
- flow_servers_json = [{"Name": flow_server.name} for flow_server in run_body.flow_servers]
- logger.info(f"flow_servers: {flow_servers_json}")
- works_json = json.dumps(_to_clean_dict(run_body.works, print_format == "web"), separators=(",", ":"))
- logger.info(f"works: {works_json}")
- logger.info(f"entrypoint_file: {run_body.app_entrypoint_file}")
- requirements_path = getattr(getattr(run_body.image_spec, "dependency_file_info", ""), "path", "")
- logger.info(f"requirements_path: {requirements_path}")
-
- def _get_cloudspace_url(
- self, project: V1Membership, cloudspace_name: str, tab: str, need_credits: bool = False
- ) -> str:
- user = self.backend.client.auth_service_get_user()
- action = "?action=add_credits" if need_credits else ""
- paths = [
- user.username,
- project.name,
- "apps",
- cloudspace_name,
- tab,
- ]
- path = "/".join([quote(path, safe="") for path in paths])
- return f"{get_lightning_cloud_url()}/{path}{action}"
-
- def _get_app_url(
- self,
- project: V1Membership,
- run_instance: Externalv1LightningappInstance,
- tab: str,
- need_credits: bool = False,
- ) -> str:
- user = self.backend.client.auth_service_get_user()
- action = "?action=add_credits" if need_credits else ""
- if user.features.project_selector:
- paths = [
- user.username,
- project.name,
- "jobs",
- run_instance.name,
- tab,
- ]
- else:
- paths = [
- user.username,
- "apps",
- run_instance.id,
- tab,
- ]
- path = "/".join([quote(path, safe="") for path in paths])
- return f"{get_lightning_cloud_url()}/{path}{action}"
diff --git a/src/lightning/app/runners/multiprocess.py b/src/lightning/app/runners/multiprocess.py
deleted file mode 100644
index c3217197a6a33..0000000000000
--- a/src/lightning/app/runners/multiprocess.py
+++ /dev/null
@@ -1,160 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import multiprocessing
-import os
-from dataclasses import dataclass
-from typing import Any, Union
-
-import click
-
-from lightning.app.api.http_methods import _add_tags_to_api, _validate_api
-from lightning.app.core import constants
-from lightning.app.core.api import start_server
-from lightning.app.runners.backends import Backend
-from lightning.app.runners.runtime import Runtime
-from lightning.app.storage.orchestrator import StorageOrchestrator
-from lightning.app.utilities.app_helpers import _is_headless, is_overridden
-from lightning.app.utilities.commands.base import _commands_to_api, _prepare_commands
-from lightning.app.utilities.component import _set_flow_context, _set_frontend_context
-from lightning.app.utilities.load_app import extract_metadata_from_app
-from lightning.app.utilities.network import find_free_network_port
-from lightning.app.utilities.port import disable_port
-
-
-@dataclass
-class MultiProcessRuntime(Runtime):
- """Runtime to launch the LightningApp into multiple processes.
-
- The MultiProcessRuntime will generate 1 process for each :class:`~lightning.app.core.work.LightningWork` and attach
- queues to enable communication between the different processes.
-
- """
-
- backend: Union[str, Backend] = "multiprocessing"
- _has_triggered_termination: bool = False
-
- def dispatch(self, *args: Any, open_ui: bool = True, **kwargs: Any):
- """Method to dispatch and run the LightningApp."""
- try:
- _set_flow_context()
-
- # Note: In case the runtime is used in the cloud.
- in_cloudspace = constants.LIGHTNING_CLOUDSPACE_HOST is not None
- self.host = "0.0.0.0" if constants.APP_SERVER_IN_CLOUD or in_cloudspace else self.host # noqa: S104
-
- self.app.backend = self.backend
- self.backend._prepare_queues(self.app)
- self.backend.resolve_url(self.app, "http://127.0.0.1")
- self.app._update_index_file()
-
- # set env variables
- os.environ.update(self.env_vars)
-
- # refresh the layout with the populated urls.
- self.app._update_layout()
-
- _set_frontend_context()
- for frontend in self.app.frontends.values():
- port = find_free_network_port()
-
- server_host = "0.0.0.0" if in_cloudspace else "localhost" # noqa: S104
- server_target = (
- f"https://{port}-{constants.LIGHTNING_CLOUDSPACE_HOST}"
- if in_cloudspace
- else f"http://localhost:{port}"
- )
-
- frontend.start_server(host=server_host, port=port)
- frontend.flow._layout["target"] = f"{server_target}/{frontend.flow.name}"
-
- _set_flow_context()
-
- storage_orchestrator = StorageOrchestrator(
- self.app,
- self.app.request_queues,
- self.app.response_queues,
- self.app.copy_request_queues,
- self.app.copy_response_queues,
- )
- self.threads.append(storage_orchestrator)
- storage_orchestrator.setDaemon(True)
- storage_orchestrator.start()
-
- if self.start_server:
- self.app.should_publish_changes_to_api = True
- has_started_queue = self.backend.queues.get_has_server_started_queue()
-
- apis = []
- if is_overridden("configure_api", self.app.root):
- apis = self.app.root.configure_api()
- _validate_api(apis)
- _add_tags_to_api(apis, ["app_api"])
-
- if is_overridden("configure_commands", self.app.root):
- commands = _prepare_commands(self.app)
- apis += _commands_to_api(commands, info=self.app.info)
-
- kwargs = {
- "apis": apis,
- "host": self.host,
- "port": self.port,
- "api_response_queue": self.app.api_response_queue,
- "api_publish_state_queue": self.app.api_publish_state_queue,
- "api_delta_queue": self.app.api_delta_queue,
- "has_started_queue": has_started_queue,
- "spec": extract_metadata_from_app(self.app),
- "root_path": self.app.root_path,
- }
- server_proc = multiprocessing.Process(target=start_server, kwargs=kwargs)
- self.processes["server"] = server_proc
- server_proc.start()
- # requires to wait for the UI to be clicked on.
-
- # wait for server to be ready
- has_started_queue.get()
-
- if all([
- open_ui,
- "PYTEST_CURRENT_TEST" not in os.environ,
- not _is_headless(self.app),
- constants.LIGHTNING_CLOUDSPACE_HOST is None,
- ]):
- click.launch(self._get_app_url())
-
- # Connect the runtime to the application.
- self.app.connect(self)
-
- # Once bootstrapping is done, run the rank 0
- # app with all the components inactive
- self.app._run()
- except KeyboardInterrupt:
- self.terminate()
- self._has_triggered_termination = True
- raise
- finally:
- if not self._has_triggered_termination:
- self.terminate()
-
- def terminate(self):
- if constants.APP_SERVER_IN_CLOUD:
- # Close all the ports open for the App within the App.
- ports = [self.port] + getattr(self.backend, "ports", [])
- for port in ports:
- disable_port(port)
- super().terminate()
-
- @staticmethod
- def _get_app_url() -> str:
- return os.getenv("APP_SERVER_HOST", "http://127.0.0.1:7501/view")
diff --git a/src/lightning/app/runners/runtime.py b/src/lightning/app/runners/runtime.py
deleted file mode 100644
index be938436ff76a..0000000000000
--- a/src/lightning/app/runners/runtime.py
+++ /dev/null
@@ -1,182 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import multiprocessing
-import os
-import sys
-from dataclasses import dataclass, field
-from pathlib import Path
-from threading import Thread
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
-
-from lightning.app.core import LightningApp, LightningFlow
-from lightning.app.core.constants import APP_SERVER_HOST, APP_SERVER_PORT
-from lightning.app.runners.backends import Backend, BackendType
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.enum import AppStage, CacheCallsKeys, WorkStageStatus, make_status
-from lightning.app.utilities.load_app import load_app_from_file
-from lightning.app.utilities.proxies import WorkRunner
-
-logger = Logger(__name__)
-
-if TYPE_CHECKING:
- import lightning.app
-
-
-def dispatch(
- entrypoint_file: Path,
- runtime_type: "lightning.app.runners.runtime_type.RuntimeType",
- start_server: bool = True,
- no_cache: bool = False,
- host: str = APP_SERVER_HOST,
- port: int = APP_SERVER_PORT,
- blocking: bool = True,
- open_ui: bool = True,
- name: str = "",
- env_vars: Optional[Dict[str, str]] = None,
- secrets: Optional[Dict[str, str]] = None,
- run_app_comment_commands: bool = False,
- enable_basic_auth: str = "",
-) -> Optional[Any]:
- """Bootstrap and dispatch the application to the target.
-
- Arguments:
- entrypoint_file: Filepath to the current script
- runtime_type: The runtime to be used for launching the app.
- start_server: Whether to run the app REST API.
- no_cache: Whether to disable the dependency cache for the app.
- host: Server host address
- port: Server port
- blocking: Whether to wait for the UI to start running.
- open_ui: Whether to open the UI in the browser.
- name: Name of app execution
- env_vars: Dict of env variables to be set on the app
- secrets: Dict of secrets to be passed as environment variables to the app
- run_app_comment_commands: whether to parse commands from the entrypoint file and execute them before app startup
- enable_basic_auth: whether to enable basic authentication for the app
- (use credentials in the format username:password as an argument)
-
- """
- from lightning.app.runners.runtime_type import RuntimeType
- from lightning.app.utilities.component import _set_flow_context
-
- _set_flow_context()
-
- runtime_type = RuntimeType(runtime_type)
- runtime_cls: Type[Runtime] = runtime_type.get_runtime()
- app = runtime_cls.load_app_from_file(str(entrypoint_file))
-
- env_vars = {} if env_vars is None else env_vars
- secrets = {} if secrets is None else secrets
-
- if blocking:
- app.stage = AppStage.BLOCKING
-
- runtime = runtime_cls(
- app=app,
- entrypoint=entrypoint_file,
- start_server=start_server,
- host=host,
- port=port,
- env_vars=env_vars,
- secrets=secrets,
- run_app_comment_commands=run_app_comment_commands,
- enable_basic_auth=enable_basic_auth,
- )
- # Used to indicate Lightning has been dispatched
- os.environ["LIGHTNING_DISPATCHED"] = "1"
- # a cloud dispatcher will return the result while local
- # dispatchers will be running the app in the main process
- return runtime.dispatch(open_ui=open_ui, name=name, no_cache=no_cache)
-
-
-@dataclass
-class Runtime:
- app: Optional[LightningApp] = None
- entrypoint: Optional[Path] = None
- start_server: bool = True
- host: str = APP_SERVER_HOST
- port: int = APP_SERVER_PORT
- processes: Dict[str, multiprocessing.Process] = field(default_factory=dict)
- threads: List[Thread] = field(default_factory=list)
- work_runners: Dict[str, WorkRunner] = field(default_factory=dict)
- done: bool = False
- backend: Optional[Union[str, Backend]] = "multiprocessing"
- env_vars: Dict[str, str] = field(default_factory=dict)
- secrets: Dict[str, str] = field(default_factory=dict)
- run_app_comment_commands: bool = False
- enable_basic_auth: str = ""
-
- def __post_init__(self):
- if isinstance(self.backend, str):
- self.backend = BackendType(self.backend).get_backend(self.entrypoint)
-
- if self.app is not None:
- LightningFlow._attach_backend(self.app.root, self.backend)
-
- def terminate(self) -> None:
- """This method is used to terminate all the objects (threads, processes, etc.) created by the app."""
- logger.info("Your Lightning App is being stopped. This won't take long.")
- self.done = False
- has_messaged = False
- while not self.done:
- try:
- if self.app.backend is not None:
- self.app.backend.stop_all_works(self.app.works)
-
- if self.app.api_publish_state_queue:
- for work in self.app.works:
- self._add_stopped_status_to_work(work)
-
- # Publish the updated state and wait for the frontend to update.
- self.app.api_publish_state_queue.put((self.app.state, self.app.status))
-
- for thread in self.threads + self.app.threads:
- thread.join(timeout=0)
-
- for frontend in self.app.frontends.values():
- frontend.stop_server()
-
- for proc in list(self.processes.values()) + list(self.app.processes.values()):
- if proc.is_alive():
- proc.kill()
-
- self.done = True
-
- except KeyboardInterrupt:
- if not has_messaged:
- logger.info("Your Lightning App is being stopped. This won't take long.")
- has_messaged = True
-
- if self.done:
- logger.info("Your Lightning App has been stopped successfully!")
-
- # Signal that the application failed.
- if self.app.stage == AppStage.FAILED:
- sys.exit(1)
-
- def dispatch(self, *args: Any, **kwargs: Any):
- raise NotImplementedError
-
- def _add_stopped_status_to_work(self, work: "lightning.app.LightningWork") -> None:
- if work.status.stage == WorkStageStatus.STOPPED:
- return
-
- latest_call_hash = work._calls[CacheCallsKeys.LATEST_CALL_HASH]
- if latest_call_hash in work._calls:
- work._calls[latest_call_hash]["statuses"].append(make_status(WorkStageStatus.STOPPED))
-
- @classmethod
- def load_app_from_file(cls, filepath: str) -> "LightningApp":
- return load_app_from_file(filepath)
diff --git a/src/lightning/app/runners/runtime_type.py b/src/lightning/app/runners/runtime_type.py
deleted file mode 100644
index 26212385d82bd..0000000000000
--- a/src/lightning/app/runners/runtime_type.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from enum import Enum
-from typing import TYPE_CHECKING, Type
-
-from lightning.app.runners import CloudRuntime, MultiProcessRuntime
-
-if TYPE_CHECKING:
- from lightning.app.runners.runtime import Runtime
-
-
-class RuntimeType(Enum):
- MULTIPROCESS = "multiprocess"
- CLOUD = "cloud"
-
- def get_runtime(self) -> Type["Runtime"]:
- if self == RuntimeType.MULTIPROCESS:
- return MultiProcessRuntime
- if self == RuntimeType.CLOUD:
- return CloudRuntime
- raise ValueError("Unknown runtime type")
diff --git a/src/lightning/app/source_code/__init__.py b/src/lightning/app/source_code/__init__.py
deleted file mode 100644
index 8b4fffd27762e..0000000000000
--- a/src/lightning/app/source_code/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from lightning.app.source_code.local import LocalSourceCodeDir
-from lightning.app.source_code.uploader import FileUploader
-
-__all__ = [
- "LocalSourceCodeDir",
- "FileUploader",
-]
diff --git a/src/lightning/app/source_code/copytree.py b/src/lightning/app/source_code/copytree.py
deleted file mode 100644
index fa2716e94c5c7..0000000000000
--- a/src/lightning/app/source_code/copytree.py
+++ /dev/null
@@ -1,183 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import fnmatch
-import os
-from functools import partial
-from pathlib import Path
-from shutil import Error, copy2, copystat
-from typing import Callable, List, Optional, Set, Tuple, Union
-
-from lightning.app.core.constants import DOT_IGNORE_FILENAME
-from lightning.app.utilities.app_helpers import Logger
-
-logger = Logger(__name__)
-
-_IGNORE_FUNCTION = Callable[[Path, List[Path]], List[Path]]
-
-
-def _copytree(
- src: Union[Path, str],
- dst: Union[Path, str],
- ignore_functions: Optional[List[_IGNORE_FUNCTION]] = None,
- dirs_exist_ok=False,
- dry_run=False,
-) -> List[str]:
- """Vendor in from `shutil.copytree` to support ignoring files recursively based on `.lightningignore`, like `git`
- does with `.gitignore`. Also removed a few checks from the original copytree related to symlink checks. Differences
- between the original and this function are:
-
- 1. It supports a list of ignore functions instead of a single one in the
- original. We can use this for filtering out files based on nested
- .lightningignore files
- 2. It supports a dry run. When enabled, this function will not copy anything but just recursively
- find the source files which are not-ignored and return them. It is useful while calculating
- the hash or checking the size of files
- 3. This function returns a list of copied files unlike the original which was returning the
- destination directory
-
- Recursively copy a directory tree and return the list of copied files.
-
- Parameters
- ----------
- src:
- Source directory path to copy from
- dst:
- Destination directory path to copy to
- ignore_functions:
- List of functions that will be used to filter out files
- and directories. This isn't required to be passed when calling from outside but will be
- autopopulated by the recursive calls in this function itself (Original copytree doesn't have this argument)
- dirs_exist_ok:
- If true, no error is raised when the destination directory already exists.
- dry_run:
- If true, this function will not copy anything (this is not present in the original copytree)
-
-
- If exception(s) occur, an Error is raised with a list of reasons.
-
- """
- files_copied = []
-
- if ignore_functions is None:
- ignore_functions = []
-
- _ignore_filename_spell_check(src)
- src = Path(src)
- dst = Path(dst)
- ignore_filepath = src / DOT_IGNORE_FILENAME
- if ignore_filepath.is_file():
- patterns = _read_lightningignore(ignore_filepath)
- ignore_fn = partial(_filter_ignored, src, patterns)
- # creating new list so we won't modify the original
- ignore_functions = [*ignore_functions, ignore_fn]
-
- if not dry_run:
- os.makedirs(dst, exist_ok=dirs_exist_ok)
-
- errors = []
-
- entries = list(src.iterdir())
- for fn in ignore_functions:
- # each ignore function returns only the entries that are not ignored
- entries = fn(src, entries)
-
- for srcentry in entries:
- dstpath = dst / srcentry.name
- try:
- if srcentry.is_dir():
- _files = _copytree(
- src=srcentry,
- dst=dstpath,
- ignore_functions=ignore_functions,
- dirs_exist_ok=dirs_exist_ok,
- dry_run=dry_run,
- )
- files_copied.extend(_files)
- else:
- files_copied.append(str(srcentry))
- if not dry_run:
- # Will raise a SpecialFileError for unsupported file types
- copy2(srcentry, dstpath)
- # catch the Error from the recursive copytree so that we can
- # continue with other files
- except Error as err:
- errors.extend(err.args[0])
- except OSError as why:
- errors.append((srcentry, dstpath, str(why)))
- try:
- if not dry_run:
- copystat(src, dst)
- except OSError as why:
- # Copying file access times may fail on Windows
- if getattr(why, "winerror", None) is None:
- errors.append((src, dst, str(why)))
- if errors:
- raise Error(errors)
- return files_copied
-
-
-def _filter_ignored(src: Path, patterns: Set[str], current_dir: Path, entries: List[Path]) -> List[Path]:
- relative_dir = current_dir.relative_to(src)
- names = [str(relative_dir / entry.name) for entry in entries]
- ignored_names = set()
- for pattern in patterns:
- ignored_names.update(fnmatch.filter(names, pattern))
- return [entry for entry in entries if str(relative_dir / entry.name) not in ignored_names]
-
-
-def _parse_lightningignore(lines: Tuple[str]) -> Set[str]:
- """Creates a set that removes empty lines and comments."""
- lines = [ln.strip() for ln in lines]
- # removes first `/` character for posix and `\\` for windows
- lines = [ln.lstrip("/").lstrip("\\") for ln in lines if ln != "" and not ln.startswith("#")]
- # convert to path and converting back to string to sanitize the pattern
- return {str(Path(ln)) for ln in lines}
-
-
-def _read_lightningignore(path: Path) -> Set[str]:
- """Reads the ignore file and filters out comments and empty lines. This will also strip a leading `/` from each
- pattern. That's done so the pattern matching simulates `git`, which interprets a leading `/` as the root path.
-
- Parameters
- ----------
- path: Path
- Path to .lightningignore file or equivalent.
-
- Returns
- -------
- Set[str]
- Set of unique lines.
-
- """
- raw_lines = path.open().readlines()
- return _parse_lightningignore(raw_lines)
-
-
-def _ignore_filename_spell_check(src: Path):
- possible_spelling_mistakes = [
- ".gridignore",
- ".lightingignore",
- ".lightinginore",
- ".lightninginore",
- ".lightninignore",
- ".lightinignore",
- ]
- possible_spelling_mistakes.extend([p.lstrip(".") for p in possible_spelling_mistakes])
- for path in src.iterdir():
- if path.is_file() and path.name in possible_spelling_mistakes:
- logger.warn(
- f"Lightning uses `{DOT_IGNORE_FILENAME}` as the ignore file but found {path.name} at "
- f"{path.parent} instead. If this was a mistake, please rename the file."
- )
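Reviewer note: the ignore handling removed above can be summarized in a short standalone sketch. The names `parse_patterns` and `filter_ignored` are hypothetical stand-ins for the removed `_parse_lightningignore` and `_filter_ignored`; the sketch assumes POSIX-style paths and only mirrors the logic visible in this hunk.

```python
import fnmatch
from pathlib import Path
from typing import Iterable, List, Set


def parse_patterns(lines: Iterable[str]) -> Set[str]:
    """Drop blank lines and comments, strip a leading slash, normalize via Path."""
    stripped = [ln.strip() for ln in lines]
    kept = [ln.lstrip("/").lstrip("\\") for ln in stripped if ln != "" and not ln.startswith("#")]
    return {str(Path(ln)) for ln in kept}


def filter_ignored(src: Path, patterns: Set[str], current_dir: Path, entries: List[Path]) -> List[Path]:
    """Return only the entries whose path relative to `src` matches no pattern."""
    relative_dir = current_dir.relative_to(src)
    names = [str(relative_dir / entry.name) for entry in entries]
    ignored = set()
    for pattern in patterns:
        ignored.update(fnmatch.filter(names, pattern))
    return [entry for entry in entries if str(relative_dir / entry.name) not in ignored]
```

Because `fnmatch` lets `*` match path separators, a pattern such as `*.pyc` also filters files in nested directories, which is how the removed code approximated `.gitignore` behavior.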
diff --git a/src/lightning/app/source_code/hashing.py b/src/lightning/app/source_code/hashing.py
deleted file mode 100644
index 8ca8f0cf72f51..0000000000000
--- a/src/lightning/app/source_code/hashing.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import hashlib
-from typing import List
-
-
-def _get_hash(files: List[str], algorithm: str = "blake2", chunk_num_blocks: int = 128) -> str:
- """Hashes the contents of a list of files.
-
- Parameters
- ----------
- files: List[str]
- List of files.
- algorithm: str, default "blake2"
- Algorithm to hash contents. "blake2" is set by default because it
- is faster than "md5". [1]
- chunk_num_blocks: int, default 128
- Number of hash blocks to read per chunk when iterating over file chunks.
-
- References
- ----------
- [1] https://crypto.stackexchange.com/questions/70101/blake2-vs-md5-for-checksum-file-integrity
- [2] https://stackoverflow.com/questions/1131220/get-md5-hash-of-big-files-in-python
-
- """
- # validate input
- if algorithm == "blake2":
- h = hashlib.blake2b(digest_size=20)
- elif algorithm == "md5":
- h = hashlib.md5()
- else:
- raise ValueError(f"Algorithm {algorithm} not supported")
-
- # calculate hash for all files
- for file in files:
- with open(file, "rb") as f:
- for chunk in iter(lambda: f.read(chunk_num_blocks * h.block_size), b""):
- h.update(chunk)
- return h.hexdigest()
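Reviewer note: for anyone who still needs the behavior of the removed helper, a minimal sketch of the chunked hashing is below. `hash_files` is a hypothetical name, and only the "blake2" branch is reproduced. Note that the digest depends on the concatenated file contents in list order, not on file names.

```python
import hashlib
from typing import List


def hash_files(files: List[str], chunk_num_blocks: int = 128) -> str:
    """Hash the contents of several files into one 20-byte blake2b digest."""
    h = hashlib.blake2b(digest_size=20)
    for file in files:
        with open(file, "rb") as f:
            # Read fixed-size chunks so large files are never fully held in memory.
            for chunk in iter(lambda: f.read(chunk_num_blocks * h.block_size), b""):
                h.update(chunk)
    return h.hexdigest()
```

Callers that need a stable checksum across machines would have to sort the file list first, since the order of the list changes the result.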
diff --git a/src/lightning/app/source_code/local.py b/src/lightning/app/source_code/local.py
deleted file mode 100644
index 4062b5cab54d7..0000000000000
--- a/src/lightning/app/source_code/local.py
+++ /dev/null
@@ -1,149 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import uuid
-from contextlib import contextmanager
-from pathlib import Path
-from shutil import copytree, rmtree
-from typing import List, Optional
-
-from lightning.app.core.constants import DOT_IGNORE_FILENAME, SYS_CUSTOMIZATIONS_SYNC_PATH
-from lightning.app.source_code.copytree import _IGNORE_FUNCTION, _copytree
-from lightning.app.source_code.tar import _tar_path
-from lightning.app.source_code.uploader import FileUploader
-
-
-class LocalSourceCodeDir:
- """Represents the source code directory and provides the utilities to manage it."""
-
- def __init__(
- self,
- path: Path,
- ignore_functions: Optional[List[_IGNORE_FUNCTION]] = None,
- default_ignore: bool = True,
- package_source: bool = True,
- sys_customizations_root: Optional[Path] = None,
- ) -> None:
- if "LIGHTNING_VSCODE_WORKSPACE" in os.environ:
- # Don't use home to store the tar ball. This won't play nice with symlinks
- self.cache_location: Path = Path("/tmp", ".lightning", "cache", "repositories")
- else:
- self.cache_location: Path = Path.home() / ".lightning" / "cache" / "repositories"
-
- self.path = path
- self.ignore_functions = ignore_functions
- self.package_source = package_source
- self.sys_customizations_root = sys_customizations_root
-
- # cache version
- self._version: Optional[str] = None
- self._non_ignored_files: Optional[List[str]] = None
-
- # create global cache location if it doesn't exist
- if not self.cache_location.exists():
- self.cache_location.mkdir(parents=True, exist_ok=True)
-
- # Create a default dotignore if requested and it doesn't exist
- if default_ignore and not (path / DOT_IGNORE_FILENAME).is_file():
- with open(path / DOT_IGNORE_FILENAME, "w") as f:
- f.write("venv/\n")
- if (path / "bin" / "activate").is_file() or (path / "pyvenv.cfg").is_file():
- # the user is developing inside venv
- f.write("bin/\ninclude/\nlib/\npyvenv.cfg\n")
-
- # clean old cache entries
- self._prune_cache()
-
- @property
- def files(self) -> List[str]:
- """Returns the list of files that are not ignored by .lightningignore."""
- if self._non_ignored_files is None:
- if self.package_source:
- self._non_ignored_files = _copytree(self.path, "", ignore_functions=self.ignore_functions, dry_run=True)
- else:
- self._non_ignored_files = []
- return self._non_ignored_files
-
- @property
- def version(self):
- """Returns a random version ID for the local path, generated once and then cached."""
- # cache value to prevent doing this over again
- if self._version is not None:
- return self._version
-
- # create a random version ID and store it
- self._version = uuid.uuid4().hex
- return self._version
-
- @property
- def package_path(self):
- """Location of the tarball in the local cache."""
- filename = f"{self.version}.tar.gz"
- return self.cache_location / filename
-
- @contextmanager
- def packaging_session(self) -> Path:
- """Creates a local directory with source code that is used for creating a local source-code package."""
- session_path = self.cache_location / "packaging_sessions" / self.version
- try:
- rmtree(session_path, ignore_errors=True)
- if self.package_source:
- _copytree(self.path, session_path, ignore_functions=self.ignore_functions)
- if self.sys_customizations_root is not None:
- path_to_sync = Path(session_path, SYS_CUSTOMIZATIONS_SYNC_PATH)
- copytree(self.sys_customizations_root, path_to_sync, dirs_exist_ok=True)
- yield session_path
- finally:
- rmtree(session_path, ignore_errors=True)
-
- def _prune_cache(self) -> None:
- """Prunes cache; only keeps the 10 most recent items."""
- packages = sorted(self.cache_location.iterdir(), key=os.path.getmtime, reverse=True)
- for package in packages[10:]:
- if package.is_file():
- package.unlink()
-
- def package(self) -> Path:
- """Packages local path using tar."""
- if self.package_path.exists():
- return self.package_path
- # create a packaging session if not available
- with self.packaging_session() as session_path:
- _tar_path(source_path=session_path, target_file=str(self.package_path), compression=True)
- return self.package_path
-
- def upload(self, url: str) -> None:
- """Uploads the package to a URL, usually a pre-signed URL.
-
- Notes
- -----
- Since we do not use multipart uploads here, we cannot upload any
- packaged repository files which have a size > 2GB.
-
- This limitation should be removed during the datastore upload redesign
-
- """
- if self.package_path.stat().st_size > 2e9:
- raise OSError(
- "cannot upload directory code whose total file size is greater than 2GB (2e9 bytes)"
- ) from None
-
- uploader = FileUploader(
- presigned_url=url,
- source_file=str(self.package_path),
- name=self.package_path.name,
- total_size=self.package_path.stat().st_size,
- )
- uploader.upload()
diff --git a/src/lightning/app/source_code/tar.py b/src/lightning/app/source_code/tar.py
deleted file mode 100644
index 5b92201df2485..0000000000000
--- a/src/lightning/app/source_code/tar.py
+++ /dev/null
@@ -1,201 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import math
-import os
-import subprocess
-import tarfile
-from dataclasses import dataclass
-from typing import Optional, Tuple
-
-import click
-
-MAX_SPLIT_COUNT = 999
-
-
-def _get_dir_size_and_count(source_dir: str, prefix: Optional[str] = None) -> Tuple[int, int]:
- """Get size and file count of a directory.
-
- Parameters
- ----------
- source_dir: str
- Directory path
-
- Returns
- -------
- Tuple[int, int]
- Size in bytes and file count
-
- """
- size = 0
- count = 0
- for root, _, files in os.walk(source_dir, topdown=True):
- for f in files:
- if prefix and not f.startswith(prefix):
- continue
-
- full_path = os.path.join(root, f)
- size += os.path.getsize(full_path)
- count += 1
-
- return (size, count)
-
-
-@dataclass
-class _TarResults:
- """This class holds the results of running tar_path.
-
- Attributes
- ----------
- before_size: int
- The total size of the original directory files in bytes
- after_size: int
- The total size of the compressed and tarred split files in bytes
-
- """
-
- before_size: int
- after_size: int
-
-
-def _get_split_size(
- total_size: int, minimum_split_size: int = 1024 * 1000 * 20, max_split_count: int = MAX_SPLIT_COUNT
-) -> int:
- """Calculate the split size we should use to split the multipart upload of an object to a bucket. We are limited
- to 1000 parts at most because of the way we use ListMultipartUploads. More info: https://github.com/gridai/grid/pull/5267
- https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpuoverview.html#mpu-process
- https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListMultipartUploads.html
- https://github.com/psf/requests/issues/2717#issuecomment-724725392 Python or requests has a limit of 2**31 bytes
- for a single file upload.
-
- Parameters
- ----------
- minimum_split_size: int
- The minimum split size to use
- max_split_count: int
- The maximum split count
- total_size: int
- Total size of the file to split
-
- Returns
- -------
- int
- Split size
-
- """
- max_size = max_split_count * (1 << 31) # max size per part limited by Requests or urllib as shown in ref above
- if total_size > max_size:
- raise click.ClickException(
- f"The size of the datastore to be uploaded is bigger than our {max_size / (1 << 40):.2f} TBytes limit"
- )
-
- split_size = minimum_split_size
- split_count = math.ceil(total_size / split_size)
- if split_count > max_split_count:
- # Adjust the split size based on max split count
- split_size = math.ceil(total_size / max_split_count)
-
- return split_size
-
-
-def _tar_path(source_path: str, target_file: str, compression: bool = False) -> _TarResults:
- """Create tar from directory using `tar`
-
- Parameters
- ----------
- source_path: str
- Source directory or file
- target_file
- Target tar file
- compression: bool, default False
- Enable compression, which is disabled by default.
-
- Returns
- -------
- _TarResults
- Results holding the total size in bytes before and after archiving
-
- """
- if os.path.isdir(source_path):
- before_size, _ = _get_dir_size_and_count(source_path)
- else:
- before_size = os.path.getsize(source_path)
-
- try:
- _tar_path_subprocess(source_path, target_file, compression)
- except subprocess.CalledProcessError:
- _tar_path_python(source_path, target_file, compression)
-
- after_size = os.stat(target_file).st_size
- return _TarResults(before_size=before_size, after_size=after_size)
-
-
-def _tar_path_python(source_path: str, target_file: str, compression: bool = False) -> None:
- """Create tar from directory using `python`
-
- Parameters
- ----------
- source_path: str
- Source directory or file
- target_file
- Target tar file
- compression: bool, default False
- Enable compression, which is disabled by default.
-
- """
- file_mode = "w:gz" if compression else "w:"
-
- with tarfile.open(target_file, file_mode) as tar:
- if os.path.isdir(source_path):
- tar.add(str(source_path), arcname=".")
- elif os.path.isfile(source_path):
- file_info = tar.gettarinfo(str(source_path), arcname=os.path.basename(str(source_path)))
- with open(source_path, "rb") as fo:
- tar.addfile(file_info, fo)
-
-
-def _tar_path_subprocess(source_path: str, target_file: str, compression: bool = False) -> None:
- """Create tar from directory using `tar`
-
- Parameters
- ----------
- source_path: str
- Source directory or file
- target_file
- Target tar file
- compression: bool, default False
- Enable compression, which is disabled by default.
-
- """
- # Only add compression when users explicitly request it.
- # We do this because it takes too long to compress
- # large datastores.
- tar_flags = "-cvf"
- if compression:
- tar_flags = "-zcvf"
- if os.path.isdir(source_path):
- command = f"tar -C {source_path} {tar_flags} {target_file} ./"
- else:
- abs_path = os.path.abspath(source_path)
- parent_dir = os.path.dirname(abs_path)
- base_name = os.path.basename(abs_path)
- command = f"tar -C {parent_dir} {tar_flags} {target_file} {base_name}"
-
- subprocess.check_call(
- command,
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- shell=True,
- env={"GZIP": "-9", "COPYFILE_DISABLE": "1"},
- )
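Reviewer note: the split-size rule in the removed `_get_split_size` is easy to misread, so here is a minimal sketch of it. `split_size_for` is a hypothetical name, and `ValueError` stands in for the `click.ClickException` used in the original so the sketch has no third-party dependency.

```python
import math

MAX_SPLIT_COUNT = 999
MIN_SPLIT_SIZE = 1024 * 1000 * 20  # default minimum part size from the removed helper


def split_size_for(
    total_size: int,
    minimum_split_size: int = MIN_SPLIT_SIZE,
    max_split_count: int = MAX_SPLIT_COUNT,
) -> int:
    """Pick a part size so the upload needs at most `max_split_count` parts."""
    max_size = max_split_count * (1 << 31)  # each part is capped at 2**31 bytes
    if total_size > max_size:
        raise ValueError(f"total size exceeds the {max_size / (1 << 40):.2f} TBytes limit")
    split_size = minimum_split_size
    if math.ceil(total_size / split_size) > max_split_count:
        # Too many parts at the minimum size: grow the parts instead.
        split_size = math.ceil(total_size / max_split_count)
    return split_size
```

Small uploads keep the minimum part size; the part size only grows once the minimum would require more than `max_split_count` parts.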
diff --git a/src/lightning/app/source_code/uploader.py b/src/lightning/app/source_code/uploader.py
deleted file mode 100644
index 82336c7b0b96f..0000000000000
--- a/src/lightning/app/source_code/uploader.py
+++ /dev/null
@@ -1,111 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import time
-
-import requests
-from requests.adapters import HTTPAdapter
-from rich.progress import BarColumn, Progress, TextColumn
-from urllib3.util.retry import Retry
-
-
-class FileUploader:
- """This class uploads a source file with presigned url to S3.
-
- Attributes
- ----------
- source_file: str
- Source file to upload
- presigned_url: str
- Presigned URL to upload the source file to
- retries: int
- Amount of retries when requests encounter an error
- total_size: int
- Size of all files to upload
- name: str
- Name of this upload to display progress
-
- """
-
- workers: int = 8
- retries: int = 10000
- disconnect_retry_wait_seconds: int = 5
- progress = Progress(
- TextColumn("[bold blue]{task.description}", justify="left"),
- BarColumn(bar_width=None),
- "[self.progress.percentage]{task.percentage:>3.1f}%",
- )
-
- def __init__(self, presigned_url: str, source_file: str, total_size: int, name: str, use_progress: bool = True):
- self.presigned_url = presigned_url
- self.source_file = source_file
- self.total_size = total_size
- self.name = name
- self.use_progress = use_progress
- self.task_id = None
-
- def upload_data(self, url: str, data: bytes, retries: int, disconnect_retry_wait_seconds: int) -> str:
- """Send data to url.
-
- Parameters
- ----------
- url: str
- url string to send data to
- data: bytes
- Bytes of data to send to url
- retries: int
- Amount of retries
- disconnect_retry_wait_seconds: int
- Amount of seconds between disconnect retry
-
- Returns
- -------
- str
- ETag from response
-
- """
- disconnect_retries = retries
- while disconnect_retries > 0:
- try:
- retries = Retry(total=10)
- with requests.Session() as s:
- s.mount("https://", HTTPAdapter(max_retries=retries))
- return self._upload_data(s, url, data)
- except BrokenPipeError:
- time.sleep(disconnect_retry_wait_seconds)
- disconnect_retries -= 1
-
- raise ValueError("Unable to upload file after multiple attempts")
-
- def _upload_data(self, s: requests.Session, url: str, data: bytes):
- resp = s.put(url, data=data)
- if "ETag" not in resp.headers:
- raise ValueError(f"Unexpected response from {url}, response: {resp.content}")
- return resp.headers["ETag"]
-
- def upload(self) -> None:
- """Upload the source file to the presigned URL."""
- no_task = self.task_id is None
- if self.use_progress and no_task:
- self.task_id = self.progress.add_task("upload", filename=self.name, total=self.total_size)
- self.progress.start()
- try:
- with open(self.source_file, "rb") as f:
- data = f.read()
- self.upload_data(self.presigned_url, data, self.retries, self.disconnect_retry_wait_seconds)
- if self.use_progress:
- self.progress.update(self.task_id, advance=len(data))
- finally:
- if self.use_progress and no_task:
- self.progress.stop()
diff --git a/src/lightning/app/storage/__init__.py b/src/lightning/app/storage/__init__.py
deleted file mode 100644
index 8ed541dd2de7e..0000000000000
--- a/src/lightning/app/storage/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from lightning.app.storage.drive import Drive # noqa: F401
-from lightning.app.storage.filesystem import FileSystem # noqa: F401
-from lightning.app.storage.mount import Mount # noqa: F401
-from lightning.app.storage.orchestrator import StorageOrchestrator # noqa: F401
-from lightning.app.storage.path import Path # noqa: F401
-from lightning.app.storage.payload import Payload # noqa: F401
diff --git a/src/lightning/app/storage/copier.py b/src/lightning/app/storage/copier.py
deleted file mode 100644
index f4bf799e2f395..0000000000000
--- a/src/lightning/app/storage/copier.py
+++ /dev/null
@@ -1,155 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import concurrent.futures
-import pathlib
-import threading
-from threading import Thread
-from time import time
-from typing import TYPE_CHECKING, Optional, Union
-
-from fsspec import AbstractFileSystem
-from fsspec.implementations.local import LocalFileSystem
-
-from lightning.app.core.queues import BaseQueue
-from lightning.app.storage.path import _filesystem
-from lightning.app.storage.requests import _ExistsRequest, _GetRequest
-from lightning.app.utilities.app_helpers import Logger
-
-_PathRequest = Union[_GetRequest, _ExistsRequest]
-
-_logger = Logger(__name__)
-
-num_workers = 8
-if TYPE_CHECKING:
- import lightning.app
-
-
-class _Copier(Thread):
- """The Copier is a thread running alongside a LightningWork.
-
- It maintains two queues that connect to the central
- :class:`~lightning.app.storage.orchestrator.StorageOrchestrator`,
- the request queue and the response queue. The Copier waits for a request to be pushed to the request queue,
- processes it and sends back the response through the response queue. In the current implementation, the Copier
- simply copies the requested file from the local filesystem to a shared directory (determined by
- :func:`~lightning.app.storage.path.shared_storage_path`). Any errors raised during the copy will be added to the
- response and get re-raised within the corresponding LightningWork.
-
- Args:
- copy_request_queue: A queue connecting the central StorageOrchestrator with the Copier. The orchestrator
- will send requests to this queue.
- copy_response_queue: A queue connecting the central StorageOrchestrator with the Copier. The Copier
- will send a response to this queue whenever a requested copy has finished.
-
- """
-
- def __init__(
- self, work: "lightning.app.LightningWork", copy_request_queue: "BaseQueue", copy_response_queue: "BaseQueue"
- ) -> None:
- super().__init__(daemon=True)
- self._work = work
- self.copy_request_queue = copy_request_queue
- self.copy_response_queue = copy_response_queue
- self._exit_event = threading.Event()
- self._sleep_time = 0.1
-
- def run(self) -> None:
- while not self._exit_event.is_set():
- self._exit_event.wait(self._sleep_time)
- self.run_once()
-
- def join(self, timeout: Optional[float] = None) -> None:
- self._exit_event.set()
- super().join(timeout)
-
- def run_once(self):
- request: _PathRequest = self.copy_request_queue.get() # blocks until we get a request
-
- t0 = time()
-
- obj: Optional[lightning.app.storage.path.Path] = _find_matching_path(self._work, request)
- if obj is None:
- # If it's not a path, it must be a payload
- obj: lightning.app.storage.Payload = getattr(self._work, request.name)
-
- if isinstance(request, _ExistsRequest):
- response = obj._handle_exists_request(self._work, request)
- elif isinstance(request, _GetRequest):
- response = obj._handle_get_request(self._work, request)
- else:
- raise TypeError(
- f"The file copy request had an invalid type. Expected _GetRequest or _ExistsRequest, got:"
- f" {type(request)}"
- )
-
- response.timedelta = time() - t0
- self.copy_response_queue.put(response)
-
-
-def _find_matching_path(work, request: _GetRequest) -> Optional["lightning.app.storage.path.Path"]:
- for name in work._paths:
- candidate: lightning.app.storage.path.Path = getattr(work, name)
- if candidate.hash == request.hash:
- return candidate
- return None
-
-
-def _copy_files(
- source_path: pathlib.Path,
- destination_path: pathlib.Path,
- fs: Optional[AbstractFileSystem] = None,
-) -> None:
- """Copy files from one path to another.
-
- The source path must either be an existing file or folder. If the source is a folder, the destination path is
- interpreted as a folder as well. If the source is a file, the destination path is interpreted as a file too.
-
- Files in a folder are copied recursively and efficiently using multiple threads.
-
- """
- if fs is None:
- fs = _filesystem()
-
- def _copy(from_path: pathlib.Path, to_path: pathlib.Path) -> Optional[Exception]:
- _logger.debug(f"Copying {str(from_path)} -> {str(to_path)}")
-
- try:
- # NOTE: S3 does not have a concept of directories, so we do not need to create one.
- if isinstance(fs, LocalFileSystem):
- fs.makedirs(str(to_path.parent), exist_ok=True)
-
- fs.put(str(from_path), str(to_path), recursive=False)
- except Exception as ex:
- # Return the exception so that it can be handled in the main thread
- return ex
-
- # NOTE: Cannot use `S3FileSystem.put(recursive=True)` because it tries to access parent directories
- # which it does not have access to.
- if source_path.is_dir():
- src = [file for file in source_path.rglob("*") if file.is_file()]
- dst = [destination_path / file.relative_to(source_path) for file in src]
-
- with concurrent.futures.ThreadPoolExecutor(num_workers) as executor:
- results = executor.map(_copy, src, dst)
-
- # Raise the first exception found
- exception = next((e for e in results if isinstance(e, Exception)), None)
- if exception:
- raise exception
- else:
- if isinstance(fs, LocalFileSystem):
- fs.makedirs(str(destination_path.parent), exist_ok=True)
-
- fs.put(str(source_path), str(destination_path))
diff --git a/src/lightning/app/storage/drive.py b/src/lightning/app/storage/drive.py
deleted file mode 100644
index b3f5f8f56726c..0000000000000
--- a/src/lightning/app/storage/drive.py
+++ /dev/null
@@ -1,341 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import pathlib
-import shutil
-import sys
-from copy import deepcopy
-from time import sleep, time
-from typing import Dict, List, Optional, Union
-
-from lightning.app.storage.path import LocalFileSystem, _filesystem, _shared_storage_path
-from lightning.app.utilities.component import _is_flow_context
-
-
-class Drive:
- __IDENTIFIER__ = "__drive__"
- __PROTOCOLS__ = ["lit://"]
-
- def __init__(
- self,
- id: str,
- allow_duplicates: bool = False,
- component_name: Optional[str] = None,
- root_folder: Optional[str] = None,
- ):
- """The Drive object provides a shared space to write and read files from.
-
- When the drive object is passed from one component to another, a copy is made and ownership
- is transferred to the new component.
-
- Arguments:
- id: Unique identifier for this Drive.
- allow_duplicates: Whether to enable file duplication between components.
- component_name: The component name which owns this drive.
- When not provided, it is automatically inferred by Lightning.
- root_folder: The folder from which the Drive perceives the data (e.g., this acts as a mount dir).
-
- """
- if id.startswith("s3://"):
- raise ValueError(
- "Using S3 buckets in a Drive is no longer supported. Please pass an S3 `Mount` to "
- "a Work's CloudCompute config in order to mount an s3 bucket as a filesystem in a work.\n"
- f"`CloudCompute(mount=Mount({id}), ...)`"
- )
-
- self.id = None
- self.protocol = None
- for protocol in self.__PROTOCOLS__:
- if id.startswith(protocol):
- self.protocol = protocol
- self.id = id.replace(protocol, "")
- break
- else: # N.B. for-else loop
- raise ValueError(
- f"Unknown protocol for the drive 'id' argument '{id}'. The 'id' string "
- f"must start with one of the following prefixes {self.__PROTOCOLS__}"
- )
-
- if not self.id:
- raise Exception(f"The Drive id needs to start with one of the following protocols: {self.__PROTOCOLS__}")
-
- if "/" in self.id:
- raise Exception(f"The id should be unique to identify your drive. Found `{self.id}`.")
-
- self.root_folder = pathlib.Path(root_folder).resolve() if root_folder else pathlib.Path(os.getcwd())
- if self.protocol != "s3://" and not os.path.isdir(self.root_folder):
- raise Exception(f"The provided root_folder isn't a directory: {root_folder}")
- self.component_name = component_name
- self.allow_duplicates = allow_duplicates
- self.fs = _filesystem()
-
- @property
- def root(self) -> pathlib.Path:
- root_path = self.drive_root / self.component_name
- if isinstance(self.fs, LocalFileSystem):
- self.fs.makedirs(root_path, exist_ok=True)
- return root_path
-
- @property
- def drive_root(self) -> pathlib.Path:
- return _shared_storage_path() / "artifacts" / "drive" / self.id
-
- def put(self, path: str) -> None:
- """Puts a file into the Drive in a blocking fashion.
-
- Arguments:
- path: The relative path to your files to be added to the Drive.
-
- """
- if not self.component_name:
- raise Exception("The component name needs to be known to put a path to the Drive.")
- if _is_flow_context():
- raise Exception("The flow isn't allowed to put files into a Drive.")
-
- self._validate_path(path)
-
- if not self.allow_duplicates:
- self._check_for_allow_duplicates(path)
-
- from lightning.app.storage.copier import _copy_files
-
- src = pathlib.Path(os.path.join(self.root_folder, path)).resolve()
- dst = self._to_shared_path(path, component_name=self.component_name)
-
- _copy_files(src, dst)
-
- def list(self, path: Optional[str] = ".", component_name: Optional[str] = None) -> List[str]:
- """This method enables to list files under the provided path from the Drive in a blocking fashion.
-
- Arguments:
- path: The relative path you want to list files from the Drive.
- component_name: By default, the Drive lists files across all components.
- If you provide a component name, the listing is specific to this component.
-
- """
- if _is_flow_context():
- raise Exception("The flow isn't allowed to list files from a Drive.")
-
- if component_name:
- paths = [
- self._to_shared_path(
- path,
- component_name=component_name,
- )
- ]
- else:
- paths = [
- self._to_shared_path(
- path,
- component_name=component_name,
- )
- for component_name in self._collect_component_names()
- ]
-
- files = []
- sep = "\\" if sys.platform == "win32" else "/"
- prefix_len = len(str(self.root).split(sep))
- for p in paths:
- if self.fs.exists(p):
- for f in self.fs.ls(p):
- files.append(str(pathlib.Path(*pathlib.Path(f).parts[prefix_len:])))
- return files
-
- def get(
- self,
- path: str,
- component_name: Optional[str] = None,
- timeout: Optional[float] = None,
- overwrite: bool = False,
- ) -> None:
- """This method enables to get files under the provided path from the Drive in a blocking fashion.
-
- Arguments:
- path: The relative path of the files you want to get from the Drive.
- component_name: By default, the Drive gets the matching files across all components.
- If you provide a component name, the matching is specific to this component.
- timeout: How long to wait, in seconds, for the files to become available if they are not created yet.
- overwrite: Whether to overwrite the provided path if it exists.
-
- """
- if _is_flow_context():
- raise Exception("The flow isn't allowed to get files from a Drive.")
-
- if component_name:
- shared_path = self._to_shared_path(
- path,
- component_name=component_name,
- )
- if timeout:
- start_time = time()
- while not self.fs.exists(shared_path):
- sleep(1)
- if (time() - start_time) > timeout:
- raise Exception(f"The following {path} wasn't found in {timeout} seconds")
- break
-
- self._get(
- self.fs,
- shared_path,
- pathlib.Path(os.path.join(self.root_folder, path)).resolve(),
- overwrite=overwrite,
- )
- else:
- if timeout:
- start_time = time()
- while True:
- if (time() - start_time) > timeout:
- raise Exception(f"The following {path} wasn't found in {timeout} seconds.")
- match = self._find_match(path)
- if match is None:
- sleep(1)
- continue
- break
- else:
- match = self._find_match(path)
- if not match:
- raise Exception(f"We didn't find any match for the associated {path}.")
-
- self._get(self.fs, match, pathlib.Path(os.path.join(self.root_folder, path)).resolve(), overwrite=overwrite)
-
- def delete(self, path: str) -> None:
- """This method enables to delete files under the provided path from the Drive in a blocking fashion. Only the
- component which added a file can delete it.
-
- Arguments:
- path: The relative path you want to delete files from the Drive.
-
- """
- if not self.component_name:
- raise Exception("The component name needs to be known to delete a path from the Drive.")
-
- shared_path = self._to_shared_path(
- path,
- component_name=self.component_name,
- )
- if self.fs.exists(str(shared_path)):
- self.fs.rm(str(shared_path))
- else:
- raise Exception(f"The file {path} doesn't exist in the component_name space {self.component_name}.")
-
- def to_dict(self):
- return {
- "type": self.__IDENTIFIER__,
- "id": self.id,
- "protocol": self.protocol,
- "allow_duplicates": self.allow_duplicates,
- "component_name": self.component_name,
- "root_folder": str(self.root_folder),
- }
-
- @classmethod
- def from_dict(cls, dict: Dict) -> "Drive":
- assert dict["type"] == cls.__IDENTIFIER__
- drive = cls(
- dict["protocol"] + dict["id"],
- allow_duplicates=dict["allow_duplicates"],
- root_folder=dict["root_folder"],
- )
- drive.component_name = dict["component_name"]
- return drive
-
- def __deepcopy__(self, memo):
- cls = self.__class__
- result = cls.__new__(cls)
- memo[id(self)] = result
- for k, v in self.__dict__.items():
- setattr(result, k, deepcopy(v, memo))
- return result
-
- def _collect_component_names(self) -> List[str]:
- sep = "/"
- if self.fs.exists(self.drive_root):
- # Invalidate cache before running ls in case new directories have been added
- # TODO: Re-evaluate this - may lead to performance issues
- self.fs.invalidate_cache()
- return [str(p.split(sep)[-1]) for p in self.fs.ls(self.drive_root)]
- return []
-
- def _to_shared_path(self, path: str, component_name: Optional[str] = None) -> pathlib.Path:
- shared_path = self.drive_root
- if component_name:
- shared_path /= component_name
- shared_path /= path
- return shared_path
-
- def _get(self, fs, src: pathlib.Path, dst: pathlib.Path, overwrite: bool):
- if fs.isdir(src):
- if isinstance(fs, LocalFileSystem):
- dst = dst.resolve()
- if fs.exists(dst):
- if overwrite:
- fs.rm(str(dst), recursive=True)
- else:
- raise FileExistsError(f"The file {dst} was found. Add get(..., overwrite=True) to replace it.")
-
- shutil.copytree(src, dst)
- else:
- glob = f"{str(src)}/**"
- fs.get(glob, str(dst.absolute()), recursive=False)
- else:
- fs.get(str(src), str(dst.absolute()), recursive=False)
-
- def _find_match(self, path: str) -> Optional[pathlib.Path]:
- matches = []
- for component_name in self._collect_component_names():
- possible_path = self._to_shared_path(path, component_name=component_name)
- if self.fs.exists(possible_path):
- matches.append(possible_path)
-
- if not matches:
- return None
-
- if len(matches) > 1:
- sep = "\\" if sys.platform == "win32" else "/"
- prefix_len = len(str(self.root).split(sep))
- matches = [str(pathlib.Path(*pathlib.Path(p).parts[prefix_len:])) for p in matches]
- raise Exception(f"We found several matching files created by multiple components: {matches}.")
-
- return matches[0]
-
- def _check_for_allow_duplicates(self, path):
- possible_paths = [
- self._to_shared_path(
- path,
- component_name=component_name,
- )
- for component_name in self._collect_component_names()
- if component_name != self.component_name
- ]
- matches = [self.fs.exists(p) for p in possible_paths]
-
- if sum(matches):
- raise Exception(f"The file {path} can't be added as already found in the Drive.")
-
- def _validate_path(self, path: str) -> None:
- if not os.path.exists(os.path.join(self.root_folder, path)):
- raise FileExistsError(f"The provided path {path} doesn't exist")
-
- def __str__(self) -> str:
- assert self.id
- return self.id
-
-
-def _maybe_create_drive(component_name: str, state: Dict) -> Union[Dict, Drive]:
- if state.get("type") == Drive.__IDENTIFIER__:
- drive = Drive.from_dict(state)
- drive.component_name = component_name
- return drive
- return state
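Review note: the prefix-matching logic of the removed `Drive.__init__` can be sketched in isolation for anyone who needs to reproduce it after this deletion. This is a minimal sketch, not the removed implementation: `split_drive_id` is a hypothetical helper, and the `("lit://",)` default is an assumed protocol list, since the real `__PROTOCOLS__` value is defined outside this hunk. Unlike the original, the sketch strips the prefix by slicing instead of `str.replace`, so a protocol string that recurs later in the id is left alone.

```python
from typing import Sequence, Tuple


def split_drive_id(drive_id: str, protocols: Sequence[str] = ("lit://",)) -> Tuple[str, str]:
    """Hypothetical helper: split ``drive_id`` into ``(protocol, name)``."""
    for protocol in protocols:
        if drive_id.startswith(protocol):
            # Slice off only the leading prefix.
            name = drive_id[len(protocol):]
            break
    else:  # for-else: runs only when no protocol matched
        raise ValueError(f"Unknown protocol for drive id '{drive_id}'; expected one of {list(protocols)}")

    if not name:
        raise ValueError("The drive id must contain a name after the protocol prefix.")
    if "/" in name:
        raise ValueError(f"The drive id must be a single name without '/'. Found '{name}'.")
    return protocol, name
```

Calling `split_drive_id("lit://my_drive")` yields the protocol and the bare drive name; ids with an unknown prefix, an empty name, or a nested path raise `ValueError`.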
diff --git a/src/lightning/app/storage/filesystem.py b/src/lightning/app/storage/filesystem.py
deleted file mode 100644
index 943a6a750bd2b..0000000000000
--- a/src/lightning/app/storage/filesystem.py
+++ /dev/null
@@ -1,166 +0,0 @@
-import os
-import shutil
-from pathlib import Path
-from typing import Callable, List
-
-from fsspec.implementations.local import LocalFileSystem
-
-from lightning.app.storage.copier import _copy_files
-from lightning.app.storage.path import _filesystem, _shared_storage_path
-
-
-def _get_files(fs, src: Path, dst: Path, overwrite: bool = True):
- dst = dst.resolve()
- if fs.isdir(src):
- if isinstance(fs, LocalFileSystem):
- dst = dst.resolve()
- if fs.exists(dst):
- if overwrite:
- fs.rm(str(dst), recursive=True)
- else:
- raise FileExistsError(f"The file {dst} was found. Add get(..., overwrite=True) to replace it.")
-
- shutil.copytree(src, dst)
- else:
- glob = f"{str(src)}/**"
- fs.get(glob, str(dst), recursive=False)
- else:
- fs.get(str(src), str(dst), recursive=False)
-
-
-class FileSystem:
- """This filesystem enables to easily move files from and to the shared storage."""
-
- def __init__(self) -> None:
- self._fs = _filesystem()
- self._root = str(_shared_storage_path())
-
- def put(self, src_path: str, dst_path: str, put_fn: Callable = _copy_files) -> None:
- """This method enables to put a file to the shared storage in a blocking fashion.
-
- Arguments:
- src_path: The path to your files locally
- dst_path: The path to your files transferred in the shared storage.
- put_fn: The method to use to put files in the shared storage.
-
- """
- if not os.path.exists(Path(src_path).resolve()):
- raise FileExistsError(f"The provided path {src_path} doesn't exist")
-
- if not dst_path.startswith("/"):
- raise Exception(f"The provided destination {dst_path} needs to start with `/`.")
-
- if dst_path == "/":
- dst_path = os.path.join(self._root, os.path.basename(src_path))
- else:
- dst_path = os.path.join(self._root, dst_path[1:])
-
- src = Path(src_path).resolve()
- dst = Path(dst_path).resolve()
-
- return put_fn(src, dst, fs=self._fs)
-
- def get(self, src_path: str, dst_path: str, overwrite: bool = True, get_fn: Callable = _get_files) -> None:
- """This method enables to get files from the shared storage in a blocking fashion.
-
- Arguments:
- src_path: The path to your files in the shared storage
- dst_path: The path to your files transferred locally
- get_fn: The method to use to get files from the shared storage.
-
- """
- if not src_path.startswith("/"):
- raise Exception(f"The provided source {src_path} needs to start with `/`.")
-
- src = Path(os.path.join(self._root, src_path[1:])).resolve()
- dst = Path(dst_path).resolve()
-
- return get_fn(fs=self._fs, src=src, dst=dst, overwrite=overwrite)
-
- def listdir(self, path: str) -> List[str]:
- """This method enables to list files from the shared storage in a blocking fashion.
-
- Arguments:
- path: The path to files to list.
-
- """
- if not path.startswith("/"):
- raise Exception(f"The provided destination {path} needs to start with `/`.")
-
- shared_path = Path(os.path.join(self._root, path[1:])).resolve()
-
- if not self._fs.exists(shared_path):
- raise RuntimeError(f"The provided path {shared_path} doesn't exist.")
-
- # Invalidate cache before running ls in case new directories have been added
- # TODO: Re-evaluate this - may lead to performance issues
- self._fs.invalidate_cache()
-
- paths = self._fs.ls(shared_path)
- if not paths:
- return paths
-
- return sorted([path.replace(self._root + os.sep, "") for path in paths if not path.endswith("info.txt")])
-
- def walk(self, path: str) -> List[str]:
- """This method enables to recursively list files from the shared storage in a blocking fashion.
-
- Arguments:
- path: The path to files to list.
-
- """
- if not path.startswith("/"):
- raise Exception(f"The provided destination {path} needs to start with `/`.")
-
- shared_path = Path(os.path.join(self._root, path[1:])).resolve()
-
- if not self._fs.exists(shared_path):
- raise RuntimeError(f"The provided path {shared_path} doesn't exist.")
-
- # Invalidate cache before running ls in case new directories have been added
- # TODO: Re-evaluate this - may lead to performance issues
- self._fs.invalidate_cache()
-
- paths = self._fs.ls(shared_path)
- if not paths:
- return paths
-
- out = []
-
- for shared_path in paths:
- path = str(shared_path).replace(self._root, "")
- if self._fs.isdir(shared_path):
- out.extend(self.walk(path))
- else:
- if path.endswith("info.txt"):
- continue
- out.append(path[1:])
- return sorted(out)
-
- def rm(self, path) -> None:
- if not path.startswith("/"):
- raise Exception(f"The provided destination {path} needs to start with `/`.")
-
- delete_path = Path(os.path.join(self._root, path[1:])).resolve()
-
- if self._fs.exists(str(delete_path)):
- if self._fs.isdir(str(delete_path)):
- self._fs.rmdir(str(delete_path))
- else:
- self._fs.rm(str(delete_path))
- else:
- raise Exception(f"The file path {path} doesn't exist.")
-
- def isfile(self, path: str) -> bool:
- if not path.startswith("/"):
- raise Exception(f"The provided destination {path} needs to start with `/`.")
-
- path = Path(os.path.join(self._root, path[1:])).resolve()
- return self._fs.isfile(path)
-
- def isdir(self, path: str) -> bool:
- if not path.startswith("/"):
- raise Exception(f"The provided destination {path} needs to start with `/`.")
-
- path = Path(os.path.join(self._root, path[1:])).resolve()
- return self._fs.isdir(path)
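Review note: every public method of the removed `FileSystem` class repeats the same path mapping, where a storage path must start with `/` and is then re-rooted under the shared storage directory. The sketch below isolates that step under stated assumptions: `to_shared_path` is a hypothetical name, and it omits the `.resolve()` call the original applies so the result does not depend on the local filesystem.

```python
import os
from pathlib import Path


def to_shared_path(root: str, path: str) -> Path:
    """Hypothetical helper: map an absolute-style storage path under ``root``."""
    if not path.startswith("/"):
        raise ValueError(f"The provided path {path} needs to start with `/`.")
    # Drop the leading "/" first: os.path.join discards earlier components
    # when a later component is absolute, which would silently ignore ``root``.
    return Path(os.path.join(root, path[1:]))
```

Slicing off the leading slash is the important detail here; joining `root` with the unmodified path would return the path unchanged on POSIX systems.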
diff --git a/src/lightning/app/storage/mount.py b/src/lightning/app/storage/mount.py
deleted file mode 100644
index 8142b4574a8b0..0000000000000
--- a/src/lightning/app/storage/mount.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-from dataclasses import dataclass
-from pathlib import Path
-from typing import List
-
-__MOUNT_IDENTIFIER__: str = "__mount__"
-__MOUNT_PROTOCOLS__: List[str] = ["s3://"]
-
-
-@dataclass
-class Mount:
- """Allows you to mount the contents of an AWS S3 bucket on disk when running an app on the cloud.
-
- Arguments:
- source: The location which contains the external data which should be mounted in the
- running work. At the moment, only AWS S3 mounts are supported. This must be a full
- `s3` style identifier pointing to a bucket and (optionally) prefix to mount. For
- example: `s3://foo/bar/`.
-
- mount_path: An absolute directory path in the work where external data source should
- be mounted as a filesystem. This path should not already exist in your codebase.
- If not included, then the mount_path will be set to `/data/` followed by the stem of the source path.
-
- """
-
- source: str = ""
- mount_path: str = ""
-
- def __post_init__(self) -> None:
- for protocol in __MOUNT_PROTOCOLS__:
- if self.source.startswith(protocol):
- protocol = protocol
- break
- else: # N.B. for-else loop
- raise ValueError(
- f"Unknown protocol for the mount 'source' argument '{self.source}'. The 'source' "
- f"string must start with one of the following prefixes: {__MOUNT_PROTOCOLS__}"
- )
-
- if protocol == "s3://" and not self.source.endswith("/"):
- raise ValueError(
- "S3 mounts must end in a trailing slash (`/`) to indicate a folder is being mounted. "
- f"Received: '{self.source}'. Mounting a single file is not currently supported."
- )
-
- if self.mount_path == "":
- self.mount_path = f"/data/{Path(self.source).stem}"
-
- if not os.path.isabs(self.mount_path):
- raise ValueError(
- f"mount_path argument must be an absolute path to a "
- f"location; received relative path {self.mount_path}"
- )
-
- @property
- def protocol(self) -> str:
- """The backing storage protocol indicated by this mount source."""
- for protocol in __MOUNT_PROTOCOLS__:
- if self.source.startswith(protocol):
- return protocol
- return ""
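Review note: the default `mount_path` computed in the removed `Mount.__post_init__` is derived from the stem of the source, which is easy to misread from the docstring. A minimal sketch of that derivation follows; `default_mount_path` is a hypothetical helper, and it uses `PurePosixPath` (an assumption, the original uses `pathlib.Path`) so that the result does not depend on the platform.

```python
from pathlib import PurePosixPath


def default_mount_path(source: str) -> str:
    """Hypothetical helper: the mount path used when none is given."""
    if not source.startswith("s3://"):
        raise ValueError(f"Unknown protocol for mount source '{source}'; expected 's3://'.")
    if not source.endswith("/"):
        raise ValueError(f"S3 mounts must end in a trailing slash. Received: '{source}'.")
    # The trailing slash is ignored when parsing, so ``stem`` is the last
    # path component with its final suffix (if any) removed.
    return f"/data/{PurePosixPath(source).stem}"
```

Because `stem` strips the final suffix, a bucket or prefix name containing a dot loses everything after the last dot in the default mount path.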
diff --git a/src/lightning/app/storage/orchestrator.py b/src/lightning/app/storage/orchestrator.py
deleted file mode 100644
index adb7078ff3360..0000000000000
--- a/src/lightning/app/storage/orchestrator.py
+++ /dev/null
@@ -1,208 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import threading
-import traceback
-from queue import Empty
-from threading import Thread
-from typing import TYPE_CHECKING, Dict, Optional, Union
-
-from lightning.app.core.queues import BaseQueue
-from lightning.app.storage.path import _filesystem, _path_to_work_artifact
-from lightning.app.storage.requests import _ExistsRequest, _ExistsResponse, _GetRequest, _GetResponse
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.enum import WorkStageStatus
-
-if TYPE_CHECKING:
- from lightning.app.core.app import LightningApp
-
-
-_PathRequest = Union[_GetRequest, _ExistsRequest]
-_PathResponse = Union[_ExistsResponse, _GetResponse]
-_logger = Logger(__name__)
-
-
-class StorageOrchestrator(Thread):
- """The StorageOrchestrator processes file transfer requests from Work that need file(s) from other Work.
-
- Args:
- app: A reference to the ``LightningApp`` which holds the copy request- and response queues for storage.
- request_queues: A dictionary with Queues connected to consumer Work. The Queue will contain transfer requests
- coming from a consumer Work.
- response_queues: A dictionary with Queues connected to consumer Work.
- The Queue will contain the confirmation responses to the consumer Work that files were transferred.
- copy_request_queues: A dictionary of Queues where each Queue connects to one Work. The orchestrator will
- put requests on this queue for the file-transfer thread to complete.
- copy_response_queues: A dictionary of Queues where each Queue connects to one Work. The queue is expected to
- contain the completion response from the file-transfer thread running in the Work process.
-
- """
-
- def __init__(
- self,
- app: "LightningApp",
- request_queues: Dict[str, BaseQueue],
- response_queues: Dict[str, BaseQueue],
- copy_request_queues: Dict[str, BaseQueue],
- copy_response_queues: Dict[str, BaseQueue],
- ) -> None:
- super().__init__(daemon=True)
- self.app = app
- self.request_queues = request_queues
- self.response_queues = response_queues
- self.copy_request_queues = copy_request_queues
- self.copy_response_queues = copy_response_queues
- self.waiting_for_response: Dict[str, str] = {}
- self._validate_queues()
- self._exit_event = threading.Event()
-
- # Note: Use different sleep time locally and in the cloud
- # to reduce queue calls.
- self._sleep_time = 0.1 if "LIGHTNING_APP_STATE_URL" not in os.environ else 2
- self._fs = None
-
- @property
- def fs(self):
- if self._fs is None:
- self._fs = _filesystem()
- return self._fs
-
- def _validate_queues(self):
- assert (
- self.request_queues.keys()
- == self.response_queues.keys()
- == self.copy_request_queues.keys()
- == self.copy_response_queues.keys()
- )
-
- def run(self) -> None:
- while not self._exit_event.is_set():
- for work_name in list(self.request_queues.keys()):
- try:
- self.run_once(work_name)
- except Exception:
- _logger.error(traceback.format_exc())
- self._exit_event.wait(self._sleep_time)
-
- def join(self, timeout: Optional[float] = None) -> None:
- self._exit_event.set()
- super().join(timeout)
-
- def run_once(self, work_name: str) -> None:
- if work_name not in self.waiting_for_response:
- # check if there is a new request from this work for a file transfer
- # there can only be one pending request per work
- request_queue = self.request_queues[work_name]
- try:
- request: _PathRequest = request_queue.get(timeout=0) # this should not block
- # This should not happen under normal conditions, but it has occurred.
- # For now we are tolerant with respect to requests being None in the queue
- # and just move on.
- if request is None:
- raise Empty
- except Empty:
- pass
- else:
- request.destination = work_name
- source_work = self.app.get_component_by_name(request.source)
- maybe_artifact_path = str(_path_to_work_artifact(request.path, source_work))
-
- if self.fs.exists(maybe_artifact_path):
- # First check if the shared filesystem has the requested file stored as an artifact
- # If so, we will let the destination Work access this file directly
- # NOTE: This is NOT the right thing to do, because the Work could still be running and producing
- # a newer version of the requested file, but we can't rely on the Work status to be accurate
- # (at the moment)
- if isinstance(request, _GetRequest):
- response = _GetResponse(
- source=request.source,
- name=request.name,
- path=maybe_artifact_path,
- hash=request.hash,
- size=self.fs.info(maybe_artifact_path)["size"],
- destination=request.destination,
- )
- if isinstance(request, _ExistsRequest):
- response = _ExistsResponse(
- source=request.source,
- path=maybe_artifact_path,
- name=request.name,
- hash=request.hash,
- destination=request.destination,
- exists=True,
- )
- response_queue = self.response_queues[response.destination]
- response_queue.put(response)
- elif source_work.status.stage not in (
- WorkStageStatus.NOT_STARTED,
- WorkStageStatus.STOPPED,
- WorkStageStatus.FAILED,
- ):
- _logger.debug(
- f"Request for File Transfer received from {work_name}: {request}. Sending request to"
- f" {request.source} to copy the file."
- )
- # The Work is running, and we can send a request to the copier for moving the file to the
- # shared storage
- self.copy_request_queues[request.source].put(request)
- # Store a destination to source mapping.
- self.waiting_for_response[work_name] = request.source
- else:
- if isinstance(request, _GetRequest):
- response = _GetResponse(
- source=request.source,
- path=request.path,
- name=request.name,
- hash=request.hash,
- size=0,
- destination=request.destination,
- )
- if isinstance(request, _ExistsRequest):
- response = _ExistsResponse(
- source=request.source,
- path=request.path,
- hash=request.hash,
- destination=request.destination,
- exists=False,
- name=request.name,
- )
- response.exception = FileNotFoundError(
- "The work is not running and the requested object is not available in the artifact store."
- )
- response_queue = self.response_queues[response.destination]
- response_queue.put(response)
-
- # Check the current work is within the sources.
- # It is possible to have multiple destinations targeting
- # the same source concurrently.
- if work_name in self.waiting_for_response.values():
- # check if the current work has responses for file transfers to other works.
- copy_response_queue = self.copy_response_queues[work_name]
- try:
- # check if the share-point file manager has confirmed a copy request
- response: _PathResponse = copy_response_queue.get(timeout=0) # this should not block
- except Empty:
- pass
- else:
- _logger.debug(
- f"Received confirmation of a completed file copy request from {work_name}:{response}."
- f" Sending the confirmation back to {response.destination}."
- )
- destination = response.destination
- assert response.source == work_name
- response_queue = self.response_queues[destination]
- response_queue.put(response)
- # the request has been processed, allow new requests to come in for the destination work
- del self.waiting_for_response[destination]
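Review note: `run_once` in the removed `StorageOrchestrator` polls each queue without blocking and treats both an empty queue and a stray `None` entry as "nothing to do". The pattern can be sketched with the standard library alone; `poll_once` is a hypothetical helper, and `queue.Queue` stands in for Lightning's `BaseQueue`, which is an assumption made to keep the example self-contained.

```python
import queue
from typing import Any, Optional


def poll_once(request_queue: queue.Queue) -> Optional[Any]:
    """Hypothetical helper: return the next request, or ``None`` if nothing is pending."""
    try:
        # Non-blocking read: raises queue.Empty right away when the queue is empty.
        request = request_queue.get_nowait()
    except queue.Empty:
        return None
    # A ``None`` entry is tolerated and reported the same way as an empty queue.
    return request
```

The caller then only needs a single `if request is not None:` branch, which matches how the removed code folds the `None` case into its `Empty` handling.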
diff --git a/src/lightning/app/storage/path.py b/src/lightning/app/storage/path.py
deleted file mode 100644
index 0d1297d998fb2..0000000000000
--- a/src/lightning/app/storage/path.py
+++ /dev/null
@@ -1,453 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import hashlib
-import os
-import pathlib
-import shutil
-import sys
-from time import sleep
-from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union
-
-from fsspec import AbstractFileSystem
-from fsspec.implementations.local import LocalFileSystem
-
-from lightning.app.core.constants import REMOTE_STORAGE_WAIT
-from lightning.app.core.queues import BaseQueue
-from lightning.app.storage.requests import _ExistsRequest, _ExistsResponse, _GetRequest, _GetResponse
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.component import _is_flow_context
-from lightning.app.utilities.imports import _is_s3fs_available
-
-if _is_s3fs_available():
- from s3fs import S3FileSystem
-
-PathlibPath = type(pathlib.Path()) # PosixPath or a WindowsPath depending on the platform
-
-if TYPE_CHECKING:
- from lightning.app.core.work import LightningWork
-
-num_workers = 8
-
-_logger = Logger(__name__)
-
-
-class Path(PathlibPath):
- """A drop-in replacement for :class:`pathlib.Path` for all paths in Lightning.
-
- The Lightning Path works exactly the same as :class:`pathlib.Path` but it also remembers in which LightningWork
- it was created. If the Path gets passed to a different LightningWork, the file or folder can then be easily
- accessed no matter where it is located in the other Work's filesystem.
-
- Args:
- *args: Accepts the same arguments as in :class:`pathlib.Path`
- **kwargs: Accepts the same keyword arguments as in :class:`pathlib.Path`
-
- """
-
- @classmethod
- def _from_parts(cls, args: Any, **__unused) -> "Path":
- """This gets called from the super class in ``pathlib.Path.__new__``.
-
- The Lightning Path overrides this to validate the instantiation in the case parts are passed in individually. In
- such a case we need to validate that all parts have the same `origin` and if not, an error is raised.
-
- """
- if args and isinstance(args[0], str) and args[0].startswith("lit://"):
- parts = list(args)
- parts[0] = parts[0][len("lit://") :]
- args = (_storage_root_dir(), *parts)
-
- if (sys.version_info.major, sys.version_info.minor) < (3, 10):
- __unused.setdefault("init", True)
- new_path = super()._from_parts(args, **__unused)
- else:
- new_path = super()._from_parts(args)
-
- new_path._init_attributes() # we use this instead of defining a __init__() method
-
- paths_from_parts = [part for part in args if isinstance(part, Path)]
- if not paths_from_parts:
- return new_path
- top_path = paths_from_parts[0]
- origins = [part._origin for part in paths_from_parts]
- if not all(origins[0] == origin or origin is None for origin in origins):
- raise TypeError(
- "Tried to instantiate a Lightning Path from multiple other Paths that originate from different"
- " LightningWork."
- )
- new_path._copy_properties_from(top_path)
- return new_path
-
- def _init_attributes(self):
- self._name: Optional[str] = None
- # the origin is the work that created this Path and wants to expose file(s)
- self._origin: Optional[Union["LightningWork", str]] = None
- # the consumer is the Work that needs access to the file(s) from the origin
- self._consumer: Optional[Union["LightningWork", str]] = None
- self._metadata = {}
- # request queue: used to transfer message to storage orchestrator
- self._request_queue: Optional[BaseQueue] = None
- # response queue: used to receive status message from storage orchestrator
- self._response_queue: Optional[BaseQueue] = None
-
- @property
- def origin_name(self) -> str:
- """The name of the LightningWork where this path was first created.
-
- Attaching a Path to a LightningWork will automatically make it the `origin`.
-
- """
- from lightning.app.core.work import LightningWork
-
- return self._origin.name if isinstance(self._origin, LightningWork) else self._origin
-
- @property
- def consumer_name(self) -> str:
- """The name of the LightningWork where this path is being accessed.
-
- By default, this is the same as the :attr:`origin_name`.
-
- """
- from lightning.app.core.work import LightningWork
-
- return self._consumer.name if isinstance(self._consumer, LightningWork) else self._consumer
-
- @property
- def hash(self) -> Optional[str]:
- """The hash of this Path uniquely identifies the file path and the associated origin Work.
-
- Returns ``None`` if the origin is not defined, i.e., this Path did not yet get attached to a LightningWork.
-
- """
- if self._origin is None:
- return None
- contents = f"{self.origin_name}/{self}"
- return hashlib.sha1(contents.encode("utf-8")).hexdigest()
-
- @property
- def parents(self) -> Sequence["Path"]:
- parents: List["Path"] = list(super().parents)
- for parent in parents:
- parent._copy_properties_from(self)
- return parents
-
- @property
- def parent(self) -> "Path":
- parent: Path = super().parent
- parent._copy_properties_from(self)
- return parent
-
- def exists(self) -> bool:
- """Check if the path exists locally or remotely.
-
- If the path exists locally, this method immediately returns ``True``, otherwise it will make an RPC call
- to the attached origin Work and check if the path exists remotely.
- If you strictly want to check local existence only, use :meth:`exists_local` instead. If you strictly want
- to check existence on the remote (regardless of whether the file exists locally or not), use
- :meth:`exists_remote`.
-
- """
- return self.exists_local() or (self._origin and self.exists_remote())
-
- def exists_local(self) -> bool:
- """Check if the path exists locally."""
- return super().exists()
-
- def exists_remote(self) -> bool:
- """Check if the path exists remotely on the attached origin Work.
-
- Raises:
- RuntimeError: If the path is not attached to any Work (origin undefined).
-
- """
- # Fail early if we need to check the remote but an origin is not defined
- if not self._origin or self._request_queue is None or self._response_queue is None:
- raise RuntimeError(
- f"Trying to check if the file {self} exists, but the path is not attached to a LightningWork."
- f" Set it as an attribute to a LightningWork or pass it to the `run()` method."
- )
-
- # 1. Send a message to the orchestrator through the queue with a request for a path existence check
- request = _ExistsRequest(source=self.origin_name, path=str(self), name=self._name, hash=self.hash)
- self._request_queue.put(request)
-
- # 2. Wait for the response to come back
- response: _ExistsResponse = self._response_queue.get() # blocking
- return response.exists
-
- def get(self, overwrite: bool = False) -> None:
- if _is_flow_context():
- raise RuntimeError("`Path.get()` can only be called from within the `run()` method of LightningWork.")
- if self._request_queue is None or self._response_queue is None:
- raise RuntimeError(
- f"Trying to get the file {self}, but the path is not attached to a LightningApp."
- f" Are you trying to get the file from within `__init__`?"
- )
- if self._origin is None:
- raise RuntimeError(
- f"Trying to get the file {self}, but the path is not attached to a LightningWork. Set it as an"
- f" attribute to a LightningWork or pass it to the `run()` method."
- )
-
- if self.exists_local() and not overwrite:
- raise FileExistsError(
- f"The file or folder {self} exists locally. Pass `overwrite=True` if you wish to replace it"
- f" with the new contents."
- )
-
- # 1. Send message to orchestrator through queue with details of the transfer
- # the source is the name of the work that owns the file that we request
- # the destination is determined by the queue, since each work has a dedicated send and recv queue
- request = _GetRequest(source=self.origin_name, path=str(self), hash=self.hash, name=self._name)
- self._request_queue.put(request)
-
- # 2. Wait for the transfer to finish
- response: _GetResponse = self._response_queue.get() # blocking
- self._validate_get_response(response)
-
- fs = _filesystem()
-
- # 3. Wait until the file appears in shared storage
- while not fs.exists(response.path) or fs.info(response.path)["size"] != response.size:
- sleep(REMOTE_STORAGE_WAIT)
-
- if self.exists_local() and self.is_dir():
- # Delete the directory, otherwise we can't overwrite it
- shutil.rmtree(self)
-
- # 4. Copy the file from the shared storage to the destination on the local filesystem
- if fs.isdir(response.path):
- if isinstance(fs, LocalFileSystem):
- shutil.copytree(response.path, self.resolve())
- else:
- glob = f"{str(response.path)}/**"
- _logger.debug(f"Attempting to copy {glob} -> {str(self.absolute())}")
- fs.get(glob, str(self.absolute()), recursive=False)
- else:
- _logger.debug(f"Attempting to copy {str(response.path)} -> {str(self.absolute())}")
- fs.get(str(response.path), str(self.absolute()), recursive=False)
-
- def to_dict(self) -> dict:
- """Serialize this Path to a dictionary."""
- return {
- "path": str(self),
- "origin_name": self.origin_name,
- "consumer_name": self.consumer_name,
- "metadata": self._metadata,
- }
-
- @classmethod
- def from_dict(cls, content: dict) -> "Path":
- """Instantiate a Path from a dictionary."""
- path = cls(content["path"])
- path._origin = content["origin_name"]
- path._consumer = content["consumer_name"]
- path._metadata = content["metadata"]
- return path
-
- def _validate_get_response(self, response: "_GetResponse") -> None:
- if response.source != self._origin or response.hash != self.hash:
- raise RuntimeError(
- f"Tried to get the file {self} but received a response for a request it did not send. The response"
- f" contents are: {response}"
- )
-
- if response.exception is not None:
- raise RuntimeError(
- f"An exception was raised while trying to transfer the contents at {response.path}"
- f" from Work {response.source} to {response.destination}. See the full stacktrace above."
- ) from response.exception
-
- def _attach_work(self, work: "LightningWork") -> None:
- """Attach a LightningWork to this Path.
-
- The first work to be attached becomes the `origin`, i.e., the Work that is meant to expose the file to other
- Work. Attaching a Work to a Path that already has an `origin` Work will make it a `consumer`. A consumer Work
- is a work that can access the file only by first transferring it via :meth:`transfer`.
-
- Args:
- work: LightningWork to be attached to this Path.
-
- """
- if self._origin is None:
- # Can become an owner only if there is not already one
- self._origin = work
- self._consumer = work
-
- def _attach_queues(self, request_queue: BaseQueue, response_queue: BaseQueue) -> None:
- """Attaches the queues for communication with the Storage Orchestrator."""
- self._request_queue = request_queue
- self._response_queue = response_queue
-
- def _sanitize(self) -> None:
- """Sanitize this Path so that it can be deep-copied."""
- self._origin = self.origin_name
- self._consumer = self.consumer_name
- self._request_queue = None
- self._response_queue = None
-
- def _copy_properties_from(self, other: "Path") -> None:
- self._origin = other._origin
- self._consumer = other._consumer
- self._metadata = other._metadata
- self._request_queue = other._request_queue
- self._response_queue = other._response_queue
-
- def with_name(self, name: str) -> "Path":
- path: Path = super().with_name(name)
- path._copy_properties_from(self)
- return path
-
- def with_stem(self, stem: str) -> "Path":
- path: Path = super().with_stem(stem)
- path._copy_properties_from(self)
- return path
-
- def with_suffix(self, suffix: str) -> "Path":
- path: Path = super().with_suffix(suffix)
- path._copy_properties_from(self)
- return path
-
- def relative_to(self, *other) -> "Path":
- path: Path = super().relative_to(*other)
- path._copy_properties_from(self)
- return path
-
- def __truediv__(self, other: Union["Path", PathlibPath, str]) -> "Path":
- path: Path = super().__truediv__(other)
- path._copy_properties_from(self)
- return path
-
- def __rtruediv__(self, other: Union["Path", PathlibPath, str]) -> "Path":
- path: Path = super().__rtruediv__(other)
- path._copy_properties_from(self)
- return path
-
- def __reduce__(self):
- return Path.from_dict, (self.to_dict(),)
-
- def __json__(self) -> dict:
- """Converts the Path to a json-serializable dict object."""
- return self.to_dict()
-
- @staticmethod
- def _handle_exists_request(work: "LightningWork", request: _ExistsRequest) -> _ExistsResponse:
- return _ExistsResponse(
- source=request.source,
- name=request.name,
- hash=request.hash,
- path=request.path,
- destination=request.destination,
- exists=os.path.exists(request.path),
- )
-
- @staticmethod
- def _handle_get_request(work: "LightningWork", request: _GetRequest) -> _GetResponse:
- from lightning.app.storage.copier import _copy_files
-
- source_path = pathlib.Path(request.path)
- destination_path = _shared_storage_path() / request.hash
- response = _GetResponse(
- source=request.source,
- name=request.name,
- path=str(destination_path),
- hash=request.hash,
- size=source_path.stat().st_size,
- destination=request.destination,
- )
-
- try:
- _copy_files(source_path, destination_path)
- _logger.debug(f"All files copied from {request.path} to {response.path}.")
- except Exception as ex:
- response.exception = ex
- return response
-
-
-def _is_lit_path(path: Union[str, Path]) -> bool:
- path = Path(path)
- return path == _storage_root_dir() or _storage_root_dir() in path.parents
-
-
-def _shared_local_mount_path() -> pathlib.Path:
- """Returns the shared directory through which the Copier threads move files from one Work filesystem to another.
-
- The shared directory can be set via the environment variable ``SHARED_MOUNT_DIRECTORY`` and should be pointing to a
- directory that all Works have mounted (shared filesystem).
-
- """
- path = pathlib.Path(os.environ.get("SHARED_MOUNT_DIRECTORY", ".shared"))
- path.mkdir(parents=True, exist_ok=True)
- return path.absolute()
-
-
-def _storage_root_dir() -> pathlib.Path:
- path = pathlib.Path(os.environ.get("STORAGE_ROOT_DIR", "./.storage")).absolute()
- path.mkdir(parents=True, exist_ok=True)
- return path
-
-
-def _shared_storage_path() -> pathlib.Path:
- """Returns the shared path through which the Copier threads move files from one Work filesystem to another.
-
- The shared path gets set by the environment. Locally, it is pointing to a directory determined by the
- ``SHARED_MOUNT_DIRECTORY`` environment variable. In the cloud, the shared path will point to an S3 bucket. All Works
- have access to this shared dropbox.
-
- """
- storage_path = os.getenv("LIGHTNING_STORAGE_PATH", "")
- if storage_path != "":
- return pathlib.Path(storage_path)
-
- # TODO[dmitsf]: this logic is still needed for compatibility reasons.
- # We should remove it after some time.
- bucket_name = os.getenv("LIGHTNING_BUCKET_NAME", "")
- app_id = os.getenv("LIGHTNING_CLOUD_APP_ID", "")
-
- if bucket_name != "" and app_id != "":
- return pathlib.Path(f"{bucket_name}/lightningapps/{app_id}")
-
- return _shared_local_mount_path()
-
-
-def _artifacts_path(work: "LightningWork") -> pathlib.Path:
- return _shared_storage_path() / "artifacts" / work.name
-
-
-def _path_to_work_artifact(path: Union[Path, pathlib.Path, str], work: "LightningWork") -> pathlib.Path:
- return _artifacts_path(work) / pathlib.Path(*pathlib.Path(path).absolute().parts[1:])
-
-
-def _filesystem() -> AbstractFileSystem:
- fs = LocalFileSystem()
-
- endpoint_url = os.getenv("LIGHTNING_BUCKET_ENDPOINT_URL", "")
- bucket_name = os.getenv("LIGHTNING_BUCKET_NAME", "")
- if endpoint_url != "" and bucket_name != "":
- # FIXME: Temporary fix until we remove the injection from the platform
- if "AWS_ACCESS_KEY_ID" in os.environ:
- del os.environ["AWS_ACCESS_KEY_ID"]
- del os.environ["AWS_SECRET_ACCESS_KEY"]
-
- fs = S3FileSystem()
-
- app_id = os.getenv("LIGHTNING_CLOUD_APP_ID", "")
- if app_id == "":
- raise RuntimeError("missing LIGHTNING_CLOUD_APP_ID")
-
- if not fs.exists(_shared_storage_path()):
- raise RuntimeError(f"shared filesystem {_shared_storage_path()} does not exist")
-
- return fs
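Review note: the removed `_shared_storage_path()` picks the shared location in a fixed order: an explicit `LIGHTNING_STORAGE_PATH`, then the legacy bucket name plus app id, then the local mount directory. The sketch below restates that order as a standalone function so that any replacement can be checked against it. `resolve_shared_storage` and its `env` parameter are hypothetical names used only for illustration; unlike the removed helper, the sketch reads from a plain mapping instead of `os.environ` and does not create the fallback directory.

```python
import pathlib
from typing import Mapping


def resolve_shared_storage(env: Mapping[str, str]) -> pathlib.Path:
    """Hypothetical restatement of the lookup order used by the removed helper."""
    # 1. An explicit storage path takes priority over everything else.
    storage_path = env.get("LIGHTNING_STORAGE_PATH", "")
    if storage_path != "":
        return pathlib.Path(storage_path)

    # 2. Legacy fallback: both the bucket name and the app id must be set.
    bucket_name = env.get("LIGHTNING_BUCKET_NAME", "")
    app_id = env.get("LIGHTNING_CLOUD_APP_ID", "")
    if bucket_name != "" and app_id != "":
        return pathlib.Path(f"{bucket_name}/lightningapps/{app_id}")

    # 3. Local fallback: the shared mount directory, ".shared" by default.
    return pathlib.Path(env.get("SHARED_MOUNT_DIRECTORY", ".shared")).absolute()
```

Passing the environment as an argument keeps the sketch easy to exercise without mutating process state.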
diff --git a/src/lightning/app/storage/payload.py b/src/lightning/app/storage/payload.py
deleted file mode 100644
index 03dd018bd8fcc..0000000000000
--- a/src/lightning/app/storage/payload.py
+++ /dev/null
@@ -1,274 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import hashlib
-import pathlib
-import pickle
-from abc import ABC, abstractmethod
-from time import sleep
-from typing import TYPE_CHECKING, Any, Optional, Union
-
-from lightning.app.core.constants import REMOTE_STORAGE_WAIT
-from lightning.app.core.queues import BaseQueue
-from lightning.app.storage.path import Path, _filesystem, _shared_storage_path
-from lightning.app.storage.requests import _ExistsRequest, _ExistsResponse, _GetRequest, _GetResponse
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.component import _is_flow_context
-
-_logger = Logger(__name__)
-
-if TYPE_CHECKING:
- from lightning.app.core.work import LightningWork
-
-
-class _BasePayload(ABC):
- def __init__(self, value: Any) -> None:
- self._value = value
- # the attribute name given to the payload
- self._name: Optional[str] = None
- # the origin is the Work that created this Payload and wants to expose it
- self._origin: Optional[Union["LightningWork", str]] = None
- # the consumer is the Work that needs access to the payload from the origin
- self._consumer: Optional[Union["LightningWork", str]] = None
- self._metadata = {}
- # request queue: used to transfer message to storage orchestrator
- self._request_queue: Optional[BaseQueue] = None
- # response queue: used to receive status message from storage orchestrator
- self._response_queue: Optional[BaseQueue] = None
-
- @property
- def name(self) -> Optional[str]:
- return self._name
-
- @property
- def value(self) -> Optional[Any]:
- """The real object that this payload holds."""
- return self._value
-
- @property
- def hash(self) -> Optional[str]:
- """The hash of this Payload uniquely identifies the payload and the associated origin Work.
-
- Returns ``None`` if the origin is not defined, i.e., this Payload did not yet get attached to a LightningWork.
-
- """
- if self._origin is None:
- return None
- contents = f"{self.origin_name}/{self.consumer_name}/{self.name}"
- return hashlib.sha1(contents.encode("utf-8")).hexdigest()
-
- @property
- def origin_name(self) -> str:
- """The name of the LightningWork where this payload was first created.
-
- Attaching a Payload to a LightningWork will automatically make it the `origin`.
-
- """
- from lightning.app.core.work import LightningWork
-
- return self._origin.name if isinstance(self._origin, LightningWork) else self._origin
-
- @property
- def consumer_name(self) -> str:
- """The name of the LightningWork where this payload is being accessed.
-
- By default, this is the same as the :attr:`origin_name`.
-
- """
- from lightning.app.core.work import LightningWork
-
- return self._consumer.name if isinstance(self._consumer, LightningWork) else self._consumer
-
- @property
- def _path(self) -> Optional[Path]:
- """Path to the file that the payload value gets serialized to."""
- if not self._name:
- return None
- return Path("lit://", self._name)
-
- @abstractmethod
- def save(self, obj: Any, path: str) -> None:
- """Override this method with your own saving logic."""
-
- @abstractmethod
- def load(self, path: str) -> Any:
- """Override this method with your own loading logic."""
-
- def _attach_work(self, work: "LightningWork") -> None:
- """Attach a LightningWork to this Payload.
-
- Args:
- work: LightningWork to be attached to this Payload.
-
- """
- if self._origin is None:
- # Can become an owner only if there is not already one
- self._origin = work.name
- self._consumer = work.name
-
- def _attach_queues(self, request_queue: BaseQueue, response_queue: BaseQueue) -> None:
- """Attaches the queues for communication with the Storage Orchestrator."""
- self._request_queue = request_queue
- self._response_queue = response_queue
-
- def _sanitize(self) -> None:
- """Sanitize this Payload so that it can be deep-copied."""
- self._origin = self.origin_name
- self._consumer = self.consumer_name
- self._request_queue = None
- self._response_queue = None
-
- def exists_remote(self):
- """Check if the payload exists remotely on the attached origin Work.
-
- Raises:
- RuntimeError: If the payload is not attached to any Work (origin undefined).
-
- """
- # Fail early if we need to check the remote but an origin is not defined
- if not self._origin or self._request_queue is None or self._response_queue is None:
- raise RuntimeError(
- f"Trying to check if the payload {self} exists, but the payload is not attached to a LightningWork."
- f" Set it as an attribute to a LightningWork or pass it to the `run()` method."
- )
-
- # 1. Send message to orchestrator through queue with a request for a path existence check
- request = _ExistsRequest(source=self.origin_name, name=self._name, path=str(self._path), hash=self.hash)
- self._request_queue.put(request)
-
- # 2. Wait for the response to come back
- response: _ExistsResponse = self._response_queue.get() # blocking
- return response.exists
-
- def get(self) -> Any:
- if _is_flow_context():
- raise RuntimeError("`Payload.get()` can only be called from within the `run()` method of LightningWork.")
-
- if self._request_queue is None or self._response_queue is None:
- raise RuntimeError(
- f"Trying to get the file {self}, but the payload is not attached to a LightningApp."
- f" Are you trying to get the file from within `__init__`?"
- )
- if self._origin is None:
- raise RuntimeError(
- f"Trying to get the file {self}, but the payload is not attached to a LightningWork. Set it as an"
- f" attribute to a LightningWork or pass it to the `run()` method."
- )
-
- # 1. Send message to orchestrator through queue with details of the transfer
- # the source is the name of the work that owns the file that we request
- # the destination is determined by the queue, since each work has a dedicated send and recv queue
- request = _GetRequest(source=self.origin_name, name=self._name, path=str(self._path), hash=self.hash)
- self._request_queue.put(request)
-
- # 2. Wait for the transfer to finish
- response: _GetResponse = self._response_queue.get() # blocking
- self._validate_get_response(response)
-
- fs = _filesystem()
-
- # 3. Wait until the file appears in shared storage
- while not fs.exists(response.path) or fs.info(response.path)["size"] != response.size:
- sleep(REMOTE_STORAGE_WAIT)
-
- # 4. Copy the file from the shared storage to the destination on the local filesystem
- local_path = self._path
- _logger.debug(f"Attempting to copy {str(response.path)} -> {str(local_path)}")
- fs.get(str(response.path), str(local_path), recursive=False)
-
- # Ensure the file is properly written
- sleep(0.5)
-
- self._value = self.load(local_path)
- return self._value
-
- def _validate_get_response(self, response: "_GetResponse") -> None:
- if response.source != self._origin or response.hash != self.hash:
- raise RuntimeError(
- f"Tried to get the file {self} but received a response for a request it did not send. The response"
- f" contents are: {response}"
- )
-
- if response.exception is not None:
- raise RuntimeError(
- f"An exception was raised while trying to transfer the contents at {response.path}"
- f" from Work {response.source} to {response.destination}. See the full stacktrace above."
- ) from response.exception
-
- def to_dict(self) -> dict:
- """Serialize this Payload to a dictionary."""
- return {
- "name": self.name,
- "origin_name": self.origin_name,
- "consumer_name": self.consumer_name,
- "metadata": self._metadata,
- }
-
- @classmethod
- def from_dict(cls, content: dict) -> "_BasePayload":
- """Instantiate a Payload from a dictionary."""
- payload = cls(None)
- payload._name = content["name"]
- payload._origin = content["origin_name"]
- payload._consumer = content["consumer_name"]
- payload._metadata = content["metadata"]
- return payload
-
- @staticmethod
- def _handle_exists_request(work: "LightningWork", request: _ExistsRequest) -> _ExistsResponse:
- return _ExistsResponse(
- source=request.source,
- path=request.path,
- name=request.name,
- destination=request.destination,
- hash=request.hash,
- exists=getattr(work, request.name, None) is not None,
- )
-
- @staticmethod
- def _handle_get_request(work: "LightningWork", request: _GetRequest) -> _GetResponse:
- from lightning.app.storage.copier import _copy_files
-
- source_path = pathlib.Path(request.path)
- destination_path = _shared_storage_path() / request.hash
- response = _GetResponse(
- source=request.source,
- name=request.name,
- path=str(destination_path),
- hash=request.hash,
- destination=request.destination,
- )
-
- try:
- payload = getattr(work, request.name)
- payload.save(payload.value, source_path)
- response.size = source_path.stat().st_size
- _copy_files(source_path, destination_path)
- _logger.debug(f"All files copied from {request.path} to {response.path}.")
- except Exception as ex:
- response.exception = ex
- return response
-
-
-class Payload(_BasePayload):
- """The Payload object enables transferring Python objects from one Work to another, in a similar fashion to
- :class:`~lightning.app.storage.path.Path`."""
-
- def save(self, obj: Any, path: str) -> None:
- with open(path, "wb") as f:
- pickle.dump(obj, f)
-
- def load(self, path: str) -> Any:
- with open(path, "rb") as f:
- return pickle.load(f)
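Review note: the removed `_BasePayload.hash` property identifies a payload by the SHA-1 of the string `origin/consumer/name`, and `_handle_get_request` uses that digest as the file name in shared storage. A minimal standalone sketch of that computation follows; `payload_hash` is a hypothetical free function written for illustration, not a name from the removed module.

```python
import hashlib
from typing import Optional


def payload_hash(origin_name: Optional[str], consumer_name: Optional[str], name: Optional[str]) -> Optional[str]:
    """Hypothetical free-function version of the removed `hash` property."""
    # Without an origin the payload is not attached to any Work yet.
    if origin_name is None:
        return None
    contents = f"{origin_name}/{consumer_name}/{name}"
    return hashlib.sha1(contents.encode("utf-8")).hexdigest()
```

Because the consumer name is part of the hashed string, the same payload requested by two different Works maps to two different digests.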
diff --git a/src/lightning/app/storage/requests.py b/src/lightning/app/storage/requests.py
deleted file mode 100644
index 83f430c452ce0..0000000000000
--- a/src/lightning/app/storage/requests.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass
-from typing import Optional
-
-
-@dataclass
-class _GetRequest:
- source: str
- name: str
- path: str
- hash: str
- destination: str = ""
-
-
-@dataclass
-class _GetResponse:
- source: str
- name: str
- path: str
- hash: str
- size: int = 0
- destination: str = ""
- exception: Optional[Exception] = None
- timedelta: Optional[float] = None
-
-
-@dataclass
-class _ExistsRequest:
- source: str
- name: str
- path: str
- hash: str
- destination: str = ""
-
-
-@dataclass
-class _ExistsResponse:
- source: str
- name: str
- path: str
- hash: str
- destination: str = ""
- exists: Optional[bool] = None
- timedelta: Optional[float] = None
diff --git a/src/lightning/app/structures/__init__.py b/src/lightning/app/structures/__init__.py
deleted file mode 100644
index a432fb6daf1b5..0000000000000
--- a/src/lightning/app/structures/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from lightning.app.structures.dict import Dict
-from lightning.app.structures.list import List
-
-__all__ = ["Dict", "List"]
diff --git a/src/lightning/app/structures/dict.py b/src/lightning/app/structures/dict.py
deleted file mode 100644
index 7accc0c1db627..0000000000000
--- a/src/lightning/app/structures/dict.py
+++ /dev/null
@@ -1,160 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import typing as t
-
-from lightning.app.utilities.app_helpers import _LightningAppRef, _set_child_name
-
-T = t.TypeVar("T")
-
-if t.TYPE_CHECKING:
- from lightning.app.utilities.types import Component
-
-
-def _prepare_name(component: "Component") -> str:
- return str(component.name.split(".")[-1])
-
-
-# TODO: add support and tests for dict operations (insertion, update, etc.)
-class Dict(t.Dict[str, T]):
- def __init__(self, **kwargs: T):
- """The Dict object is used to represent a dict collection of :class:`~lightning.app.core.work.LightningWork` or
- :class:`~lightning.app.core.flow.LightningFlow`.
-
- Example:
-
- >>> from lightning.app import LightningFlow, LightningWork
- >>> from lightning.app.structures import Dict
- >>> class CounterWork(LightningWork):
- ... def __init__(self):
- ... super().__init__()
- ... self.counter = 0
- ... def run(self):
- ... self.counter += 1
- ...
- >>> class RootFlow(LightningFlow):
- ... def __init__(self):
- ... super().__init__()
- ... self.dict = Dict(**{"work_0": CounterWork(), "work_1": CounterWork()})
- ... def run(self):
- ... for work_name, work in self.dict.items():
- ... work.run()
- ...
- >>> flow = RootFlow()
- >>> flow.run()
- >>> assert flow.dict["work_0"].counter == 1
-
- Arguments:
- kwargs: LightningWork or LightningFlow instances, keyed by name.
-
- """
- super().__init__(**kwargs)
- from lightning.app.runners.backends import Backend
-
- self._name: t.Optional[str] = ""
- self._backend: t.Optional[Backend] = None
- for k, v in kwargs.items():
- if "." in k:
- raise Exception(f"The provided name {k} contains . which is forbidden.")
- _set_child_name(self, v, k)
-
- def __setitem__(self, k, v):
- from lightning.app.core import LightningFlow, LightningWork
-
- if not isinstance(k, str):
- raise Exception("The provided key should be a string")
-
- if isinstance(k, str) and "." in k:
- raise Exception(f"The provided name {k} contains . which is forbidden.")
-
- _set_child_name(self, v, k)
- if self._backend:
- if isinstance(v, LightningFlow):
- LightningFlow._attach_backend(v, self._backend)
- elif isinstance(v, LightningWork):
- self._backend._wrap_run_method(_LightningAppRef().get_current(), v)
- v._name = f"{self.name}.{k}"
- super().__setitem__(k, v)
-
- @property
- def works(self):
- from lightning.app.core import LightningFlow, LightningWork
-
- works = [item for item in self.values() if isinstance(item, LightningWork)]
- for flow in [item for item in self.values() if isinstance(item, LightningFlow)]:
- for child_work in flow.works(recurse=False):
- works.append(child_work)
- return works
-
- @property
- def flows(self):
- from lightning.app.core.flow import LightningFlow
- from lightning.app.structures import Dict as _Dict
- from lightning.app.structures import List as _List
-
- flows = {}
- for item in self.values():
- if isinstance(item, LightningFlow):
- flows[item.name] = item
- for child_flow in item.flows.values():
- flows[child_flow.name] = child_flow
- if isinstance(item, (_Dict, _List)):
- for child_flow in item.flows.values():
- flows[child_flow.name] = child_flow
- return flows
-
- @property
- def name(self):
- return self._name or "root"
-
- @property
- def state(self):
- """Returns the state of its flows and works."""
- from lightning.app.core import LightningFlow, LightningWork
-
- return {
- "works": {key: item.state for key, item in self.items() if isinstance(item, LightningWork)},
- "flows": {key: item.state for key, item in self.items() if isinstance(item, LightningFlow)},
- }
-
- @property
- def state_vars(self):
- from lightning.app.core import LightningFlow, LightningWork
-
- return {
- "works": {key: item.state_vars for key, item in self.items() if isinstance(item, LightningWork)},
- "flows": {key: item.state_vars for key, item in self.items() if isinstance(item, LightningFlow)},
- }
-
- @property
- def state_with_changes(self):
- from lightning.app.core import LightningFlow, LightningWork
-
- return {
- "works": {key: item.state_with_changes for key, item in self.items() if isinstance(item, LightningWork)},
- "flows": {key: item.state_with_changes for key, item in self.items() if isinstance(item, LightningFlow)},
- }
-
- def set_state(self, state):
- state_keys = set(list(state["works"].keys()) + list(state["flows"].keys()))
- current_state_keys = set(self.keys())
- if current_state_keys != state_keys:
- key_diff = (current_state_keys - state_keys) | (state_keys - current_state_keys)
- raise Exception(
- f"The provided state doesn't match the `Dict` {self.name}. Found `{key_diff}` un-matching keys"
- )
- for work_key, work_state in state["works"].items():
- self[work_key].set_state(work_state)
- for child_key, child_state in state["flows"].items():
- self[child_key].set_state(child_state)
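Review note: the removed `Dict.set_state` rejects a state whose work and flow keys do not exactly match the keys currently held, and reports the symmetric difference of the two key sets. The sketch below isolates that check; `unmatched_keys` is a hypothetical helper name, and the `state` argument is assumed to have the same `{"works": ..., "flows": ...}` shape that the removed `state` property returns.

```python
from typing import Any, Dict, Iterable, Set


def unmatched_keys(current_keys: Iterable[str], state: Dict[str, Dict[str, Any]]) -> Set[str]:
    """Hypothetical helper: keys present on only one side of a state update."""
    state_keys = set(state["works"]) | set(state["flows"])
    current = set(current_keys)
    # Symmetric difference, written the same way as in the removed method.
    return (current - state_keys) | (state_keys - current)
```

An empty result means the state can be applied; any other result corresponds to the case where the removed method raised an exception.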
diff --git a/src/lightning/app/structures/list.py b/src/lightning/app/structures/list.py
deleted file mode 100644
index f2d81ac9f86dc..0000000000000
--- a/src/lightning/app/structures/list.py
+++ /dev/null
@@ -1,177 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import typing as t
-
-from lightning.app.utilities.app_helpers import _LightningAppRef, _set_child_name
-
-T = t.TypeVar("T")
-
-if t.TYPE_CHECKING:
- from lightning.app.utilities.types import Component
-
-
-def _prepare_name(component: "Component") -> str:
- return str(component.name.split(".")[-1])
-
-
-# TODO: add support and tests for list operations (concatenation, deletion, insertion, etc.)
-class List(t.List[T]):
- def __init__(self, *items: T):
- """The List object is used to represent a list collection of :class:`~lightning.app.core.work.LightningWork` or
- :class:`~lightning.app.core.flow.LightningFlow`.
-
- Example:
-
- >>> from lightning.app import LightningFlow, LightningWork
- >>> from lightning.app.structures import List
- >>> class CounterWork(LightningWork):
- ... def __init__(self):
- ... super().__init__()
- ... self.counter = 0
- ... def run(self):
- ... self.counter += 1
- ...
- >>> class RootFlow(LightningFlow):
- ... def __init__(self):
- ... super().__init__()
- ... self.list = List(*[CounterWork(), CounterWork()])
- ... def run(self):
- ... for work in self.list:
- ... work.run()
- ...
- >>> flow = RootFlow()
- >>> flow.run()
- >>> assert flow.list[0].counter == 1
-
- Arguments:
- items: A sequence of LightningWork or LightningFlow.
-
- """
- super().__init__()
- from lightning.app.runners.backends import Backend
-
- self._name: t.Optional[str] = ""
- self._last_index = 0
- self._backend: t.Optional[Backend] = None
- for item in items:
- self.append(item)
-
- def append(self, v):
- from lightning.app.core import LightningFlow, LightningWork
-
- _set_child_name(self, v, str(self._last_index))
- if self._backend:
- if isinstance(v, LightningFlow):
- LightningFlow._attach_backend(v, self._backend)
- elif isinstance(v, LightningWork):
- self._backend._wrap_run_method(_LightningAppRef().get_current(), v)
- v._name = f"{self.name}.{self._last_index}"
- self._last_index += 1
- super().append(v)
-
- @property
- def name(self):
- """Returns the name of this List object."""
- return self._name or "root"
-
- @property
- def works(self):
- from lightning.app.core import LightningFlow, LightningWork
-
- works = [item for item in self if isinstance(item, LightningWork)]
- for flow in [item for item in self if isinstance(item, LightningFlow)]:
- for child_work in flow.works(recurse=False):
- works.append(child_work)
- return works
-
- @property
- def flows(self):
- from lightning.app.core import LightningFlow
- from lightning.app.structures import Dict as _Dict
- from lightning.app.structures import List as _List
-
- flows = {}
- for item in self:
- if isinstance(item, LightningFlow):
- flows[item.name] = item
- for child_flow in item.flows.values():
- flows[child_flow.name] = child_flow
- if isinstance(item, (_Dict, _List)):
- for child_flow in item.flows.values():
- flows[child_flow.name] = child_flow
- return flows
-
- @property
- def state(self):
- """Returns the state of its flows and works."""
- from lightning.app.core import LightningFlow, LightningWork
-
- works = [item for item in self if isinstance(item, LightningWork)]
- children = [item for item in self if isinstance(item, LightningFlow)]
- return {
- "works": {_prepare_name(w): w.state for w in works},
- "flows": {_prepare_name(flow): flow.state for flow in children},
- }
-
- @property
- def state_vars(self):
- from lightning.app.core import LightningFlow, LightningWork
-
- works = [item for item in self if isinstance(item, LightningWork)]
- children = [item for item in self if isinstance(item, LightningFlow)]
- return {
- "works": {_prepare_name(w): w.state_vars for w in works},
- "flows": {_prepare_name(flow): flow.state_vars for flow in children},
- }
-
- @property
- def state_with_changes(self):
- from lightning.app.core import LightningFlow, LightningWork
-
- works = [item for item in self if isinstance(item, LightningWork)]
- children = [item for item in self if isinstance(item, LightningFlow)]
- return {
- "works": {str(_prepare_name(w)): w.state_with_changes for w in works},
- "flows": {_prepare_name(flow): flow.state_with_changes for flow in children},
- }
-
- def set_state(self, state):
- """Method to set the state of the list and its children."""
- from lightning.app.core import LightningFlow, LightningWork
-
- works = [item for item in self if isinstance(item, LightningWork)]
- children = [item for item in self if isinstance(item, LightningFlow)]
-
- current_state_keys = {_prepare_name(w) for w in self}
- state_keys = set(list(state["works"].keys()) + list(state["flows"].keys()))
-
- if current_state_keys != state_keys:
- key_diff = (current_state_keys - state_keys) | (state_keys - current_state_keys)
- raise Exception(
- f"The provided state doesn't match the `List` {self.name}. Found `{key_diff}` un-matching keys"
- )
-
- for work_key, work_state in state["works"].items():
- for work in works:
- if _prepare_name(work) == work_key:
- work.set_state(work_state)
- for child_key, child_state in state["flows"].items():
- for child in children:
- if _prepare_name(child) == child_key:
- child.set_state(child_state)
-
- def __len__(self):
- """Returns the number of elements within this List."""
- return sum(1 for _ in self)
diff --git a/src/lightning/app/testing/__init__.py b/src/lightning/app/testing/__init__.py
deleted file mode 100644
index 1c4e2e171608d..0000000000000
--- a/src/lightning/app/testing/__init__.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from lightning.app.testing.helpers import EmptyFlow, EmptyWork
-from lightning.app.testing.testing import (
- LightningTestApp,
- application_testing,
- delete_cloud_lightning_apps,
- run_app_in_cloud,
- run_work_isolated,
- wait_for,
-)
-
-__all__ = [
- "application_testing",
- "run_work_isolated",
- "LightningTestApp",
- "delete_cloud_lightning_apps",
- "run_app_in_cloud",
- "wait_for",
- "EmptyFlow",
- "EmptyWork",
-]
diff --git a/src/lightning/app/testing/config.py b/src/lightning/app/testing/config.py
deleted file mode 100644
index 677c382f5b5ca..0000000000000
--- a/src/lightning/app/testing/config.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-from dataclasses import dataclass
-
-
-@dataclass
-class _Config:
- id = os.getenv("LIGHTNING_USER_ID")
- key = os.getenv("LIGHTNING_API_KEY")
- url = os.getenv("LIGHTNING_CLOUD_URL", "https://lightning.ai")
- api_key = os.getenv("LIGHTNING_API_KEY")
- username = os.getenv("LIGHTNING_USERNAME")
- video_location = os.getenv("VIDEO_LOCATION", "./artifacts/videos")
- har_location = os.getenv("HAR_LOCATION", "./artifacts/hars")
- slowmo = os.getenv("SLOW_MO", "0")
diff --git a/src/lightning/app/testing/helpers.py b/src/lightning/app/testing/helpers.py
deleted file mode 100644
index 61a00f957299e..0000000000000
--- a/src/lightning/app/testing/helpers.py
+++ /dev/null
@@ -1,179 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import subprocess
-import sys
-from queue import Empty
-from typing import Any, List, Optional, Tuple
-
-from packaging.version import Version
-
-from lightning.app.core import LightningFlow, LightningWork
-from lightning.app.core.queues import BaseQueue
-from lightning.app.utilities.imports import (
- _CLOUD_TEST_RUN,
- _is_lightning_flash_available,
- _is_pytorch_lightning_available,
-)
-
-
-def _call_script(
- filepath: str,
- args: Optional[List[str]] = None,
- timeout: Optional[int] = 60 * 10,
-) -> Tuple[int, str, str]:
- if args is None:
- args = []
- args = [str(a) for a in args]
- command = [sys.executable, filepath] + args # todo: add back coverage
- p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
- try:
- stdout, stderr = p.communicate(timeout=timeout)
- except subprocess.TimeoutExpired:
- p.kill()
- stdout, stderr = p.communicate()
- stdout = stdout.decode("utf-8")
- stderr = stderr.decode("utf-8")
- return p.returncode, stdout, stderr
-
-
-def _run_script(filepath):
- code, stdout, stderr = _call_script(filepath)
- print(f"{filepath} STDOUT: {stdout}")
- print(f"{filepath} STDERR: {stderr}")
- assert not code, code
-
-
-class _RunIf:
- """RunIf wrapper for simply marking specific cases, fully compatible with pytest.mark::
-
- @RunIf(...)
- @pytest.mark.parametrize("arg1", [1, 2.0])
- def test_wrapper(arg1):
- assert arg1 > 0.0
-
- """
-
- def __new__(
- self,
- *args: Any,
- pl: bool = False,
- flash: bool = False,
- min_python: Optional[str] = None,
- skip_windows: bool = False,
- skip_linux: bool = False,
- skip_mac_os: bool = False,
- local_end_to_end: bool = False,
- cloud: bool = False,
- **kwargs: Any,
- ):
- """
- Args:
- *args: Any :class:`pytest.mark.skipif` arguments.
- pl: Requires that PyTorch Lightning is installed.
- flash: Requires that Flash is installed.
- min_python: Require that Python is greater than or equal to this version.
- skip_windows: Skip for Windows platform.
- skip_linux: Skip for Linux platform.
- skip_mac_os: Skip for macOS platform.
- **kwargs: Any :class:`pytest.mark.skipif` keyword arguments.
- """
- import pytest
-
- conditions = []
- reasons = []
-
- if min_python:
- py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
- conditions.append(Version(py_version) < Version(min_python))
- reasons.append(f"python>={min_python}")
-
- if skip_windows:
- conditions.append(sys.platform == "win32")
- reasons.append("unimplemented on Windows")
-
- if skip_linux:
- conditions.append(sys.platform == "linux")
- reasons.append("unimplemented on linux")
-
- if skip_mac_os:
- conditions.append(sys.platform == "darwin")
- reasons.append("unimplemented on MacOS")
-
- if pl:
- conditions.append(not _is_pytorch_lightning_available())
- reasons.append("PyTorch Lightning is required.")
-
- if flash:
- conditions.append(not _is_lightning_flash_available())
- reasons.append("Lightning Flash is required.")
-
- if cloud:
- conditions.append(not _CLOUD_TEST_RUN)
- reasons.append("Cloud End to End tests should not run.")
-
- reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
- return pytest.mark.skipif(
- *args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs
- )
-
-
-class _MockQueue(BaseQueue):
- def __init__(self, name: str = "", default_timeout: float = 0):
- super().__init__(name, default_timeout)
- self._queue = []
-
- def put(self, item):
- self._queue.append(item)
-
- def get(self, timeout: int = 0):
- if not self._queue:
- raise Empty()
- return self._queue.pop(0)
-
- def batch_get(self, timeout: int = 0, count: int = None):
- if not self._queue:
- raise Empty()
- return [self._queue.pop(0)]
-
- def __contains__(self, item):
- return item in self._queue
-
- def __len__(self):
- return len(self._queue)
-
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}({self._queue})"
-
-
-class EmptyFlow(LightningFlow):
- """A LightningFlow that implements all abstract methods to do nothing.
-
- Useful for mocking in tests.
-
- """
-
- def run(self):
- pass
-
-
-class EmptyWork(LightningWork):
- """A LightningWork that implements all abstract methods to do nothing.
-
- Useful for mocking in tests.
-
- """
-
- def run(self):
- pass
diff --git a/src/lightning/app/testing/testing.py b/src/lightning/app/testing/testing.py
deleted file mode 100644
index c1f50fc3f3643..0000000000000
--- a/src/lightning/app/testing/testing.py
+++ /dev/null
@@ -1,535 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import asyncio
-import datetime
-import json
-import os
-import shutil
-import subprocess
-import sys
-import tempfile
-import time
-from contextlib import contextmanager
-from multiprocessing import Process
-from subprocess import Popen
-from time import sleep
-from typing import Any, Callable, Dict, Generator, List, Optional, Type
-
-import requests
-from lightning_cloud.openapi import V1LightningappInstanceState
-from lightning_cloud.openapi.rest import ApiException
-from requests import Session
-from rich import print
-from rich.color import ANSI_COLOR_NAMES
-
-from lightning.app.cli.lightning_cli import run_app
-from lightning.app.core import LightningApp, LightningFlow, constants
-from lightning.app.runners.multiprocess import MultiProcessRuntime
-from lightning.app.testing.config import _Config
-from lightning.app.utilities.app_logs import _app_logs_reader
-from lightning.app.utilities.cloud import _get_project
-from lightning.app.utilities.enum import CacheCallsKeys
-from lightning.app.utilities.imports import _is_playwright_available, requires
-from lightning.app.utilities.log import get_logfile
-from lightning.app.utilities.logs_socket_api import _LightningLogsSocketAPI
-from lightning.app.utilities.network import LightningClient, _configure_session
-from lightning.app.utilities.packaging.lightning_utils import get_dist_path_if_editable_install
-from lightning.app.utilities.proxies import ProxyWorkRun
-
-if _is_playwright_available():
- from playwright.sync_api import HttpCredentials, sync_playwright
-
-
-def _on_error_callback(ws_app, *_):
- ws_app.close()
-
-
-def _print_logs(app_id: str):
- client = LightningClient()
- project = _get_project(client)
-
- works = client.lightningwork_service_list_lightningwork(
- project_id=project.project_id,
- app_id=app_id,
- ).lightningworks
- component_names = ["flow"] + [w.name for w in works]
-
- rich_colors = list(ANSI_COLOR_NAMES)
- colors = {c: rich_colors[i + 1] for i, c in enumerate(component_names)}
-
- max_length = max(len(c.replace("root.", "")) for c in component_names)
- identifiers = []
-
- print("################### PRINTING LOGS ###################")
-
- logs_api_client = _LightningLogsSocketAPI(client.api_client)
-
- while True:
- gen = _app_logs_reader(
- logs_api_client=logs_api_client,
- project_id=project.project_id,
- app_id=app_id,
- component_names=component_names,
- follow=False,
- on_error_callback=_on_error_callback,
- )
- for log_event in gen:
- message = log_event.message
- identifier = f"{log_event.timestamp}{log_event.message}"
- if identifier not in identifiers:
- date = log_event.timestamp.strftime("%m/%d/%Y %H:%M:%S")
- identifiers.append(identifier)
- color = colors[log_event.component_name]
- padding = (max_length - len(log_event.component_name)) * " "
- print(f"[{color}]{log_event.component_name}{padding}[/{color}] {date} {message}")
-
-
-class LightningTestApp(LightningApp):
- def __init__(self, *args: Any, **kwargs: Any):
- super().__init__(*args, **kwargs)
- self.counter = 0
-
- @staticmethod
- def _configure_session() -> Session:
- return _configure_session()
-
- def make_request(self, fn, *args: Any, **kwargs: Any):
- loop = asyncio.new_event_loop()
- loop.run_until_complete(self._make_request(fn, *args, **kwargs))
-
- async def _make_request(self, fn: Callable, *args: Any, **kwargs: Any):
- from lightning.app.utilities.state import AppState
-
- state = AppState()
- state._request_state()
- fn(state, *args, **kwargs)
- state.send_delta()
-
- def on_before_run_once(self):
- pass
-
- def on_after_run_once(self):
- self.counter += 1
-
- def run_once(self):
- before_done = self.on_before_run_once()
- if before_done is not None:
- return before_done
- done = super().run_once()
- after_done = self.on_after_run_once()
- if after_done is not None:
- return after_done
- return done
-
- def kill_work(self, work_name: str, sleep_time: int = 1):
- """Use this method to kill a specific work by its name."""
- self.processes[work_name].kill()
-
- def restart_work(self, work_name: str):
- """Use this method to restart a specific work by its name."""
- self.processes[work_name].restart()
-
-
-@requires("click")
-def application_testing(lit_app_cls: Type[LightningTestApp] = LightningTestApp, command_line: List[str] = []) -> Any:
- from unittest import mock
-
- from click.testing import CliRunner
-
- with mock.patch("lightning.app.LightningApp", lit_app_cls):
- original = sys.argv
- sys.argv = command_line
- runner = CliRunner()
- result = runner.invoke(run_app, command_line, catch_exceptions=False)
- sys.argv = original
- return result
-
-
-class _SingleWorkFlow(LightningFlow):
- def __init__(self, work, args, kwargs):
- super().__init__()
- self.work = work
- self.args = args
- self.kwargs = kwargs
-
- def run(self):
- if self.work.has_succeeded or self.work.has_failed:
- self.stop()
- self.work.run(*self.args, **self.kwargs)
-
-
-def run_work_isolated(work, *args: Any, start_server: bool = False, **kwargs: Any):
- """This function is used to run a work a single time with multiprocessing runtime."""
- MultiProcessRuntime(
- LightningApp(_SingleWorkFlow(work, args, kwargs), log_level="debug"),
- start_server=start_server,
- ).dispatch()
- # pop the stopped status.
- call_hash = work._calls[CacheCallsKeys.LATEST_CALL_HASH]
-
- if call_hash in work._calls:
- work._calls[call_hash]["statuses"].pop(-1)
-
- if isinstance(work.run, ProxyWorkRun):
- work.run = work.run.work_run
-
-
-def _browser_context_args(browser_context_args: Dict) -> Dict:
- return {
- **browser_context_args,
- "viewport": {
- "width": 1920,
- "height": 1080,
- },
- "ignore_https_errors": True,
- }
-
-
-@contextmanager
-def _run_cli(args) -> Generator:
- """This utility is used to automate end-to-end testing of the Lightning AI CLI."""
- cmd = [
- sys.executable,
- "-m",
- "lightning",
- ] + args
-
- with tempfile.TemporaryDirectory() as tmpdir:
- env_copy = os.environ.copy()
- process = Popen(
- cmd,
- cwd=tmpdir,
- env=env_copy,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- )
- process.wait()
-
- yield process.stdout.read().decode("UTF-8"), process.stderr.read().decode("UTF-8")
-
-
-def _fetch_app_by_name(client, project_id, name):
- lit_apps = [
- app
- for app in client.lightningapp_instance_service_list_lightningapp_instances(project_id=project_id).lightningapps
- if app.name == name or getattr(app, "display_name", None) == name
- ]
- if not len(lit_apps) == 1:
- raise ValueError(f"Expected to find just one app, found {len(lit_apps)}")
- return lit_apps[0]
-
-
-@requires("playwright")
-@contextmanager
-def run_app_in_cloud(
- app_folder: str, app_name: str = "app.py", extra_args: List[str] = [], debug: bool = True
-) -> Generator:
- """This utility is used to automate testing e2e application with lightning.ai."""
- # 1. Validate that the provided app_folder is correct.
- if not os.path.exists(os.path.join(app_folder, app_name)):
- raise Exception(f"The app folder should contain an {app_name} file.")
- if app_folder.endswith("/"):
- app_folder = app_folder[:-1]
-
- # 2. Create the right application name.
- basename = app_folder.split("/")[-1]
- PR_NUMBER = os.getenv("PR_NUMBER", None)
-
- is_editable_mode = get_dist_path_if_editable_install("lightning")
- if not is_editable_mode and PR_NUMBER is not None:
- raise Exception("Lightning requires to be installed in editable mode in the CI.")
-
- TEST_APP_NAME = os.getenv("TEST_APP_NAME", basename)
- os.environ["TEST_APP_NAME"] = TEST_APP_NAME
-
- if PR_NUMBER:
- name = f"test-{PR_NUMBER}-{TEST_APP_NAME}-" + str(int(time.time()))
- else:
- name = f"test-{TEST_APP_NAME}-" + str(int(time.time()))
-
- os.environ["LIGHTNING_APP_NAME"] = name
-
- url = _Config.url
- if url.endswith("/"):
- url = url[:-1]
- payload = {"apiKey": _Config.api_key, "username": _Config.username}
- url_login = url + "/v1/auth/login"
- res = requests.post(url_login, data=json.dumps(payload))
- if "token" not in res.json():
- raise RuntimeError(
- f"You haven't properly setup your environment variables with {url_login} and data: \n{payload}"
- )
-
- token = res.json()["token"]
-
- # 3. Disconnect from the App if any.
- Popen("lightning_app logout", shell=True).wait()
-
- # 4. Launch the application in the cloud from the Lightning CLI.
- with tempfile.TemporaryDirectory() as tmpdir:
- env_copy = os.environ.copy()
- env_copy["PACKAGE_LIGHTNING"] = "1"
- env_copy["LIGHTING_TESTING"] = "1"
- if debug:
- print("Debug mode is enabled")
- env_copy["LIGHTNING_DEBUG"] = "1"
- shutil.copytree(app_folder, tmpdir, dirs_exist_ok=True)
- # TODO - add -no-cache to the command line.
- stdout_path = get_logfile(f"run_app_in_cloud_{name}")
-
- cmd_extra_args = []
-
- with open(stdout_path, "w") as stdout:
- cmd = [
- sys.executable,
- "-m",
- "lightning",
- "run",
- "app",
- app_name,
- "--cloud",
- "--name",
- name,
- "--open-ui",
- "false",
- *cmd_extra_args,
- ]
- print(f"Command: {cmd}")
- process = Popen((cmd + extra_args), cwd=tmpdir, env=env_copy, stdout=stdout, stderr=sys.stderr)
- process.wait()
-
- # Fallback URL to prevent failures in case we don't get the admin URL
- admin_url = _Config.url
- with open(stdout_path) as fo:
- for line in fo.readlines():
- if line.startswith("APP_LOGS_URL: "):
- admin_url = line.replace("APP_LOGS_URL: ", "")
- break
-
- if is_editable_mode:
- # Added to ensure the current code is properly uploaded.
- # Otherwise, it could result in un-tested PRs.
- pkg_found = False
- with open(stdout_path) as fo:
- for line in fo.readlines():
- if "Packaged Lightning with your application" in line:
- pkg_found = True
- print(line) # TODO: use logging
- assert pkg_found
- os.remove(stdout_path)
-
- # 5. Print your application name
- print(f"The Lightning App Name is: [bold magenta]{name}[/bold magenta]")
-
- # 6. Create chromium browser, auth to lightning.app.ai and yield the admin and view pages.
- with sync_playwright() as p:
- browser = p.chromium.launch(headless=bool(int(os.getenv("HEADLESS", "0"))))
- context = browser.new_context(
- # Eventually this will need to be deleted
- http_credentials=HttpCredentials({
- "username": os.getenv("LAI_USER", "").strip(),
- "password": os.getenv("LAI_PASS", ""),
- }),
- record_video_dir=os.path.join(_Config.video_location, TEST_APP_NAME),
- record_har_path=_Config.har_location,
- )
-
- client = LightningClient()
- project_id = _get_project(client).project_id
-
- app = _fetch_app_by_name(client, project_id, name)
- app_id = app.id
- print(f"The Lightning App ID is: {app.id}") # useful for Grafana
-
- if debug:
- process = Process(target=_print_logs, kwargs={"app_id": app_id})
- process.start()
-
- admin_page = context.new_page()
- admin_page.goto(admin_url)
- admin_page.evaluate(
- """data => {
- window.localStorage.setItem('gridUserId', data[0]);
- window.localStorage.setItem('gridUserKey', data[1]);
- window.localStorage.setItem('gridUserToken', data[2]);
- }
- """,
- [_Config.id, _Config.key, token],
- )
- if constants.LIGHTNING_CLOUD_PROJECT_ID:
- admin_page.evaluate(
- """data => {
- window.localStorage.setItem('gridDefaultProjectIdOverride', JSON.stringify(data[0]));
- }
- """,
- [constants.LIGHTNING_CLOUD_PROJECT_ID],
- )
-
- admin_page.reload()
-
- view_page = context.new_page()
- i = 1
- while True:
- app = _fetch_app_by_name(client, project_id, name)
- msg = f"Still in phase {app.status.phase}"
-
- # wait until the app is running and openapi.json is ready
- if app.status.phase == V1LightningappInstanceState.RUNNING:
- status_code = requests.get(f"{app.status.url}/openapi.json").status_code
- if status_code == 200:
- print("App is running, continuing with testing...")
- view_page.goto(f"{app.status.url}/view")
- break
- msg = f"Received status code {status_code} at {app.status.url!r}"
- elif app.status.phase not in (V1LightningappInstanceState.PENDING, V1LightningappInstanceState.NOT_STARTED):
- # there's a race condition if the app goes from pending to running to something else before we evaluate
- # the condition above. avoid it by checking stopped explicitly
- print(f"App finished with phase {app.status.phase}, finished testing...")
- break
- if debug and i % 30 == 0:
- print(f"{msg}, continuing infinite loop...")
- i += 1
- sleep(1)
-
- logs_api_client = _LightningLogsSocketAPI(client.api_client)
-
- def fetch_logs(component_names: Optional[List[str]] = None) -> Generator:
- """This method creates websocket connections in threads and returns the logs to the main thread."""
- if not component_names:
- works = client.lightningwork_service_list_lightningwork(
- project_id=project_id,
- app_id=app_id,
- ).lightningworks
-
- component_names = ["flow"] + [w.name for w in works]
- else:
-
- def add_prefix(c: str) -> str:
- if c == "flow":
- return c
- if not c.startswith("root."):
- return "root." + c
- return c
-
- component_names = [add_prefix(c) for c in component_names]
-
- gen = _app_logs_reader(
- logs_api_client=logs_api_client,
- project_id=project_id,
- app_id=app_id,
- component_names=component_names,
- follow=False,
- on_error_callback=_on_error_callback,
- )
- for log_event in gen:
- yield log_event.message
-
- try:
- yield admin_page, view_page, fetch_logs, name
- except KeyboardInterrupt:
- pass
- finally:
- if debug:
- process.kill()
-
- context.close()
- browser.close()
- Popen("lightning disconnect", shell=True).wait()
-
- delete_cloud_lightning_apps(name=name)
-
-
-def wait_for(page, callback: Callable, *args: Any, **kwargs: Any) -> Any:
- import playwright
-
- while True:
- try:
- res = callback(*args, **kwargs)
- if res:
- return res
- except (playwright._impl._api_types.Error, playwright._impl._api_types.TimeoutError) as err:
- print(err)
- try:
- sleep(7)
- page.reload()
- except (playwright._impl._api_types.Error, playwright._impl._api_types.TimeoutError) as err:
- print(err)
- pass
- sleep(3)
-
-
-def _delete_lightning_app(client, project_id, app_id, app_name):
- print(f"Deleting {app_name} id: {app_id}")
- try:
- res = client.lightningapp_instance_service_delete_lightningapp_instance(
- project_id=project_id,
- id=app_id,
- )
- assert res == {}
- except ApiException as ex:
- print(f"Failed to delete {app_name}. Exception {ex}")
-
-
-def _delete_cloud_space(client, project_id, cloud_space_id, app_name):
- """Used to delete the parent cloudspace."""
- print(f"Deleting {app_name} id: {cloud_space_id}")
- try:
- res = client.cloud_space_service_delete_cloud_space(
- project_id=project_id,
- id=cloud_space_id,
- )
- assert res == {}
- except ApiException as ex:
- print(f"Failed to delete {app_name}. Exception {ex}")
-
-
-def delete_cloud_lightning_apps(name=None):
- """Cleanup cloud apps that start with the name test-{PR_NUMBER}-{TEST_APP_NAME}.
-
- PR_NUMBER and TEST_APP_NAME are environment variables.
-
- """
-
- client = LightningClient()
-
- try:
- pr_number = int(os.getenv("PR_NUMBER", None))
- except (TypeError, ValueError):
- # Failed when the PR is running master or 'PR_NUMBER' isn't defined.
- pr_number = ""
-
- app_name = os.getenv("TEST_APP_NAME", "").replace("_", "-")
-
- print(f"deleting apps for pr_number: {pr_number}, app_name: {app_name}")
- project_id = _get_project(client).project_id
- list_apps = client.lightningapp_instance_service_list_lightningapp_instances(project_id=project_id)
-
- if pr_number and app_name:
- for lit_app in list_apps.lightningapps:
- if name == lit_app.name or (str(pr_number) in lit_app.name and app_name in lit_app.name):
- _delete_lightning_app(client, project_id=project_id, app_id=lit_app.id, app_name=lit_app.name)
- _delete_cloud_space(
- client, project_id=project_id, cloud_space_id=lit_app.spec.cloud_space_id, app_name=lit_app.name
- )
-
- print("deleting apps that were created more than 20 minutes ago.")
-
- for lit_app in list_apps.lightningapps:
- time_diff = datetime.datetime.now(lit_app.created_at.tzinfo) - lit_app.created_at
- if time_diff > datetime.timedelta(minutes=20):
- _delete_lightning_app(client, project_id=project_id, app_id=lit_app.id, app_name=lit_app.name)
- _delete_cloud_space(
- client, project_id=project_id, cloud_space_id=lit_app.spec.cloud_space_id, app_name=lit_app.name
- )
diff --git a/src/lightning/app/utilities/__init__.py b/src/lightning/app/utilities/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/src/lightning/app/utilities/app_commands.py b/src/lightning/app/utilities/app_commands.py
deleted file mode 100644
index e3e8af50d3e23..0000000000000
--- a/src/lightning/app/utilities/app_commands.py
+++ /dev/null
@@ -1,127 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import os
-import subprocess
-from dataclasses import dataclass, field
-from typing import List
-
-from lightning.app.utilities.exceptions import MisconfigurationException
-
-logger = logging.getLogger(__name__)
-
-# These are common lines at the top of python files which conflict with our
-# command syntax but which should not be executed. This is non-exhaustive,
- # and it may be better to just ignore shebang lines if we see problems here.
-APP_COMMAND_LINES_TO_IGNORE = {
- "#!/usr/bin/python",
- "#!/usr/local/bin/python",
- "#!/usr/bin/env python",
- "#!/usr/bin/env python3",
-}
-
-
-@dataclass
-class CommandLines:
- file: str
- commands: List[str] = field(default_factory=list)
- line_numbers: List[int] = field(default_factory=list)
-
-
-def _extract_commands_from_file(file_name: str) -> CommandLines:
- """Extract all lines at the top of the file which contain commands to execute.
-
- The returned struct contains the list of commands to execute, each with the line number it appears on.
-
- """
- cl = CommandLines(
- file=file_name,
- )
- with open(file_name) as f:
- file_lines = f.readlines()
-
- for line_number, line in enumerate(file_lines):
- line = line.strip()
- if line in APP_COMMAND_LINES_TO_IGNORE:
- continue
-
- # stop parsing at first non-comment line at top of file
- if not line.startswith("#"):
- continue
-
- # remove comment marker and any leading / trailing whitespaces
- line = line.lstrip("#").strip()
- if len(line) == 0:
- # do not stop parsing on empty comment lines
- continue
-
- # only run commands starting with a bang (!) & strip the bang from the executed command.
- if line[0] != "!":
- continue
- line = line[1:].strip()
-
- cl.commands.append(line)
- # add 1 to line number because enumerate returns indexes starting at 0, while
- # text editors list lines beginning at index 1.
- cl.line_numbers.append(line_number + 1)
-
- return cl
-
-
-def _execute_app_commands(cl: CommandLines) -> None:
- """Open a subprocess shell to execute app commands.
-
- The subprocess shell inherits the environment that the calling code is running in.
-
- """
- for command, line_number in zip(cl.commands, cl.line_numbers):
- logger.info(f"Running app setup command: {command}")
- completed = subprocess.run(
- command,
- shell=True,
- env=os.environ,
- )
- try:
- completed.check_returncode()
- except subprocess.CalledProcessError:
- err_txt = (
- f"There was a problem on line {line_number} of {cl.file} while executing the command: "
- f"{command}. More information on the problem is shown in the output above this "
- f"message. After editing this line to fix the problem you can run the app again."
- )
- logger.error(err_txt)
- raise MisconfigurationException(err_txt) from None
-
-
-def run_app_commands(file: str) -> None:
- """Extract all lines at the top of the file which contain commands & execute them.
-
- Commands to execute are comment lines whose first non-whitespace character begins with the "bang" symbol (`!`).
- After the first non comment line we stop parsing the rest of the file. Running environment is preserved in the
- subprocess shell.
-
- For example:
-
- # some file <--- not a command # !echo "hello world" <--- a command # ! pip install foo <--- a command #
- foo! bar <--- not a command import lightning <--- not a command, end parsing.
-
- where `echo "hello world" && pip install foo` would be executed in the current running environment.
-
- """
- cl = _extract_commands_from_file(file_name=file)
- if len(cl.commands) == 0:
- logger.debug("No in app commands to install.")
- return
- _execute_app_commands(cl=cl)
diff --git a/src/lightning/app/utilities/app_helpers.py b/src/lightning/app/utilities/app_helpers.py
deleted file mode 100644
index 9efb173a001d4..0000000000000
--- a/src/lightning/app/utilities/app_helpers.py
+++ /dev/null
@@ -1,582 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import abc
-import asyncio
-import builtins
-import enum
-import functools
-import inspect
-import json
-import logging
-import os
-import sys
-import threading
-import time
-from abc import ABC, abstractmethod
-from contextlib import contextmanager
-from copy import deepcopy
-from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Type
-from unittest.mock import MagicMock
-
-import websockets
-from deepdiff import Delta
-
-import lightning.app
-from lightning.app.utilities.exceptions import LightningAppStateException
-from lightning.app.utilities.tree import breadth_first
-
-if TYPE_CHECKING:
- from lightning.app.core.app import LightningApp
- from lightning.app.core.flow import LightningFlow
- from lightning.app.utilities.types import Component
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class StateEntry:
- """Dataclass used to keep track of the latest state shared through the app REST API."""
-
- app_state: Mapping = field(default_factory=dict)
- served_state: Mapping = field(default_factory=dict)
- session_id: Optional[str] = None
-
-
-class StateStore(ABC):
- """Base class for a state store: a simple key-value store that keeps track of the app state and served app state."""
-
- @abstractmethod
- def __init__(self):
- pass
-
- @abstractmethod
- def add(self, k: str):
- """Creates a new empty state with input key 'k'."""
- pass
-
- @abstractmethod
- def remove(self, k: str):
- """Deletes a state with input key 'k'."""
- pass
-
- @abstractmethod
- def get_app_state(self, k: str) -> Mapping:
- """Returns a stored appstate for an input key 'k'."""
- pass
-
- @abstractmethod
- def get_served_state(self, k: str) -> Mapping:
- """Returns a last served app state for an input key 'k'."""
- pass
-
- @abstractmethod
- def get_served_session_id(self, k: str) -> str:
- """Returns session id for state of a key 'k'."""
- pass
-
- @abstractmethod
- def set_app_state(self, k: str, v: Mapping):
- """Sets the app state for state of a key 'k'."""
- pass
-
- @abstractmethod
- def set_served_state(self, k: str, v: Mapping):
- """Sets the served state for state of a key 'k'."""
- pass
-
- @abstractmethod
- def set_served_session_id(self, k: str, v: str):
- """Sets the session id for state of a key 'k'."""
- pass
-
-
-class InMemoryStateStore(StateStore):
- """In memory simple store to keep track of state through the app REST API."""
-
- def __init__(self):
- self.store = {}
- self.counter = 0
-
- def add(self, k):
- self.store[k] = StateEntry()
-
- def remove(self, k):
- del self.store[k]
-
- def get_app_state(self, k):
- return self.store[k].app_state
-
- def get_served_state(self, k):
- return self.store[k].served_state
-
- def get_served_session_id(self, k):
- return self.store[k].session_id
-
- def set_app_state(self, k, v):
- state_size = sys.getsizeof(v)
- if state_size > lightning.app.core.constants.APP_STATE_MAX_SIZE_BYTES:
- raise LightningAppStateException(
- f"App state size is {state_size} bytes, which is larger than the recommended size "
- f"of {lightning.app.core.constants.APP_STATE_MAX_SIZE_BYTES}. Please investigate this."
- )
- self.store[k].app_state = deepcopy(v)
- self.counter += 1
-
- def set_served_state(self, k, v):
- self.store[k].served_state = deepcopy(v)
-
- def set_served_session_id(self, k, v):
- self.store[k].session_id = v
-
-
-class _LightningAppRef:
- _app_instance: Optional["LightningApp"] = None
-
- @classmethod
- def connect(cls, app_instance: "LightningApp") -> None:
- cls._app_instance = app_instance
-
- @classmethod
- def get_current(cls) -> Optional["LightningApp"]:
- if cls._app_instance:
- return cls._app_instance
- return None
-
-
-def affiliation(component: "Component") -> Tuple[str, ...]:
- """Returns the affiliation of a component."""
- if component.name in ("root", ""):
- return ()
- return tuple(component.name.split(".")[1:])
-
-
-class AppStateType(str, enum.Enum):
- STREAMLIT = "STREAMLIT"
- DEFAULT = "DEFAULT"
-
-
-class BaseStatePlugin(abc.ABC):
- def __init__(self):
- self.authorized = None
-
- @abc.abstractmethod
- def should_update_app(self, deep_diff):
- pass
-
- @abc.abstractmethod
- def get_context(self):
- pass
-
- @abc.abstractmethod
- def render_non_authorized(self):
- pass
-
-
-class AppStatePlugin(BaseStatePlugin):
- def should_update_app(self, deep_diff):
- return deep_diff
-
- def get_context(self):
- return {"type": AppStateType.DEFAULT.value}
-
- def render_non_authorized(self):
- pass
-
-
-def target_fn():
- try:
- # streamlit >= 1.14.0
- from streamlit import runtime
-
- get_instance = runtime.get_instance
- exists = runtime.exists()
- except ImportError:
- # Older versions
- from streamlit.server.server import Server
-
- get_instance = Server.get_current
- exists = bool(Server._singleton)
-
- async def update_fn():
- runtime_instance = get_instance()
- sessions = list(runtime_instance._session_info_by_id.values())
- url = (
- "localhost:8080"
- if "LIGHTNING_APP_STATE_URL" in os.environ
- else f"localhost:{lightning.app.core.constants.APP_SERVER_PORT}"
- )
- ws_url = f"ws://{url}/api/v1/ws"
- last_updated = time.time()
- async with websockets.connect(ws_url) as websocket:
- while True:
- try:
- _ = await websocket.recv()
-
- while (time.time() - last_updated) < 1:
- time.sleep(0.1)
- for session in sessions:
- session = session.session
- session.request_rerun(session._client_state)
- last_updated = time.time()
- except websockets.exceptions.ConnectionClosedOK:
- # The websocket is not enabled
- break
-
- if exists:
- asyncio.run(update_fn())
-
-
-class StreamLitStatePlugin(BaseStatePlugin):
- def __init__(self):
- super().__init__()
- import streamlit as st
-
- if hasattr(st, "session_state") and "websocket_thread" not in st.session_state:
- thread = threading.Thread(target=target_fn)
- st.session_state.websocket_thread = thread
- thread.setDaemon(True)
- thread.start()
-
- def should_update_app(self, deep_diff):
- return deep_diff
-
- def get_context(self):
- return {"type": AppStateType.DEFAULT.value}
-
- def render_non_authorized(self):
- pass
-
-
-def is_overridden(method_name: str, instance: Optional[object] = None, parent: Optional[Type[object]] = None) -> bool:
- if instance is None:
- return False
- if parent is None:
- if isinstance(instance, lightning.app.LightningFlow):
- parent = lightning.app.LightningFlow
- elif isinstance(instance, lightning.app.LightningWork):
- parent = lightning.app.LightningWork
- if parent is None:
- raise ValueError("Expected a parent")
- from lightning_utilities.core.overrides import is_overridden
-
- return is_overridden(method_name, instance, parent)
-
-
-def _is_json_serializable(x: Any) -> bool:
- """Test whether a variable can be encoded as json."""
- if type(x) in lightning.app.core.constants.SUPPORTED_PRIMITIVE_TYPES:
- # shortcut for primitive types that are not containers
- return True
- try:
- json.dumps(x, cls=LightningJSONEncoder)
- return True
- except (TypeError, OverflowError):
- # OverflowError is raised if number is too large to encode
- return False
-
-
-def _set_child_name(component: "Component", child: "Component", new_name: str) -> str:
- """Computes and sets the name of a child given the parent, and returns the name."""
- child_name = f"{component.name}.{new_name}"
- child._name = child_name
-
- # the name changed, so recursively update the names of the children of this child
- if isinstance(child, lightning.app.core.LightningFlow):
- for n in child._flows:
- c = getattr(child, n)
- _set_child_name(child, c, n)
- for n in child._works:
- c = getattr(child, n)
- _set_child_name(child, c, n)
- for n in child._structures:
- s = getattr(child, n)
- _set_child_name(child, s, n)
- if isinstance(child, lightning.app.structures.Dict):
- for n, c in child.items():
- _set_child_name(child, c, n)
- if isinstance(child, lightning.app.structures.List):
- for c in child:
- _set_child_name(child, c, c.name.split(".")[-1])
-
- return child_name
-
-
-def _delta_to_app_state_delta(root: "LightningFlow", component: "Component", delta: Delta) -> Delta:
- delta_dict = delta.to_dict()
- for changed in delta_dict.values():
- for delta_key in changed.copy():
- val = changed[delta_key]
-
- new_prefix = "root"
- for p, c in _walk_to_component(root, component):
- if isinstance(c, lightning.app.core.LightningWork):
- new_prefix += "['works']"
-
- if isinstance(c, lightning.app.core.LightningFlow):
- new_prefix += "['flows']"
-
- if isinstance(c, (lightning.app.structures.Dict, lightning.app.structures.List)):
- new_prefix += "['structures']"
-
- c_n = c.name.split(".")[-1]
- new_prefix += f"['{c_n}']"
-
- delta_key_without_root = delta_key[4:] # the first 4 chars are the word 'root', strip it
- new_key = new_prefix + delta_key_without_root
- if new_key != delta_key:
- changed[new_key] = val
- del changed[delta_key]
-
- return Delta(delta_dict)
-
-
-def _walk_to_component(
- root: "LightningFlow",
- component: "Component",
-) -> Generator[Tuple["Component", "Component"], None, None]:
- """Returns a generator that runs through the tree starting from the root down to the given component.
-
- At each node, yields parent and child as a tuple.
-
- """
- from lightning.app.structures import Dict, List
-
- name_parts = component.name.split(".")[1:] # exclude 'root' from the name
- parent = root
- for n in name_parts:
- if isinstance(parent, (Dict, List)):
- child = parent[n] if isinstance(parent, Dict) else parent[int(n)]
- else:
- child = getattr(parent, n)
- yield parent, child
- parent = child
-
-
-def _collect_child_process_pids(pid: int) -> List[int]:
- """Function to return the list of child process pid's of a process."""
- processes = os.popen("ps -ej | grep -i 'python' | grep -v 'grep' | awk '{ print $2,$3 }'").read()
- processes = [p.split(" ") for p in processes.split("\n")[:-1]]
- return [int(child) for child, parent in processes if parent == str(pid) and child != str(pid)]
-
-
-def _print_to_logger_info(*args: Any, **kwargs: Any):
- # TODO Find a better way to re-direct print to loggers.
- lightning.app._logger.info(" ".join([str(v) for v in args]))
-
-
-def convert_print_to_logger_info(func: Callable) -> Callable:
- """This function is used to transform any print into logger.info calls, so it gets tracked in the cloud."""
-
- @functools.wraps(func)
- def wrapper(*args: Any, **kwargs: Any) -> Any:
- original_print = __builtins__["print"]
- __builtins__["print"] = _print_to_logger_info
- res = func(*args, **kwargs)
- __builtins__["print"] = original_print
- return res
-
- return wrapper
-
-
-def pretty_state(state: Dict) -> Dict:
- """Utility to prettify the state by removing hidden attributes."""
- new_state = {}
- for k, v in state["vars"].items():
- if not k.startswith("_"):
- if "vars" not in new_state:
- new_state["vars"] = {}
- new_state["vars"][k] = v
- if "flows" in state:
- for k, v in state["flows"].items():
- if "flows" not in new_state:
- new_state["flows"] = {}
- new_state["flows"][k] = pretty_state(state["flows"][k])
- if "works" in state:
- for k, v in state["works"].items():
- if "works" not in new_state:
- new_state["works"] = {}
- new_state["works"][k] = pretty_state(state["works"][k])
- return new_state
-
-
-class LightningJSONEncoder(json.JSONEncoder):
- def default(self, obj: Any) -> Any:
- if callable(getattr(obj, "__json__", None)):
- return obj.__json__()
- return json.JSONEncoder.default(self, obj)
-
-
-class Logger:
- """This class is used to improve the debugging experience."""
-
- def __init__(self, name: str):
- self.logger = logging.getLogger(name)
- self.level = None
-
- def info(self, msg, *args: Any, **kwargs: Any):
- self.logger.info(msg, *args, **kwargs)
-
- def warn(self, msg, *args: Any, **kwargs: Any):
- self._set_level()
- self.logger.warning(msg, *args, **kwargs)
-
- def debug(self, msg, *args: Any, **kwargs: Any):
- self._set_level()
- self.logger.debug(msg, *args, **kwargs)
-
- def error(self, msg, *args: Any, **kwargs: Any):
- self._set_level()
- self.logger.error(msg, *args, **kwargs)
-
- def _set_level(self):
- """Lazily set the level once set by the users."""
- # Set on the first warn, debug or error call.
- if self.level is None:
- self.level = logging.DEBUG if bool(int(os.getenv("LIGHTNING_DEBUG", "0"))) else logging.INFO
- self.logger.setLevel(self.level)
-
-
-def _state_dict(flow: "LightningFlow"):
- state = {}
- flows = [flow] + list(flow.flows.values())
- for f in flows:
- state[f.name] = f.state_dict()
- for w in flow.works():
- state[w.name] = w.state
- return state
-
-
-def _load_state_dict(root_flow: "LightningFlow", state: Dict[str, Any], strict: bool = True) -> None:
- """This function is used to reload the state assuming dynamic components creation.
-
- When a component isn't found but its state exists, its state is passed up to its closest existing parent.
-
- Arguments:
- root_flow: The flow at the top of the component tree.
- state: The collected state dict.
- strict: Whether to validate all components have been re-created.
-
- """
- # 1: Reload the state of the existing works
- for w in root_flow.works():
- w.set_state(state.pop(w.name))
-
- # 2: Collect the existing flows
- flows = [root_flow] + list(root_flow.flows.values())
- flow_map = {f.name: f for f in flows}
-
- # 3: Find the state of all the dynamic components
- dynamic_components = {k: v for k, v in state.items() if k not in flow_map}
-
- # 4: Propagate the state of the dynamic components to their closest parents
- dynamic_children_state = {}
- for name, component_state in dynamic_components.items():
- affiliation = name.split(".")
- for idx in range(0, len(affiliation)):
- parent_name = ".".join(affiliation[:-idx])
- has_matched = False
- for flow_name, flow in flow_map.items():
- if flow_name == parent_name:
- if flow_name not in dynamic_children_state:
- dynamic_children_state[flow_name] = {}
-
- dynamic_children_state[flow_name].update({name.replace(parent_name + ".", ""): component_state})
- has_matched = True
- break
- if has_matched:
- break
-
- # 5: Reload the flow states
- for flow_name, flow in flow_map.items():
- flow.load_state_dict(state.pop(flow_name), dynamic_children_state.get(flow_name, {}), strict=strict)
-
- # 6: Verify all dynamic components have been re-created.
- if strict:
- components_names = (
- [root_flow.name] + [f.name for f in root_flow.flows.values()] + [w.name for w in root_flow.works()]
- )
- for component_name in dynamic_components:
- if component_name not in components_names:
- raise Exception(f"The component {component_name} was not re-created during state reloading.")
-
-
-class _MagicMockJsonSerializable(MagicMock):
- @staticmethod
- def __json__():
- return "{}"
-
-
-def _mock_import(*args, original_fn=None):
- try:
- return original_fn(*args)
- except Exception:
- return _MagicMockJsonSerializable()
-
-
-@contextmanager
-def _mock_missing_imports():
- original_fn = builtins.__import__
- builtins.__import__ = functools.partial(_mock_import, original_fn=original_fn)
- try:
- yield
- finally:
- builtins.__import__ = original_fn
-
-
-def is_static_method(klass_or_instance, attr) -> bool:
- return isinstance(inspect.getattr_static(klass_or_instance, attr), staticmethod)
-
-
-def _lightning_dispatched() -> bool:
- return bool(int(os.getenv("LIGHTNING_DISPATCHED", 0)))
-
-
-def _using_debugger() -> bool:
- """This method is used to detect whether the app is run with a debugger attached."""
- if "LIGHTNING_DETECTED_DEBUGGER" in os.environ:
- return True
-
- # Collect the information about the process.
- parent_process = os.popen(f"ps -ax | grep -i {os.getpid()} | grep -v grep").read()
-
- # Detect whether VSCode or PyCharm debugger are used
- use_debugger = "debugpy" in parent_process or "pydev" in parent_process
-
- # Store the result to avoid multiple popen calls.
- if use_debugger:
- os.environ["LIGHTNING_DETECTED_DEBUGGER"] = "1"
- return use_debugger
-
-
-def _should_dispatch_app() -> bool:
- return (
- not _lightning_dispatched()
- and "LIGHTNING_APP_STATE_URL" not in os.environ
- # Keep last to avoid running it if already dispatched
- and _using_debugger()
- )
-
-
-def _is_headless(app: "LightningApp") -> bool:
- """Utility which returns True if the given App has no ``Frontend`` objects or URLs exposed through
- ``configure_layout``."""
- if app.frontends:
- return False
- for component in breadth_first(app.root, types=(lightning.app.LightningFlow,)):
- for entry in component._layout:
- if "target" in entry:
- return False
- return True
diff --git a/src/lightning/app/utilities/app_logs.py b/src/lightning/app/utilities/app_logs.py
deleted file mode 100644
index 446418f9b18e2..0000000000000
--- a/src/lightning/app/utilities/app_logs.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-import queue
-from dataclasses import dataclass
-from threading import Thread
-from typing import Callable, Iterator, List, Optional
-
-import dateutil.parser
-from websocket import WebSocketApp
-
-from lightning.app.utilities.log_helpers import _error_callback, _OrderedLogEntry
-from lightning.app.utilities.logs_socket_api import _LightningLogsSocketAPI
-
-
-@dataclass
-class _LogEventLabels:
- app: Optional[str] = None
- container: Optional[str] = None
- filename: Optional[str] = None
- job: Optional[str] = None
- namespace: Optional[str] = None
- node_name: Optional[str] = None
- pod: Optional[str] = None
- component: Optional[str] = None
- projectID: Optional[str] = None
- stream: Optional[str] = None
-
-
-@dataclass
-class _LogEvent(_OrderedLogEntry):
- component_name: str
- labels: _LogEventLabels
-
-
-def _push_log_events_to_read_queue_callback(component_name: str, read_queue: queue.PriorityQueue):
- """Pushes _LogEvents from websocket to read_queue.
-
- Returns callback function used with `on_message_callback` of websocket.WebSocketApp.
-
- """
-
- def callback(ws_app: WebSocketApp, msg: str):
- # We strongly trust that the contract on API will hold atm :D
- event_dict = json.loads(msg)
- labels = _LogEventLabels(**event_dict.get("labels", {}))
-
- if "message" in event_dict:
- message = event_dict["message"]
- timestamp = dateutil.parser.isoparse(event_dict["timestamp"])
- event = _LogEvent(
- message=message,
- timestamp=timestamp,
- component_name=component_name,
- labels=labels,
- )
- read_queue.put(event)
-
- return callback
-
-
-def _app_logs_reader(
- logs_api_client: _LightningLogsSocketAPI,
- project_id: str,
- app_id: str,
- component_names: List[str],
- follow: bool,
- on_error_callback: Optional[Callable] = None,
-) -> Iterator[_LogEvent]:
- read_queue = queue.PriorityQueue()
-
- # We will use a socket per component
- log_sockets = [
- logs_api_client.create_lightning_logs_socket(
- project_id=project_id,
- app_id=app_id,
- component=component_name,
- on_message_callback=_push_log_events_to_read_queue_callback(component_name, read_queue),
- on_error_callback=on_error_callback or _error_callback,
- )
- for component_name in component_names
- ]
-
- # Each socket runs on a separate thread, pushing log events to the read queue
- # run_forever() will run until we close() the connection from outside
- log_threads = [Thread(target=work.run_forever, daemon=True) for work in log_sockets]
-
- # Establish connection and begin pushing logs to the print queue
- for th in log_threads:
- th.start()
-
- # Tokens that mark where the relevant logs of the flow and of each work start
- flow = "Your app has started."
- work = "USER_RUN_WORK"
- start_timestamps = {}
-
- # Print logs from queue when log event is available
- try:
- while True:
- log_event: _LogEvent = read_queue.get(timeout=None if follow else 1.0)
-
- token = flow if log_event.component_name == "flow" else work
- if token in log_event.message:
- start_timestamps[log_event.component_name] = log_event.timestamp
-
- timestamp = start_timestamps.get(log_event.component_name, None)
- if timestamp and log_event.timestamp >= timestamp and "launcher" not in log_event.message:
- yield log_event
-
- except queue.Empty:
- # Empty is raised by queue.get if timeout is reached. Follow = False case.
- pass
-
- except KeyboardInterrupt:
- # User pressed CTRL+C to exit, we should respect that
- pass
-
- finally:
- # Close connections: this causes run_forever() to finish, so the threads finish as well
- for socket in log_sockets:
- socket.close()
-
- # Because all sockets were closed, we can just wait for threads to finish.
- for th in log_threads:
- th.join()
diff --git a/src/lightning/app/utilities/app_status.py b/src/lightning/app/utilities/app_status.py
deleted file mode 100644
index 1f40da05bc140..0000000000000
--- a/src/lightning/app/utilities/app_status.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from datetime import datetime
-from typing import Any, Dict, Optional
-
-from pydantic import BaseModel
-
-
-class WorkStatus(BaseModel):
- """The ``WorkStatus`` captures the status of a work according to the app."""
-
- stage: str
- timestamp: float
- reason: Optional[str] = None
- message: Optional[str] = None
- count: int = 1
-
- def __init__(self, *args: Any, **kwargs: Any) -> None:
- super().__init__(*args, **kwargs)
-
- assert self.timestamp > 0
- assert self.timestamp < (int(datetime.now().timestamp()) + 10)
-
-
-class AppStatus(BaseModel):
- """The ``AppStatus`` captures the current status of the app and its components."""
-
- # ``True`` when the app UI is ready to be viewed
- is_ui_ready: bool
-
- # The statuses of ``LightningWork`` objects currently associated with this app
- work_statuses: Dict[str, WorkStatus]
diff --git a/src/lightning/app/utilities/auth.py b/src/lightning/app/utilities/auth.py
deleted file mode 100644
index 2ccb2d068109f..0000000000000
--- a/src/lightning/app/utilities/auth.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Dict
-
-from lightning_cloud.openapi import ApiClient, AuthServiceApi, V1LoginRequest
-
-from lightning.app.utilities.login import Auth
-
-
-# This class joins common things for reading logs,
-# initialization and getting API token
-class _AuthTokenGetter:
- def __init__(self, api_client: ApiClient):
- self.api_client = api_client
- self._auth = Auth()
- self._auth.authenticate()
- self._auth_service = AuthServiceApi(api_client)
-
- def _get_api_token(self) -> str:
- token_resp = self._auth_service.auth_service_login(
- body=V1LoginRequest(
- username=self._auth.username,
- api_key=self._auth.api_key,
- )
- )
- return token_resp.token
-
-
-def _credential_string_to_basic_auth_params(credential_string: str) -> Dict[str, str]:
- """Parses a ``username:password`` credential string into basic auth parameters.
-
- Raises a `ValueError` if the string does not follow that format or if the username or password is empty.
-
- """
- if credential_string.count(":") != 1:
- raise ValueError(
- "Credential string must follow the format username:password; "
- + f"the provided one ('{credential_string}') does not."
- )
-
- username, password = credential_string.split(":")
-
- if not username:
- raise ValueError("Username cannot be empty.")
-
- if not password:
- raise ValueError("Password cannot be empty.")
-
- return {"username": username, "password": password}
diff --git a/src/lightning/app/utilities/cli_helpers.py b/src/lightning/app/utilities/cli_helpers.py
deleted file mode 100644
index 45428e44c7310..0000000000000
--- a/src/lightning/app/utilities/cli_helpers.py
+++ /dev/null
@@ -1,358 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import contextlib
-import functools
-import json
-import os
-import re
-import subprocess
-import sys
-from typing import Dict, Optional
-
-import arrow
-import click
-import packaging
-import requests
-import rich
-from lightning_cloud.openapi import Externalv1LightningappInstance
-
-from lightning.app import __package_name__, __version__
-from lightning.app.core.constants import APP_SERVER_PORT
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.cloud import _get_project
-from lightning.app.utilities.network import LightningClient
-
-logger = Logger(__name__)
-
-
-def _format_input_env_variables(env_list: tuple) -> Dict[str, str]:
- """
- Args:
- env_list:
- List of str for the env variables, e.g. ['foo=bar', 'bla=bloz']
-
- Returns:
- Dict of the env variables with the following format
- key: env variable name
- value: env variable value
- """
- env_vars_dict = {}
- for env_str in env_list:
- var_parts = env_str.split("=")
- if len(var_parts) != 2 or not var_parts[0]:
- raise Exception(
- f"Invalid format of environment variable {env_str}, "
- f"please ensure that the variable is in the format e.g. foo=bar."
- )
- var_name, value = var_parts
-
- if var_name in env_vars_dict:
- raise Exception(f"Environment variable '{var_name}' is duplicated. Please only include it once.")
-
- if not re.fullmatch(r"[0-9a-zA-Z_]+", var_name):
- raise ValueError(
- f"Environment variable '{var_name}' is not a valid name. It is only allowed to contain digits 0-9, "
- f"letters A-Z, a-z and _ (underscore)."
- )
-
- env_vars_dict[var_name] = value
- return env_vars_dict
-
-
-def _is_url(id: Optional[str]) -> bool:
- if isinstance(id, str) and (id.startswith("https://") or id.startswith("http://")):
- return True
- return False
-
-
-def _get_metadata_from_openapi(paths: Dict, path: str):
- parameters = paths[path]["post"].get("parameters", {})
- tag = paths[path]["post"].get("tags", [None])[0]
- cls_path = paths[path]["post"].get("cls_path", None)
- cls_name = paths[path]["post"].get("cls_name", None)
- description = paths[path]["post"].get("description", None)
- requirements = paths[path]["post"].get("requirements", None)
- app_info = paths[path]["post"].get("app_info", None)
-
- metadata = {"tag": tag, "parameters": {}}
-
- if cls_path:
- metadata["cls_path"] = cls_path
-
- if cls_name:
- metadata["cls_name"] = cls_name
-
- if description:
- metadata["description"] = description
-
- if requirements:
- metadata["requirements"] = requirements
-
- if app_info:
- metadata["app_info"] = app_info
-
- if not parameters:
- return metadata
-
- metadata["parameters"].update({d["name"]: d["schema"]["type"] for d in parameters})
- return metadata
-
-
-def _extract_command_from_openapi(openapi_resp: Dict) -> Dict[str, Dict[str, str]]:
- command_paths = [p for p in openapi_resp["paths"] if p.startswith("/command/")]
- return {p.replace("/command/", ""): _get_metadata_from_openapi(openapi_resp["paths"], p) for p in command_paths}
-
-
-def _get_app_display_name(app: Externalv1LightningappInstance) -> str:
- return getattr(app, "display_name", None) or app.name
-
-
-class _LightningAppOpenAPIRetriever:
- def __init__(
- self,
- app_id_or_name_or_url: Optional[str],
- use_cache: bool = False,
- ):
- """This class encapsulates the logic to collect the openapi.json file from the app to use the CLI Commands.
-
- Arguments:
- app_id_or_name_or_url: An identifier for the app.
- use_cache: Whether to load the openapi spec from the cache.
-
- """
- self.app_id_or_name_or_url = app_id_or_name_or_url
- self.url = None
- self.openapi = None
- self.api_commands = None
- self.app_id = None
- self.app_name = None
- home = os.path.expanduser("~")
- if use_cache:
- cache_openapi = os.path.join(home, ".lightning", "lightning_connection", "commands", "openapi.json")
- if os.path.exists(cache_openapi):
- with open(cache_openapi) as f:
- self.openapi = json.load(f)
- self.api_commands = _extract_command_from_openapi(self.openapi)
-
- if not self.api_commands:
- self._collect_open_api_json()
- if self.openapi:
- self.api_commands = _extract_command_from_openapi(self.openapi)
-
- def is_alive(self) -> bool:
- """Returns whether the Lightning App Rest API is available."""
- if self.url is None:
- self._maybe_find_url()
- if self.url is None:
- return False
- resp = requests.get(self.url)
- return resp.status_code == 200
-
- def _maybe_find_url(self):
- """Tries to resolve the app url from the provided `app_id_or_name_or_url`."""
- if _is_url(self.app_id_or_name_or_url):
- self.url = self.app_id_or_name_or_url
- assert self.url
- return
-
- if self.app_id_or_name_or_url is None:
- url = f"http://localhost:{APP_SERVER_PORT}"
- resp = requests.get(f"{url}/openapi.json")
- if resp.status_code == 200:
- self.url = url
- return
-
- app = self._maybe_find_matching_cloud_app()
- if app:
- self.url = app.status.url
-
- def _maybe_find_matching_cloud_app(self):
- """Tries to find the cloud app matching the provided `app_id_or_name_or_url`."""
- client = LightningClient(retry=False)
- project = _get_project(client)
- list_apps = client.lightningapp_instance_service_list_lightningapp_instances(project_id=project.project_id)
-
- app_names = [_get_app_display_name(lit_app) for lit_app in list_apps.lightningapps]
-
- if not self.app_id_or_name_or_url:
- print(f"ERROR: Provide an application name, id or url with --app_id=X. Found {app_names}")
- sys.exit(0)
-
- for app in list_apps.lightningapps:
- if app.id == self.app_id_or_name_or_url or _get_app_display_name(app) == self.app_id_or_name_or_url:
- if app.status.url == "":
- print("The application is starting. Try in a few moments.")
- sys.exit(0)
- return app
- return None
-
- def _collect_open_api_json(self):
- """Collects the ``openapi.json`` of the app resolved from the provided `app_id_or_name_or_url`."""
- if _is_url(self.app_id_or_name_or_url):
- self.url = self.app_id_or_name_or_url
- assert self.url
- resp = requests.get(self.url + "/openapi.json")
- if resp.status_code != 200:
- print(f"ERROR: The server didn't process the request properly. Found {resp.json()}")
- sys.exit(0)
- self.openapi = resp.json()
- return
-
- # 2: If no identifier has been provided, evaluate the local application
- if self.app_id_or_name_or_url is None:
- with contextlib.suppress(requests.exceptions.ConnectionError):
- self.url = f"http://localhost:{APP_SERVER_PORT}"
- resp = requests.get(f"{self.url}/openapi.json")
- if resp.status_code != 200:
- raise Exception(f"The server didn't process the request properly. Found {resp.json()}")
- self.openapi = resp.json()
-
- # 3: If an identifier was provided or the local evaluation has failed, evaluate the cloud.
- else:
- app = self._maybe_find_matching_cloud_app()
- if app:
- if app.status.url == "":
- raise Exception("The application is starting. Try in a few moments.")
- resp = requests.get(app.status.url + "/openapi.json")
- if resp.status_code != 200:
- raise Exception(
- "The server didn't process the request properly. " "Try once your application is ready."
- )
- self.url = app.status.url
- self.openapi = resp.json()
- self.app_id = app.id
- self.app_name = _get_app_display_name(app)
-
-
-def _arrow_time_callback(
- _ctx: "click.core.Context", _param: "click.core.Option", value: str, arw_now=arrow.utcnow()
-) -> arrow.Arrow:
- try:
- return arw_now.dehumanize(value)
- except ValueError:
- try:
- return arrow.get(value)
- except (ValueError, TypeError):
- raise click.ClickException(f"cannot parse time {value}")
-
-
-@functools.lru_cache(maxsize=1)
-def _get_newer_version() -> Optional[str]:
- """Check PyPI for newer versions of ``lightning``, returning the newest version if different from the current or
- ``None`` otherwise."""
- if packaging.version.parse(__version__).is_prerelease:
- return None
- try:
- response = requests.get(f"https://pypi.org/pypi/{__package_name__}/json")
- response_json = response.json()
- releases = response_json["releases"]
- if __version__ not in releases:
- # Always return None if not installed from PyPI (e.g. dev versions)
- return None
- latest_version = response_json["info"]["version"]
- parsed_version = packaging.version.parse(latest_version)
- is_invalid = response_json["info"]["yanked"] or parsed_version.is_devrelease or parsed_version.is_prerelease
- return None if __version__ == latest_version or is_invalid else latest_version
- except Exception:
- # Return None if any exception occurs
- return None
-
-
-def _redirect_command(executable: str):
- """Redirect the current lightning CLI call to the given executable."""
- subprocess.run(
- [executable, "-m", "lightning"] + sys.argv[1:],
- env=os.environ,
- )
-
- sys.exit()
-
-
-def _check_version_and_upgrade():
- """Checks that the current version of ``lightning`` is the latest on PyPI.
-
- If not, prompt the user to upgrade ``lightning`` for them and re-run the current call in the new version.
-
- """
- new_version = _get_newer_version()
- if new_version:
- prompt = f"A newer version of {__package_name__} is available ({new_version}). Would you like to upgrade?"
-
- if click.confirm(prompt, default=True):
- command = f"pip install {__package_name__}=={new_version}"
-
- logger.info(f"⚡ RUN: {command}")
-
- # Upgrade
- subprocess.run(
- [sys.executable, "-m"] + command.split(" "),
- check=True,
- )
-
- # Re-launch
- _redirect_command(sys.executable)
- return
-
-
-def _check_environment_and_redirect():
- """Checks that the current ``sys.executable`` is the same as the executable resolved from the current environment.
-
- If not, this utility tries to redirect the ``lightning`` call to the environment executable (prompting the user to
- install lightning for them there if needed).
-
- """
- process = subprocess.run(
- ["python", "-c", "import sys; print(sys.executable)"],
- capture_output=True,
- env=os.environ,
- check=True,
- )
-
- env_executable = os.path.realpath(process.stdout.decode().strip())
- sys_executable = os.path.realpath(sys.executable)
-
- # on windows, the extension might be different, where one uses `.EXE` and the other `.exe`
- if env_executable.lower() != sys_executable.lower():
- logger.info(
- "Lightning is running from outside your current environment. Switching to your current environment."
- )
-
- process = subprocess.run(
- [env_executable, "-m", "lightning", "--version"],
- capture_output=True,
- text=True,
- )
-
- if "No module named lightning" in process.stderr:
- prompt = f"The {__package_name__} package is not installed. Would you like to install it? [Y/n (exit)]"
-
- if click.confirm(prompt, default=True, show_default=False):
- command = f"pip install {__package_name__}"
-
- logger.info(f"⚡ RUN: {command}")
-
- subprocess.run(
- [env_executable, "-m"] + command.split(" "),
- check=True,
- )
- else:
- sys.exit()
-
- _redirect_command(env_executable)
- return
-
-
-def _error_and_exit(msg: str) -> None:
- rich.print(f"[red]ERROR[/red]: {msg}")
- sys.exit(0)
diff --git a/src/lightning/app/utilities/cloud.py b/src/lightning/app/utilities/cloud.py
deleted file mode 100644
index 95c76c1be926a..0000000000000
--- a/src/lightning/app/utilities/cloud.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-from typing import Optional
-
-from lightning_cloud.openapi import V1Membership
-
-import lightning.app
-from lightning.app.core.constants import LIGHTNING_CLOUD_PROJECT_ID
-from lightning.app.utilities.enum import AppStage
-from lightning.app.utilities.network import LightningClient
-
-
-def _get_project(client: LightningClient, project_id: Optional[str] = None, verbose: bool = True) -> V1Membership:
- """Get a project membership for the user from the backend."""
- if project_id is None:
- project_id = LIGHTNING_CLOUD_PROJECT_ID
-
- if project_id is not None:
- project = client.projects_service_get_project(project_id)
- if not project:
- raise ValueError(
- "Environment variable `LIGHTNING_CLOUD_PROJECT_ID` is set but could not find an associated project."
- )
- return V1Membership(
- name=project.name,
- display_name=project.display_name,
- description=project.description,
- created_at=project.created_at,
- project_id=project.id,
- owner_id=project.owner_id,
- owner_type=project.owner_type,
- quotas=project.quotas,
- updated_at=project.updated_at,
- )
-
- projects = client.projects_service_list_memberships()
- if len(projects.memberships) == 0:
- raise ValueError("No valid projects found. Please reach out to lightning.ai team to create a project")
- if len(projects.memberships) > 1 and verbose:
- print(f"Defaulting to the project: {projects.memberships[0].name}")
- return projects.memberships[0]
-
-
-def _sigterm_flow_handler(*_, app: "lightning.app.LightningApp"):
- app.stage = AppStage.STOPPING
-
-
-def is_running_in_cloud() -> bool:
- """Returns True if the Lightning App is running in the cloud."""
- return bool(int(os.environ.get("LAI_RUNNING_IN_CLOUD", "0"))) or "LIGHTNING_APP_STATE_URL" in os.environ
diff --git a/src/lightning/app/utilities/clusters.py b/src/lightning/app/utilities/clusters.py
deleted file mode 100644
index 663ba66d456e9..0000000000000
--- a/src/lightning/app/utilities/clusters.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import random
-
-from lightning_cloud.openapi import ProjectIdProjectclustersbindingsBody, V1ClusterType
-from lightning_cloud.openapi.rest import ApiException
-
-from lightning.app.utilities.network import LightningClient
-
-
-def _ensure_cluster_project_binding(client: LightningClient, project_id: str, cluster_id: str) -> None:
- cluster_bindings = client.projects_service_list_project_cluster_bindings(project_id=project_id)
-
- for cluster_binding in cluster_bindings.clusters:
- if cluster_binding.cluster_id != cluster_id:
- continue
- if cluster_binding.project_id == project_id:
- return
-
- client.projects_service_create_project_cluster_binding(
- project_id=project_id,
- body=ProjectIdProjectclustersbindingsBody(cluster_id=cluster_id),
- )
-
-
-def _get_default_cluster(client: LightningClient, project_id: str) -> str:
- """This utility implements a minimal version of the cluster selection logic used in the cloud.
-
- TODO: This should be requested directly from the platform.
-
- """
- cluster_bindings = client.projects_service_list_project_cluster_bindings(project_id=project_id).clusters
-
- if not cluster_bindings:
- raise ValueError(f"No clusters are bound to the project {project_id}.")
-
- if len(cluster_bindings) == 1:
- return cluster_bindings[0].cluster_id
-
- clusters = []
- for cluster_binding in cluster_bindings:
- try:
- clusters.append(client.cluster_service_get_cluster(cluster_binding.cluster_id))
- except ApiException:
- # If we failed to get the cluster, ignore it
- continue
-
- # Filter global clusters
- clusters = [cluster for cluster in clusters if cluster.spec.cluster_type == V1ClusterType.GLOBAL]
-
- if len(clusters) == 0:
- raise RuntimeError(f"No clusters found on `{client.api_client.configuration.host}`.")
-
- return random.choice(clusters).id # noqa: S311
diff --git a/src/lightning/app/utilities/commands/__init__.py b/src/lightning/app/utilities/commands/__init__.py
deleted file mode 100644
index 14e842687e43a..0000000000000
--- a/src/lightning/app/utilities/commands/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from lightning.app.utilities.commands.base import ClientCommand
-
-__all__ = ["ClientCommand"]
diff --git a/src/lightning/app/utilities/commands/base.py b/src/lightning/app/utilities/commands/base.py
deleted file mode 100644
index ca09a53318f74..0000000000000
--- a/src/lightning/app/utilities/commands/base.py
+++ /dev/null
@@ -1,308 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import errno
-import inspect
-import os
-import os.path as osp
-import shutil
-import sys
-import traceback
-from dataclasses import asdict
-from getpass import getuser
-from importlib.util import module_from_spec, spec_from_file_location
-from tempfile import gettempdir
-from typing import Any, Callable, Dict, List, Optional, Union
-
-import requests
-from fastapi import HTTPException
-from pydantic import BaseModel
-
-from lightning.app.api.http_methods import Post
-from lightning.app.api.request_types import _APIRequest, _CommandRequest, _RequestResponse
-from lightning.app.utilities import frontend
-from lightning.app.utilities.app_helpers import Logger, is_overridden
-from lightning.app.utilities.cloud import _get_project
-from lightning.app.utilities.network import LightningClient
-from lightning.app.utilities.state import AppState
-
-logger = Logger(__name__)
-
-
-def makedirs(path: str):
- r"""Recursive directory creation function."""
- try:
- os.makedirs(osp.expanduser(osp.normpath(path)))
- except OSError as ex:
- if ex.errno != errno.EEXIST or not osp.isdir(path):
- raise ex
-
-
-class ClientCommand:
- description: str = ""
- requirements: List[str] = []
-
- def __init__(self, method: Callable):
- self.method = method
- if not self.description:
- self.description = self.method.__doc__ or ""
- flow = getattr(self.method, "__self__", None)
- self.owner = flow.name if flow else None
- self.models: Optional[Dict[str, BaseModel]] = None
- self.app_url = None
- self._state = None
-
- def _setup(self, command_name: str, app_url: str) -> None:
- self.command_name = command_name
- self.app_url = app_url
-
- @property
- def state(self):
- if self._state is None:
- assert self.app_url
- # TODO: Resolve this hack
- os.environ["LIGHTNING_APP_STATE_URL"] = "1"
- self._state = AppState(host=self.app_url)
- self._state._request_state()
- os.environ.pop("LIGHTNING_APP_STATE_URL")
- return self._state
-
- def run(self, **cli_kwargs) -> None:
- """Override this method with the logic to execute on the client side."""
-
- def invoke_handler(self, config: Optional[BaseModel] = None) -> Dict[str, Any]:
- command = self.command_name.replace(" ", "_")
- resp = requests.post(self.app_url + f"/command/{command}", data=config.json() if config else None)
- if resp.status_code != 200:
- try:
- detail = str(resp.json())
- except Exception:
- detail = "Internal Server Error"
- print(f"Failed with status code {resp.status_code}. Detail: {detail}")
- sys.exit(0)
-
- return resp.json()
-
- def _to_dict(self):
- return {"owner": self.owner, "requirements": self.requirements}
-
- def __call__(self, **kwargs: Any):
- return self.method(**kwargs)
-
-
-def _download_command(
- command_name: str,
- cls_path: str,
- cls_name: str,
- app_id: Optional[str] = None,
- debug_mode: bool = False,
- target_file: Optional[str] = None,
-) -> ClientCommand:
- # TODO: This is a skateboard implementation and the final version will rely on versioned
- # immutable commands for security concerns
- command_name = command_name.replace(" ", "_")
- tmpdir = None
- if not target_file:
- tmpdir = osp.join(gettempdir(), f"{getuser()}_commands")
- makedirs(tmpdir)
- target_file = osp.join(tmpdir, f"{command_name}.py")
-
- if not debug_mode:
- if app_id:
- if not os.path.exists(target_file):
- client = LightningClient(retry=False)
- project_id = _get_project(client).project_id
- response = client.lightningapp_instance_service_list_lightningapp_instance_artifacts(
- project_id=project_id, id=app_id
- )
- for artifact in response.artifacts:
- if f"commands/{command_name}.py" == artifact.filename:
- resp = requests.get(artifact.url, allow_redirects=True)
-
- with open(target_file, "wb") as f:
- f.write(resp.content)
- else:
- shutil.copy(cls_path, target_file)
-
- spec = spec_from_file_location(cls_name, target_file)
- mod = module_from_spec(spec)
- sys.modules[cls_name] = mod
- spec.loader.exec_module(mod)
- command_type = getattr(mod, cls_name)
- if issubclass(command_type, ClientCommand):
- command = command_type(method=None)
- else:
- raise ValueError(f"Expected class {cls_name} for command {command_name} to be a `ClientCommand`.")
- if tmpdir and os.path.exists(tmpdir):
- shutil.rmtree(tmpdir)
- return command
-
-
-def _to_annotation(anno: str) -> str:
- anno = anno.split("'")[1]
- if "." in anno:
- return anno.split(".")[-1]
- return anno
-
-
-def _validate_client_command(command: ClientCommand):
- """Extract method and its metadata from a ClientCommand."""
- params = inspect.signature(command.method).parameters
- command_metadata = {
- "cls_path": inspect.getfile(command.__class__),
- "cls_name": command.__class__.__name__,
- "params": {p.name: _to_annotation(str(p.annotation)) for p in params.values()},
- **command._to_dict(),
- }
- method = command.method
- command.models = {}
- for k, v in command_metadata["params"].items():
- if v == "_empty":
- raise Exception(
- f"Please, annotate your method {method} with pydantic BaseModel. Refer to the documentation."
- )
- config = getattr(sys.modules[command.__module__], v, None)
- if config is None:
- config = getattr(sys.modules[method.__module__], v, None)
- if config:
- raise Exception(
- f"The provided annotation for the argument {k} should be in the file "
- f"{inspect.getfile(command.__class__)}, not {inspect.getfile(command.method)}."
- )
- if config is None or not issubclass(config, BaseModel):
- raise Exception(
- f"The provided annotation for the argument {k} should be a subclass of pydantic BaseModel."
- )
-
-
-def _upload(name: str, prefix: str, obj: Any) -> Optional[str]:
- from lightning.app.storage.path import _filesystem, _is_s3fs_available, _shared_storage_path
-
- name = name.replace(" ", "_")
- filepath = f"{prefix}/{name}.py"
- fs = _filesystem()
-
- if _is_s3fs_available():
- from s3fs import S3FileSystem
-
- if not isinstance(fs, S3FileSystem):
- return None
-
- source_file = str(inspect.getfile(obj.__class__))
- remote_url = str(_shared_storage_path() / "artifacts" / filepath)
- fs.put(source_file, remote_url)
- return filepath
- return None
-
-
-def _prepare_commands(app) -> List:
- if not is_overridden("configure_commands", app.root):
- return []
-
- # 1: Upload the command to s3.
- commands = app.root.configure_commands()
- for command_mapping in commands:
- for command_name, command in command_mapping.items():
- if isinstance(command, ClientCommand):
- _upload(command_name, "commands", command)
-
- # 2: Cache the commands on the app.
- app.commands = commands
- return commands
-
-
-def _process_api_request(app, request: _APIRequest):
- flow = app.get_component_by_name(request.name)
- method = getattr(flow, request.method_name)
- try:
- response = _RequestResponse(content=method(*request.args, **request.kwargs), status_code=200)
- except HTTPException as ex:
- logger.error(repr(ex))
- response = _RequestResponse(status_code=ex.status_code, content=ex.detail)
- except Exception:
- logger.error(traceback.format_exc())
- response = _RequestResponse(status_code=500)
- return {"response": response, "id": request.id}
-
-
-def _process_command_requests(app, request: _CommandRequest):
- for command in app.commands:
- for command_name, method in command.items():
- command_name = command_name.replace(" ", "_")
- if request.method_name == command_name:
- # 2.1: Evaluate the method associated to a specific command.
- # Validation is done on the CLI side.
- try:
- response = _RequestResponse(content=method(*request.args, **request.kwargs), status_code=200)
- except HTTPException as ex:
- logger.error(repr(ex))
- response = _RequestResponse(status_code=ex.status_code, content=ex.detail)
- except Exception:
- logger.error(traceback.format_exc())
- response = _RequestResponse(status_code=500)
- return {"response": response, "id": request.id}
- return None
-
-
-def _process_requests(app, requests: List[Union[_APIRequest, _CommandRequest]]) -> None:
- """Process API and command requests and put their responses on the app's API response queue."""
- responses = []
- for request in requests:
- if isinstance(request, _APIRequest):
- response = _process_api_request(app, request)
- else:
- response = _process_command_requests(app, request)
-
- if response:
- responses.append(response)
-
- app.api_response_queue.put(responses)
-
-
-def _collect_open_api_extras(command, info) -> Dict:
- if not isinstance(command, ClientCommand):
- if command.__doc__ is not None:
- return {"description": command.__doc__}
- return {}
-
- extras = {
- "cls_path": inspect.getfile(command.__class__),
- "cls_name": command.__class__.__name__,
- "description": command.description,
- }
- if command.requirements:
- extras.update({"requirements": command.requirements})
- if info:
- extras.update({"app_info": asdict(info)})
- return extras
-
-
-def _commands_to_api(
- commands: List[Dict[str, Union[Callable, ClientCommand]]], info: Optional[frontend.AppInfo] = None
-) -> List:
- """Convert user commands to API endpoint."""
- api = []
- for command in commands:
- for k, v in command.items():
- k = k.replace(" ", "_")
- api.append(
- Post(
- f"/command/{k}",
- v.method if isinstance(v, ClientCommand) else v,
- method_name=k,
- tags=["app_client_command"] if isinstance(v, ClientCommand) else ["app_command"],
- openapi_extra=_collect_open_api_extras(v, info),
- )
- )
- return api
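Review note: the removed `_to_annotation` helper relies on the string form of a class, such as `<class 'module.Name'>`, to recover the bare class name, and `_validate_client_command` treats the result `_empty` as "no annotation". A minimal sketch of that parsing, assuming only the standard library; `annotation_name`, `describe_params`, `SweepConfig`, and `run_sweep` are hypothetical names used for illustration and are not part of the removed module:

```python
import inspect
from typing import Callable, Dict


def annotation_name(anno: str) -> str:
    """Reduce a string such as "<class 'pkg.mod.Name'>" to "Name"."""
    anno = anno.split("'")[1]  # text between the first pair of single quotes
    return anno.split(".")[-1] if "." in anno else anno


class SweepConfig:
    """Hypothetical stand-in for a pydantic model used as a command argument."""


def run_sweep(config: SweepConfig, untyped):
    """Hypothetical command handler with one annotated and one bare parameter."""


def describe_params(func: Callable) -> Dict[str, str]:
    """Map each parameter name to the bare name of its annotation."""
    params = inspect.signature(func).parameters
    return {p.name: annotation_name(str(p.annotation)) for p in params.values()}
```

An unannotated parameter carries `inspect.Parameter.empty`, whose class is `inspect._empty`, which is why the validator compares against the string `"_empty"`. This approach would not work for annotations stored as plain strings (for example under `from __future__ import annotations`), because their string form contains no quoted class name.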
diff --git a/src/lightning/app/utilities/component.py b/src/lightning/app/utilities/component.py
deleted file mode 100644
index 29a60168bf1eb..0000000000000
--- a/src/lightning/app/utilities/component.py
+++ /dev/null
@@ -1,152 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
-
-from deepdiff.helper import NotPresent
-from lightning_utilities.core.apply_func import apply_to_collection
-
-from lightning.app.utilities.app_helpers import is_overridden
-from lightning.app.utilities.enum import ComponentContext
-from lightning.app.utilities.packaging.cloud_compute import CloudCompute
-from lightning.app.utilities.tree import breadth_first
-
-if TYPE_CHECKING:
- from lightning.app.core import LightningFlow
-
-COMPONENT_CONTEXT: Optional[ComponentContext] = None
-
-
-def _convert_paths_after_init(root: "LightningFlow"):
- """Converts the path attributes on a component to a dictionary.
-
- This is necessary because at the time of instantiating the component, its full affiliation is not known and Paths
- that get passed to other components during ``__init__`` are otherwise not able to reference their origin or
- consumer.
-
- """
- from lightning.app.core import LightningFlow, LightningWork
- from lightning.app.storage.path import Path
-
- for component in breadth_first(root, types=(LightningFlow, LightningWork)):
- for attr in list(component.__dict__.keys()):
- value = getattr(component, attr)
- if isinstance(value, Path):
- delattr(component, attr)
- component._paths[attr] = value.to_dict()
-
-
-def _sanitize_state(state: Dict[str, Any]) -> Dict[str, Any]:
- """Utility function to sanitize the state of a component.
-
- Sanitization enables the state to be deep-copied and hashed.
-
- """
- from lightning.app.storage import Drive, Path
- from lightning.app.storage.payload import _BasePayload
-
- def sanitize_path(path: Path) -> Path:
- path_copy = Path(path)
- path_copy._sanitize()
- return path_copy
-
- def sanitize_payload(payload: _BasePayload):
- return type(payload).from_dict(content=payload.to_dict())
-
- def sanitize_drive(drive: Drive) -> Dict:
- return drive.to_dict()
-
- def sanitize_cloud_compute(cloud_compute: CloudCompute) -> Dict:
- return cloud_compute.to_dict()
-
- state = apply_to_collection(state, dtype=Path, function=sanitize_path)
- state = apply_to_collection(state, dtype=_BasePayload, function=sanitize_payload)
- state = apply_to_collection(state, dtype=Drive, function=sanitize_drive)
- state = apply_to_collection(state, dtype=CloudCompute, function=sanitize_cloud_compute)
- return state
-
-
-def _state_to_json(state: Dict[str, Any]) -> Dict[str, Any]:
- """Utility function to make sure that state dict is json serializable."""
- from lightning.app.storage.path import Path
- from lightning.app.storage.payload import _BasePayload
-
- state_paths_cleaned = apply_to_collection(state, dtype=(Path, _BasePayload), function=lambda x: x.to_dict())
- return apply_to_collection(state_paths_cleaned, dtype=type(NotPresent), function=lambda x: None)
-
-
-def _set_context(name: Optional[str]) -> None:
- global COMPONENT_CONTEXT
- COMPONENT_CONTEXT = os.getenv("COMPONENT_CONTEXT") if name is None else ComponentContext(name)
-
-
-def _get_context() -> Optional[ComponentContext]:
- global COMPONENT_CONTEXT
- return COMPONENT_CONTEXT
-
-
-def _set_flow_context() -> None:
- global COMPONENT_CONTEXT
- COMPONENT_CONTEXT = ComponentContext.FLOW
-
-
-def _set_work_context() -> None:
- global COMPONENT_CONTEXT
- COMPONENT_CONTEXT = ComponentContext.WORK
-
-
-def _set_frontend_context() -> None:
- global COMPONENT_CONTEXT
- COMPONENT_CONTEXT = ComponentContext.FRONTEND
-
-
-def _is_flow_context() -> bool:
- global COMPONENT_CONTEXT
- return COMPONENT_CONTEXT == ComponentContext.FLOW
-
-
-def _is_work_context() -> bool:
- global COMPONENT_CONTEXT
- return COMPONENT_CONTEXT == ComponentContext.WORK
-
-
-def _is_frontend_context() -> bool:
- global COMPONENT_CONTEXT
- return COMPONENT_CONTEXT == ComponentContext.FRONTEND
-
-
-@contextmanager
-def _context(ctx: str) -> Generator[None, None, None]:
- """Set the global component context for the block below this context manager.
-
- The context is used to determine whether the current process is running for a LightningFlow or for a LightningWork.
- See also :func:`_get_context`, :func:`_set_context`. For internal use only.
-
- """
- prev = _get_context()
- _set_context(ctx)
- yield
- _set_context(prev)
-
-
-def _validate_root_flow(flow: "LightningFlow") -> None:
- from lightning.app.core.flow import LightningFlow
-
- if not is_overridden("run", instance=flow, parent=LightningFlow):
- raise TypeError(
- "The root flow passed to `LightningApp` does not override the `run()` method. This is required. Please"
- f" implement `run()` in your `{flow.__class__.__name__}` class."
- )
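Review note: the removed `_context` manager restores the previous context only when the `with` block finishes normally, because the `yield` is not wrapped in `try`/`finally`. A minimal sketch of an exception-safe variant, assuming a plain module-level string instead of `ComponentContext`; `_CONTEXT`, `current_context`, and `scoped_context` are hypothetical names for illustration:

```python
from contextlib import contextmanager
from typing import Generator, Optional

_CONTEXT: Optional[str] = None  # hypothetical stand-in for COMPONENT_CONTEXT


def current_context() -> Optional[str]:
    """Return the context that is active right now."""
    return _CONTEXT


@contextmanager
def scoped_context(ctx: str) -> Generator[None, None, None]:
    """Set the global context for the duration of the ``with`` block."""
    global _CONTEXT
    prev = _CONTEXT
    _CONTEXT = ctx
    try:
        yield
    finally:
        # Runs on normal exit and when the block raises, so the previous value always comes back.
        _CONTEXT = prev
```

With this shape, an exception raised inside the block still propagates to the caller, but the global value no longer stays stuck on the inner context.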
diff --git a/src/lightning/app/utilities/data_structures.py b/src/lightning/app/utilities/data_structures.py
deleted file mode 100644
index 495c43fd0ea24..0000000000000
--- a/src/lightning/app/utilities/data_structures.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any, Dict, Optional
-
-
-class AttributeDict(Dict):
- """Extended dictionary accessible with dot notation.
-
- >>> ad = AttributeDict({'key1': 1, 'key2': 'abc'})
- >>> ad.key1
- 1
- >>> ad.update({'my-key': 3.14})
- >>> ad.update(new_key=42)
- >>> ad.key1 = 2
- >>> ad
- "key1": 2
- "key2": abc
- "my-key": 3.14
- "new_key": 42
-
- """
-
- def __getattr__(self, key: str) -> Optional[Any]:
- try:
- return self[key]
- except KeyError as exp:
- raise AttributeError(f'Missing attribute "{key}"') from exp
-
- def __setattr__(self, key: str, val: Any) -> None:
- self[key] = val
-
- def __repr__(self) -> str:
- if not len(self):
- return ""
- max_key_length = max(len(str(k)) for k in self)
- tmp_name = "{:" + str(max_key_length + 3) + "s} {}"
- rows = [tmp_name.format(f'"{n}":', self[n]) for n in sorted(self.keys())]
- return "\n".join(rows)
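Review note: the removed `AttributeDict` exposes dictionary keys as attributes by routing failed attribute lookups to item access. A minimal sketch of the same idea without the custom `__repr__`, assuming only the standard library; `DotDict` and `settings` are hypothetical names for illustration:

```python
from typing import Any


class DotDict(dict):
    """Dictionary whose string keys can also be read and written as attributes."""

    def __getattr__(self, key: str) -> Any:
        # Only called when normal attribute lookup fails, so dict methods keep working.
        try:
            return self[key]
        except KeyError as exc:
            raise AttributeError(f'Missing attribute "{key}"') from exc

    def __setattr__(self, key: str, val: Any) -> None:
        # Every attribute assignment is stored as a dictionary item.
        self[key] = val


settings = DotDict({"key1": 1, "key2": "abc"})
settings.update({"my-key": 3.14})  # not a valid identifier, so it needs item access
settings.key1 = 2
```

Converting the `KeyError` into an `AttributeError` matters because `hasattr` and `getattr` with a default only swallow `AttributeError`; keys that are not valid identifiers, such as `"my-key"`, remain reachable through `settings["my-key"]`.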
diff --git a/src/lightning/app/utilities/dependency_caching.py b/src/lightning/app/utilities/dependency_caching.py
deleted file mode 100644
index 8cb389a20a164..0000000000000
--- a/src/lightning/app/utilities/dependency_caching.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import hashlib
-from pathlib import Path
-
-
-def get_hash(path: Path, chunk_num_blocks: int = 128) -> str:
- """Get the hash of a file."""
- h = hashlib.blake2b(digest_size=20)
- if not path.exists():
- raise FileNotFoundError(f"{path} does not exist")
- with path.open("rb") as f:
- for chunk in iter(lambda: f.read(chunk_num_blocks * h.block_size), b""):
- h.update(chunk)
- return h.hexdigest()
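Review note: the removed `get_hash` helper reads the file in chunks of `chunk_num_blocks * h.block_size` bytes, so large files never have to fit in memory. A minimal standalone sketch of the same technique, assuming only the standard library; `file_digest`, `demo`, and the temporary sample file are hypothetical and not part of the removed module:

```python
import hashlib
import tempfile
from pathlib import Path


def file_digest(path: Path, chunk_num_blocks: int = 128) -> str:
    """Hash a file incrementally with BLAKE2b using a 20-byte digest."""
    h = hashlib.blake2b(digest_size=20)
    if not path.exists():
        raise FileNotFoundError(f"{path} does not exist")
    with path.open("rb") as f:
        # iter(callable, sentinel) keeps calling f.read until it returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_num_blocks * h.block_size), b""):
            h.update(chunk)
    return h.hexdigest()


def demo() -> bool:
    """Check that chunked hashing matches hashing the whole payload at once."""
    with tempfile.TemporaryDirectory() as tmp:
        sample = Path(tmp) / "sample.bin"
        payload = b"lightning" * 100_000
        sample.write_bytes(payload)
        one_shot = hashlib.blake2b(payload, digest_size=20).hexdigest()
        return file_digest(sample) == one_shot
```

Feeding the hash object chunk by chunk yields the same digest as hashing the full byte string, which is what `demo()` checks.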
diff --git a/src/lightning/app/utilities/enum.py b/src/lightning/app/utilities/enum.py
deleted file mode 100644
index 1e5422289d220..0000000000000
--- a/src/lightning/app/utilities/enum.py
+++ /dev/null
@@ -1,82 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import enum
-from datetime import datetime, timezone
-from typing import Optional
-
-
-class ComponentContext(enum.Enum):
- """Describes whether the current process is running LightningFlow or LightningWork."""
-
- FLOW = "flow"
- WORK = "work"
- FRONTEND = "frontend"
-
-
-class AppStage(enum.Enum):
- BLOCKING = "blocking"
- RUNNING = "running"
- RESTARTING = "restarting"
- STOPPING = "stopping"
- FAILED = "failed"
-
-
-class WorkFailureReasons:
- TIMEOUT = "timeout" # triggered when pending and wait timeout has been passed
- SPOT_RETRIVAL = "spot_retrival" # triggered when a SIGTERM signal is sent to the spot instance work.
- USER_EXCEPTION = "user_exception" # triggered when an exception is raised by user code.
- INVALID_RETURN_VALUE = "invalid_return_value" # triggered when the return value isn't valid.
-
-
-class WorkStopReasons:
- SIGTERM_SIGNAL_HANDLER = "sigterm_signal_handler"
- PENDING = "pending"
-
-
-class WorkPendingReason(enum.Enum):
- IMAGE_BUILDING = "image_building"
- REQUESTING_RESOURCE = "requesting_ressource"
-
-
-class WorkStageStatus:
- NOT_STARTED = "not_started"
- STARTED = "started"
- STOPPED = "stopped"
- PENDING = "pending"
- RUNNING = "running"
- SUCCEEDED = "succeeded"
- FAILED = "failed"
-
-
-def make_status(stage: str, message: Optional[str] = None, reason: Optional[str] = None):
- status = {
- "stage": stage,
- "timestamp": datetime.now(tz=timezone.utc).timestamp(),
- }
- if message:
- status["message"] = message
- if reason:
- status["reason"] = reason
- return status
-
-
-class CacheCallsKeys:
- LATEST_CALL_HASH = "latest_call_hash"
-
-
-class OpenAPITags:
- APP_CLIENT_COMMAND = "app_client_command"
- APP_COMMAND = "app_command"
- APP_API = "app_api"
diff --git a/src/lightning/app/utilities/exceptions.py b/src/lightning/app/utilities/exceptions.py
deleted file mode 100644
index 3bb5eced46ed7..0000000000000
--- a/src/lightning/app/utilities/exceptions.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from json import JSONDecodeError, loads
-from typing import Any
-
-from click import ClickException, Context, Group
-from lightning_cloud.openapi.rest import ApiException
-
-
-class _ApiExceptionHandler(Group):
- """Attempts to convert ApiExceptions to ClickExceptions.
-
- This process clarifies the error for the user by:
- 1. Showing the error message from the lightning.ai servers,
- instead of showing the entire HTTP response
- 2. Suppressing long tracebacks
-
- However, if the ApiException cannot be decoded, or is not
- a 4xx error, the original ApiException will be re-raised.
-
- """
-
- def invoke(self, ctx: Context) -> Any:
- try:
- return super().invoke(ctx)
- except ApiException as api:
- exception_messages = []
- if 400 <= api.status < 500:
- try:
- body = loads(api.body)
- except JSONDecodeError:
- raise api
- exception_messages.append(body["message"])
- exception_messages.extend(body["details"])
- else:
- raise api
- raise ClickException("\n".join(exception_messages))
-
-
-class MisconfigurationException(Exception):
- """Exception used to inform users of misuse with Lightning."""
-
-
-class CacheMissException(Exception):
- """Exception used internally as a boundary to non-executed functions."""
-
-
-class ExitAppException(Exception):
- """Exception used by components to signal that App should exit."""
-
-
-class LightningComponentException(Exception):
- """Exception used to inform users of misuse with LightningComponent."""
-
-
-class InvalidPathException(Exception):
- """Exception used to inform users they are accessing an invalid path."""
-
-
-class LightningFlowException(Exception):
- """Exception used to inform users of misuse with LightningFlow."""
-
-
-class LightningWorkException(Exception):
- """Exception used to inform users of misuse with LightningWork."""
-
-
-class LightningPlatformException(Exception): # pragma: no cover
- """Exception used to inform users of issues related to platform the LightningApp is running on.
-
- It gets raised by the Lightning Launcher on the platform side when the app is running in the cloud, and is useful
- when framework or user code needs to catch exceptions specific to the platform, e.g., when resources exceed quotas.
-
- """
-
-
-class LightningAppStateException(Exception):
- """Exception to inform users of app state errors."""
-
-
-class LightningSigtermStateException(Exception):
- """Exception to propagate exception in work proxy."""
-
- def __init__(self, exit_code):
- self.exit_code = exit_code
-
-
-class LogLinesLimitExceeded(Exception):
- """Exception to inform the user that we've reached the maximum number of log lines."""
diff --git a/src/lightning/app/utilities/frontend.py b/src/lightning/app/utilities/frontend.py
deleted file mode 100644
index 29b6daef793f4..0000000000000
--- a/src/lightning/app/utilities/frontend.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass
-from typing import List, Optional
-
-from bs4 import BeautifulSoup
-
-
-@dataclass
-class AppInfo:
- title: Optional[str] = None
- favicon: Optional[str] = None
- description: Optional[str] = None
- image: Optional[str] = None
- # ensure the meta tags are correct or the UI might fail to load.
- meta_tags: Optional[List[str]] = None
- on_connect_end: Optional[str] = None
-
-
-def update_index_file(ui_root: str, info: Optional[AppInfo] = None, root_path: str = "") -> None:
- import shutil
- from pathlib import Path
-
- entry_file = Path(ui_root) / "index.html"
- original_file = Path(ui_root) / "index.original.html"
-
- if not original_file.exists():
- shutil.copyfile(entry_file, original_file) # keep backup
- else:
- # revert index.html in case it was modified after creating original.html
- shutil.copyfile(original_file, entry_file)
-
- if info:
- with original_file.open() as f:
- original = f.read()
-
- with entry_file.open("w") as f:
- f.write(_get_updated_content(original=original, root_path=root_path, info=info))
-
- if root_path:
- root_path_without_slash = root_path.replace("/", "", 1) if root_path.startswith("/") else root_path
- src_dir = Path(ui_root)
- dst_dir = src_dir / root_path_without_slash
-
- if dst_dir.exists():
- shutil.rmtree(dst_dir, ignore_errors=True)
- # Copy everything except the current root_path. This fixes a bug where the server doesn't start
- # if the user specifies /abc at first and then /abc/def.
- # Ideally we should copy everything except the custom root_path that the user passed.
- shutil.copytree(src_dir, dst_dir, ignore=shutil.ignore_patterns(f"{root_path_without_slash}*"))
-
-
-def _get_updated_content(original: str, root_path: str, info: AppInfo) -> str:
- soup = BeautifulSoup(original, "html.parser")
-
- # replace favicon
- if info.favicon:
- soup.find("link", {"rel": "icon"}).attrs["href"] = info.favicon
-
- if info.title is not None:
- soup.find("title").string = info.title
-
- if info.description:
- soup.find("meta", {"name": "description"}).attrs["content"] = info.description
-
- if info.image:
- soup.find("meta", {"property": "og:image"}).attrs["content"] = info.image
-
- if info.meta_tags:
- for meta in info.meta_tags:
- soup.find("head").append(BeautifulSoup(meta, "html.parser"))
-
- if root_path:
- # this will be used by the Lightning App UI to add root_path to requests
- soup.find("head").append(BeautifulSoup(f'', "html.parser"))
-
- return str(soup).replace("/static", f"{root_path}/static")
diff --git a/src/lightning/app/utilities/git.py b/src/lightning/app/utilities/git.py
deleted file mode 100644
index 1293a2a5095fa..0000000000000
--- a/src/lightning/app/utilities/git.py
+++ /dev/null
@@ -1,86 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import subprocess
-from pathlib import Path
-from typing import List, Union
-
-# TODO - github utilities are already defined in GridSDK use that?
-
-
-def execute_git_command(args: List[str], cwd=None) -> str:
- """Executes a git command. This is expected to return a single string back.
-
- Returns
- -------
- output: str
- String combining stdout and stderr.
-
- """
- process = subprocess.run(["git"] + args, capture_output=True, text=True, cwd=cwd, check=False)
- return process.stdout.strip() + process.stderr.strip()
-
-
-def get_dir_name(cwd=None) -> str:
- github_repository = execute_git_command(["config", "--get", "remote.origin.url"], cwd=cwd)
- if github_repository and "github.com" in github_repository:
- return github_repository.split("/")[-1].split(".")[0]
- raise RuntimeError("Only works with GitHub repositories.")
-
-
-def check_github_repository(cwd=None) -> bool:
- """Checks if the active directory is a GitHub repository."""
- github_repository = execute_git_command(["config", "--get", "remote.origin.url"], cwd=cwd)
-
- if not github_repository or "github.com" not in github_repository:
- return False
- return True
-
-
-def get_git_relative_path(file: Union[str, Path]) -> str:
- """Finds the relative path of the file to the git root."""
- if not check_github_repository():
- raise ValueError("Not a GitHub repository.")
- abs_path = Path(file).absolute()
- repository_path = execute_git_command(["rev-parse", "--show-toplevel"])
- return str(abs_path.relative_to(repository_path))
-
-
-def check_if_remote_head_is_different() -> Union[bool, None]:
- """Checks if remote git repository is different than the version available locally.
-
- This only compares the local SHA to the HEAD commit of a given branch. This check won't be used if user isn't in a
- HEAD locally.
-
- """
- # Check SHA values.
- local_sha = execute_git_command(["rev-parse", "@"])
- remote_sha = execute_git_command(["rev-parse", r"@{u}"])
- base_sha = execute_git_command(["merge-base", "@", r"@{u}"])
-
- # Whenever a SHA is not available, just return.
- if any("fatal" in f for f in (local_sha, remote_sha, base_sha)):
- return None
-
- return local_sha not in (remote_sha, base_sha)
-
-
-def has_uncommitted_files() -> bool:
- """Checks if the user has uncommitted files in the local repository.
-
- If there are uncommitted files, then show a prompt indicating that uncommitted files exist locally.
-
- """
- files = execute_git_command(["update-index", "--refresh"])
- return bool(files)
diff --git a/src/lightning/app/utilities/imports.py b/src/lightning/app/utilities/imports.py
deleted file mode 100644
index 33d6c259e09b5..0000000000000
--- a/src/lightning/app/utilities/imports.py
+++ /dev/null
@@ -1,148 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""General utilities."""
-
-import functools
-import os
-import platform
-import sys
-import warnings
-from typing import Any, List, Union
-
-from lightning_utilities.core.imports import module_available
-from packaging.requirements import Marker, Requirement
-
-try:
- from importlib import metadata
-except ImportError:
- # Python < 3.8
- import importlib_metadata as metadata # type: ignore
-
-
-def _get_extras(extras: str) -> str:
- """Get the given extras as a space delimited string.
-
- Used by the platform to install cloud extras in the cloud.
-
- """
- from lightning.app import __package_name__
-
- requirements = {r: Requirement(r) for r in metadata.requires(__package_name__)}
- marker = Marker(f'extra == "{extras}"')
- requirements = [r for r, req in requirements.items() if str(req.marker) == str(marker)]
-
- if requirements:
- requirements = [f"'{r.split(';')[0].strip()}'" for r in requirements]
- return " ".join(requirements)
- return ""
-
-
-def requires(module_paths: Union[str, List]):
- if not isinstance(module_paths, list):
- module_paths = [module_paths]
-
- def decorator(func):
- @functools.wraps(func)
- def wrapper(*args: Any, **kwargs: Any):
- unavailable_modules = [f"'{module}'" for module in module_paths if not module_available(module)]
- if any(unavailable_modules):
- is_lit_testing = bool(int(os.getenv("LIGHTING_TESTING", "0")))
- msg = f"Required dependencies not available. Please run: pip install {' '.join(unavailable_modules)}"
- if is_lit_testing:
- warnings.warn(msg)
- else:
- raise ModuleNotFoundError(msg)
- return func(*args, **kwargs)
-
- return wrapper
-
- return decorator
-
-
-# TODO: Automatically detect dependencies
-def _is_redis_available() -> bool:
- return module_available("redis")
-
-
-def _is_torch_available() -> bool:
- return module_available("torch")
-
-
-def _is_pytorch_lightning_available() -> bool:
- return module_available("lightning.pytorch")
-
-
-def _is_torchvision_available() -> bool:
- return module_available("torchvision")
-
-
-def _is_json_argparse_available() -> bool:
- return module_available("jsonargparse")
-
-
-def _is_streamlit_available() -> bool:
- return module_available("streamlit")
-
-
-def _is_param_available() -> bool:
- return module_available("param")
-
-
-def _is_streamlit_tensorboard_available() -> bool:
- return module_available("streamlit_tensorboard")
-
-
-def _is_gradio_available() -> bool:
- return module_available("gradio")
-
-
-def _is_lightning_flash_available() -> bool:
- return module_available("flash")
-
-
-def _is_pil_available() -> bool:
- return module_available("PIL")
-
-
-def _is_numpy_available() -> bool:
- return module_available("numpy")
-
-
-def _is_docker_available() -> bool:
- return module_available("docker")
-
-
-def _is_jinja2_available() -> bool:
- return module_available("jinja2")
-
-
-def _is_playwright_available() -> bool:
- return module_available("playwright")
-
-
-def _is_s3fs_available() -> bool:
- return module_available("s3fs")
-
-
-def _is_sqlmodel_available() -> bool:
- return module_available("sqlmodel")
-
-
-def _is_aiohttp_available() -> bool:
- return module_available("aiohttp")
-
-
-_CLOUD_TEST_RUN = bool(os.getenv("CLOUD", False))
-_IS_WINDOWS = platform.system() == "Windows"
-_IS_MACOS = sys.platform == "darwin"
diff --git a/src/lightning/app/utilities/introspection.py b/src/lightning/app/utilities/introspection.py
deleted file mode 100644
index a794c6ca9fc12..0000000000000
--- a/src/lightning/app/utilities/introspection.py
+++ /dev/null
@@ -1,400 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import ast
-import inspect
-from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, Union
-
-if TYPE_CHECKING:
- from lightning.app.core import LightningFlow, LightningWork
-
-
-class LightningVisitor(ast.NodeVisitor):
- """Base class for visitor that finds class definitions based on class inheritance. Derived classes are expected to
- define class_name and implement the analyze_class_def method.
-
- Attributes
- ----------
- class_name: str
- Name of class to identify, to be defined in subclasses.
-
- """
-
- class_name: Optional[str] = None
-
- def __init__(self):
- self.found: List[Dict[str, Any]] = []
-
- def analyze_class_def(self, node: ast.ClassDef) -> Dict[str, Any]:
- return {}
-
- def visit_ClassDef(self, node: ast.ClassDef) -> None:
- bases = []
- for base in node.bases:
- if type(base) == ast.Attribute:
- bases.append(base.attr)
- elif type(base) == ast.Name:
- bases.append(base.id)
- if self.class_name in bases:
- entry = {"name": node.name, "type": self.class_name}
- entry.update(self.analyze_class_def(node))
- self.found.append(entry)
-
-
-class LightningModuleVisitor(LightningVisitor):
- """Finds Lightning modules based on class inheritance.
-
- Attributes
- ----------
- class_name: Optional[str]
- Name of class to identify.
- methods: Set[str]
- Names of methods that are part of the LightningModule API.
- hooks: Set[str]
- Names of hooks that are part of the LightningModule API.
-
- """
-
- class_name: Optional[str] = "LightningModule"
-
- methods: Set[str] = {
- "configure_optimizers",
- "forward",
- "freeze",
- "log",
- "log_dict",
- "print",
- "save_hyperparameters",
- "test_step",
- "test_step_end",
- "to_onnx",
- "to_torchscript",
- "training_step",
- "training_step_end",
- "unfreeze",
- "validation_step",
- "validation_step_end",
- }
-
- hooks: Set[str] = {
- "backward",
- "get_progress_bar_dict",
- "manual_backward",
- "manual_optimizer_step",
- "on_after_backward",
- "on_before_zero_grad",
- "on_fit_start",
- "on_fit_end",
- "on_load_checkpoint",
- "on_save_checkpoint",
- "on_pretrain_routine_start",
- "on_pretrain_routine_end",
- "on_test_batch_start",
- "on_test_batch_end",
- "on_test_epoch_start",
- "on_test_epoch_end",
- "on_train_batch_start",
- "on_train_batch_end",
- "on_train_epoch_start",
- "on_train_epoch_end",
- "on_validation_batch_start",
- "on_validation_batch_end",
- "on_validation_epoch_start",
- "on_validation_epoch_end",
- "optimizer_step",
- "optimizer_zero_grad",
- "prepare_data",
- "setup",
- "teardown",
- "train_dataloader",
- "val_dataloader",
- "test_dataloader",
- "transfer_batch_to_device",
- }
-
-
-class LightningDataModuleVisitor(LightningVisitor):
- """Finds Lightning data modules based on class inheritance.
-
- Attributes
- ----------
- class_name: Optional[str]
- Name of class to identify.
- methods: Set[str]
- Names of methods that are part of the LightningDataModule API.
-
- """
-
- class_name = "LightningDataModule"
-
- methods: Set[str] = {
- "prepare_data",
- "setup",
- "train_dataloader",
- "val_dataloader",
- "test_dataloader",
- "transfer_batch_to_device",
- }
-
-
-class LightningLoggerVisitor(LightningVisitor):
- """Finds Lightning loggers based on class inheritance.
-
- Attributes
- ----------
- class_name: Optional[str]
- Name of class to identify.
- methods: Set[str]
- Names of methods that are part of the Logger API.
-
- """
-
- class_name = "Logger"
-
- methods: Set[str] = {"log_hyperparams", "log_metrics"}
-
-
-class LightningCallbackVisitor(LightningVisitor):
- """Finds Lightning callbacks based on class inheritance.
-
- Attributes
- ----------
- class_name: Optional[str]
- Name of class to identify.
- methods: Set[str]
- Names of methods that are part of the Callback API.
-
- """
-
- class_name = "Callback"
-
- methods: Set[str] = {
- "setup",
- "teardown",
- "on_init_start",
- "on_init_end",
- "on_fit_start",
- "on_fit_end",
- "on_sanity_check_start",
- "on_sanity_check_end",
- "on_train_batch_start",
- "on_train_batch_end",
- "on_train_epoch_start",
- "on_train_epoch_end",
- "on_validation_epoch_start",
- "on_validation_epoch_end",
- "on_test_epoch_start",
- "on_test_epoch_end",
- "on_epoch_start",
- "on_epoch_end",
- "on_batch_start",
- "on_validation_batch_start",
- "on_validation_batch_end",
- "on_test_batch_start",
- "on_test_batch_end",
- "on_batch_end",
- "on_train_start",
- "on_train_end",
- "on_pretrain_routine_start",
- "on_pretrain_routine_end",
- "on_validation_start",
- "on_validation_end",
- "on_test_start",
- "on_test_end",
- "on_keyboard_interrupt",
- "on_save_checkpoint",
- "on_load_checkpoint",
- }
-
-
-class LightningStrategyVisitor(LightningVisitor):
- """Finds Lightning strategies based on class inheritance.
-
- Attributes
- ----------
- class_name: Optional[str]
- Name of class to identify.
- methods: Set[str]
- Names of methods that are part of the Strategy API.
-
- """
-
- class_name = "Strategy"
-
- methods: Set[str] = {
- "setup",
- "train",
- "training_step",
- "validation_step",
- "test_step",
- "backward",
- "barrier",
- "broadcast",
- "sync_tensor",
- }
-
-
-class LightningTrainerVisitor(LightningVisitor):
- class_name = "Trainer"
-
-
-class LightningCLIVisitor(LightningVisitor):
- class_name = "LightningCLI"
-
-
-class LightningPrecisionPluginVisitor(LightningVisitor):
- class_name = "PrecisionPlugin"
-
-
-class LightningAcceleratorVisitor(LightningVisitor):
- class_name = "Accelerator"
-
-
-class TorchMetricVisitor(LightningVisitor):
- class_name = "Metric"
-
-
-class FabricVisitor(LightningVisitor):
- class_name = "Fabric"
-
-
-class LightningProfilerVisitor(LightningVisitor):
- class_name = "Profiler"
-
-
-class Scanner:
- """Finds relevant Lightning objects in files in the file system.
-
- Attributes
- ----------
- visitor_classes: List[Type]
- List of visitor classes to use when traversing files.
- Parameters
- ----------
- path: str
- Path to file, or directory where to look for files to scan.
- glob_pattern: str
- Glob pattern to use when looking for files in the path,
- applied when path is a directory. Default is "**/*.py".
-
- """
-
- # TODO: Finalize introspecting the methods from all the discovered methods.
- visitor_classes: List[Type] = [
- LightningCLIVisitor,
- LightningTrainerVisitor,
- LightningModuleVisitor,
- LightningDataModuleVisitor,
- LightningCallbackVisitor,
- LightningStrategyVisitor,
- LightningPrecisionPluginVisitor,
- LightningAcceleratorVisitor,
- LightningLoggerVisitor,
- TorchMetricVisitor,
- FabricVisitor,
- LightningProfilerVisitor,
- ]
-
- def __init__(self, path: str, glob_pattern: str = "**/*.py"):
- path_ = Path(path)
- if path_.is_dir():
- self.paths = path_.glob(glob_pattern)
- else:
- self.paths = [path_]
-
- self.modules_found: List[Dict[str, Any]] = []
-
- def has_class(self, cls) -> bool:
- # This method isn't strong enough as it is using only `ImportFrom`.
- # TODO: Use proper classDef scanning.
- classes = []
-
- for path in self.paths:
- try:
- module = ast.parse(path.open().read())
- except SyntaxError:
- print(f"Error while parsing {path}: SKIPPING")
- continue
-
- for node in ast.walk(module):
- if isinstance(node, ast.ImportFrom):
- for import_from_cls in node.names:
- classes.append(import_from_cls.name)
-
- if isinstance(node, ast.Call):
- cls_name = getattr(node.func, "attr", None)
- if cls_name:
- classes.append(cls_name)
-
- return cls.__name__ in classes
-
- def scan(self) -> List[Dict[str, str]]:
- """Finds Lightning modules in files, returning importable objects.
-
- Returns
- -------
- List[Dict[str, Any]]
- List of dicts containing all metadata required
- to import modules found.
-
- """
- modules_found: Dict[str, List[Dict[str, Any]]] = {}
-
- for path in self.paths:
- try:
- module = ast.parse(path.open().read())
- except SyntaxError:
- print(f"Error while parsing {path}: SKIPPING")
- continue
- for visitor_class in self.visitor_classes:
- visitor = visitor_class()
- visitor.visit(module)
- if not visitor.found:
- continue
- _path = str(path)
- ns_info = {
- "file": _path,
- "namespace": _path.replace("/", ".").replace(".py", ""),
- }
- modules_found[visitor_class.class_name] = [{**entry, **ns_info} for entry in visitor.found]
-
- return modules_found
-
-
-def _is_method_context(component: Union["LightningFlow", "LightningWork"], selected_caller_name: str) -> bool:
- """Checks whether the call to a component originates from within the context of the component's method named
- ``selected_caller_name``."""
- frame = inspect.currentframe().f_back
-
- while frame is not None:
- caller_name = frame.f_code.co_name
- caller_self = frame.f_locals.get("self")
- if caller_name == selected_caller_name and caller_self is component:
- # the call originates from a frame under the selected method of the component
- return True
- frame = frame.f_back
-
- return False
-
-
-def _is_init_context(component: Union["LightningFlow", "LightningWork"]) -> bool:
- """Checks whether the call to a component originates from within the context of the component's ``__init__``
- method."""
- return _is_method_context(component, "__init__")
-
-
-def _is_run_context(component: Union["LightningFlow", "LightningWork"]) -> bool:
- """Checks whether the call to a component originates from within the context of the component's ``run`` method."""
- return _is_method_context(component, "run") or _is_method_context(component, "load_state_dict")
diff --git a/src/lightning/app/utilities/layout.py b/src/lightning/app/utilities/layout.py
deleted file mode 100644
index 6da0b2ca412ea..0000000000000
--- a/src/lightning/app/utilities/layout.py
+++ /dev/null
@@ -1,215 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import inspect
-import warnings
-from typing import Dict, List, Union
-
-import lightning.app
-from lightning.app.frontend.frontend import Frontend
-from lightning.app.utilities.app_helpers import _MagicMockJsonSerializable, is_overridden
-from lightning.app.utilities.cloud import is_running_in_cloud
-
-
-def _add_comment_to_literal_code(method, contains, comment):
- """Inspects a method's code and adds a message to it.
-
- This is a nice to have, so if it fails for some reason, it shouldn't affect the program.
-
- """
- try:
- lines = inspect.getsource(method)
- lines = lines.split("\n")
- idx_list = [i for i, x in enumerate(lines) if contains in x]
- for i in idx_list:
- line = lines[i]
- line += comment
- lines[i] = line
-
- return "\n".join(lines)
-
- except Exception:
- return ""
-
-
-def _collect_layout(app: "lightning.app.LightningApp", flow: "lightning.app.LightningFlow") -> Union[Dict, List[Dict]]:
- """Process the layout returned by the ``configure_layout()`` method in each flow."""
- layout = flow.configure_layout()
-
- if isinstance(layout, Frontend):
- frontend = layout
- frontend.flow = flow
- app.frontends.setdefault(flow.name, frontend)
-
- # When running locally, the target will get overwritten by the dispatcher when launching the frontend servers
- # When running in the cloud, the frontend code will construct the URL based on the flow name
- return flow._layout
- if isinstance(layout, _MagicMockJsonSerializable):
- # The import was mocked, we set a dummy `Frontend` so that `is_headless` knows there is a UI
- app.frontends.setdefault(flow.name, "mock")
- return flow._layout
- if isinstance(layout, dict):
- layout = _collect_content_layout([layout], app, flow)
- elif isinstance(layout, (list, tuple)) and all(isinstance(item, dict) for item in layout):
- layout = _collect_content_layout(layout, app, flow)
- else:
- lines = _add_comment_to_literal_code(flow.configure_layout, contains="return", comment=" <------- this guy")
- raise TypeError(
- f"""
- The return value of configure_layout() in `{flow.__class__.__name__}` is an unsupported layout format:
- \n{lines}
-
- Return either an object of type {Frontend} (e.g., StreamlitFrontend, StaticWebFrontend):
- def configure_layout(self):
- return la.frontend.Frontend(...)
-
- OR a single dict:
- def configure_layout(self):
- tab1 = {{'name': 'tab name', 'content': self.a_component}}
- return tab1
-
- OR a list of dicts:
- def configure_layout(self):
- tab1 = {{'name': 'tab name 1', 'content': self.component_a}}
- tab2 = {{'name': 'tab name 2', 'content': self.component_b}}
- return [tab1, tab2]
-
- (see the docs for `LightningFlow.configure_layout`).
- """
- )
-
- return layout
-
-
-def _collect_content_layout(
- layout: List[Dict], app: "lightning.app.LightningApp", flow: "lightning.app.LightningFlow"
-) -> Union[List[Dict], Dict]:
- """Process the layout returned by the ``configure_layout()`` method if the returned format represents an
- aggregation of child layouts."""
- for entry in layout:
- if "content" not in entry:
- raise ValueError(
- f"A dictionary returned by `{flow.__class__.__name__}.configure_layout()` is missing a key 'content'."
- f" For the value, choose either a reference to a child flow or a URL."
- )
- if isinstance(entry["content"], str): # assume this is a URL
- url = entry["content"]
- if url.startswith("/"):
- # The URL isn't fully defined yet. Looks something like ``self.work.url + /something``.
- entry["target"] = ""
- else:
- entry["target"] = url
- if url.startswith("http://") and is_running_in_cloud():
- warnings.warn(
- f"You configured an http link {url[:32]}... but it won't be accessible in the cloud."
- f" Consider replacing 'http' with 'https' in the link above."
- )
-
- elif isinstance(entry["content"], lightning.app.LightningFlow):
- entry["content"] = entry["content"].name
-
- elif isinstance(entry["content"], lightning.app.LightningWork):
- work = entry["content"]
- work_layout = _collect_work_layout(work)
-
- if work_layout is None:
- entry["content"] = ""
- elif isinstance(work_layout, str):
- entry["content"] = work_layout
- entry["target"] = work_layout
- elif isinstance(work_layout, (Frontend, _MagicMockJsonSerializable)):
- if len(layout) > 1:
- lines = _add_comment_to_literal_code(
- flow.configure_layout, contains="return", comment=" <------- this guy"
- )
- m = f"""
- The return value of configure_layout() in `{flow.__class__.__name__}` is an
- unsupported format:
- \n{lines}
-
- The tab containing a `{work.__class__.__name__}` must be the only tab in the
- layout of this flow.
-
- (see the docs for `LightningWork.configure_layout`).
- """
- raise TypeError(m)
-
- if isinstance(work_layout, Frontend):
- # If the work returned a frontend, treat it as belonging to the flow.
- # NOTE: This could evolve in the future to run the Frontend directly in the work machine.
- frontend = work_layout
- frontend.flow = flow
- elif isinstance(work_layout, _MagicMockJsonSerializable):
- # The import was mocked, we set a dummy `Frontend` so that `is_headless` knows there is a UI.
- frontend = "mock"
-
- app.frontends.setdefault(flow.name, frontend)
- return flow._layout
-
- elif isinstance(entry["content"], _MagicMockJsonSerializable):
- # The import was mocked, we just record dummy content so that `is_headless` knows there is a UI
- entry["content"] = "mock"
- entry["target"] = "mock"
- else:
- m = f"""
- A dictionary returned by `{flow.__class__.__name__}.configure_layout()` contains an unsupported entry.
-
- {{'content': {repr(entry["content"])}}}
-
- Set the `content` key to a child flow or a URL, for example:
-
- class {flow.__class__.__name__}(LightningFlow):
- def configure_layout(self):
- return {{'content': childFlow OR childWork OR 'http://some/url'}}
- """
- raise ValueError(m)
- return layout
-
-
-def _collect_work_layout(work: "lightning.app.LightningWork") -> Union[None, str, Frontend, _MagicMockJsonSerializable]:
- """Check if ``configure_layout`` is overridden on the given work and return the work layout (either a string, a
- ``Frontend`` object, or an instance of a mocked import).
-
- Args:
- work: The work to collect the layout for.
-
- Raises:
- TypeError: If the value returned by ``configure_layout`` is not of a supported format.
-
- """
- work_layout = work.configure_layout() if is_overridden("configure_layout", work) else work.url
-
- if work_layout is None:
- return None
- if isinstance(work_layout, str):
- url = work_layout
- # The URL isn't fully defined yet. Looks something like ``self.work.url + /something``.
- if url and not url.startswith("/"):
- return url
- return ""
- if isinstance(work_layout, (Frontend, _MagicMockJsonSerializable)):
- return work_layout
- raise TypeError(
- f"""
- The value returned by `{work.__class__.__name__}.configure_layout()` is of an unsupported type.
-
- {repr(work_layout)}
-
- Return a `Frontend` or a URL string, for example:
-
- class {work.__class__.__name__}(LightningWork):
- def configure_layout(self):
- return MyFrontend() OR 'http://some/url'
- """
- )
diff --git a/src/lightning/app/utilities/load_app.py b/src/lightning/app/utilities/load_app.py
deleted file mode 100644
index b2016194bfe57..0000000000000
--- a/src/lightning/app/utilities/load_app.py
+++ /dev/null
@@ -1,304 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import inspect
-import os
-import sys
-import traceback
-import types
-from contextlib import contextmanager
-from copy import copy
-from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, Union
-
-from lightning.app.utilities.exceptions import MisconfigurationException
-
-if TYPE_CHECKING:
- from lightning.app.core import LightningApp, LightningFlow, LightningWork
- from lightning.app.plugin.plugin import LightningPlugin
-
-from lightning.app.utilities.app_helpers import Logger, _mock_missing_imports
-
-logger = Logger(__name__)
-
-
-def _prettifiy_exception(filepath: str):
- """Pretty print the exception that occurred when loading the app."""
- # we want to format the exception as if no frame was on top.
- exp, val, tb = sys.exc_info()
- listing = traceback.format_exception(exp, val, tb)
- # remove the entry for the first frame
- del listing[1]
- listing = [
- f"Found an exception when loading your application from {filepath}. Please, resolve it to run your app.\n\n"
- ] + listing
- logger.error("".join(listing))
- sys.exit(1)
-
-
-def _load_objects_from_file(
- filepath: str,
- target_type: Type,
- raise_exception: bool = False,
- mock_imports: bool = False,
- env_vars: Dict[str, str] = {},
-) -> Tuple[List[Any], types.ModuleType]:
- """Load all of the top-level objects of the given type from a file.
-
- Args:
- filepath: The file to load from.
- target_type: The type of object to load.
- raise_exception: If ``True`` exceptions will be raised, otherwise exceptions will trigger system exit.
- mock_imports: If ``True`` imports of missing packages will be replaced with a mock. This can allow the object to
- be loaded without installing dependencies.
-
- """
-
- # Taken from StreamLit: https://github.com/streamlit/streamlit/blob/develop/lib/streamlit/script_runner.py#L313
-
- # In order for imports to work in a non-package, Python normally adds the current working directory to the
- # system path, not however when running from an entry point like the `lightning` CLI command. So we do it manually:
- with _patch_sys_path(os.path.dirname(os.path.abspath(filepath))):
- code = _create_code(filepath)
- with _create_fake_main_module(filepath) as module:
- try:
- with _add_to_env(env_vars), _patch_sys_argv():
- if mock_imports:
- with _mock_missing_imports():
- exec(code, module.__dict__) # noqa: S102
- else:
- exec(code, module.__dict__) # noqa: S102
- except Exception as ex:
- if raise_exception:
- raise ex
- _prettifiy_exception(filepath)
-
- return [v for v in module.__dict__.values() if isinstance(v, target_type)], module
-
-
-def _load_plugin_from_file(filepath: str) -> "LightningPlugin":
- from lightning.app.plugin.plugin import LightningPlugin
-
- # TODO: Plugin should be run in the context of the created main module here
- plugins, _ = _load_objects_from_file(filepath, LightningPlugin, raise_exception=True, mock_imports=False)
-
- if len(plugins) > 1:
- raise RuntimeError(f"There should not be multiple plugins instantiated within the file. Found {plugins}")
- if len(plugins) == 1:
- return plugins[0]
-
- raise RuntimeError(f"The provided file {filepath} does not contain a Plugin.")
-
-
-def load_app_from_file(
- filepath: str,
- raise_exception: bool = False,
- mock_imports: bool = False,
- env_vars: Dict[str, str] = {},
-) -> "LightningApp":
- """Load a LightningApp from a file.
-
- Arguments:
- filepath: The path to the file containing the LightningApp.
- raise_exception: If True, raise an exception if the app cannot be loaded.
-
- """
- from lightning.app.core.app import LightningApp
-
- apps, main_module = _load_objects_from_file(
- filepath, LightningApp, raise_exception=raise_exception, mock_imports=mock_imports, env_vars=env_vars
- )
-
- # TODO: Remove this, downstream code shouldn't depend on side-effects here but it does
- sys.path.append(os.path.dirname(os.path.abspath(filepath)))
- sys.modules["__main__"] = main_module
-
- if len(apps) > 1:
- raise MisconfigurationException(f"There should not be multiple apps instantiated within a file. Found {apps}")
- if len(apps) == 1:
- return apps[0]
-
- raise MisconfigurationException(
- f"The provided file {filepath} does not contain a LightningApp. Instantiate your app at the module level"
- " like so: `app = LightningApp(flow, ...)`"
- )
-
-
-def _new_module(name):
- """Create a new module with the given name."""
- return types.ModuleType(name)
-
-
-def open_python_file(filename):
- """Open a read-only Python file taking proper care of its encoding.
-
- In Python 3, we would like all files to be opened with utf-8 encoding. However, some authors like to specify PEP263
- headers in their source files with their own encodings. In that case, we should respect the author's encoding.
-
- """
- import tokenize
-
- if hasattr(tokenize, "open"): # Added in Python 3.2
- # Open file respecting PEP263 encoding. If no encoding header is
- # found, opens as utf-8.
- return tokenize.open(filename)
- return open(filename, encoding="utf-8") # noqa: SIM115
-
-
-def _create_code(script_path: str):
- with open_python_file(script_path) as f:
- filebody = f.read()
-
- return compile(
- filebody,
- # Pass in the file path so it can show up in exceptions.
- script_path,
- # We're compiling entire blocks of Python, so we need "exec"
- # mode (as opposed to "eval" or "single").
- mode="exec",
- # Don't inherit any flags or "future" statements.
- flags=0,
- dont_inherit=1,
- # Use the default optimization options.
- optimize=-1,
- )
-
-
-@contextmanager
-def _create_fake_main_module(script_path):
- # Create fake module. This gives us a named global namespace to
- # execute the code in.
- module = _new_module("__main__")
-
- # Install the fake module as the __main__ module. This allows
- # the pickle module to work inside the user's code, since it now
- # can know the module where the pickled objects stem from.
- # IMPORTANT: This means we can't use "if __name__ == '__main__'" in
- # our code, as it will point to the wrong module!!!
- old_main_module = sys.modules["__main__"]
- sys.modules["__main__"] = module
-
- # Add special variables to the module's globals dict.
- # Note: The following is a requirement for the CodeHasher to
- # work correctly. The CodeHasher is scoped to
- # files contained in the directory of __main__.__file__, which we
- # assume is the main script directory.
- module.__dict__["__file__"] = os.path.abspath(script_path)
-
- try:
- yield module
- finally:
- sys.modules["__main__"] = old_main_module
-
-
-@contextmanager
-def _patch_sys_path(append):
- """A context manager that appends the given value to the path once entered.
-
- Args:
- append: The value to append to the path.
-
- """
- if append in sys.path:
- yield
- return
-
- sys.path.append(append)
-
- try:
- yield
- finally:
- sys.path.remove(append)
-
-
-@contextmanager
-def _add_to_env(envs: Dict[str, str]):
- """This function adds the given environment variables to the current environment."""
- original_envs = dict(os.environ)
- os.environ.update(envs)
-
- try:
- yield
- finally:
- os.environ.clear()
- os.environ.update(original_envs)
-
-
-@contextmanager
-def _patch_sys_argv():
- """This function modifies ``sys.argv`` by extracting the arguments after ``--app_args`` and removing everything
- else before executing the user app script.
-
- The command: ``lightning_app run app app.py --without-server --app_args --use_gpu --env ...`` will be converted into
- ``app.py --use_gpu``
-
- """
- from lightning.app.cli.lightning_cli import run_app
-
- original_argv = copy(sys.argv)
- # 1: Remove the CLI command
- if sys.argv[:3] == ["lightning", "run", "app"]:
- sys.argv = sys.argv[3:]
-
- if "--app_args" not in sys.argv:
- # 2: If app_args wasn't used, there are no app arguments, so we keep only the script name.
- new_argv = sys.argv[:1]
- else:
- # 3: Collect all the arguments from the CLI
- options = [p.opts[0] for p in run_app.params[1:] if p.opts[0] != "--app_args"]
- argv_slice = sys.argv
- # 4: Find the index of `app_args`
- first_index = argv_slice.index("--app_args") + 1
- # 5: Find the next argument from the CLI if any.
- matches = [
- argv_slice.index(opt) for opt in options if opt in argv_slice and argv_slice.index(opt) >= first_index
- ]
- last_index = len(argv_slice) if not matches else min(matches)
- # 6: last_index is either the end of the full command or the first matching CLI option after `--app_args`.
- new_argv = [argv_slice[0]] + argv_slice[first_index:last_index]
-
- # 7: Patch the command
- sys.argv = new_argv
-
- try:
- yield
- finally:
- # 8: Restore the command
- sys.argv = original_argv
-
-
-def component_to_metadata(obj: Union["LightningWork", "LightningFlow"]) -> Dict:
- from lightning.app.core import LightningWork
-
- extras = {}
-
- if isinstance(obj, LightningWork):
- extras = {
- "local_build_config": obj.local_build_config.to_dict(),
- "cloud_build_config": obj.cloud_build_config.to_dict(),
- "cloud_compute": obj.cloud_compute.to_dict(),
- }
-
- return dict(
- affiliation=obj.name.split("."),
- cls_name=obj.__class__.__name__,
- module=obj.__module__,
- docstring=inspect.getdoc(obj.__init__),
- **extras,
- )
-
-
-def extract_metadata_from_app(app: "LightningApp") -> List:
- metadata = {flow.name: component_to_metadata(flow) for flow in app.flows}
- metadata.update({work.name: component_to_metadata(work) for work in app.works})
- return [metadata[key] for key in sorted(metadata.keys())]
diff --git a/src/lightning/app/utilities/log.py b/src/lightning/app/utilities/log.py
deleted file mode 100644
index 0e4f94e292517..0000000000000
--- a/src/lightning/app/utilities/log.py
+++ /dev/null
@@ -1,23 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from pathlib import Path
-
-from lightning.app.storage.path import _storage_root_dir
-
-
-def get_logfile(filename: str = "logs.log") -> Path:
- log_dir = Path(_storage_root_dir(), "frontend")
- log_dir.mkdir(parents=True, exist_ok=True)
- return log_dir / filename
diff --git a/src/lightning/app/utilities/log_helpers.py b/src/lightning/app/utilities/log_helpers.py
deleted file mode 100644
index b7777d94e49e4..0000000000000
--- a/src/lightning/app/utilities/log_helpers.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass
-from datetime import datetime
-from json import JSONDecodeError
-
-from websocket import WebSocketApp
-
-from lightning.app.utilities.app_helpers import Logger
-
-logger = Logger(__name__)
-
-
-# This is a superclass to inherit log entry classes from it:
-# it implements magic methods to sort logs by timestamps.
-@dataclass
-class _OrderedLogEntry:
- message: str
- timestamp: datetime
-
- def __ge__(self, other: "_OrderedLogEntry") -> bool:
- return self.timestamp >= other.timestamp
-
- def __gt__(self, other: "_OrderedLogEntry") -> bool:
- return self.timestamp > other.timestamp
-
-
-# A general error callback for log reading, prints most common types of possible errors.
-def _error_callback(ws_app: WebSocketApp, error: Exception):
- errors = {
- KeyError: "Malformed log message, missing key",
- JSONDecodeError: "Malformed log message",
- TypeError: "Malformed log format",
- ValueError: "Malformed date format",
- }
- logger.error(f"⚡ Error while reading logs ({errors.get(type(error), 'Unknown')}), {error}")
- ws_app.close()
diff --git a/src/lightning/app/utilities/login.py b/src/lightning/app/utilities/login.py
deleted file mode 100644
index eb03585eddc03..0000000000000
--- a/src/lightning/app/utilities/login.py
+++ /dev/null
@@ -1,213 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import base64
-import json
-import os
-import pathlib
-from dataclasses import dataclass
-from enum import Enum
-from time import sleep
-from typing import Optional
-from urllib.parse import urlencode
-
-import click
-import requests
-import uvicorn
-from fastapi import FastAPI, Query, Request
-from starlette.background import BackgroundTask
-from starlette.responses import RedirectResponse
-
-from lightning.app.core.constants import LIGHTNING_CREDENTIAL_PATH, get_lightning_cloud_url
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.network import find_free_network_port
-
-logger = Logger(__name__)
-
-
-class Keys(Enum):
- USERNAME = "LIGHTNING_USERNAME"
- USER_ID = "LIGHTNING_USER_ID"
- API_KEY = "LIGHTNING_API_KEY"
-
- @property
- def suffix(self):
- return self.value.lstrip("LIGHTNING_").lower()
-
-
-@dataclass
-class Auth:
- username: Optional[str] = None
- user_id: Optional[str] = None
- api_key: Optional[str] = None
-
- secrets_file = pathlib.Path(LIGHTNING_CREDENTIAL_PATH)
-
- def load(self) -> bool:
- """Load credentials from disk and update properties with credentials.
-
- Returns
- ----------
- True if credentials are available.
-
- """
- if not self.secrets_file.exists():
- logger.debug("Credentials file not found.")
- return False
- with self.secrets_file.open() as creds:
- credentials = json.load(creds)
- for key in Keys:
- setattr(self, key.suffix, credentials.get(key.suffix, None))
- return True
-
- def save(self, token: str = "", user_id: str = "", api_key: str = "", username: str = "") -> None:
- """Save credentials to disk."""
- self.secrets_file.parent.mkdir(exist_ok=True, parents=True)
- with self.secrets_file.open("w") as f:
- json.dump(
- {
- f"{Keys.USERNAME.suffix}": username,
- f"{Keys.USER_ID.suffix}": user_id,
- f"{Keys.API_KEY.suffix}": api_key,
- },
- f,
- )
-
- self.username = username
- self.user_id = user_id
- self.api_key = api_key
- logger.debug("credentials saved successfully")
-
- def clear(self) -> None:
- """Remove credentials from disk."""
- if self.secrets_file.exists():
- self.secrets_file.unlink()
- for key in Keys:
- setattr(self, key.suffix, None)
- logger.debug("credentials removed successfully")
-
- @property
- def auth_header(self) -> Optional[str]:
- """Authentication header used by lightning-cloud client."""
- if self.api_key:
- token = f"{self.user_id}:{self.api_key}"
- return f"Basic {base64.b64encode(token.encode('ascii')).decode('ascii')}" # E501
- raise AttributeError(
- "Authentication Failed, no authentication header available. "
- "This is most likely a bug in the LightningCloud Framework"
- )
-
- def _run_server(self) -> None:
- """Start a server to complete authentication."""
- AuthServer().login_with_browser(self)
-
- def authenticate(self) -> Optional[str]:
- """Performs end to end authentication flow.
-
- Returns
- ----------
- authorization header to use when authentication completes.
-
- """
- if not self.load():
- # First try to authenticate from env
- for key in Keys:
- setattr(self, key.suffix, os.environ.get(key.value, None))
-
- if self.user_id and self.api_key:
- self.save("", self.user_id, self.api_key, self.user_id)
- logger.info("Credentials loaded from environment variables")
- return self.auth_header
- if self.api_key or self.user_id:
- raise ValueError(
- "To use env vars for authentication both "
- f"{Keys.USER_ID.value} and {Keys.API_KEY.value} should be set."
- )
-
- logger.debug("failed to load credentials, opening browser to get new.")
- self._run_server()
- return self.auth_header
-
- if self.user_id and self.api_key:
- return self.auth_header
-
- raise ValueError(
- "We couldn't find any credentials linked to your account. "
- "Please try logging in using the CLI command `lightning_app login`"
- )
-
-
-class AuthServer:
- @staticmethod
- def get_auth_url(port: int) -> str:
- redirect_uri = f"http://localhost:{port}/login-complete"
- params = urlencode({"redirectTo": redirect_uri})
- return f"{get_lightning_cloud_url()}/sign-in?{params}"
-
- def login_with_browser(self, auth: Auth) -> None:
- app = FastAPI()
- port = find_free_network_port()
- url = self.get_auth_url(port)
-
- try:
- # check if server is reachable or catch any network errors
- requests.head(url)
- except requests.ConnectionError as ex:
- raise requests.ConnectionError(
- f"No internet connection available. Please connect to a stable internet connection \n{ex}" # E501
- )
- except requests.RequestException as ex:
- raise requests.RequestException(
- f"An error occurred with the request. Please report this issue to Lightning Team \n{ex}" # E501
- )
-
- logger.info(
- "\nAttempting to automatically open the login page in your default browser.\n"
- 'If the browser does not open, navigate to the "Keys" tab on your Lightning AI profile page:\n\n'
- f"{get_lightning_cloud_url()}/me/keys\n\n"
- 'Copy the "Headless CLI Login" command, and execute it in your terminal.\n'
- )
- click.launch(url)
-
- @app.get("/login-complete")
- async def save_token(request: Request, token="", key="", user_id: str = Query("", alias="userID")):
- async def stop_server_once_request_is_done():
- while not await request.is_disconnected():
- sleep(0.25)
- server.should_exit = True
-
- if not token:
- logger.warn(
- "Login Failed. This is most likely because you're using an older version of the CLI. \n" # E501
- "Please try to update the CLI or open an issue with this information \n" # E501
- f"expected token in {request.query_params.items()}"
- )
- return RedirectResponse(
- url=f"{get_lightning_cloud_url()}/cli-login-failed",
- background=BackgroundTask(stop_server_once_request_is_done),
- )
-
- auth.save(token=token, username=user_id, user_id=user_id, api_key=key)
- logger.info("Login Successful")
-
- # Include the credentials in the redirect so that UI will also be logged in
- params = urlencode({"token": token, "key": key, "userID": user_id})
-
- return RedirectResponse(
- url=f"{get_lightning_cloud_url()}/cli-login-successful?{params}",
- background=BackgroundTask(stop_server_once_request_is_done),
- )
-
- server = uvicorn.Server(config=uvicorn.Config(app, port=port, log_level="error"))
- server.run()
diff --git a/src/lightning/app/utilities/logs_socket_api.py b/src/lightning/app/utilities/logs_socket_api.py
deleted file mode 100644
index 9d25c82ba72d3..0000000000000
--- a/src/lightning/app/utilities/logs_socket_api.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Callable, Optional
-from urllib.parse import urlparse
-
-from websocket import WebSocketApp
-
-from lightning.app.utilities.auth import _AuthTokenGetter
-
-
-class _LightningLogsSocketAPI(_AuthTokenGetter):
- @staticmethod
- def _app_logs_socket_url(host: str, project_id: str, app_id: str, token: str, component: str) -> str:
- return (
- f"wss://{host}/v1/projects/{project_id}/appinstances/{app_id}/logs?"
- f"token={token}&component={component}&follow=true"
- )
-
- def create_lightning_logs_socket(
- self,
- project_id: str,
- app_id: str,
- component: str,
- on_message_callback: Callable,
- on_error_callback: Optional[Callable] = None,
- ) -> WebSocketApp:
- """Creates and returns WebSocketApp to listen to lightning app logs.
-
- .. code-block:: python
- # Synchronous reading, run_forever() is blocking
-
-
- def print_log_msg(ws_app, msg):
- print(msg)
-
-
- flow_logs_socket = client.create_lightning_logs_socket("project_id", "app_id", "flow", print_log_msg)
- flow_logs_socket.run_forever()
-
- .. code-block:: python
- # Asynchronous reading (with Threads)
-
-
- def print_log_msg(ws_app, msg):
- print(msg)
-
-
- flow_logs_socket = client.create_lightning_logs_socket("project_id", "app_id", "flow", print_log_msg)
- work_logs_socket = client.create_lightning_logs_socket("project_id", "app_id", "work_1", print_log_msg)
-
- flow_logs_thread = Thread(target=flow_logs_socket.run_forever)
- work_logs_thread = Thread(target=work_logs_socket.run_forever)
-
- flow_logs_thread.start()
- work_logs_thread.start()
- # .......
-
- flow_logs_socket.close()
- work_logs_socket.close()
-
- Arguments:
- project_id: Project ID.
- app_id: Application ID.
- component: Component name, e.g. ``flow``.
- on_message_callback: Callback object which is called when data is received.
- on_error_callback: Callback object which is called when an error occurs.
-
- Returns:
- WebSocketApp of the wanted socket
-
- """
- _token = self._get_api_token()
- clean_ws_host = urlparse(self.api_client.configuration.host).netloc
- socket_url = self._app_logs_socket_url(
- host=clean_ws_host,
- project_id=project_id,
- app_id=app_id,
- token=_token,
- component=component,
- )
-
- return WebSocketApp(socket_url, on_message=on_message_callback, on_error=on_error_callback)
diff --git a/src/lightning/app/utilities/name_generator.py b/src/lightning/app/utilities/name_generator.py
deleted file mode 100644
index c57a65f63a3d8..0000000000000
--- a/src/lightning/app/utilities/name_generator.py
+++ /dev/null
@@ -1,1359 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from random import choice, randint
-
-_adjectives = [
- # Appearance, sound, smell...
- "acrid",
- "ambrosial",
- "amorphous",
- "armored",
- "aromatic",
- "bald",
- "blazing",
- "boisterous",
- "bouncy",
- "brawny",
- "bulky",
- "camouflaged",
- "caped",
- "chubby",
- "curvy",
- "elastic",
- "ethereal",
- "fat",
- "feathered",
- "fiery",
- "flashy",
- "flat",
- "fluffy",
- "foamy",
- "fragrant",
- "furry",
- "fuzzy",
- "glaring",
- "hairy",
- "heavy",
- "hissing",
- "horned",
- "icy",
- "imaginary",
- "invisible",
- "lean",
- "loud",
- "loutish",
- "lumpy",
- "lush",
- "masked",
- "meaty",
- "messy",
- "misty",
- "nebulous",
- "noisy",
- "nondescript",
- "organic",
- "purring",
- "quiet",
- "quirky",
- "radiant",
- "roaring",
- "ruddy",
- "rustling",
- "screeching",
- "shaggy",
- "shapeless",
- "shiny",
- "silent",
- "silky",
- "singing",
- "skinny",
- "smooth",
- "soft",
- "spicy",
- "spiked",
- "statuesque",
- "sticky",
- "tacky",
- "tall",
- "tangible",
- "tentacled",
- "thick",
- "thundering",
- "venomous",
- "warm",
- "weightless",
- "whispering",
- "winged",
- "wooden",
- # Beauty & Charm
- "adorable",
- "affable",
- "amazing",
- "amiable",
- "attractive",
- "beautiful",
- "calm",
- "charming",
- "cherubic",
- "classic",
- "classy",
- "convivial",
- "cordial",
- "cuddly",
- "curly",
- "cute",
- "debonair",
- "elegant",
- "famous",
- "fresh",
- "friendly",
- "funny",
- "gorgeous",
- "graceful",
- "gregarious",
- "grinning",
- "handsome",
- "hilarious",
- "hot",
- "interesting",
- "kind",
- "laughing",
- "lovely",
- "meek",
- "mellow",
- "merciful",
- "neat",
- "nifty",
- "notorious",
- "poetic",
- "pretty",
- "refined",
- "refreshing",
- "sexy",
- "smiling",
- "sociable",
- "spiffy",
- "stylish",
- "sweet",
- "tactful",
- "whimsical",
- "boring",
- # Character & Emotions
- "abiding",
- "accurate",
- "adamant",
- "adaptable",
- "adventurous",
- "alluring",
- "aloof",
- "ambitious",
- "amusing",
- "annoying",
- "arrogant",
- "aspiring",
- "belligerent",
- "benign",
- "berserk",
- "benevolent",
- "bold",
- "brave",
- "cheerful",
- "chirpy",
- "cocky",
- "congenial",
- "courageous",
- "cryptic",
- "curious",
- "daft",
- "dainty",
- "daring",
- "defiant",
- "delicate",
- "delightful",
- "determined",
- "devout",
- "didactic",
- "diligent",
- "discreet",
- "dramatic",
- "dynamic",
- "eager",
- "eccentric",
- "elated",
- "encouraging",
- "enigmatic",
- "enthusiastic",
- "evasive",
- "faithful",
- "fair",
- "fanatic",
- "fearless",
- "fervent",
- "festive",
- "fierce",
- "fine",
- "free",
- "gabby",
- "garrulous",
- "gay",
- "gentle",
- "glistening",
- "greedy",
- "grumpy",
- "happy",
- "honest",
- "hopeful",
- "hospitable",
- "impetuous",
- "independent",
- "industrious",
- "innocent",
- "intrepid",
- "jolly",
- "jovial",
- "just",
- "lively",
- "loose",
- "loyal",
- "merry",
- "modest",
- "mysterious",
- "nice",
- "obedient",
- "optimistic",
- "orthodox",
- "outgoing",
- "outrageous",
- "overjoyed",
- "passionate",
- "perky",
- "placid",
- "polite",
- "positive",
- "proud",
- "prudent",
- "puzzling",
- "quixotic",
- "quizzical",
- "rebel",
- "resolute",
- "rampant",
- "righteous",
- "romantic",
- "rough",
- "rousing",
- "sassy",
- "satisfied",
- "sly",
- "sincere",
- "snobbish",
- "spirited",
- "spry",
- "stalwart",
- "stirring",
- "swinging",
- "tasteful",
- "thankful",
- "tidy",
- "tremendous",
- "truthful",
- "unselfish",
- "upbeat",
- "uppish",
- "valiant",
- "vehement",
- "vengeful",
- "vigorous",
- "vivacious",
- "zealous",
- "zippy",
- # Intelligence & Abilities
- "able",
- "adept",
- "analytic",
- "astute",
- "attentive",
- "brainy",
- "busy",
- "calculating",
- "capable",
- "careful",
- "cautious",
- "certain",
- "clever",
- "competent",
- "conscious",
- "cooperative",
- "crafty",
- "crazy",
- "cunning",
- "daffy",
- "devious",
- "discerning",
- "efficient",
- "expert",
- "functional",
- "gifted",
- "helpful",
- "enlightened",
- "idealistic",
- "impartial",
- "industrious",
- "ingenious",
- "inquisitive",
- "intelligent",
- "inventive",
- "judicious",
- "keen",
- "knowing",
- "literate",
- "logical",
- "masterful",
- "mindful",
- "nonchalant",
- "observant",
- "omniscient",
- "poised",
- "practical",
- "pragmatic",
- "proficient",
- "provocative",
- "qualified",
- "radical",
- "rational",
- "realistic",
- "resourceful",
- "savvy",
- "sceptical",
- "sensible",
- "serious",
- "shrewd",
- "skilled",
- "slick",
- "slim",
- "sloppy",
- "smart",
- "sophisticated",
- "stoic",
- "succinct",
- "talented",
- "thoughtful",
- "tricky",
- "unbiased",
- "uptight",
- "versatile",
- "versed",
- "visionary",
- "wise",
- "witty",
- # Strength & Agility
- "accelerated",
- "active",
- "agile",
- "athletic",
- "dashing",
- "deft",
- "dexterous",
- "energetic",
- "fast",
- "frisky",
- "hasty",
- "hypersonic",
- "meteoric",
- "mighty",
- "muscular",
- "nimble",
- "nippy",
- "powerful",
- "prompt",
- "quick",
- "rapid",
- "resilient",
- "robust",
- "rugged",
- "solid",
- "speedy",
- "steadfast",
- "steady",
- "strong",
- "sturdy",
- "tireless",
- "tough",
- "unyielding",
- # Money & Power
- "rich",
- "wealthy",
- # Science
- "meticulous",
- "precise",
- "rigorous",
- "scrupulous",
- "strict",
- # Movement type
- "airborne",
- "burrowing",
- "crouching",
- "flying",
- "hidden",
- "hopping",
- "jumping",
- "lurking",
- "tunneling",
- "warping",
- # Location and Dwelling
- "aboriginal",
- "amphibian",
- "aquatic",
- "arboreal",
- "polar",
- "terrestrial",
- "urban",
- # Awesome
- "accomplished",
- "astonishing",
- "authentic",
- "awesome",
- "delectable",
- "excellent",
- "exotic",
- "exuberant",
- "fabulous",
- "fantastic",
- "fascinating",
- "flawless",
- "fortunate",
- "funky",
- "godlike",
- "glorious",
- "groovy",
- "honored",
- "illustrious",
- "imposing",
- "important",
- "impressive",
- "incredible",
- "invaluable",
- "kickass",
- "majestic",
- "magnificent",
- "marvellous",
- "monumental",
- "perfect",
- "phenomenal",
- "pompous",
- "precious",
- "premium",
- "private",
- "remarkable",
- "spectacular",
- "splendid",
- "successful",
- "wonderful",
- "wondrous",
- # Original
- "offbeat",
- "original",
- "outstanding",
- "quaint",
- "unique",
- # Time
- "ancient",
- "antique",
- "prehistoric",
- "primitive",
- # Misc
- "abstract",
- "acoustic",
- "angelic",
- "arcane",
- "archetypal",
- "augmented",
- "auspicious",
- "axiomatic",
- "beneficial",
- "bipedal",
- "bizarre",
- "complex",
- "dancing",
- "dangerous",
- "demonic",
- "divergent",
- "economic",
- "electric",
- "elite",
- "eminent",
- "enchanted",
- "esoteric",
- "finicky",
- "fractal",
- "futuristic",
- "gainful",
- "hallowed",
- "heavenly",
- "heretic",
- "holistic",
- "hungry",
- "hypnotic",
- "hysterical",
- "illegal",
- "imperial",
- "imported",
- "impossible",
- "inescapable",
- "juicy",
- "liberal",
- "ludicrous",
- "lyrical",
- "magnetic",
- "manipulative",
- "mature",
- "military",
- "macho",
- "married",
- "melodic",
- "natural",
- "naughty",
- "nocturnal",
- "nostalgic",
- "optimal",
- "pastoral",
- "peculiar",
- "piquant",
- "pristine",
- "prophetic",
- "psychedelic",
- "quantum",
- "rare",
- "real",
- "secret",
- "simple",
- "spectral",
- "spiritual",
- "stereotyped",
- "stimulating",
- "straight",
- "strange",
- "tested",
- "therapeutic",
- "true",
- "ubiquitous",
- "uncovered",
- "unnatural",
- "utopian",
- "vagabond",
- "vague",
- "vegan",
- "victorious",
- "vigilant",
- "voracious",
- "wakeful",
- "wandering",
- "watchful",
- "wild",
- # Pseudo-colors
- "bright",
- "brilliant",
- "colorful",
- "crystal",
- "dark",
- "dazzling",
- "fluorescent",
- "glittering",
- "glossy",
- "gleaming",
- "light",
- "mottled",
- "neon",
- "opalescent",
- "pastel",
- "smoky",
- "sparkling",
- "spotted",
- "striped",
- "translucent",
- "transparent",
- "vivid",
-]
-
-# Docker, starting from 0.7.x, generated names from notable scientists and hackers.
-# Please, for any amazing man that you add to the list, consider adding an equally amazing woman to it, and vice versa.
-_surnames = [
- # Muhammad ibn Jabir Al-Battani was a founding father of astronomy. https://en.wikipedia.org/wiki/Al-Battani
- "albattani",
- # Frances E. Allen, became the first female IBM Fellow in 1989. In 2006, she became the first female
- # recipient of the ACM's Turing Award. https://en.wikipedia.org/wiki/Frances_E._Allen
- "allen",
- # June Almeida - Scottish virologist who took the first pictures of the rubella
- # virus - https://en.wikipedia.org/wiki/June_Almeida
- "almeida",
- # Kathleen Antonelli, American computer programmer and one of the six original
- # programmers of the ENIAC - https://en.wikipedia.org/wiki/Kathleen_Antonelli
- "antonelli",
- # Maria Gaetana Agnesi - Italian mathematician, philosopher, theologian and humanitarian.
- # She was the first woman to write a mathematics handbook and the first woman appointed
- # as a Mathematics Professor at a University. https://en.wikipedia.org/wiki/Maria_Gaetana_Agnesi
- "agnesi",
- # Archimedes was a physicist, engineer and mathematician who invented too many
- # things to list them here. https://en.wikipedia.org/wiki/Archimedes
- "archimedes",
- # Maria Ardinghelli - Italian translator, mathematician and physicist -
- # https://en.wikipedia.org/wiki/Maria_Ardinghelli
- "ardinghelli",
- # Aryabhata - Ancient Indian mathematician-astronomer during 476-550 CE https://en.wikipedia.org/wiki/Aryabhata
- "aryabhata",
- # Wanda Austin - Wanda Austin is the President and CEO of The Aerospace Corporation,
- # a leading architect for the US security space programs. https://en.wikipedia.org/wiki/Wanda_Austin
- "austin",
- # Charles Babbage invented the concept of a programmable computer. https://en.wikipedia.org/wiki/Charles_Babbage.
- "babbage",
- # Stefan Banach - Polish mathematician, was one of the founders of modern
- # functional analysis. https://en.wikipedia.org/wiki/Stefan_Banach
- "banach",
- # Buckaroo Banzai and his mentor Dr. Hikita perfected the "oscillation overthruster",
- # a device that allows one to pass through solid matter. -
- # https://en.wikipedia.org/wiki/The_Adventures_of_Buckaroo_Banzai_Across_the_8th_Dimension
- "banzai",
- # John Bardeen co-invented the transistor - https://en.wikipedia.org/wiki/John_Bardeen
- "bardeen",
- # Jean Bartik, born Betty Jean Jennings, was one of the original programmers
- # for the ENIAC computer. https://en.wikipedia.org/wiki/Jean_Bartik
- "bartik",
- # Laura Bassi, the world's first female professor https://en.wikipedia.org/wiki/Laura_Bassi
- "bassi",
- # Hugh Beaver, British engineer, founder of the Guinness Book of World
- # Records https://en.wikipedia.org/wiki/Hugh_Beaver
- "beaver",
- # Alexander Graham Bell - an eminent Scottish-born scientist, inventor, engineer and innovator
- # who is credited with inventing the first practical telephone - https://en.wikipedia.org/wiki/Alexander_Graham_Bell
- "bell",
- # Karl Friedrich Benz - a German automobile engineer. Inventor of the first
- # practical motorcar. https://en.wikipedia.org/wiki/Karl_Benz
- "benz",
- # Homi J Bhabha - was an Indian nuclear physicist, founding director, and professor of
- # physics at the Tata Institute of Fundamental Research. Colloquially known as "father of
- # Indian nuclear programme"- https://en.wikipedia.org/wiki/Homi_J._Bhabha
- "bhabha",
- # Bhaskara II - Ancient Indian mathematician-astronomer whose work on calculus predates
- # Newton and Leibniz by over half a millennium - https://en.wikipedia.org/wiki/Bh%C4%81skara_II#Calculus
- "bhaskara",
- # Sue Black - British computer scientist and campaigner. She has been instrumental in
- # saving Bletchley Park, the site of World War II codebreaking -
- # https://en.wikipedia.org/wiki/Sue_Black_(computer_scientist)
- "black",
- # Elizabeth Helen Blackburn - Australian-American Nobel laureate; best known
- # for co-discovering telomerase. https://en.wikipedia.org/wiki/Elizabeth_Blackburn
- "blackburn",
- # Elizabeth Blackwell - American doctor and first American woman to receive a
- # medical degree - https://en.wikipedia.org/wiki/Elizabeth_Blackwell
- "blackwell",
- # Niels Bohr is the father of quantum theory. https://en.wikipedia.org/wiki/Niels_Bohr.
- "bohr",
- # Kathleen Booth is credited with writing the first assembly language.
- # https://en.wikipedia.org/wiki/Kathleen_Booth
- "booth",
- # Anita Borg was the founding director of the Institute for
- # Women and Technology (IWT). https://en.wikipedia.org/wiki/Anita_Borg
- "borg",
- # Satyendra Nath Bose - He provided the foundation for Bose–Einstein statistics
- # and the theory of the Bose–Einstein condensate. - https://en.wikipedia.org/wiki/Satyendra_Nath_Bose
- "bose",
- # Katherine Louise Bouman is an imaging scientist and Assistant Professor of Computer
- # Science at the California Institute of Technology. She researches computational methods for
- # imaging, and developed an algorithm that made possible the first visualization of a
- # black hole using the Event Horizon Telescope. - https://en.wikipedia.org/wiki/Katie_Bouman
- "bouman",
- # Evelyn Boyd Granville - She was one of the first African-American women to receive a Ph.D.
- # in mathematics; she earned it in 1949 from Yale University. https://en.wikipedia.org/wiki/Evelyn_Boyd_Granville
- "boyd",
- # Brahmagupta - Ancient Indian mathematician during 598-670 CE who gave rules
- # to compute with zero - https://en.wikipedia.org/wiki/Brahmagupta#Zero
- "brahmagupta",
- # Walter Houser Brattain co-invented the transistor - https://en.wikipedia.org/wiki/Walter_Houser_Brattain
- "brattain",
- # Emmett Brown invented time travel. https://en.wikipedia.org/wiki/Emmett_Brown (thanks Brian Goff)
- "brown",
- # Linda Brown Buck - American biologist and Nobel laureate best known for her genetic and
- # molecular analyses of the mechanisms of smell. https://en.wikipedia.org/wiki/Linda_B._Buck
- "buck",
- # Dame Susan Jocelyn Bell Burnell - Northern Irish astrophysicist who discovered radio pulsars
- # and was the first to analyse them. https://en.wikipedia.org/wiki/Jocelyn_Bell_Burnell
- "burnell",
- # Annie Jump Cannon - pioneering female astronomer who classified hundreds of thousands of stars
- # and created the system we use to understand stars today. https://en.wikipedia.org/wiki/Annie_Jump_Cannon
- "cannon",
- # Rachel Carson - American marine biologist and conservationist, her book Silent Spring and other
- # writings are credited with advancing the global environmental movement.
- # https://en.wikipedia.org/wiki/Rachel_Carson
- "carson",
- # Dame Mary Lucy Cartwright - British mathematician who was one of the first to study what is
- # now known as chaos theory. Also known for Cartwright's theorem which finds applications in
- # signal processing. https://en.wikipedia.org/wiki/Mary_Cartwright
- "cartwright",
- # George Washington Carver - American agricultural scientist and inventor. He was the most
- # prominent black scientist of the early 20th century. https://en.wikipedia.org/wiki/George_Washington_Carver
- "carver",
- # Vinton Gray Cerf - American Internet pioneer, recognised as one of "the fathers of the Internet".
- # With Robert Elliot Kahn, he designed TCP and IP, the primary data communication protocols of
- # the Internet and other computer networks. https://en.wikipedia.org/wiki/Vint_Cerf
- "cerf",
- # Subrahmanyan Chandrasekhar - Astrophysicist known for his mathematical theory on different
- # stages and evolution in structures of the stars. He won the Nobel Prize in Physics -
- # https://en.wikipedia.org/wiki/Subrahmanyan_Chandrasekhar
- "chandrasekhar",
- # Sergey Alexeyevich Chaplygin (April 5, 1869 - October 8, 1942) was a Russian and Soviet physicist,
- # mathematician, and mechanical engineer. He is known for mathematical formulas such as Chaplygin's
- # equation and for a hypothetical substance in cosmology called Chaplygin gas,
- # named after him. https://en.wikipedia.org/wiki/Sergey_Chaplygin
- "chaplygin",
- # Emilie du Chatelet - French natural philosopher, mathematician, physicist, and author
- # during the early 1730s, known for her translation of and commentary on Isaac Newton's book
- # Principia containing basic laws of physics. https://en.wikipedia.org/wiki/%C3%89milie_du_Ch%C3%A2telet
- "chatelet",
- # Asima Chatterjee was an Indian organic chemist noted for her research on vinca alkaloids,
- # development of drugs for treatment of epilepsy and malaria - https://en.wikipedia.org/wiki/Asima_Chatterjee
- "chatterjee",
- # Pafnuty Chebyshev - Russian mathematician. He is known for his work on probability,
- # statistics, mechanics, analytical geometry and number theory https://en.wikipedia.org/wiki/Pafnuty_Chebyshev
- "chebyshev",
- # Bram Cohen - American computer programmer and author of the BitTorrent
- # peer-to-peer protocol. https://en.wikipedia.org/wiki/Bram_Cohen
- "cohen",
- # David Lee Chaum - American computer scientist and cryptographer. Known for his
- # seminal contributions in the field of anonymous communication. https://en.wikipedia.org/wiki/David_Chaum
- "chaum",
- # Joan Clarke - Bletchley Park code breaker during the Second World War who pioneered techniques
- # that remained top secret for decades. Also an accomplished numismatist https://en.wikipedia.org/wiki/Joan_Clarke
- "clarke",
- # Jane Colden - American botanist widely considered the first female
- # American botanist - https://en.wikipedia.org/wiki/Jane_Colden
- "colden",
- # Gerty Theresa Cori - American biochemist who became the third woman and first American woman to win a
- # Nobel Prize in science, and the first woman to be awarded the Nobel Prize in Physiology
- # or Medicine. Cori was born in Prague. https://en.wikipedia.org/wiki/Gerty_Cori
- "cori",
- # Seymour Roger Cray was an American electrical engineer and supercomputer architect who designed a series
- # of computers that were the fastest in the world for decades. https://en.wikipedia.org/wiki/Seymour_Cray
- "cray",
- # This entry reflects a husband and wife team who worked together:
- # Joan Curran was a Welsh scientist who developed radar and invented chaff, a radar countermeasure.
- # https://en.wikipedia.org/wiki/Joan_Curran Samuel Curran was an Irish physicist who worked
- # alongside his wife during WWII and invented the proximity fuse. https://en.wikipedia.org/wiki/Samuel_Curran
- "curran",
- # Marie Curie pioneered research on radioactivity. https://en.wikipedia.org/wiki/Marie_Curie.
- "curie",
- # Charles Darwin established the principles of natural evolution. https://en.wikipedia.org/wiki/Charles_Darwin.
- "darwin",
- # Leonardo Da Vinci invented too many things to list here. https://en.wikipedia.org/wiki/Leonardo_da_Vinci.
- "davinci",
- # A. K. (Alexander Keewatin) Dewdney, Canadian mathematician, computer scientist, author and filmmaker.
- # Contributor to Scientific American's "Computer Recreations" from 1984 to 1991. Author of Core War (program),
- # The Planiverse, The Armchair Universe, The Magic Machine, The New Turing Omnibus, and more.
- # https://en.wikipedia.org/wiki/Alexander_Dewdney
- "dewdney",
- # Satish Dhawan - Indian mathematician and aerospace engineer, known for leading the successful and
- # indigenous development of the Indian space programme. https://en.wikipedia.org/wiki/Satish_Dhawan
- "dhawan",
- # Bailey Whitfield Diffie - American cryptographer and one of the pioneers of
- # public-key cryptography. https://en.wikipedia.org/wiki/Whitfield_Diffie
- "diffie",
- # Edsger Wybe Dijkstra was a Dutch computer scientist and mathematical scientist.
- # https://en.wikipedia.org/wiki/Edsger_W._Dijkstra.
- "dijkstra",
- # Paul Adrien Maurice Dirac - English theoretical physicist who made fundamental contributions to the
- # early development of both quantum mechanics and quantum electrodynamics. https://en.wikipedia.org/wiki/Paul_Dirac
- "dirac",
- # Agnes Meyer Driscoll - American cryptanalyst during World Wars I and II who successfully cryptanalysed a
- # number of Japanese ciphers. She was also the co-developer of one of the cipher machines of
- # the US Navy, the CM. https://en.wikipedia.org/wiki/Agnes_Meyer_Driscoll
- "driscoll",
- # Donna Dubinsky - played an integral role in the development of personal digital assistants (PDAs)
- # serving as CEO of Palm, Inc. and co-founding Handspring. https://en.wikipedia.org/wiki/Donna_Dubinsky
- "dubinsky",
- # Annie Easley - She was a leading member of the team which developed software for the Centaur
- # rocket stage and one of the first African-Americans in her field. https://en.wikipedia.org/wiki/Annie_Easley
- "easley",
- # Thomas Alva Edison, prolific inventor https://en.wikipedia.org/wiki/Thomas_Edison
- "edison",
- # Albert Einstein invented the general theory of relativity. https://en.wikipedia.org/wiki/Albert_Einstein
- "einstein",
- # Alexandra Asanovna Elbakyan is a Kazakhstani graduate student, computer programmer, internet pirate in
- # hiding, and the creator of the site Sci-Hub. Nature has listed her in 2016 in the top ten people that
- # mattered in science, and Ars Technica has compared her to Aaron Swartz. -
- # https://en.wikipedia.org/wiki/Alexandra_Elbakyan
- "elbakyan",
- # Taher A. ElGamal - Egyptian cryptographer best known for the ElGamal discrete log cryptosystem and the
- # ElGamal digital signature scheme. https://en.wikipedia.org/wiki/Taher_Elgamal
- "elgamal",
- # Gertrude Elion - American biochemist, pharmacologist and the 1988 recipient of the
- # Nobel Prize in Medicine - https://en.wikipedia.org/wiki/Gertrude_Elion
- "elion",
- # James Henry Ellis - British engineer and cryptographer employed by the GCHQ. Best known for
- # conceiving for the first time, the idea of public-key cryptography. https://en.wikipedia.org/wiki/James_H._Ellis
- "ellis",
- # Douglas Engelbart gave the mother of all demos: https://en.wikipedia.org/wiki/Douglas_Engelbart
- "engelbart",
- # Euclid invented geometry. https://en.wikipedia.org/wiki/Euclid
- "euclid",
- # Leonhard Euler invented large parts of modern mathematics. https://de.wikipedia.org/wiki/Leonhard_Euler
- "euler",
- # Michael Faraday - British scientist who contributed to the study of electromagnetism and
- # electrochemistry. https://en.wikipedia.org/wiki/Michael_Faraday
- "faraday",
- # Horst Feistel - German-born American cryptographer who was one of the earliest non-government
- # researchers to study the design and theory of block ciphers. Co-developer of DES and Lucifer.
- # Feistel networks, a symmetric structure used in the construction of block ciphers are named after him.
- # https://en.wikipedia.org/wiki/Horst_Feistel
- "feistel",
- # Pierre de Fermat pioneered several aspects of modern mathematics. https://en.wikipedia.org/wiki/Pierre_de_Fermat
- "fermat",
- # Enrico Fermi invented the first nuclear reactor. https://en.wikipedia.org/wiki/Enrico_Fermi.
- "fermi",
- # Richard Feynman was a key contributor to quantum mechanics and particle physics.
- # https://en.wikipedia.org/wiki/Richard_Feynman
- "feynman",
- # Benjamin Franklin is famous for his experiments in electricity and the invention of the lightning rod.
- "franklin",
- # Yuri Alekseyevich Gagarin - Soviet pilot and cosmonaut, best known as the first human to
- # journey into outer space. https://en.wikipedia.org/wiki/Yuri_Gagarin
- "gagarin",
- # Galileo was a founding father of modern astronomy, and faced politics and obscurantism to
- # establish scientific truth. https://en.wikipedia.org/wiki/Galileo_Galilei
- "galileo",
- # Evariste Galois - French mathematician whose work laid the foundations of Galois theory and group theory,
- # two major branches of abstract algebra, and the subfield of Galois connections, all while still in
- # his late teens. https://en.wikipedia.org/wiki/%C3%89variste_Galois
- "galois",
- # Kadambini Ganguly - Indian physician, known for being the first South Asian female physician,
- # trained in western medicine, to graduate in South Asia. https://en.wikipedia.org/wiki/Kadambini_Ganguly
- "ganguly",
- # William Henry "Bill" Gates III is an American business magnate, philanthropist, investor,
- # computer programmer, and inventor. https://en.wikipedia.org/wiki/Bill_Gates
- "gates",
- # Johann Carl Friedrich Gauss - German mathematician who made significant contributions to many fields,
- # including number theory, algebra, statistics, analysis, differential geometry, geodesy, geophysics, mechanics,
- # electrostatics, magnetic fields, astronomy, matrix theory, and optics.
- # https://en.wikipedia.org/wiki/Carl_Friedrich_Gauss
- "gauss",
- # Marie-Sophie Germain - French mathematician, physicist and philosopher. Known for her work
- # on elasticity theory, number theory and philosophy. https://en.wikipedia.org/wiki/Sophie_Germain
- "germain",
- # Adele Goldberg, was one of the designers and developers of the Smalltalk language.
- # https://en.wikipedia.org/wiki/Adele_Goldberg_(computer_scientist)
- "goldberg",
- # Adele Goldstine, born Adele Katz, wrote the complete technical description for the first electronic
- # digital computer, ENIAC. https://en.wikipedia.org/wiki/Adele_Goldstine
- "goldstine",
- # Shafi Goldwasser is a computer scientist known for creating theoretical foundations of modern
- # cryptography. Winner of 2012 ACM Turing Award. https://en.wikipedia.org/wiki/Shafi_Goldwasser
- "goldwasser",
- # James Golick, all around gangster.
- "golick",
- # Jane Goodall - British primatologist, ethologist, and anthropologist who is considered to be the
- # world's foremost expert on chimpanzees - https://en.wikipedia.org/wiki/Jane_Goodall
- "goodall",
- # Stephen Jay Gould was an American paleontologist, evolutionary biologist, and historian of science.
- # He is most famous for the theory of punctuated equilibrium - https://en.wikipedia.org/wiki/Stephen_Jay_Gould
- "gould",
- # Carolyn Widney Greider - American molecular biologist and joint winner of the 2009 Nobel Prize for
- # Physiology or Medicine for the discovery of telomerase. https://en.wikipedia.org/wiki/Carol_W._Greider
- "greider",
- # Alexander Grothendieck - German-born French mathematician who became a leading figure in the creation
- # of modern algebraic geometry. https://en.wikipedia.org/wiki/Alexander_Grothendieck
- "grothendieck",
- # Lois Haibt - American computer scientist, part of the team at IBM that developed FORTRAN -
- # https://en.wikipedia.org/wiki/Lois_Haibt
- "haibt",
- # Margaret Hamilton - Director of the Software Engineering Division of the MIT Instrumentation Laboratory,
- # which developed on-board flight software for the Apollo space program.
- # https://en.wikipedia.org/wiki/Margaret_Hamilton_(scientist)
- "hamilton",
- # Caroline Harriet Haslett - English electrical engineer, electricity industry administrator and champion of
- # women's rights. Co-author of British Standard 1363 that specifies AC power plugs and sockets used across
- # the United Kingdom (which is widely considered as one of the safest designs).
- # https://en.wikipedia.org/wiki/Caroline_Haslett
- "haslett",
- # Stephen Hawking pioneered the field of cosmology by combining general relativity and quantum mechanics.
- # https://en.wikipedia.org/wiki/Stephen_Hawking
- "hawking",
- # Martin Edward Hellman - American cryptologist, best known for his invention of public-key cryptography
- # in co-operation with Whitfield Diffie and Ralph Merkle. https://en.wikipedia.org/wiki/Martin_Hellman
- "hellman",
- # Werner Heisenberg was a founding father of quantum mechanics. https://en.wikipedia.org/wiki/Werner_Heisenberg
- "heisenberg",
- # Grete Hermann was a German philosopher noted for her philosophical work on the foundations of quantum mechanics.
- # https://en.wikipedia.org/wiki/Grete_Hermann
- "hermann",
- # Caroline Lucretia Herschel - German astronomer and discoverer of several comets.
- # https://en.wikipedia.org/wiki/Caroline_Herschel
- "herschel",
- # Heinrich Rudolf Hertz - German physicist who first conclusively proved the existence of the electromagnetic waves.
- # https://en.wikipedia.org/wiki/Heinrich_Hertz
- "hertz",
- # Jaroslav Heyrovsky was the inventor of the polarographic method, father of the electroanalytical method, and
- # recipient of the Nobel Prize in 1959. His main field of work was polarography.
- # https://en.wikipedia.org/wiki/Jaroslav_Heyrovsk%C3%BD
- "heyrovsky",
- # Dorothy Hodgkin was a British biochemist, credited with the development of protein crystallography. She was
- # awarded the Nobel Prize in Chemistry in 1964. https://en.wikipedia.org/wiki/Dorothy_Hodgkin
- "hodgkin",
- # Douglas R. Hofstadter is an American professor of cognitive science and author of the Pulitzer Prize and American
- # Book Award-winning work Goedel, Escher, Bach: An Eternal Golden Braid in 1979. A mind-bending work which coined
- # Hofstadter's Law: "It always takes longer than you expect, even when you take into account Hofstadter's Law."
- # https://en.wikipedia.org/wiki/Douglas_Hofstadter
- "hofstadter",
- # Erna Schneider Hoover revolutionized modern communication by inventing a computerized telephone switching method.
- # https://en.wikipedia.org/wiki/Erna_Schneider_Hoover
- "hoover",
- # Grace Hopper developed the first compiler for a computer programming language and is credited with popularizing
- # the term "debugging" for fixing computer glitches. https://en.wikipedia.org/wiki/Grace_Hopper
- "hopper",
- # Frances Hugle was an American scientist, engineer, and inventor who contributed to the understanding of
- # semiconductors, integrated circuitry, and the unique electrical principles of microscopic materials.
- # https://en.wikipedia.org/wiki/Frances_Hugle
- "hugle",
- # Hypatia - Greek Alexandrine Neoplatonist philosopher in Egypt who was one of the earliest mothers of mathematics -
- # https://en.wikipedia.org/wiki/Hypatia
- "hypatia",
- # Teruko Ishizaka - Japanese scientist and immunologist who co-discovered the antibody class Immunoglobulin E.
- # https://en.wikipedia.org/wiki/Teruko_Ishizaka
- "ishizaka",
- # Mary Jackson, American mathematician and aerospace engineer who earned the highest title within NASA's engineering
- # department - https://en.wikipedia.org/wiki/Mary_Jackson_(engineer)
- "jackson",
- # Yeong-Sil Jang was a Korean scientist and astronomer during the Joseon Dynasty; he invented the first metal
- # printing press and water gauge. https://en.wikipedia.org/wiki/Jang_Yeong-sil
- "jang",
- # Mae Carol Jemison - is an American engineer, physician, and former NASA astronaut. She became the first black
- # woman to travel in space when she served as a mission specialist aboard the Space Shuttle Endeavour -
- # https://en.wikipedia.org/wiki/Mae_Jemison
- "jemison",
- # Betty Jennings - one of the original programmers of the ENIAC. https://en.wikipedia.org/wiki/ENIAC -
- # https://en.wikipedia.org/wiki/Jean_Bartik
- "jennings",
- # Mary Lou Jepsen, was the founder and chief technology officer of One Laptop Per Child (OLPC), and the founder of
- # Pixel Qi. https://en.wikipedia.org/wiki/Mary_Lou_Jepsen
- "jepsen",
- # Katherine Coleman Goble Johnson - American physicist and mathematician who contributed to NASA.
- # https://en.wikipedia.org/wiki/Katherine_Johnson
- "johnson",
- # Irene Joliot-Curie - French scientist who was awarded the Nobel Prize for Chemistry in 1935. Daughter of Marie
- # and Pierre Curie. https://en.wikipedia.org/wiki/Ir%C3%A8ne_Joliot-Curie
- "joliot",
- # Karen Sparck Jones came up with the concept of inverse document frequency, which is used in most search engines
- # today. https://en.wikipedia.org/wiki/Karen_Sp%C3%A4rck_Jones
- "jones",
- # A. P. J. Abdul Kalam - an Indian scientist known as the Missile Man of India for his work on the development of
- # ballistic missile and launch vehicle technology - https://en.wikipedia.org/wiki/A._P._J._Abdul_Kalam
- "kalam",
- # Sergey Petrovich Kapitsa (14 February 1928 - 14 August 2012) was a Russian physicist and demographer. He was best
- # known as host of the popular and long-running Russian scientific TV show, Evident, but Incredible. His father was
- # the Nobel laureate Soviet-era physicist Pyotr Kapitsa, and his brother was the geographer and Antarctic explorer
- # Andrey Kapitsa. - https://en.wikipedia.org/wiki/Sergey_Kapitsa
- "kapitsa",
- # Susan Kare, created the icons and many of the interface elements for the original Apple Macintosh in the 1980s,
- # and was an original employee of NeXT, working as the Creative Director. https://en.wikipedia.org/wiki/Susan_Kare
- "kare",
- # Mstislav Keldysh - a Soviet scientist in the field of mathematics and mechanics, academician of the USSR Academy
- # of Sciences (1946), President of the USSR Academy of Sciences (1961-1975),
- # three times Hero of Socialist Labor (1956, 1961, 1971), fellow of the Royal Society of Edinburgh (1968).
- # https://en.wikipedia.org/wiki/Mstislav_Keldysh
- "keldysh",
- # Sister Mary Kenneth Keller became, in 1965, the first American woman to earn a
- # PhD in Computer Science. https://en.wikipedia.org/wiki/Mary_Kenneth_Keller
- "keller",
- # Johannes Kepler, German astronomer known for his three laws of planetary motion -
- # https://en.wikipedia.org/wiki/Johannes_Kepler
- "kepler",
- # Omar Khayyam - Persian mathematician, astronomer and poet. Known for his work on the classification and solution
- # of cubic equations, for his contribution to the understanding of Euclid's fifth postulate and for computing the
- # length of a year very accurately. https://en.wikipedia.org/wiki/Omar_Khayyam
- "khayyam",
- # Har Gobind Khorana - Indian-American biochemist who shared the 1968 Nobel Prize for Physiology or Medicine -
- # https://en.wikipedia.org/wiki/Har_Gobind_Khorana
- "khorana",
- # Jack Kilby co-invented the integrated circuit. -
- # https://en.wikipedia.org/wiki/Jack_Kilby
- "kilby",
- # Maria Kirch - German astronomer and first woman to discover a comet -
- # https://en.wikipedia.org/wiki/Maria_Margarethe_Kirch
- "kirch",
- # Donald Knuth - American computer scientist, author of "The Art of Computer Programming" and creator of the TeX
- # typesetting system. https://en.wikipedia.org/wiki/Donald_Knuth
- "knuth",
- # Sophie Kowalevski - Russian mathematician responsible for important original contributions to analysis,
- # differential equations and mechanics - https://en.wikipedia.org/wiki/Sofia_Kovalevskaya
- "kowalevski",
- # Marie-Jeanne de Lalande - French astronomer, mathematician and cataloguer of stars -
- # https://en.wikipedia.org/wiki/Marie-Jeanne_de_Lalande
- "lalande",
- # Hedy Lamarr - Actress and inventor. The principles of her work are now incorporated into modern Wi-Fi, CDMA
- # and Bluetooth technology. https://en.wikipedia.org/wiki/Hedy_Lamarr
- "lamarr",
- # Leslie B. Lamport - American computer scientist. Lamport is best known for his seminal work in distributed
- # systems and was the winner of the 2013 Turing Award. https://en.wikipedia.org/wiki/Leslie_Lamport
- "lamport",
- # Mary Leakey - British paleoanthropologist who discovered the first fossilized Proconsul skull -
- # https://en.wikipedia.org/wiki/Mary_Leakey
- "leakey",
- # Henrietta Swan Leavitt - she was an American astronomer who discovered the relation between the luminosity and
- # the period of Cepheid variable stars. https://en.wikipedia.org/wiki/Henrietta_Swan_Leavitt
- "leavitt",
- # Esther Miriam Zimmer Lederberg - American microbiologist and a pioneer of bacterial genetics.
- # https://en.wikipedia.org/wiki/Esther_Lederberg
- "lederberg",
- # Inge Lehmann - Danish seismologist and geophysicist. Known for discovering in 1936 that the Earth has a solid
- # inner core inside a molten outer core. https://en.wikipedia.org/wiki/Inge_Lehmann
- "lehmann",
- # Daniel Lewin - Mathematician, Akamai co-founder, soldier, 9/11 victim. Developed optimization techniques for
- # routing traffic on the internet. Died attempting to stop the 9/11 hijackers.
- # https://en.wikipedia.org/wiki/Daniel_Lewin
- "lewin",
- # Ruth Lichterman - one of the original programmers of the ENIAC. https://en.wikipedia.org/wiki/ENIAC -
- # https://en.wikipedia.org/wiki/Ruth_Teitelbaum
- "lichterman",
- # Barbara Liskov - co-developed the Liskov substitution principle. Liskov was also the winner of the Turing
- # Award in 2008. - https://en.wikipedia.org/wiki/Barbara_Liskov
- "liskov",
- # Ada Lovelace invented the first algorithm. https://en.wikipedia.org/wiki/Ada_Lovelace (thanks James Turnbull)
- "lovelace",
- # Auguste and Louis Lumiere - the first filmmakers in history -
- # https://en.wikipedia.org/wiki/Auguste_and_Louis_Lumi%C3%A8re
- "lumiere",
- # Mahavira - Ancient Indian mathematician during 9th century AD who discovered basic algebraic identities -
- # https://en.wikipedia.org/wiki/Mah%C4%81v%C4%ABra_(mathematician)
- "mahavira",
- # Lynn Margulis (b. Lynn Petra Alexander) - an American evolutionary theorist and biologist, science author,
- # educator, and popularizer, and was the primary modern proponent for the significance of symbiosis in evolution. -
- # https://en.wikipedia.org/wiki/Lynn_Margulis
- "margulis",
- # Yukihiro Matsumoto - Japanese computer scientist and software programmer best known as the chief designer of
- # the Ruby programming language. https://en.wikipedia.org/wiki/Yukihiro_Matsumoto
- "matsumoto",
- # James Clerk Maxwell - Scottish physicist, best known for his formulation of electromagnetic theory.
- # https://en.wikipedia.org/wiki/James_Clerk_Maxwell
- "maxwell",
- # Maria Mayer - American theoretical physicist and Nobel laureate in Physics for proposing the nuclear shell model
- # of the atomic nucleus - https://en.wikipedia.org/wiki/Maria_Mayer
- "mayer",
- # John McCarthy invented LISP: https://en.wikipedia.org/wiki/John_McCarthy_(computer_scientist)
- "mccarthy",
- # Barbara McClintock - a distinguished American cytogeneticist, 1983 Nobel Laureate in Physiology or Medicine for
- # discovering transposons. https://en.wikipedia.org/wiki/Barbara_McClintock
- "mcclintock",
- # Anne Laura Dorinthea McLaren - British developmental biologist whose work helped lead to human
- # in-vitro fertilisation. https://en.wikipedia.org/wiki/Anne_McLaren
- "mclaren",
- # Malcolm McLean invented the modern shipping container: https://en.wikipedia.org/wiki/Malcom_McLean
- "mclean",
- # Kay McNulty - one of the original programmers of the ENIAC. https://en.wikipedia.org/wiki/ENIAC -
- # https://en.wikipedia.org/wiki/Kathleen_Antonelli
- "mcnulty",
- # Gregor Johann Mendel - Czech scientist and founder of genetics. https://en.wikipedia.org/wiki/Gregor_Mendel
- "mendel",
- # Dmitri Mendeleev - a chemist and inventor. He formulated the Periodic Law, created a farsighted version of the
- # periodic table of elements, and used it to correct the properties of some already discovered elements and also
- # to predict the properties of eight elements yet to be discovered. https://en.wikipedia.org/wiki/Dmitri_Mendeleev
- "mendeleev",
- # Lise Meitner - Austrian/Swedish physicist who was involved in the discovery of nuclear fission. The element
- # meitnerium is named after her - https://en.wikipedia.org/wiki/Lise_Meitner
- "meitner",
- # Carla Meninsky, was the game designer and programmer for Atari 2600 games Dodge 'Em and Warlords.
- # https://en.wikipedia.org/wiki/Carla_Meninsky
- "meninsky",
- # Ralph C. Merkle - American computer scientist, known for devising Merkle's puzzles - one of the very first
- # schemes for public-key cryptography. Also, inventor of Merkle trees and co-inventor of the Merkle-Damgard
- # construction for building collision-resistant cryptographic hash functions and the Merkle-Hellman knapsack
- # cryptosystem. https://en.wikipedia.org/wiki/Ralph_Merkle
- "merkle",
- # Johanna Mestorf - German prehistoric archaeologist and first female museum director in Germany -
- # https://en.wikipedia.org/wiki/Johanna_Mestorf
- "mestorf",
- # Maryam Mirzakhani - an Iranian mathematician and the first woman to win the Fields Medal.
- # https://en.wikipedia.org/wiki/Maryam_Mirzakhani
- "mirzakhani",
- # Rita Levi-Montalcini - Won Nobel Prize in Physiology or Medicine jointly with colleague Stanley Cohen for the
- # discovery of nerve growth factor (https://en.wikipedia.org/wiki/Rita_Levi-Montalcini)
- "montalcini",
- # Gordon Earle Moore - American engineer, Silicon Valley founding father, author of Moore's law.
- # https://en.wikipedia.org/wiki/Gordon_Moore
- "moore",
- # Samuel Morse - contributed to the invention of a single-wire telegraph system based on European telegraphs
- # and was a co-developer of the Morse code - https://en.wikipedia.org/wiki/Samuel_Morse
- "morse",
- # Ian Murdock - founder of the Debian project - https://en.wikipedia.org/wiki/Ian_Murdock
- "murdock",
- # May-Britt Moser - Nobel Prize-winning neuroscientist who contributed to the discovery of grid cells in the brain.
- # https://en.wikipedia.org/wiki/May-Britt_Moser
- "moser",
- # John Napier of Merchiston - Scottish landowner known as an astronomer, mathematician and physicist.
- # Best known for his discovery of logarithms. https://en.wikipedia.org/wiki/John_Napier
- "napier",
- # John Forbes Nash, Jr. - American mathematician who made fundamental contributions to game theory, differential
- # geometry, and the study of partial differential equations. https://en.wikipedia.org/wiki/John_Forbes_Nash_Jr.
- "nash",
- # John von Neumann - today's computer architectures are based on the von Neumann architecture.
- # https://en.wikipedia.org/wiki/Von_Neumann_architecture
- "neumann",
- # Isaac Newton invented classical mechanics and modern optics. https://en.wikipedia.org/wiki/Isaac_Newton
- "newton",
- # Florence Nightingale, more prominently known as a nurse, was also the first female member of the Royal Statistical
- # Society and a pioneer in statistical graphics
- # https://en.wikipedia.org/wiki/Florence_Nightingale#Statistics_and_sanitary_reform
- "nightingale",
- # Alfred Nobel - a Swedish chemist, engineer, innovator, and armaments manufacturer (inventor of dynamite) -
- # https://en.wikipedia.org/wiki/Alfred_Nobel
- "nobel",
- # Emmy Noether, German mathematician. Noether's Theorem is named after her.
- # https://en.wikipedia.org/wiki/Emmy_Noether
- "noether",
- # Poppy Northcutt was the first woman to work as part of NASA's Mission Control.
- # http://www.businessinsider.com/poppy-northcutt-helped-apollo-astronauts-2014-12?op=1
- "northcutt",
- # Robert Noyce invented silicon integrated circuits and gave Silicon Valley its name. -
- # https://en.wikipedia.org/wiki/Robert_Noyce
- "noyce",
- # Panini - Ancient Indian linguist and grammarian from 4th century BCE who worked on the world's first formal system
- # - https://en.wikipedia.org/wiki/P%C4%81%E1%B9%87ini#Comparison_with_modern_formal_systems
- "panini",
- # Ambroise Pare invented modern surgery. https://en.wikipedia.org/wiki/Ambroise_Par%C3%A9
- "pare",
- # Blaise Pascal, French mathematician, physicist, and inventor - https://en.wikipedia.org/wiki/Blaise_Pascal
- "pascal",
- # Louis Pasteur discovered vaccination, fermentation and pasteurization.
- # https://en.wikipedia.org/wiki/Louis_Pasteur.
- "pasteur",
- # Cecilia Payne-Gaposchkin was an astronomer and astrophysicist who, in 1925, proposed in her Ph.D. thesis an
- # explanation for the composition of stars in terms of the relative abundances of hydrogen and helium.
- # https://en.wikipedia.org/wiki/Cecilia_Payne-Gaposchkin
- "payne",
- # Radia Perlman is a software designer and network engineer and most famous for her invention of the
- # spanning-tree protocol (STP). https://en.wikipedia.org/wiki/Radia_Perlman
- "perlman",
- # Rob Pike was a key contributor to Unix, Plan 9, the X graphic system, utf-8, and the Go programming language.
- # https://en.wikipedia.org/wiki/Rob_Pike
- "pike",
- # Henri Poincare made fundamental contributions in several fields of mathematics.
- # https://en.wikipedia.org/wiki/Henri_Poincar%C3%A9
- "poincare",
- # Laura Poitras is a director and producer whose work, made possible by open source crypto tools, advances the
- # causes of truth and freedom of information by reporting disclosures by whistleblowers such as Edward Snowden.
- # https://en.wikipedia.org/wiki/Laura_Poitras
- "poitras",
- # Tat'yana Avenirovna Proskuriakova (January 23 [O.S. January 10] 1909 - August 30, 1985) was a Russian-American
- # Mayanist scholar and archaeologist who contributed significantly to the deciphering of Maya hieroglyphs, the
- # writing system of the pre-Columbian Maya civilization of Mesoamerica.
- # https://en.wikipedia.org/wiki/Tatiana_Proskouriakoff
- "proskuriakova",
- # Claudius Ptolemy - a Greco-Egyptian writer of Alexandria, known as a mathematician, astronomer, geographer,
- # astrologer, and poet of a single epigram in the Greek Anthology - https://en.wikipedia.org/wiki/Ptolemy
- "ptolemy",
- # C. V. Raman - Indian physicist who won the Nobel Prize in 1930 for discovering the Raman effect. -
- # https://en.wikipedia.org/wiki/C._V._Raman
- "raman",
- # Srinivasa Ramanujan - Indian mathematician and autodidact who made extraordinary contributions to mathematical
- # analysis, number theory, infinite series, and continued fractions. -
- # https://en.wikipedia.org/wiki/Srinivasa_Ramanujan
- "ramanujan",
- # Sally Kristen Ride was an American physicist and astronaut. She was the first American woman in space, and the
- # youngest American astronaut. https://en.wikipedia.org/wiki/Sally_Ride
- "ride",
- # Dennis Ritchie - co-creator of UNIX and the C programming language. - https://en.wikipedia.org/wiki/Dennis_Ritchie
- "ritchie",
- # Ida Rhodes - American pioneer in computer programming, designed the first computer used for Social Security.
- # https://en.wikipedia.org/wiki/Ida_Rhodes
- "rhodes",
- # Julia Hall Bowman Robinson - American mathematician renowned for her contributions to the fields of computability
- # theory and computational complexity theory. https://en.wikipedia.org/wiki/Julia_Robinson
- "robinson",
- # Wilhelm Conrad Rontgen - German physicist who was awarded the first Nobel Prize in Physics in 1901 for the
- # discovery of X-rays (Rontgen rays). https://en.wikipedia.org/wiki/Wilhelm_R%C3%B6ntgen
- "roentgen",
- # Rosalind Franklin - British biophysicist and X-ray crystallographer whose research was critical to the
- # understanding of DNA - https://en.wikipedia.org/wiki/Rosalind_Franklin
- "rosalind",
- # Vera Rubin - American astronomer who pioneered work on galaxy rotation rates.
- # https://en.wikipedia.org/wiki/Vera_Rubin
- "rubin",
- # Meghnad Saha - Indian astrophysicist best known for his development of the Saha equation, used to describe
- # chemical and physical conditions in stars - https://en.wikipedia.org/wiki/Meghnad_Saha
- "saha",
- # Jean E. Sammet developed FORMAC, the first widely used computer language for symbolic manipulation of
- # mathematical formulas. https://en.wikipedia.org/wiki/Jean_E._Sammet
- "sammet",
- # Mildred Sanderson - American mathematician best known for Sanderson's theorem concerning modular invariants.
- # https://en.wikipedia.org/wiki/Mildred_Sanderson
- "sanderson",
- # Satoshi Nakamoto is the name used by the unknown person or group of people who developed bitcoin, authored the
- # bitcoin white paper, and created and deployed bitcoin's original reference implementation.
- # https://en.wikipedia.org/wiki/Satoshi_Nakamoto
- "satoshi",
- # Adi Shamir - Israeli cryptographer whose numerous inventions and contributions to cryptography include the
- # Feige-Fiat-Shamir identification scheme, the Rivest-Shamir-Adleman (RSA) public-key cryptosystem, Shamir's secret
- # sharing scheme, the breaking of the Merkle-Hellman cryptosystem, the TWINKLE and TWIRL factoring devices and the
- # discovery of differential cryptanalysis (with Eli Biham). https://en.wikipedia.org/wiki/Adi_Shamir
- "shamir",
- # Claude Shannon - The father of information theory and founder of digital circuit design theory.
- # (https://en.wikipedia.org/wiki/Claude_Shannon)
- "shannon",
- # Carol Shaw - Originally an Atari employee, Carol Shaw is said to be the first female video game designer.
- # https://en.wikipedia.org/wiki/Carol_Shaw_(video_game_designer)
- "shaw",
- # Dame Stephanie "Steve" Shirley - Founded a software company in 1962 employing women working from home.
- # https://en.wikipedia.org/wiki/Steve_Shirley
- "shirley",
- # William Shockley co-invented the transistor - https://en.wikipedia.org/wiki/William_Shockley
- "shockley",
- # Lina Solomonovna Stern (or Shtern; 26 August 1878 - 7 March 1968) was a Soviet biochemist, physiologist and
- # humanist whose medical discoveries saved thousands of lives at the fronts of World War II. She is best known
- # for her pioneering work on blood–brain barrier, which she described as hemato-encephalic barrier in 1921.
- # https://en.wikipedia.org/wiki/Lina_Stern
- "shtern",
- # Francoise Barre-Sinoussi - French virologist and Nobel Prize Laureate in Physiology or Medicine; her work was
- # fundamental in identifying HIV as the cause of AIDS.
- # https://en.wikipedia.org/wiki/Fran%C3%A7oise_Barr%C3%A9-Sinoussi
- "sinoussi",
- # Betty Snyder - one of the original programmers of the ENIAC. https://en.wikipedia.org/wiki/ENIAC -
- # https://en.wikipedia.org/wiki/Betty_Holberton
- "snyder",
- # Cynthia Solomon - Pioneer in the fields of artificial intelligence, computer science and educational computing.
- # Known for creation of Logo, an educational programming language. https://en.wikipedia.org/wiki/Cynthia_Solomon
- "solomon",
- # Frances Spence - one of the original programmers of the ENIAC. https://en.wikipedia.org/wiki/ENIAC -
- # https://en.wikipedia.org/wiki/Frances_Spence
- "spence",
- # Michael Stonebraker is a database research pioneer and architect of Ingres, Postgres, VoltDB and SciDB.
- # Winner of 2014 ACM Turing Award. https://en.wikipedia.org/wiki/Michael_Stonebraker
- "stonebraker",
- # Ivan Edward Sutherland - American computer scientist and Internet pioneer, widely regarded as the father of
- # computer graphics. https://en.wikipedia.org/wiki/Ivan_Sutherland
- "sutherland",
- # Janese Swanson (with others) developed the first of the Carmen Sandiego games. She went on to found Girl Tech.
- # https://en.wikipedia.org/wiki/Janese_Swanson
- "swanson",
- # Aaron Swartz was influential in creating RSS, Markdown, Creative Commons, Reddit, and much of the internet as we
- # know it today. He was devoted to freedom of information on the web. https://en.wikiquote.org/wiki/Aaron_Swartz
- "swartz",
- # Bertha Swirles was a theoretical physicist who made a number of contributions to early quantum theory.
- # https://en.wikipedia.org/wiki/Bertha_Swirles
- "swirles",
- # Helen Brooke Taussig - American cardiologist and founder of the field of paediatric cardiology.
- # https://en.wikipedia.org/wiki/Helen_B._Taussig
- "taussig",
- # Valentina Tereshkova is a Russian engineer, cosmonaut and politician. She was the first woman to fly to space in
- # 1963. In 2013, at the age of 76, she offered to go on a one-way mission to Mars.
- # https://en.wikipedia.org/wiki/Valentina_Tereshkova
- "tereshkova",
- # Nikola Tesla invented the AC electric system and every gadget ever used by a James Bond villain.
- # https://en.wikipedia.org/wiki/Nikola_Tesla
- "tesla",
- # Marie Tharp - American geologist and oceanic cartographer who co-created the first scientific map of the Atlantic
- # Ocean floor. Her work led to the acceptance of the theories of plate tectonics and continental drift.
- # https://en.wikipedia.org/wiki/Marie_Tharp
- "tharp",
- # Ken Thompson - co-creator of UNIX and the C programming language - https://en.wikipedia.org/wiki/Ken_Thompson
- "thompson",
- # Linus Torvalds invented Linux and Git. https://en.wikipedia.org/wiki/Linus_Torvalds
- "torvalds",
- # Youyou Tu - Chinese pharmaceutical chemist and educator known for discovering artemisinin and dihydroartemisinin,
- # used to treat malaria, which has saved millions of lives. Joint winner of the 2015 Nobel Prize in Physiology or
- # Medicine. https://en.wikipedia.org/wiki/Tu_Youyou
- "tu",
- # Alan Turing was a founding father of computer science. https://en.wikipedia.org/wiki/Alan_Turing.
- "turing",
- # Varahamihira - Ancient Indian mathematician who discovered trigonometric formulae during 505-587 CE -
- # https://en.wikipedia.org/wiki/Var%C4%81hamihira#Contributions
- "varahamihira",
- # Dorothy Vaughan was a NASA mathematician and computer programmer on the SCOUT launch vehicle program that put
- # America's first satellites into space - https://en.wikipedia.org/wiki/Dorothy_Vaughan
- "vaughan",
- # Sir Mokshagundam Visvesvaraya - a notable Indian engineer. He received the Indian Republic's highest
- # honour, the Bharat Ratna, in 1955. His birthday, 15 September, is celebrated as Engineer's Day in India in his
- # memory - https://en.wikipedia.org/wiki/Visvesvaraya
- "visvesvaraya",
- # Christiane Nusslein-Volhard - German biologist, won Nobel Prize in Physiology or Medicine in 1995 for research on
- # the genetic control of embryonic development. https://en.wikipedia.org/wiki/Christiane_N%C3%BCsslein-Volhard
- "volhard",
- # Cedric Villani - French mathematician, won Fields Medal, Fermat Prize and Poincare Prize for his work in
- # differential geometry and statistical mechanics. https://en.wikipedia.org/wiki/C%C3%A9dric_Villani
- "villani",
- # Marlyn Wescoff - one of the original programmers of the ENIAC. https://en.wikipedia.org/wiki/ENIAC -
- # https://en.wikipedia.org/wiki/Marlyn_Meltzer
- "wescoff",
- # Sylvia B. Wilbur - British computer scientist who helped develop the ARPANET, was one of the first to exchange
- # email in the UK and a leading researcher in computer-supported collaborative work.
- # https://en.wikipedia.org/wiki/Sylvia_Wilbur
- "wilbur",
- # Andrew Wiles - Notable British mathematician who proved the enigmatic Fermat's Last Theorem -
- # https://en.wikipedia.org/wiki/Andrew_Wiles
- "wiles",
- # Roberta Williams, did pioneering work in graphical adventure games for personal computers, particularly the King's
- # Quest series. https://en.wikipedia.org/wiki/Roberta_Williams
- "williams",
- # Malcolm John Williamson - British mathematician and cryptographer employed by the GCHQ. Developed in 1974 what
- # is now known as Diffie-Hellman key exchange (Diffie and Hellman first published the scheme in 1976).
- # https://en.wikipedia.org/wiki/Malcolm_J._Williamson
- "williamson",
- # Sophie Wilson designed the first Acorn Micro-Computer and the instruction set for ARM processors.
- # https://en.wikipedia.org/wiki/Sophie_Wilson
- "wilson",
- # Jeannette Wing - co-developed the Liskov substitution principle. - https://en.wikipedia.org/wiki/Jeannette_Wing
- "wing",
- # Steve Wozniak invented the Apple I and Apple II. https://en.wikipedia.org/wiki/Steve_Wozniak
- "wozniak",
- # The Wright brothers, Orville and Wilbur - credited with inventing and building the world's first successful
- # airplane and making the first controlled, powered and sustained heavier-than-air human flight -
- # https://en.wikipedia.org/wiki/Wright_brothers
- "wright",
- # Chien-Shiung Wu - Chinese-American experimental physicist who made significant contributions to nuclear physics.
- # https://en.wikipedia.org/wiki/Chien-Shiung_Wu
- "wu",
- # Rosalyn Sussman Yalow - American medical physicist and a co-winner of the 1977
- # Nobel Prize in Physiology or Medicine for development of the radioimmunoassay technique.
- # https://en.wikipedia.org/wiki/Rosalyn_Sussman_Yalow
- "yalow",
- # Ada Yonath - an Israeli crystallographer, the first woman from the Middle East to win a Nobel prize in the
- # sciences. https://en.wikipedia.org/wiki/Ada_Yonath
- "yonath",
- # Nikolay Yegorovich Zhukovsky (January 17 1847 - March 17, 1921) was a Russian scientist, mathematician and
- # engineer, and a founding father of modern aero- and hydrodynamics. Whereas contemporary scientists scoffed at the
- # idea of human flight, Zhukovsky was the first to undertake the study of airflow. He is often called the Father
- # of Russian Aviation. https://en.wikipedia.org/wiki/Nikolay_Yegorovich_Zhukovsky
- "zhukovsky",
-]
-
-
-def get_unique_name():
- """Generates a random name in the style of "docker containers".
-
- This is generated from the list of adjectives and surnames in this package,
- formatted as "adjective-surname" with a random integer between 0 and 9999
- added to the end.
-
- A python port of docker's random container name generator.
- Original source:
- https://raw.githubusercontent.com/moby/moby/master/pkg/namesgenerator/names-generator.go
-
- Examples:
-
- >>> import random ; random.seed(42)
- >>> get_unique_name()
- 'meek-ardinghelli-4506'
- >>> get_unique_name()
- 'truthful-dijkstra-2286'
-
- """
- adjective, surname, i = choice(_adjectives), choice(_surnames), randint(0, 9999) # noqa: S311
- return f"{adjective}-{surname}-{i}"
diff --git a/src/lightning/app/utilities/network.py b/src/lightning/app/utilities/network.py
deleted file mode 100644
index d7ba2a4f88102..0000000000000
--- a/src/lightning/app/utilities/network.py
+++ /dev/null
@@ -1,211 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import socket
-from functools import wraps
-from typing import Any, Callable, Dict, Optional
-from urllib.parse import urljoin
-
-import requests
-
-# for backwards compatibility
-from lightning_cloud.rest_client import GridRestClient, LightningClient, create_swagger_client # noqa: F401
-from requests import Session
-from requests.adapters import HTTPAdapter
-from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout
-from urllib3.util.retry import Retry
-
-from lightning.app.core import constants
-from lightning.app.utilities.app_helpers import Logger
-
-logger = Logger(__name__)
-
-
-# Global record to track ports that have been allocated in this session.
-_reserved_ports = set()
-
-
-def find_free_network_port() -> int:
- """Finds a free port on localhost."""
- if constants.LIGHTNING_CLOUDSPACE_HOST is not None:
- return _find_free_network_port_cloudspace()
-
- port = None
-
- for _ in range(10):
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.bind(("", 0))
- port = sock.getsockname()[1]
- sock.close()
-
- if port not in _reserved_ports:
- break
-
- if port in _reserved_ports:
- # Prevent an infinite loop, if we tried 10 times and didn't get a free port then something is wrong
- raise RuntimeError(
- "Couldn't find a free port. Please open an issue at `https://github.com/Lightning-AI/lightning/issues`."
- )
-
- _reserved_ports.add(port)
- return port
-
-
-def _find_free_network_port_cloudspace():
- """Finds a free port in the exposed range when running in a cloudspace."""
- for port in range(
- constants.APP_SERVER_PORT + 1, # constants.APP_SERVER_PORT is reserved for the app server
- constants.APP_SERVER_PORT + constants.LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT,
- ):
- if port in _reserved_ports:
- continue
-
- try:
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.bind(("", port))
- sock.close()
- _reserved_ports.add(port)
- return port
- except OSError:
- continue
-
- # This error should never happen. An app using this many ports would probably fail on a single machine anyway.
- raise RuntimeError(f"All {constants.LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT} ports are already in use.")
-
-
-_CONNECTION_RETRY_TOTAL = 2880
-_CONNECTION_RETRY_BACKOFF_FACTOR = 0.5
-_DEFAULT_REQUEST_TIMEOUT = 30 # seconds
-
-
-def _configure_session() -> Session:
- """Configures the session for GET and POST requests.
-
- It enables a generous retry strategy that waits for the application server to become reachable.
-
- """
- retry_strategy = Retry(
- # wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1))
- total=_CONNECTION_RETRY_TOTAL,
- backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR,
- status_forcelist=[429, 500, 502, 503, 504],
- )
- adapter = HTTPAdapter(max_retries=retry_strategy)
- http = requests.Session()
- http.mount("https://", adapter)
- http.mount("http://", adapter)
- return http
-
-
-def _check_service_url_is_ready(url: str, timeout: float = 5, metadata="") -> bool:
- try:
- response = requests.get(url, timeout=timeout)
- return response.status_code in (200, 404)
- except (ConnectionError, ConnectTimeout, ReadTimeout):
- logger.debug(f"The url {url} is not ready. {metadata}")
- return False
-
-
-class CustomRetryAdapter(HTTPAdapter):
- def __init__(self, *args: Any, **kwargs: Any):
- self.timeout = kwargs.pop("timeout", _DEFAULT_REQUEST_TIMEOUT)
- super().__init__(*args, **kwargs)
-
- def send(self, request, **kwargs: Any):
- kwargs["timeout"] = kwargs.get("timeout", self.timeout)
- return super().send(request, **kwargs)
-
-
-def _http_method_logger_wrapper(func: Callable) -> Callable:
- """Returns the function decorated by a wrapper that logs the message using the `log_function` hook."""
-
- @wraps(func)
- def wrapped(self: "HTTPClient", *args: Any, **kwargs: Any) -> Any:
- message = f"HTTPClient: Method: {func.__name__.upper()}, Path: {args[0]}\n"
- message += f" Base URL: {self.base_url}\n"
- params = kwargs.get("query_params", {})
- if params:
- message += f" Params: {params}\n"
- resp: requests.Response = func(self, *args, **kwargs)
- message += f" Response: {resp.status_code} {resp.reason}"
- self.log_function(message)
- return resp
-
- return wrapped
-
-
-def _response(r, *args: Any, **kwargs: Any):
- return r.raise_for_status()
-
-
-class HTTPClient:
- """A wrapper class around the requests library which handles chores like logging, retries, and timeouts
- automatically."""
-
- def __init__(
- self, base_url: str, auth_token: Optional[str] = None, log_callback: Optional[Callable] = None
- ) -> None:
- self.base_url = base_url
- retry_strategy = Retry(
- # wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1))
- # but the maximum wait time is 120 secs. By setting a large value (2880), we make sure clients
- # stay alive for a very long time (~ 4 days) while retrying every 120 seconds
- total=_CONNECTION_RETRY_TOTAL,
- backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR,
- status_forcelist=[
- 408, # Request Timeout
- 429, # Too Many Requests
- 500, # Internal Server Error
- 502, # Bad Gateway
- 503, # Service Unavailable
- 504, # Gateway Timeout
- ],
- )
- adapter = CustomRetryAdapter(max_retries=retry_strategy, timeout=_DEFAULT_REQUEST_TIMEOUT)
- self.session = requests.Session()
-
- self.session.hooks = {"response": _response}
-
- self.session.mount("http://", adapter)
- self.session.mount("https://", adapter)
-
- if auth_token:
- self.session.headers.update({"Authorization": f"Bearer {auth_token}"})
-
- self.log_function = log_callback or self.log_function
-
- @_http_method_logger_wrapper
- def get(self, path: str):
- url = urljoin(self.base_url, path)
- return self.session.get(url)
-
- @_http_method_logger_wrapper
- def post(self, path: str, *, query_params: Optional[Dict] = None, data: Optional[bytes] = None, json: Any = None):
- url = urljoin(self.base_url, path)
- return self.session.post(url, data=data, params=query_params, json=json)
-
- @_http_method_logger_wrapper
- def delete(self, path: str):
- url = urljoin(self.base_url, path)
- return self.session.delete(url)
-
- def log_function(self, message: str, *args, **kwargs: Any):
- """This function is used to log the messages in the client. It can be overridden by the caller to customise the
- logging logic.
-
- We enabled customisation here instead of just using `logger.debug` because HTTP logging can be very noisy, but
- it is crucial for finding bugs when we have them.
-
- """
- pass
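Reviewer note: the "~ 4 days" figure in the removed `HTTPClient` comment can be checked with a little arithmetic. The sketch below follows only the formula and the 120-second cap stated in that comment; it is not a model of urllib3's exact behaviour, which may differ (for instance in how the first retry is timed). The function names are hypothetical.

```python
def backoff_seconds(retry: int, factor: float = 0.5, cap: float = 120.0) -> float:
    """Wait before the given retry (1-based): factor * 2 ** (retry - 1), capped."""
    # Clamp the exponent so the power stays a modest float. With the defaults,
    # factor * 2 ** 32 is already far above the cap, so clamping cannot change
    # the result (an assumption that holds for any factor >= cap / 2 ** 32).
    exponent = min(retry - 1, 32)
    return min(cap, factor * 2.0 ** exponent)


def total_wait_seconds(total_retries: int = 2880) -> float:
    """Sum of all waits if every one of the retries is used."""
    return sum(backoff_seconds(n) for n in range(1, total_retries + 1))


if __name__ == "__main__":
    print([backoff_seconds(n) for n in range(1, 11)])
    print(total_wait_seconds() / 86400, "days")
```

Under these assumptions the wait reaches the cap at the ninth retry, and the total comes to 344767.5 seconds, just under four days, which agrees with the comment.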
diff --git a/src/lightning/app/utilities/openapi.py b/src/lightning/app/utilities/openapi.py
deleted file mode 100644
index f210c3cd47b04..0000000000000
--- a/src/lightning/app/utilities/openapi.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-from typing import Any, Dict
-
-
-def _duplicate_checker(js):
- """_duplicate_checker verifies that your JSON object doesn't contain duplicate keys."""
- result = {}
- for name, value in js:
- if name in result:
- raise ValueError(
- f"Unable to load JSON. A duplicate key {name} was detected. JSON objects must have unique keys."
- )
- result[name] = value
- return result
-
-
-def string2dict(text):
- """String2dict parses a JSON string into a dictionary, ensuring no keys are duplicated by accident."""
- if not isinstance(text, str):
- text = text.decode("utf-8")
- try:
- js = json.loads(text, object_pairs_hook=_duplicate_checker)
- return js
- except ValueError as ex:
- raise ValueError(f"Unable to load JSON: {str(ex)}.")
-
-
-def is_openapi(obj):
- """is_openapi checks if an object was generated by OpenAPI."""
- return hasattr(obj, "swagger_types")
-
-
-def create_openapi_object(json_obj: Dict, target: Any):
- """Create the OpenAPI object from the given JSON dict and based on the target object.
-
- Lightning AI uses the target object to make new objects from the given JSON spec so the target must be a valid
- object.
-
- """
- if not isinstance(json_obj, dict):
- raise TypeError("json_obj must be a dictionary")
- if not is_openapi(target):
- raise TypeError("target must be an openapi object")
-
- target_attribs = {}
- for key, value in json_obj.items():
- try:
- # user provided key is not a valid key on openapi object
- sub_target = getattr(target, key)
- except AttributeError:
- raise ValueError(f"Field {key} not found in the target object")
-
- if is_openapi(sub_target): # it's an openapi object
- target_attribs[key] = create_openapi_object(value, sub_target)
- else:
- target_attribs[key] = value
-
- # TODO(sherin) - specifically process list and dict and do the validation. Also do the
- # verification for enum types
-
- return target.__class__(**target_attribs)
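Reviewer note: the duplicate-key check in the removed `string2dict` relies on the standard `object_pairs_hook` argument of `json.loads`, which receives each JSON object as a list of key/value pairs before they are collapsed into a dict. A minimal standalone sketch of the same technique follows; `reject_duplicate_keys` and `loads_strict` are hypothetical names, not the deleted API.

```python
import json


def reject_duplicate_keys(pairs):
    """Build a dict from (key, value) pairs, failing on a repeated key."""
    result = {}
    for key, value in pairs:
        if key in result:
            raise ValueError(f"duplicate key {key!r} in JSON object")
        result[key] = value
    return result


def loads_strict(text: str):
    """Parse JSON, raising ValueError if any object repeats a key."""
    # The hook runs for every object in the document, including nested ones.
    return json.loads(text, object_pairs_hook=reject_duplicate_keys)


if __name__ == "__main__":
    print(loads_strict('{"a": 1, "b": {"c": 2}}'))
    try:
        loads_strict('{"a": 1, "a": 2}')
    except ValueError as err:
        print("rejected:", err)
```

Without the hook, `json.loads` silently keeps the last value for a repeated key, which is the behaviour the removed helper guarded against.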
diff --git a/src/lightning/app/utilities/packaging/__init__.py b/src/lightning/app/utilities/packaging/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/src/lightning/app/utilities/packaging/app_config.py b/src/lightning/app/utilities/packaging/app_config.py
deleted file mode 100644
index 57177344a8fe0..0000000000000
--- a/src/lightning/app/utilities/packaging/app_config.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import pathlib
-from dataclasses import asdict, dataclass, field
-from typing import Union
-
-import yaml
-
-from lightning.app.utilities.name_generator import get_unique_name
-
-_APP_CONFIG_FILENAME = ".lightning"
-
-
-@dataclass
-class AppConfig:
- """The AppConfig holds configuration metadata for the application.
-
- Args:
- name: Optional name of the application. If not provided, auto-generates a new name.
-
- """
-
- name: str = field(default_factory=get_unique_name)
-
- def save_to_file(self, path: Union[str, pathlib.Path]) -> None:
- """Save the configuration to the given file in YAML format."""
- path = pathlib.Path(path)
- with open(path, "w") as file:
- yaml.dump(asdict(self), file)
-
- def save_to_dir(self, directory: Union[str, pathlib.Path]) -> None:
- """Save the configuration to a '.lightning' file in the given folder in YAML format."""
- self.save_to_file(_get_config_file(directory))
-
- @classmethod
- def load_from_file(cls, path: Union[str, pathlib.Path]) -> "AppConfig":
- """Load the configuration from the given file."""
- with open(path) as file:
- config = yaml.safe_load(file)
- return cls(**config)
-
- @classmethod
- def load_from_dir(cls, directory: Union[str, pathlib.Path]) -> "AppConfig":
- """Load the configuration from the given folder.
-
- Args:
- directory: Path to a folder which contains the '.lightning' config file to load.
-
- """
- return cls.load_from_file(pathlib.Path(directory, _APP_CONFIG_FILENAME))
-
-
-def _get_config_file(source_path: Union[str, pathlib.Path]) -> pathlib.Path:
- """Get the Lightning app config file '.lightning' at the given source path.
-
- Args:
- source_path: A path to a folder or a file.
-
- """
- source_path = pathlib.Path(source_path).absolute()
- if source_path.is_file():
- source_path = source_path.parent
-
- return pathlib.Path(source_path / _APP_CONFIG_FILENAME)
diff --git a/src/lightning/app/utilities/packaging/build_config.py b/src/lightning/app/utilities/packaging/build_config.py
deleted file mode 100644
index 05099fa1b4c77..0000000000000
--- a/src/lightning/app/utilities/packaging/build_config.py
+++ /dev/null
@@ -1,208 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import inspect
-import os
-import re
-from dataclasses import asdict, dataclass, field
-from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
-
-from typing_extensions import Self
-
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.packaging.cloud_compute import CloudCompute
-
-if TYPE_CHECKING:
- from lightning.app.core.work import LightningWork
-
-logger = Logger(__name__)
-
-
-def load_requirements(
- path_dir: str, file_name: str = "base.txt", comment_char: str = "#", unfreeze: bool = True
-) -> List[str]:
- """Load requirements from a file."""
- path = os.path.join(path_dir, file_name)
- if not os.path.isfile(path):
- return []
-
- with open(path) as file:
- lines = [ln.strip() for ln in file.readlines()]
- reqs = []
- for ln in lines:
- # filter out all comments
- comment = ""
- if comment_char in ln:
- comment = ln[ln.index(comment_char) :]
- ln = ln[: ln.index(comment_char)]
- req = ln.strip()
- # skip directly installed dependencies
- if not req or req.startswith("http") or "@http" in req:
- continue
- # remove version restrictions unless they are strict
- if unfreeze and "<" in req and "strict" not in comment:
- req = re.sub(r",? *<=? *[\d\.\*]+", "", req).strip()
- reqs.append(req)
- return reqs
-
-
-@dataclass
-class _Dockerfile:
- path: str
- data: List[str]
-
-
-@dataclass
-class BuildConfig:
- """The Build Configuration describes how the environment a LightningWork runs in should be set up.
-
- Arguments:
- requirements: List of requirements or list of paths to requirement files. If not passed, they will be
- automatically extracted from a `requirements.txt` if it exists.
- dockerfile: The path to a dockerfile to be used to build your container.
- You need to add the following lines to ensure your container works in the cloud.
-
- .. warning:: This feature isn't supported yet, but coming soon.
-
- Example::
-
- WORKDIR /gridai/project
- COPY . .
- image: The base image that the work runs on. This should be a publicly accessible image from a registry that
- doesn't enforce rate limits (such as DockerHub) to pull this image, otherwise your application will not
- start.
-
- """
-
- requirements: List[str] = field(default_factory=list)
- dockerfile: Optional[Union[str, Path, _Dockerfile]] = None
- image: Optional[str] = None
-
- def __post_init__(self) -> None:
- current_frame = inspect.currentframe()
- co_filename = current_frame.f_back.f_back.f_code.co_filename # type: ignore[union-attr]
- self._call_dir = os.path.dirname(co_filename)
- self._prepare_requirements()
- self._prepare_dockerfile()
-
- def build_commands(self) -> List[str]:
- """Override to run some commands before your requirements are installed.
-
- .. note:: If you provide your own dockerfile, this would be ignored.
-
- Example:
-
- from dataclasses import dataclass
- from lightning.app import BuildConfig
-
- @dataclass
- class MyOwnBuildConfig(BuildConfig):
-
- def build_commands(self):
- return ["apt-get install libsparsehash-dev"]
-
- MyOwnBuildConfig(requirements=["git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0"])
-
- """
- return []
-
- def on_work_init(self, work: "LightningWork", cloud_compute: Optional["CloudCompute"] = None) -> None:
- """Override with your own logic to load the requirements or dockerfile."""
- found_requirements = self._find_requirements(work)
- if self.requirements:
- if found_requirements and self.requirements != found_requirements:
- # notify the user of this silent behaviour
- logger.info(
- f"A 'requirements.txt' exists with {found_requirements} but {self.requirements} was passed to"
- f" the `{type(self).__name__}` in {work.name!r}. The `requirements.txt` file will be ignored."
- )
- else:
- self.requirements = found_requirements
- self._prepare_requirements()
-
- found_dockerfile = self._find_dockerfile(work)
- if self.dockerfile:
- if found_dockerfile and self.dockerfile != found_dockerfile:
- # notify the user of this silent behaviour
- logger.info(
- f"A Dockerfile exists at {found_dockerfile!r} but {self.dockerfile!r} was passed to"
- f" the `{type(self).__name__}` in {work.name!r}. {found_dockerfile!r} will be ignored."
- )
- else:
- self.dockerfile = found_dockerfile
- self._prepare_dockerfile()
-
- def _find_requirements(self, work: "LightningWork", filename: str = "requirements.txt") -> List[str]:
- # 1. Get work file
- file = _get_work_file(work)
- if file is None:
- return []
- # 2. Try to find a requirement file associated with the file.
- dirname = os.path.dirname(file)
- try:
- requirements = load_requirements(dirname, filename)
- except NotADirectoryError:
- return []
- return [r for r in requirements if r != "lightning"]
-
- def _find_dockerfile(self, work: "LightningWork", filename: str = "Dockerfile") -> Optional[str]:
- # 1. Get work file
- file = _get_work_file(work)
- if file is None:
- return None
- # 2. Check for Dockerfile.
- dirname = os.path.dirname(file)
- dockerfile = os.path.join(dirname, filename)
- if os.path.isfile(dockerfile):
- return dockerfile
- return None
-
- def _prepare_requirements(self) -> None:
- requirements = []
- for req in self.requirements:
- # 1. Check for relative path
- path = os.path.join(self._call_dir, req)
- if os.path.isfile(path):
- try:
- new_requirements = load_requirements(self._call_dir, req)
- except NotADirectoryError:
- continue
- requirements.extend(new_requirements)
- else:
- requirements.append(req)
- self.requirements = requirements
-
- def _prepare_dockerfile(self) -> None:
- if isinstance(self.dockerfile, (str, Path)):
- path = os.path.join(self._call_dir, self.dockerfile)
- if os.path.exists(path):
- with open(path) as f:
- self.dockerfile = _Dockerfile(path, f.readlines())
-
- def to_dict(self) -> Dict:
- return {"__build_config__": asdict(self)}
-
- @classmethod
- def from_dict(cls, d: Dict) -> Self:
- return cls(**d["__build_config__"])
-
-
-def _get_work_file(work: "LightningWork") -> Optional[str]:
- cls = work.__class__
- try:
- return inspect.getfile(cls)
- except TypeError:
- logger.debug(f"The {cls.__name__} file couldn't be found.")
- return None
diff --git a/src/lightning/app/utilities/packaging/cloud_compute.py b/src/lightning/app/utilities/packaging/cloud_compute.py
deleted file mode 100644
index e4f30aee14a63..0000000000000
--- a/src/lightning/app/utilities/packaging/cloud_compute.py
+++ /dev/null
@@ -1,188 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import asdict, dataclass
-from typing import Dict, List, Optional, Tuple, Union
-from uuid import uuid4
-
-from lightning.app.core.constants import ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER, enable_interruptible_works
-from lightning.app.storage.mount import Mount
-
-__CLOUD_COMPUTE_IDENTIFIER__ = "__cloud_compute__"
-
-
-@dataclass
-class _CloudComputeStore:
- id: str
- component_names: List[str]
-
- def add_component_name(self, new_component_name: str) -> None:
- found_index = None
- # When the work is being named by the flow, pop its previous names
- for index, component_name in enumerate(self.component_names):
- if new_component_name.endswith(component_name.replace("root.", "")):
- found_index = index
-
- if found_index is not None:
- self.component_names[found_index] = new_component_name
- else:
- if (
- len(self.component_names) == 1
- and not ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER
- and self.id != "default"
- ):
- raise Exception(
- f"A Cloud Compute can be assigned only to a single Work. Attached to {self.component_names[0]}"
- )
- self.component_names.append(new_component_name)
-
- def remove(self, new_component_name: str) -> None:
- found_index = None
- for index, component_name in enumerate(self.component_names):
- if new_component_name == component_name:
- found_index = index
-
- if found_index is not None:
- del self.component_names[found_index]
-
-
-_CLOUD_COMPUTE_STORE = {}
-
-
-@dataclass
-class CloudCompute:
- """Configure the cloud runtime for a lightning work or flow.
-
- Arguments:
- name: The name of the hardware to use. A full list of supported options can be found in
- :doc:`/core_api/lightning_work/compute`. If you have a request for more hardware options, please contact
- `onprem@lightning.ai `_.
-
- disk_size: The disk size in Gigabytes.
- The value you set here will be allocated to the /home folder.
-
- idle_timeout: The number of seconds to wait before pausing the compute when the work is running and idle.
- This timeout starts whenever your run() method succeeds (or fails).
- If the timeout is reached, the instance pauses until the next run() call happens.
-
- shm_size: Shared memory size in MiB, backed by RAM. Min 512, max 8192; values are adjusted in steps of 512.
- For example, 1100 becomes 1024. If set to zero (the default), the work gets Docker's default of 64 MiB.
-
- mounts: External data sources which should be mounted into a work as a filesystem at runtime.
-
- colocation_group_id: Identifier for groups of works to be colocated in the same datacenter.
- Set this to a string of max. 64 characters and all works with this group id will run in the same datacenter.
- If not set, the works are not guaranteed to be colocated.
-
- interruptible: Whether to run on an interruptible machine, i.e. one that the provider can stop
- at any time. Such machines are also known as spot or preemptible machines.
- Compared to on-demand machines, they tend to be cheaper.
-
- """
-
- name: str = "default"
- disk_size: int = 0
- idle_timeout: Optional[int] = None
- shm_size: Optional[int] = None
- mounts: Optional[Union[Mount, List[Mount]]] = None
- colocation_group_id: Optional[str] = None
- interruptible: bool = False
- _internal_id: Optional[str] = None
-
- def __post_init__(self) -> None:
- _verify_mount_root_dirs_are_unique(self.mounts)
-
- self.name = self.name.lower()
-
- if self.shm_size is None:
- if "gpu" in self.name:
- self.shm_size = 1024
- else:
- self.shm_size = 0
-
- if self.interruptible:
- if not enable_interruptible_works():
- raise ValueError("CloudCompute with `interruptible=True` isn't supported yet.")
- if "gpu" not in self.name:
- raise ValueError("CloudCompute `interruptible=True` is supported only with GPU.")
-
- # FIXME: Clean the mess on the platform side
- if self.name == "default" or self.name == "cpu":
- self.name = "cpu-small"
- self._internal_id = "default"
-
- # TODO: Remove from the platform first.
- self.preemptible = self.interruptible
-
- # All `default` CloudCompute are identified in the same way.
- if self._internal_id is None:
- self._internal_id = self._generate_id()
-
- if self.colocation_group_id is not None and (
- not isinstance(self.colocation_group_id, str)
- or (isinstance(self.colocation_group_id, str) and len(self.colocation_group_id) > 64)
- ):
- raise ValueError("colocation_group_id can only be a string of maximum 64 characters.")
-
- def to_dict(self) -> dict:
- _verify_mount_root_dirs_are_unique(self.mounts)
- return {"type": __CLOUD_COMPUTE_IDENTIFIER__, **asdict(self)}
-
- @classmethod
- def from_dict(cls, d: dict) -> "CloudCompute":
- assert d.pop("type") == __CLOUD_COMPUTE_IDENTIFIER__
- mounts = d.pop("mounts", None)
- if mounts is None:
- pass
- elif isinstance(mounts, dict):
- d["mounts"] = Mount(**mounts)
- elif isinstance(mounts, (list)):
- d["mounts"] = []
- for mount in mounts:
- d["mounts"].append(Mount(**mount))
- else:
- raise TypeError(
- f"mounts argument must be one of [None, Mount, List[Mount]], "
- f"received {mounts} of type {type(mounts)}"
- )
- _verify_mount_root_dirs_are_unique(d.get("mounts"))
- return cls(**d)
-
- @property
- def id(self) -> Optional[str]:
- return self._internal_id
-
- def is_default(self) -> bool:
- return self.name in ("default", "cpu-small")
-
- def _generate_id(self):
- return "default" if self.name == "default" else uuid4().hex[:7]
-
- def clone(self):
- new_dict = self.to_dict()
- new_dict["_internal_id"] = self._generate_id()
- return self.from_dict(new_dict)
-
-
-def _verify_mount_root_dirs_are_unique(mounts: Union[None, Mount, List[Mount], Tuple[Mount]]) -> None:
- if isinstance(mounts, (list, tuple, set)):
- mount_paths = [mount.mount_path for mount in mounts]
- if len(set(mount_paths)) != len(mount_paths):
- raise ValueError("Every Mount attached to a work must have a unique 'mount_path' argument.")
-
-
-def _maybe_create_cloud_compute(state: Dict) -> Union[CloudCompute, Dict]:
- if state and state.get("type") == __CLOUD_COMPUTE_IDENTIFIER__:
- return CloudCompute.from_dict(state)
- return state
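Review note: the name and shared-memory defaulting in the removed `CloudCompute.__post_init__` can be summarised in a few lines. The sketch below is a standalone approximation for readers of this diff, not the library's code; `normalize_compute` is a hypothetical helper name, and the interruptible, colocation, and ID-generation branches are left out.

```python
from typing import Optional, Tuple


def normalize_compute(name: str, shm_size: Optional[int] = None) -> Tuple[str, int, Optional[str]]:
    """Hypothetical helper mirroring the defaulting rules of the removed __post_init__."""
    name = name.lower()

    # GPU machines get 1024 MiB of shared memory unless a value is passed explicitly.
    if shm_size is None:
        shm_size = 1024 if "gpu" in name else 0

    # "default" and "cpu" are both mapped to "cpu-small" and share the "default" id.
    internal_id = None
    if name in ("default", "cpu"):
        name = "cpu-small"
        internal_id = "default"

    return name, shm_size, internal_id
```

In the removed class, any compute that does not end up with the `"default"` id receives a short random hex id instead.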
diff --git a/src/lightning/app/utilities/packaging/docker.py b/src/lightning/app/utilities/packaging/docker.py
deleted file mode 100644
index 8a59fec288c7c..0000000000000
--- a/src/lightning/app/utilities/packaging/docker.py
+++ /dev/null
@@ -1,127 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import pickle
-import shutil
-import sys
-from datetime import datetime
-from typing import Optional
-
-from lightning.app import _PROJECT_ROOT, LightningWork
-from lightning.app.storage.path import _shared_local_mount_path
-from lightning.app.utilities.imports import _is_docker_available, _is_jinja2_available, requires
-
-if _is_docker_available():
- import docker
- from docker.models.containers import Container
-
-if _is_jinja2_available():
- import jinja2
-
-
-class DockerRunner:
- @requires("docker")
- def __init__(self, file: str, work: LightningWork, queue_id: str, create_base: bool = False):
- self.file = file
- self.work = work
- self.queue_id = queue_id
- self.image: Optional[str] = None
- if create_base:
- self._create_base_container()
- self._create_work_container()
-
- def _create_base_container(self) -> None:
- # 1. Get base container
- container_base = f"{_PROJECT_ROOT}/dockers/Dockerfile.base.cpu"
- destination_path = os.path.join(_PROJECT_ROOT, "Dockerfile")
-
- # 2. Copy the base Dockerfile within the Lightning project
- shutil.copy(container_base, destination_path)
-
- # 3. Build the docker image.
- os.system("docker build . --tag thomaschaton/base")
-
- # 4. Clean the copied base Dockerfile.
- os.remove(destination_path)
-
- def _create_work_container(self) -> None:
- # 1. Get work container.
- source_path = os.path.join(_PROJECT_ROOT, "dockers/Dockerfile.jinja")
- destination_path = os.path.join(_PROJECT_ROOT, "Dockerfile")
- work_destination_path = os.path.join(_PROJECT_ROOT, "work.p")
-
- # 2. Pickle the work.
- with open(work_destination_path, "wb") as f:
- pickle.dump(self.work, f)
-
- # 3. Load Lightning requirements.
- with open(source_path) as f:
- template = jinja2.Template(f.read())
-
- # Get the work local build spec.
- requirements = self.work.local_build_config.requirements
-
- # Render template with the requirements.
- dockerfile_str = template.render(
- requirements=" ".join(requirements),
- redis_host="host.docker.internal" if sys.platform == "darwin" else "127.0.0.1",
- )
-
- with open(destination_path, "w") as f:
- f.write(dockerfile_str)
-
- # 4. Build the container.
- self.image = f"work-{self.work.__class__.__qualname__.lower()}"
- os.system(f"docker build . --tag {self.image}")
-
- # 5. Clean the leftover files.
- os.remove(destination_path)
- os.remove(work_destination_path)
-
- def run(self) -> None:
- assert self.image
-
- # 1. Run the work container and launch the work.
- client = docker.DockerClient(base_url="unix://var/run/docker.sock")
- self.container: Container = client.containers.run(
- image=self.image,
- shm_size="10G",
- stderr=True,
- stdout=True,
- stdin_open=True,
- detach=True,
- ports=[url.split(":")[-1] for url in self.work._urls if url],
- volumes=[f"{str(_shared_local_mount_path())}:/home/.shared"],
- command=f"python -m lightning run work {self.file} --work_name={self.work.name} --queue_id {self.queue_id}",
- environment={"SHARED_MOUNT_DIRECTORY": "/home/.shared"},
- network_mode="host",
- )
-
- # 2. Check the log and exit when ``Starting WorkRunner`` is found.
- for line in self.container.logs(stream=True):
- line = str(line.strip())
- print(line)
- if "command not found" in line:
- raise RuntimeError("The container wasn't properly executed.")
- if "Starting WorkRunner" in line:
- break
-
- def get_container_logs(self) -> str:
- """Returns the logs of the container produced until now."""
- return "".join([chr(c) for c in self.container.logs(until=datetime.now())])
-
- def kill(self) -> None:
- """Kill the container."""
- self.container.kill()
diff --git a/src/lightning/app/utilities/packaging/lightning_utils.py b/src/lightning/app/utilities/packaging/lightning_utils.py
deleted file mode 100644
index 4dbba499625eb..0000000000000
--- a/src/lightning/app/utilities/packaging/lightning_utils.py
+++ /dev/null
@@ -1,219 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import functools
-import logging
-import os
-import pathlib
-import shutil
-import subprocess
-import sys
-import tarfile
-import tempfile
-import urllib.request
-from pathlib import Path
-from typing import Any, Callable, Optional
-
-from packaging.version import Version
-
-from lightning.app import _PROJECT_ROOT, _logger, _root_logger
-from lightning.app import __version__ as version
-from lightning.app.core.constants import FRONTEND_DIR, PACKAGE_LIGHTNING
-from lightning.app.utilities.app_helpers import Logger
-from lightning.app.utilities.git import check_github_repository, get_dir_name
-
-logger = Logger(__name__)
-
-
-# FIXME(alecmerdler): Use GitHub release artifacts once the `lightning-ui` repo is public
-LIGHTNING_FRONTEND_RELEASE_URL = "https://storage.googleapis.com/grid-packages/lightning-ui/v0.0.0/build.tar.gz"
-
-
-def download_frontend(root: str = _PROJECT_ROOT):
- """Downloads an archive file for a specific release of the Lightning frontend and extracts it to the correct
- directory."""
- build_dir = "build"
- frontend_dir = pathlib.Path(FRONTEND_DIR)
- download_dir = tempfile.mkdtemp()
-
- shutil.rmtree(frontend_dir, ignore_errors=True)
-
- response = urllib.request.urlopen(LIGHTNING_FRONTEND_RELEASE_URL) # noqa: S310
-
- file = tarfile.open(fileobj=response, mode="r|gz")
- file.extractall(path=download_dir) # noqa: S202
-
- shutil.move(os.path.join(download_dir, build_dir), frontend_dir)
- print("The Lightning UI has successfully been downloaded!")
-
-
-def _cleanup(*tar_files: str):
- for tar_file in tar_files:
- shutil.rmtree(os.path.join(_PROJECT_ROOT, "dist"), ignore_errors=True)
- os.remove(tar_file)
-
-
-def _prepare_wheel(path):
- with open("log.txt", "w") as logfile:
- with subprocess.Popen(
- ["rm", "-r", "dist"], stdout=logfile, stderr=logfile, bufsize=0, close_fds=True, cwd=path
- ) as proc:
- proc.wait()
-
- with subprocess.Popen(
- ["python", "setup.py", "sdist"],
- stdout=logfile,
- stderr=logfile,
- bufsize=0,
- close_fds=True,
- cwd=path,
- ) as proc:
- proc.wait()
-
- os.remove("log.txt")
-
-
-def _copy_tar(project_root, dest: Path) -> str:
- dist_dir = os.path.join(project_root, "dist")
- tar_files = os.listdir(dist_dir)
- assert len(tar_files) == 1
- tar_name = tar_files[0]
- tar_path = os.path.join(dist_dir, tar_name)
- shutil.copy(tar_path, dest)
- return tar_name
-
-
-def get_dist_path_if_editable_install(project_name) -> str:
- """Is distribution an editable install - modified version from pip that
- fetches egg-info instead of egg-link"""
- for path_item in sys.path:
- if not os.path.isdir(path_item):
- continue
-
- egg_info = os.path.join(path_item, project_name + ".egg-info")
- if os.path.isdir(egg_info):
- return path_item
- return ""
-
-
-def _prepare_lightning_wheels_and_requirements(root: Path, package_name: str = "lightning") -> Optional[Callable]:
- """This function determines if lightning is installed in editable mode (for developers) and packages the current
- lightning source along with the app.
-
- For normal users who install via PyPI or Conda, this function does nothing.
-
- """
- if not get_dist_path_if_editable_install(package_name):
- return None
-
- os.environ["PACKAGE_NAME"] = "app" if package_name == "lightning" + "_app" else "lightning"
-
- # Packaging the Lightning codebase happens only inside the `lightning` repo.
- git_dir_name = get_dir_name() if check_github_repository() else None
-
- is_lightning = git_dir_name and git_dir_name == package_name
-
- if (PACKAGE_LIGHTNING is None and not is_lightning) or PACKAGE_LIGHTNING == "0":
- return None
-
- download_frontend(_PROJECT_ROOT)
- _prepare_wheel(_PROJECT_ROOT)
-
- # todo: check why logging.info is missing in outputs
- print(f"Packaged Lightning with your application. Version: {version}")
-
- tar_name = _copy_tar(_PROJECT_ROOT, root)
-
- tar_files = [os.path.join(root, tar_name)]
-
- # Don't skip by default
- if (PACKAGE_LIGHTNING or is_lightning) and not bool(int(os.getenv("SKIP_LIGHTING_UTILITY_WHEELS_BUILD", "0"))):
- # building and copying lightning-cloud wheel if installed in editable mode
- lightning_cloud_project_path = get_dist_path_if_editable_install("lightning_cloud")
- if lightning_cloud_project_path:
- from lightning_cloud.__version__ import __version__ as cloud_version
-
- # todo: check why logging.info is missing in outputs
- print(f"Packaged Lightning Cloud with your application. Version: {cloud_version}")
- _prepare_wheel(lightning_cloud_project_path)
- tar_name = _copy_tar(lightning_cloud_project_path, root)
- tar_files.append(os.path.join(root, tar_name))
-
- lightning_launcher_project_path = get_dist_path_if_editable_install("lightning_launcher")
- if lightning_launcher_project_path:
- from lightning_launcher.__version__ import __version__ as cloud_version
-
- # todo: check why logging.info is missing in outputs
- print(f"Packaged Lightning Launcher with your application. Version: {cloud_version}")
- _prepare_wheel(lightning_launcher_project_path)
- tar_name = _copy_tar(lightning_launcher_project_path, root)
- tar_files.append(os.path.join(root, tar_name))
-
- return functools.partial(_cleanup, *tar_files)
-
-
-def _enable_debugging():
- tar_file = os.path.join(os.getcwd(), f"lightning-{version}.tar.gz")
-
- if not os.path.exists(tar_file):
- return
-
- _root_logger.propagate = True
- _logger.propagate = True
- _root_logger.setLevel(logging.DEBUG)
- _root_logger.debug("Setting debugging mode.")
-
-
-def enable_debugging(func: Callable) -> Callable:
- """This function is used to transform any print into logger.info calls, so it gets tracked in the cloud."""
-
- @functools.wraps(func)
- def wrapper(*args: Any, **kwargs: Any) -> Any:
- _enable_debugging()
- res = func(*args, **kwargs)
- _logger.setLevel(logging.INFO)
- return res
-
- return wrapper
-
-
-def _fetch_latest_version(package_name: str) -> str:
- args = [
- sys.executable,
- "-m",
- "pip",
- "install",
- f"{package_name}==1000",
- ]
-
- proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=0, close_fds=True)
- if proc.stdout:
- logs = " ".join([line.decode("utf-8") for line in iter(proc.stdout.readline, b"")])
- return logs.split(")\n")[0].split(",")[-1].replace(" ", "")
- return version
-
-
-def _verify_lightning_version():
- """This function verifies that users are running the latest lightning version for the cloud."""
- # TODO (tchaton) Add support for windows
- if sys.platform == "win32":
- return
-
- lightning_latest_version = _fetch_latest_version("lightning")
-
- if Version(lightning_latest_version) > Version(version):
- raise Exception(
- f"You need to use the latest version of Lightning ({lightning_latest_version}) to run in the cloud. "
- "Please, run `pip install -U lightning`"
- )
diff --git a/src/lightning/app/utilities/packaging/tarfile.py b/src/lightning/app/utilities/packaging/tarfile.py
deleted file mode 100644
index d19c25e78403e..0000000000000
--- a/src/lightning/app/utilities/packaging/tarfile.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import shutil
-import tarfile
-
-
-def clean_tarfile(file_path: str, mode: str) -> None:
- """This utility removes all files extracted from a tarfile."""
- if not os.path.exists(file_path):
- return
-
- with tarfile.open(file_path, mode=mode) as tar_ref:
- for member in tar_ref.getmembers():
- p = member.path
- if p == "." or not os.path.exists(p):
- continue
- try:
- if os.path.isfile(p):
- os.remove(p)
- else:
- shutil.rmtree(p)
- except (FileNotFoundError, OSError, PermissionError):
- pass
-
- if os.path.exists(file_path):
- os.remove(file_path)
-
-
-def extract_tarfile(file_path: str, extract_path: str, mode: str) -> None:
- """This utility extracts all files from a tarfile."""
- if not os.path.exists(file_path):
- return
-
- with tarfile.open(file_path, mode=mode) as tar_ref:
- for member in tar_ref.getmembers():
- try:
- tar_ref.extract(member, path=extract_path, set_attrs=False)
- except PermissionError:
- raise PermissionError(f"Could not extract tar file {file_path}")
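Review note: a minimal round trip through an extraction helper shaped like the removed `extract_tarfile`, for anyone who needs to replace it locally. This is a sketch under assumptions: the `demo` function and its file names are invented for illustration, and the `from err` chaining is an addition that the removed code did not have.

```python
import os
import tarfile
import tempfile


def extract_tarfile(file_path: str, extract_path: str, mode: str) -> None:
    """Extract every member of the archive; do nothing if the archive is missing."""
    if not os.path.exists(file_path):
        return
    with tarfile.open(file_path, mode=mode) as tar_ref:
        for member in tar_ref.getmembers():
            try:
                tar_ref.extract(member, path=extract_path, set_attrs=False)
            except PermissionError as err:
                # Chain the original error so the failing member stays visible.
                raise PermissionError(f"Could not extract tar file {file_path}") from err


def demo() -> str:
    """Build a one-file archive in a temporary directory, extract it, return the content."""
    with tempfile.TemporaryDirectory() as tmp:
        source = os.path.join(tmp, "hello.txt")
        with open(source, "w") as f:
            f.write("hi")
        archive = os.path.join(tmp, "archive.tar.gz")
        with tarfile.open(archive, "w:gz") as tar:
            tar.add(source, arcname="hello.txt")
        out_dir = os.path.join(tmp, "out")
        os.makedirs(out_dir)
        extract_tarfile(archive, out_dir, "r:gz")
        with open(os.path.join(out_dir, "hello.txt")) as f:
            return f.read()
```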
diff --git a/src/lightning/app/utilities/port.py b/src/lightning/app/utilities/port.py
deleted file mode 100644
index d26036794f2ca..0000000000000
--- a/src/lightning/app/utilities/port.py
+++ /dev/null
@@ -1,168 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import socket
-from typing import Optional
-
-from lightning_cloud.openapi import AppinstancesIdBody, Externalv1LightningappInstance, V1NetworkConfig
-
-from lightning.app.utilities.network import LightningClient, find_free_network_port
-
-
-def is_port_in_use(port: int) -> bool:
- """Checks if the given port is in use."""
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- return s.connect_ex(("localhost", port)) == 0
-
-
-def _find_lit_app_port(default_port: int) -> int:
- """Make a request to the cloud controlplane to find a disabled port of the flow, enable it and return it."""
- app_id = os.getenv("LIGHTNING_CLOUD_APP_ID", None)
- project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)
- enable_multiple_works_in_default_container = bool(int(os.getenv("ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER", "0")))
-
- if not app_id or not project_id or not enable_multiple_works_in_default_container:
- app_port = default_port
-
- # If the default port is not available, picks any other available one
- if is_port_in_use(default_port):
- app_port = find_free_network_port()
-
- return app_port
-
- client = LightningClient()
- list_apps_resp = client.lightningapp_instance_service_list_lightningapp_instances(project_id=project_id)
- lit_app: Optional[Externalv1LightningappInstance] = None
-
- for lapp in list_apps_resp.lightningapps:
- if lapp.id == app_id:
- lit_app = lapp
-
- if not lit_app:
- raise RuntimeError(
- "App was not found. Please open an issue at https://github.com/lightning-AI/lightning/issues."
- )
-
- found_nc = None
-
- for nc in lit_app.spec.network_config:
- if not nc.enable:
- found_nc = nc
- nc.enable = True
- break
-
- client.lightningapp_instance_service_update_lightningapp_instance(
- project_id=project_id,
- id=lit_app.id,
- body=AppinstancesIdBody(name=lit_app.name, spec=lit_app.spec),
- )
-
- if not found_nc:
- raise RuntimeError(
- "No available port was found. Please open an issue at https://github.com/lightning-AI/lightning/issues."
- )
-
- # Note: This is required for the framework to know we need to use the CloudMultiProcessRuntime.
- os.environ["APP_SERVER_HOST"] = f"https://{found_nc.host}"
-
- return found_nc.port
-
-
-def enable_port() -> V1NetworkConfig:
- """Make a request to the cloud controlplane to open a port of the flow."""
- app_id = os.getenv("LIGHTNING_CLOUD_APP_ID", None)
- project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)
-
- if not app_id or not project_id:
- raise Exception("The app_id and project_id should be defined.")
-
- client = LightningClient()
- list_apps_resp = client.lightningapp_instance_service_list_lightningapp_instances(project_id=project_id)
- lit_app: Optional[Externalv1LightningappInstance] = None
-
- for lapp in list_apps_resp.lightningapps:
- if lapp.id == app_id:
- lit_app = lapp
-
- if not lit_app:
- raise RuntimeError(
- "App was not found. Please open an issue at https://github.com/lightning-AI/lightning/issues."
- )
-
- found_nc = None
-
- for nc in lit_app.spec.network_config:
- if not nc.enable:
- found_nc = nc
- nc.enable = True
- break
-
- client.lightningapp_instance_service_update_lightningapp_instance(
- project_id=project_id,
- id=lit_app.id,
- body=AppinstancesIdBody(name=lit_app.name, spec=lit_app.spec),
- )
-
- if not found_nc:
- raise RuntimeError(
- "No available port was found. Please open an issue at https://github.com/lightning-AI/lightning/issues."
- )
-
- return found_nc
-
-
-def disable_port(port: int, ignore_disabled: bool = True) -> None:
- """Make a request to the cloud controlplane to close a port of the flow."""
- app_id = os.getenv("LIGHTNING_CLOUD_APP_ID", None)
- project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)
-
- if not app_id or not project_id:
- raise Exception("The app_id and project_id should be defined.")
-
- client = LightningClient()
- list_apps_resp = client.lightningapp_instance_service_list_lightningapp_instances(project_id=project_id)
- lit_app: Optional[Externalv1LightningappInstance] = None
-
- for lapp in list_apps_resp.lightningapps:
- if lapp.id == app_id:
- lit_app = lapp
-
- if not lit_app:
- raise RuntimeError(
- "App was not found. Please open an issue at https://github.com/lightning-AI/lightning/issues."
- )
-
- found_nc = None
-
- for nc in lit_app.spec.network_config:
- if nc.port == port:
- if not nc.enable and not ignore_disabled:
- raise RuntimeError(f"The port {port} was already disabled.")
-
- nc.enable = False
- found_nc = nc
- break
-
- client.lightningapp_instance_service_update_lightningapp_instance(
- project_id=project_id,
- id=lit_app.id,
- body=AppinstancesIdBody(name=lit_app.name, spec=lit_app.spec),
- )
-
- if not found_nc:
- ports = [nc.port for nc in lit_app.spec.network_config]
- raise ValueError(f"The provided port doesn't exist. Available ports are {ports}.")
-
- assert found_nc
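Review note: the only part of the removed `port.py` that does not depend on the cloud client is the local port probe. A rough standalone sketch follows; the `host` parameter and the `demo` function are assumptions added for illustration (the removed helper always probed `"localhost"`).

```python
import socket


def is_port_in_use(port: int, host: str = "127.0.0.1") -> bool:
    """Return True if a TCP connection to (host, port) succeeds."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # connect_ex returns 0 on success and an error code otherwise.
        return s.connect_ex((host, port)) == 0


def demo() -> bool:
    """Open a listening socket on an OS-chosen port and probe it."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        server.bind(("127.0.0.1", 0))  # port 0 lets the OS pick a free port
        server.listen(1)
        port = server.getsockname()[1]
        return is_port_in_use(port)
```

A probe like this is inherently racy: a port reported as free can be taken by another process before the caller binds it.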
diff --git a/src/lightning/app/utilities/proxies.py b/src/lightning/app/utilities/proxies.py
deleted file mode 100644
index bce7d661cbb03..0000000000000
--- a/src/lightning/app/utilities/proxies.py
+++ /dev/null
@@ -1,745 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import pathlib
-import queue
-import signal
-import sys
-import threading
-import time
-import traceback
-import warnings
-from contextlib import contextmanager
-from copy import deepcopy
-from dataclasses import dataclass, field
-from functools import partial
-from threading import Event, Thread
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Optional, Set, Tuple, Type, Union
-
-from deepdiff import DeepDiff, Delta
-from lightning_utilities.core.apply_func import apply_to_collection
-
-from lightning.app.core import constants
-from lightning.app.core.queues import MultiProcessQueue
-from lightning.app.storage.copier import _Copier, _copy_files
-from lightning.app.storage.drive import Drive, _maybe_create_drive
-from lightning.app.storage.path import Path, _path_to_work_artifact
-from lightning.app.storage.payload import Payload
-from lightning.app.utilities.app_helpers import affiliation
-from lightning.app.utilities.component import _set_work_context
-from lightning.app.utilities.enum import (
- CacheCallsKeys,
- WorkFailureReasons,
- WorkStageStatus,
- WorkStopReasons,
- make_status,
-)
-from lightning.app.utilities.exceptions import CacheMissException, LightningSigtermStateException
-
-if TYPE_CHECKING:
- from lightning.app.core import LightningWork
- from lightning.app.core.queues import BaseQueue
-
-from lightning.app.utilities.app_helpers import Logger
-
-logger = Logger(__name__)
-_state_observer_lock = threading.Lock()
-
-
-@dataclass
-class Action:
- method: str = "run"
- args: Tuple = field(default_factory=lambda: ())
- kwargs: Dict = field(default_factory=lambda: {})
-
-
-def unwrap(fn):
- if isinstance(fn, partial):
- fn = fn.keywords["work_run"]
- if isinstance(fn, ProxyWorkRun):
- fn = fn.work_run
- while hasattr(fn, "__wrapped__"):
- fn = fn.__wrapped__
- return fn
-
-
-def _send_data_to_caller_queue(
- proxy, work: "LightningWork", caller_queue: "BaseQueue", data: Dict, call_hash: str
-) -> Dict:
- proxy.has_sent = True
-
- if work._calls[CacheCallsKeys.LATEST_CALL_HASH] is None:
- work._calls[CacheCallsKeys.LATEST_CALL_HASH] = call_hash
-
- if call_hash not in work._calls:
- work._calls[call_hash] = {"statuses": []}
- else:
- # remove ret when relaunching the work.
- work._calls[call_hash].pop("ret", None)
-
- work._calls[call_hash]["statuses"].append(make_status(WorkStageStatus.PENDING))
-
- work_state = work.state
-
- # There is no need to send all call hashes to the work.
- calls = deepcopy(work_state["calls"])
- work_state["calls"] = {
- k: v for k, v in work_state["calls"].items() if k in (call_hash, CacheCallsKeys.LATEST_CALL_HASH)
- }
-
- data.update({"state": work_state})
- logger.debug(f"Sending to {work.name}: {data}")
- caller_queue.put(deepcopy(data))
-
- # Reset the calls entry.
- work_state["calls"] = calls
- work._restarting = False
- return work_state
-
-
-@dataclass
-class ProxyWorkRun:
- work_run: Callable
- work_name: str # TODO: remove this argument and get the name from work.name directly
- work: "LightningWork"
- caller_queue: "BaseQueue"
-
- def __post_init__(self):
- self.work_state = None
-
- def __call__(self, *args: Any, **kwargs: Any):
- self.has_sent = False
-
- self._validate_call_args(args, kwargs)
- args, kwargs = self._process_call_args(args, kwargs)
-
- call_hash = self.work._call_hash(self.work_run, *self._convert_hashable(args, kwargs))
- entered = call_hash in self.work._calls
- returned = entered and "ret" in self.work._calls[call_hash]
- # TODO (tchaton): Handle spot instance retrieval differently from stopped work.
- stopped_on_sigterm = self.work._restarting and self.work.status.reason == WorkStopReasons.SIGTERM_SIGNAL_HANDLER
-
- data = {"args": args, "kwargs": kwargs, "call_hash": call_hash}
-
- # The if/else conditions are left un-compressed to simplify readability for the readers.
- if not entered or stopped_on_sigterm:
- _send_data_to_caller_queue(self, self.work, self.caller_queue, data, call_hash)
- else:
- if self.work.cache_calls and returned:
- return
- if returned or stopped_on_sigterm:
- # the previous task has completed and we can re-queue the next one.
- # overriding the return value for next loop iteration.
- _send_data_to_caller_queue(self, self.work, self.caller_queue, data, call_hash)
- if not self.work.parallel:
- raise CacheMissException("Task never called before. Triggered now")
-
- def _validate_call_args(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
- """Validate the call args before they get passed to the run method of the Work.
-
- Currently, this performs a check against strings that look like filesystem paths and may need to be wrapped with
- a Lightning Path by the user.
-
- """
-
- def warn_if_pathlike(obj: Union[os.PathLike, str]):
- if isinstance(obj, Path):
- return
- if os.sep in str(obj) and os.path.exists(obj):
- # NOTE: The existence check is wrong in general, as the file will never exist on the disk
- # where the flow is running unless we are running locally
- warnings.warn(
- f"You passed the value {obj!r} as an argument to the `run()` method of {self.work_name} and"
- f" it looks like this is a path to a file or a folder. Consider wrapping this path in a"
- f" `lightning.app.storage.Path` object to be able to access these files in your Work.",
- UserWarning,
- )
-
- apply_to_collection(args, dtype=(os.PathLike, str), function=warn_if_pathlike)
- apply_to_collection(kwargs, dtype=(os.PathLike, str), function=warn_if_pathlike)
-
- @staticmethod
- def _process_call_args(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
- """Processes all positional and keyword arguments before they get passed to the caller queue and sent to the
- LightningWork.
-
- Currently, this method only applies sanitization to Lightning Path objects.
-
- Args:
- args: The tuple of positional arguments passed to the run method.
- kwargs: The dictionary of named arguments passed to the run method.
-
- Returns:
- The positional and keyword arguments in the same order they were passed in.
-
- """
-
- def sanitize(obj: Union[Path, Drive]) -> Union[Path, Dict]:
- if isinstance(obj, Path):
- # create a copy of the Path and erase the consumer
- # the LightningWork on the receiving end of the caller queue will become the new consumer
- # this is necessary to make the Path deepdiff-hashable
- path_copy = Path(obj)
- path_copy._sanitize()
- path_copy._consumer = None
- return path_copy
- return obj.to_dict()
-
- return apply_to_collection((args, kwargs), dtype=(Path, Drive), function=sanitize)
-
- @staticmethod
- def _convert_hashable(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
- """Processes all positional and keyword arguments before they get passed to the caller queue and sent to the
- LightningWork.
-
- Currently, this method only applies sanitization to Hashable Objects.
-
- Args:
- args: The tuple of positional arguments passed to the run method.
- kwargs: The dictionary of named arguments passed to the run method.
-
- Returns:
- The positional and keyword arguments in the same order they were passed in.
-
- """
- from lightning.app.utilities.types import Hashable
-
- def sanitize(obj: Hashable) -> Union[Path, Dict]:
- return obj.to_dict()
-
- return apply_to_collection((args, kwargs), dtype=Hashable, function=sanitize)
-
-
-class WorkStateObserver(Thread):
- """This thread runs alongside LightningWork and periodically checks for state changes. If the state changed from
- one interval to the next, it will compute the delta and add it to the queue which is connected to the Flow. This
- enables state changes to be captured that are not triggered through a setattr call.
-
- Args:
- work: The LightningWork for which the state should be monitored
- delta_queue: The queue to send deltas to when state changes occur
- interval: The interval at which to check for state changes.
-
- Example:
-
- class Work(LightningWork):
- ...
-
- def run(self):
- # This update gets sent to the Flow once the thread compares the new state with the previous one
- self.list.append(1)
-
- """
-
- def __init__(
- self,
- work: "LightningWork",
- delta_queue: "BaseQueue",
- flow_to_work_delta_queue: Optional["BaseQueue"] = None,
- error_queue: Optional["BaseQueue"] = None,
- interval: float = 1,
- ) -> None:
- super().__init__(daemon=True)
- self.started = False
- self._work = work
- self._delta_queue = delta_queue
- self._flow_to_work_delta_queue = flow_to_work_delta_queue
- self._error_queue = error_queue
- self._interval = interval
- self._exit_event = Event()
- self._delta_memory = []
- self._last_state = deepcopy(self._work.state)
-
- def run(self) -> None:
- self.started = True
- while not self._exit_event.is_set():
- time.sleep(self._interval)
- # Run the thread only if active
- self.run_once()
-
- @staticmethod
- def get_state_changed_from_queue(q: "BaseQueue", timeout: Optional[int] = None):
- try:
- delta = q.get(timeout=timeout or q.default_timeout)
- return delta
- except queue.Empty:
- return None
-
- def run_once(self) -> None:
- with _state_observer_lock:
- # Add all deltas the LightningWorkSetAttrProxy has processed and sent to the Flow already while
- # the WorkStateObserver was sleeping
- for delta in self._delta_memory:
- self._last_state += delta
- self._delta_memory.clear()
-
- # The remaining delta is the result of state updates triggered outside the setattr, e.g, by a list append
- delta = Delta(DeepDiff(self._last_state, self._work.state, verbose_level=2))
- if not delta.to_dict():
- return
- self._last_state = deepcopy(self._work.state)
- self._delta_queue.put(ComponentDelta(id=self._work.name, delta=delta))
-
- if self._flow_to_work_delta_queue:
- while True:
- deep_diff = self.get_state_changed_from_queue(self._flow_to_work_delta_queue)
- if not isinstance(deep_diff, dict):
- break
- try:
- with _state_observer_lock:
- self._work.apply_flow_delta(Delta(deep_diff, raise_errors=True))
- except Exception as ex:
- traceback.print_exc()
- self._error_queue.put(ex)
- raise ex
-
- def join(self, timeout: Optional[float] = None) -> None:
- self._exit_event.set()
- super().join(timeout)
-
-
-@dataclass
-class LightningWorkSetAttrProxy:
- """This wrapper around the ``LightningWork.__setattr__`` ensures that state changes get sent to the delta queue to
- be reflected in the Flow.
-
- Example:
-
- class Work(LightningWork):
- ...
-
- def run(self):
- self.var += 1 # This update gets sent to the Flow immediately
-
- """
-
- work_name: str
- work: "LightningWork"
- delta_queue: "BaseQueue"
- state_observer: Optional["WorkStateObserver"]
-
- def __call__(self, name: str, value: Any) -> None:
- logger.debug(f"Setting {name}: {value}")
- with _state_observer_lock:
- state = deepcopy(self.work.state)
- self.work._default_setattr(name, value)
- delta = Delta(DeepDiff(state, self.work.state, verbose_level=2))
- if not delta.to_dict():
- return
-
- # push the delta only if there is any
- self.delta_queue.put(ComponentDelta(id=self.work_name, delta=delta))
-
- # add the delta to the buffer to let WorkStateObserver know we already sent this one to the Flow
- if self.state_observer:
- self.state_observer._delta_memory.append(delta)
-
-
-@dataclass
-class ComponentDelta:
- id: str
- delta: Delta
-
-
-@dataclass
-class WorkRunExecutor:
- work: "LightningWork"
- work_run: Callable
- delta_queue: "BaseQueue"
- enable_start_observer: bool = True
-
- def __call__(self, *args, **kwargs):
- return self.work_run(*args, **kwargs)
-
- @contextmanager
- def enable_spawn(self) -> Generator:
- self.work._setattr_replacement = None
- self.work._backend = None
- self._clean_queues()
- yield
-
- def _clean_queues(self):
- if not isinstance(self.work._request_queue, MultiProcessQueue):
- self.work._request_queue = self.work._request_queue.to_dict()
- self.work._response_queue = self.work._response_queue.to_dict()
-
- @staticmethod
- def process_queue(queue):
- from lightning.app.core.queues import HTTPQueue, RedisQueue
-
- if isinstance(queue, dict):
- queue_type = queue.pop("type")
- if queue_type == "redis":
- return RedisQueue.from_dict(queue)
- return HTTPQueue.from_dict(queue)
- return queue
-
-
-@dataclass
-class WorkRunner:
- work: "LightningWork"
- work_name: str
- caller_queue: "BaseQueue"
- delta_queue: "BaseQueue"
- readiness_queue: "BaseQueue"
- error_queue: "BaseQueue"
- request_queue: "BaseQueue"
- response_queue: "BaseQueue"
- copy_request_queue: "BaseQueue"
- copy_response_queue: "BaseQueue"
- flow_to_work_delta_queue: Optional["BaseQueue"] = None
- run_executor_cls: Type[WorkRunExecutor] = WorkRunExecutor
-
- def __post_init__(self):
- self.parallel = self.work.parallel
- self.copier: Optional[_Copier] = None
- self.state_observer: Optional[WorkStateObserver] = None
-
- def __call__(self):
- self.setup()
- while True:
- try:
- self.run_once()
- except KeyboardInterrupt:
- if self.state_observer:
- if self.state_observer.started:
- self.state_observer.join(0)
- self.state_observer = None
- self.copier.join(0)
- except LightningSigtermStateException as ex:
- logger.debug("Exiting")
- os._exit(ex.exit_code)
- except Exception as ex:
- # Inform the flow the work failed. This would fail the entire application.
- self.error_queue.put(ex)
- # Terminate the threads
- if self.state_observer:
- if self.state_observer.started:
- self.state_observer.join(0)
- self.state_observer = None
- self.copier.join(0)
- raise ex
-
- def setup(self):
- from lightning.app.utilities.state import AppState
-
- _set_work_context()
-
- # 1. Make the AppState aware of the affiliation of the work.
- # hacky: attach the affiliation so that it is known from the AppState object
- AppState._MY_AFFILIATION = affiliation(self.work)
-
- # 2. Attach the queues to the work.
- # Each work gets their own request- and response storage queue for communicating with the storage orchestrator
- self.work._request_queue = self.request_queue
- self.work._response_queue = self.response_queue
-
- # 3. Start the Copier thread. This thread enables transferring files using
- # the Path object between works.
- self.copier = _Copier(self.work, self.copy_request_queue, self.copy_response_queue)
- self.copier.setDaemon(True)
- self.copier.start()
-
- # 4. If the work is restarting, reload the latest state.
- # TODO (tchaton) Add support for capturing the latest state.
- if self.work._restarting:
- self.work.load_state_dict(self.work.state)
-
- # 5. Inform the flow that the work is ready to receive data through the caller queue.
- self.readiness_queue.put(True)
-
- def run_once(self):
- # 1. Wait for the caller queue data.
- called: Dict[str, Any] = self.caller_queue.get()
- logger.debug(f"Work {self.work_name} {called}")
-
- # 2. Extract the info from the caller queue data and process the input arguments. Arguments can contain
- # Lightning Path objects; if they don't have a consumer, the current Work will become one.
- call_hash = called["call_hash"]
- args, kwargs = self._process_call_args(called["args"], called["kwargs"])
-
- # 3. Register the signal handler for spot instances.
- # `SIGUSR1` signal isn't supported on windows.
- # TODO (tchaton) Add support for windows
- if sys.platform != "win32":
- signal.signal(signal.SIGTERM, partial(self._sigterm_signal_handler, call_hash=call_hash))
-
- # 4. Set the received state to the work.
- self.work.set_state(called["state"])
-
- # 5. Transfer all paths in the state automatically if they have an origin and exist
- self._transfer_path_attributes()
-
- # 6. Create the state observer thread.
- if self.run_executor_cls.enable_start_observer:
- self.state_observer = WorkStateObserver(
- self.work,
- delta_queue=self.delta_queue,
- flow_to_work_delta_queue=self.flow_to_work_delta_queue,
- error_queue=self.error_queue,
- )
-
- # 7. Deepcopy the work state and send the first `RUNNING` status delta to the flow.
- reference_state = deepcopy(self.work.state)
-
- # Set the internal IP address.
- # Set this here after the state observer is initialized, since it needs to record it as a change and send
- # it back to the flow
- default_internal_ip = "127.0.0.1" if constants.LIGHTNING_CLOUDSPACE_HOST is None else "0.0.0.0" # noqa: S104
- self.work._internal_ip = os.environ.get("LIGHTNING_NODE_PRIVATE_IP", default_internal_ip)
- self.work._public_ip = os.environ.get("LIGHTNING_NODE_IP", "")
-
- # 8. Patch the setattr method of the work. This needs to be done after step 4, so we don't
- # send delta while calling `set_state`.
- self._proxy_setattr()
-
- if self._is_starting(called, reference_state, call_hash):
- return
-
- # 9. Inform the flow the work is running and add the delta to the deepcopy.
- self.work._calls[CacheCallsKeys.LATEST_CALL_HASH] = call_hash
- self.work._calls[call_hash]["statuses"].append(make_status(WorkStageStatus.RUNNING))
- delta = Delta(DeepDiff(reference_state, self.work.state))
- self.delta_queue.put(ComponentDelta(id=self.work_name, delta=delta))
-
- # 10. Unwrap the run method if wrapped.
- work_run = self.work.run
- if hasattr(work_run, "__wrapped__"):
- work_run = work_run.__wrapped__
-
- # 11. Start the state observer thread. It will look for state changes and send them back to the Flow
- # The observer has to be initialized here, after the set_state call above so that the thread can start with
- # the proper initial state of the work
- if self.run_executor_cls.enable_start_observer:
- self.state_observer.start()
-
- # 12. Run the `work_run` method.
- # If an exception is raised, send a `FAILED` status delta to the flow and call the `on_exception` hook.
- try:
- ret = self.run_executor_cls(self.work, work_run, self.delta_queue)(*args, **kwargs)
- except LightningSigtermStateException as ex:
- raise ex
- except BaseException as ex:
- # 10.2 Send failed delta to the flow.
- reference_state = deepcopy(self.work.state)
- exp, val, tb = sys.exc_info()
- listing = traceback.format_exception(exp, val, tb)
- user_exception = False
- used_runpy = False
- trace = []
- for p in listing:
- if "runpy.py" in p:
- trace = []
- used_runpy = True
- if user_exception:
- trace.append(p)
- if "ret = self.run_executor_cls(" in p:
- user_exception = True
-
- if used_runpy:
- trace = trace[1:]
-
- self.work._calls[call_hash]["statuses"].append(
- make_status(
- WorkStageStatus.FAILED,
- message=str("\n".join(trace)),
- reason=WorkFailureReasons.USER_EXCEPTION,
- )
- )
- self.delta_queue.put(
- ComponentDelta(
- id=self.work_name, delta=Delta(DeepDiff(reference_state, self.work.state, verbose_level=2))
- )
- )
- self.work.on_exception(ex)
- print("########## CAPTURED EXCEPTION ###########")
- traceback.print_exc()
- print("########## CAPTURED EXCEPTION ###########")
- return
-
- # 13. Destroy the state observer.
- if self.run_executor_cls.enable_start_observer and self.state_observer.started:
- self.state_observer.join(0)
- self.state_observer = None
-
- # 14. Copy all artifacts to the shared storage so other Works can access them while this Work gets scaled down
- persist_artifacts(work=self.work)
-
- # 15. An asynchronous work shouldn't return a return value.
- if ret is not None:
- raise RuntimeError(
- f"Your work {self.work} shouldn't have a return value. Found {ret}. "
- "HINT: Use the Payload API instead."
- )
-
- # 17. DeepCopy the state and send the latest delta to the flow.
- # use the latest state as we have already sent delta
- # during its execution.
- # inform the task has completed
- reference_state = deepcopy(self.work.state)
- self.work._calls[call_hash]["statuses"].append(make_status(WorkStageStatus.SUCCEEDED))
- self.work._calls[call_hash]["ret"] = ret
- self.delta_queue.put(
- ComponentDelta(id=self.work_name, delta=Delta(DeepDiff(reference_state, self.work.state, verbose_level=2)))
- )
-
- # 18. Update the work for the next delta if any.
- self._proxy_setattr(cleanup=True)
-
- def _sigterm_signal_handler(self, signum, frame, call_hash: str) -> None:
- """Signal handler used to react when spot instances are being retrieved."""
- logger.info(f"Received SIGTERM signal. Gracefully terminating {self.work.name.replace('root.', '')}...")
- persist_artifacts(work=self.work)
- with _state_observer_lock:
- self.work.on_exit()
- self.work._calls[call_hash]["statuses"] = []
- state = deepcopy(self.work.state)
- self.work._calls[call_hash]["statuses"].append(
- make_status(WorkStageStatus.STOPPED, reason=WorkStopReasons.SIGTERM_SIGNAL_HANDLER)
- )
-
- # kill the thread as the job is going to be terminated.
- if self.state_observer:
- if self.state_observer.started:
- self.state_observer.join(0)
- self.state_observer = None
- delta = Delta(DeepDiff(state, deepcopy(self.work.state), verbose_level=2))
- self.delta_queue.put(ComponentDelta(id=self.work_name, delta=delta))
-
- self.copier.join(0)
- raise LightningSigtermStateException(0)
-
- def _proxy_setattr(self, cleanup: bool = False):
- _proxy_setattr(self.work, self.delta_queue, self.state_observer, cleanup=cleanup)
-
- def _process_call_args(
- self, args: Tuple[Any, ...], kwargs: Dict[str, Any]
- ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
- """Process the arguments that were passed in to the ``run()`` method of the
- :class:`lightning.app.core.work.LightningWork`.
-
- This method currently only implements special treatments for the :class:`lightning.app.storage.path.Path`
- objects. Any Path objects that get passed into the run method get attached to the Work automatically, i.e.,
- the Work becomes the `origin` or the `consumer` if they were not already before. Additionally,
- if the file or folder under the Path exists, we transfer it.
-
- Args:
- args: The tuple of positional arguments passed to the run method.
- kwargs: The dictionary of named arguments passed to the run method.
-
- Returns:
- The positional and keyword arguments in the same order they were passed in.
-
- """
-
- def _attach_work_and_get(transporter: Union[Path, Payload, dict]) -> Union[Path, Drive, dict, Any]:
- if not transporter.origin_name:
- # If path/payload is not attached to an origin, there is no need to attach or transfer anything
- return transporter
-
- transporter._attach_work(self.work)
- transporter._attach_queues(self.work._request_queue, self.work._response_queue)
- if transporter.exists_remote():
- # All paths/payloads passed to the `run` method under a Lightning obj need to be copied (if they exist)
- if isinstance(transporter, Payload):
- transporter.get()
- else:
- transporter.get(overwrite=True)
- return transporter
-
- def _handle_drive(dict):
- return _maybe_create_drive(self.work_name, dict)
-
- args, kwargs = apply_to_collection((args, kwargs), dtype=(Path, Payload), function=_attach_work_and_get)
- return apply_to_collection((args, kwargs), dtype=dict, function=_handle_drive)
-
- def _transfer_path_attributes(self) -> None:
- """Transfer all Path attributes in the Work if they have an origin and exist."""
- for name in self.work._paths:
- path = getattr(self.work, name)
- if isinstance(path, str):
- path = Path(path)
- path._attach_work(self.work)
- if path.origin_name and path.origin_name != self.work.name and path.exists_remote():
- path.get(overwrite=True)
-
- def _is_starting(self, called, reference_state, call_hash) -> bool:
- if len(called["args"]) == 1 and isinstance(called["args"][0], Action):
- action = called["args"][0]
- if action.method == "start":
- # Inform the flow the work has started and add the delta to the deepcopy.
- self.work._calls[CacheCallsKeys.LATEST_CALL_HASH] = call_hash
- self.work._calls[call_hash]["statuses"].append(make_status(WorkStageStatus.STARTED))
- delta = Delta(DeepDiff(reference_state, self.work.state))
- self.delta_queue.put(ComponentDelta(id=self.work_name, delta=delta))
- self._proxy_setattr(cleanup=True)
- return True
- raise Exception("Only the `start` action is supported right now !")
- return False
-
-
-def persist_artifacts(work: "LightningWork") -> None:
- """Copies all :class:`~lightning.app.storage.path.Path` referenced by the given LightningWork to the shared
- storage.
-
- Files that don't exist or do not originate from the given Work will be skipped.
-
- """
- artifact_paths = [getattr(work, name) for name in work._paths]
- # only copy files that belong to this Work, i.e., when the path's origin refers to the current Work
- artifact_paths = [path for path in artifact_paths if isinstance(path, Path) and path.origin_name == work.name]
-
- for name in work._state:
- if isinstance(getattr(work, name), Payload):
- artifact_path = pathlib.Path(name).resolve()
- payload = getattr(work, name)
- payload.save(payload.value, artifact_path)
- artifact_paths.append(artifact_path)
-
- missing_artifacts: Set[str] = set()
- destination_paths = []
- for artifact_path in artifact_paths:
- artifact_path = pathlib.Path(artifact_path).absolute()
- if not artifact_path.exists():
- missing_artifacts.add(str(artifact_path))
- continue
- destination_path = _path_to_work_artifact(artifact_path, work)
- _copy_files(artifact_path, destination_path)
- destination_paths.append(destination_path)
-
- if missing_artifacts:
- warnings.warn(
- f"{len(missing_artifacts)} artifacts could not be saved because they don't exist:"
- f" {','.join(missing_artifacts)}.",
- UserWarning,
- )
- else:
- logger.debug(
- f"All {destination_paths} artifacts from Work {work.name} successfully "
- "stored at {artifacts_path(work.name)}."
- )
-
-
-def _proxy_setattr(work, delta_queue, state_observer: Optional[WorkStateObserver], cleanup: bool = False):
- if cleanup:
- setattr_proxy = None
- else:
- setattr_proxy = LightningWorkSetAttrProxy(
- work.name,
- work,
- delta_queue=delta_queue,
- state_observer=state_observer,
- )
- work._setattr_replacement = setattr_proxy
diff --git a/src/lightning/app/utilities/redis.py b/src/lightning/app/utilities/redis.py
deleted file mode 100644
index 5461d5d302ae3..0000000000000
--- a/src/lightning/app/utilities/redis.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Optional
-
-from lightning.app.core.constants import REDIS_HOST, REDIS_PASSWORD, REDIS_PORT
-from lightning.app.utilities.imports import _is_redis_available
-
-
-def check_if_redis_running(
- host: Optional[str] = "", port: Optional[int] = 6379, password: Optional[str] = None
-) -> bool:
- if not _is_redis_available():
- return False
- import redis
-
- try:
- host = host or REDIS_HOST
- port = port or REDIS_PORT
- password = password or REDIS_PASSWORD
- return redis.Redis(host=host, port=port, password=password).ping()
- except redis.exceptions.ConnectionError:
- return False
diff --git a/src/lightning/app/utilities/safe_pickle.py b/src/lightning/app/utilities/safe_pickle.py
deleted file mode 100644
index ddd77ddcc6509..0000000000000
--- a/src/lightning/app/utilities/safe_pickle.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import contextlib
-import pickle
-import sys
-import types
-import typing
-from copy import deepcopy
-from pathlib import Path
-
-from lightning.app.core.work import LightningWork
-from lightning.app.utilities.app_helpers import _LightningAppRef
-
-NON_PICKLABLE_WORK_ATTRIBUTES = ["_request_queue", "_response_queue", "_backend", "_setattr_replacement"]
-
-
-@contextlib.contextmanager
-def _trimmed_work(work: LightningWork, to_trim: typing.List[str]) -> typing.Iterator[None]:
- """Context manager to trim the work object to remove attributes that are not picklable."""
- holder = {}
- for arg in to_trim:
- holder[arg] = getattr(work, arg)
- setattr(work, arg, None)
- yield
- for arg in to_trim:
- setattr(work, arg, holder[arg])
-
-
-def get_picklable_work(work: LightningWork) -> LightningWork:
- """Pickling a LightningWork instance fails if done from the work process
- itself. This function is safe to call from the work process within both MultiprocessRuntime
- and Cloud.
- Note: This function modifies the module information of the work object. Specifically, it injects
- the relative module path into the __module__ attribute of the work object. If the object is not
- importable from the CWD, then the pickle load will fail.
-
- Example:
- for a directory structure like below and the work class is defined in the app.py where
- the app.py is the entrypoint for the app, it will inject `foo.bar.app` into the
- __module__ attribute
-
- └── foo
- ├── __init__.py
- └── bar
- └── app.py
- """
- # If the work object is not taken from the app ref, there is a thread lock reference
- # somewhere that's preventing it from being pickled. Investigate it later. We
- # shouldn't be fetching the work object from the app ref. TODO @sherin
- app_ref = _LightningAppRef.get_current()
- if app_ref is None:
- raise RuntimeError("Cannot pickle LightningWork outside of a LightningApp")
- for w in app_ref.works:
- if work.name == w.name:
- # deep-copying the work object to avoid modifying the original work object
- with _trimmed_work(w, to_trim=NON_PICKLABLE_WORK_ATTRIBUTES):
- copied_work = deepcopy(w)
- break
- else:
- raise ValueError(f"Work with name {work.name} not found in the app references")
-
- # if work is defined in the __main__ or __mp__main__ (the entrypoint file for `lightning run app` command),
- # pickling/unpickling will fail, hence we need to patch the module information
- if "_main__" in copied_work.__class__.__module__:
- work_class_module = sys.modules[copied_work.__class__.__module__]
- work_class_file = work_class_module.__file__
- if not work_class_file:
- raise ValueError(
- f"Cannot pickle work class {copied_work.__class__.__name__} because we "
- f"couldn't identify the module file"
- )
- relative_path = Path(work_class_module.__file__).relative_to(Path.cwd()) # type: ignore
- expected_module_name = relative_path.as_posix().replace(".py", "").replace("/", ".")
- # TODO @sherin: also check if the module is importable from the CWD
- fake_module = types.ModuleType(expected_module_name)
- fake_module.__dict__.update(work_class_module.__dict__)
- fake_module.__dict__["__name__"] = expected_module_name
- sys.modules[expected_module_name] = fake_module
- for k, v in fake_module.__dict__.items():
- if not k.startswith("__") and hasattr(v, "__module__") and "_main__" in v.__module__:
- v.__module__ = expected_module_name
- return copied_work
-
-
-def dump(work: LightningWork, f: typing.BinaryIO) -> None:
- picklable_work = get_picklable_work(work)
- pickle.dump(picklable_work, f)
-
-
-def load(f: typing.BinaryIO) -> typing.Any:
- # inject current working directory to sys.path
- sys.path.insert(1, str(Path.cwd()))
- work = pickle.load(f)
- sys.path.pop(1)
- return work
diff --git a/src/lightning/app/utilities/scheduler.py b/src/lightning/app/utilities/scheduler.py
deleted file mode 100644
index 67b081fb56000..0000000000000
--- a/src/lightning/app/utilities/scheduler.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import threading
-from datetime import datetime
-from typing import Optional
-
-from croniter import croniter
-from deepdiff import Delta
-
-from lightning.app.utilities.proxies import ComponentDelta
-
-
-class SchedulerThread(threading.Thread):
- # TODO (tchaton) Abstract this logic to a generic scheduling service.
-
- def __init__(self, app) -> None:
- super().__init__(daemon=True)
- self._exit_event = threading.Event()
- self._sleep_time = 1.0
- self._app = app
-
- def run(self) -> None:
- while not self._exit_event.is_set():
- self._exit_event.wait(self._sleep_time)
- self.run_once()
-
- def run_once(self):
- for call_hash in list(self._app._schedules.keys()):
- metadata = self._app._schedules[call_hash]
- start_time = datetime.fromisoformat(metadata["start_time"])
- current_date = datetime.now()
- next_event = croniter(metadata["cron_pattern"], start_time).get_next(datetime)
- # When the event is reached, send a delta to activate scheduling.
- if current_date > next_event:
- component_delta = ComponentDelta(
- id=metadata["name"],
- delta=Delta({
- "values_changed": {
- f"root['calls']['scheduling']['{call_hash}']['running']": {"new_value": True}
- }
- }),
- )
- self._app.delta_queue.put(component_delta)
- metadata["start_time"] = next_event.isoformat()
-
- def join(self, timeout: Optional[float] = None) -> None:
- self._exit_event.set()
- super().join(timeout)
diff --git a/src/lightning/app/utilities/secrets.py b/src/lightning/app/utilities/secrets.py
deleted file mode 100644
index dee96c5d163a9..0000000000000
--- a/src/lightning/app/utilities/secrets.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Dict, Iterable
-
-from lightning.app.utilities.cloud import _get_project
-from lightning.app.utilities.network import LightningClient
-
-
-def _names_to_ids(secret_names: Iterable[str]) -> Dict[str, str]:
- """Returns the name/ID pair for each given Secret name.
-
- Raises a `ValueError` if any of the given Secret names do not exist.
-
- """
- lightning_client = LightningClient()
-
- project = _get_project(lightning_client)
- secrets = lightning_client.secret_service_list_secrets(project_id=project.project_id)
-
- secret_names_to_ids: Dict[str, str] = {}
- for secret in secrets.secrets:
- if secret.name in secret_names:
- secret_names_to_ids[secret.name] = secret.id
-
- for secret_name in secret_names:
- if secret_name not in secret_names_to_ids:
- raise ValueError(f"Secret with name '{secret_name}' not found")
-
- return secret_names_to_ids
diff --git a/src/lightning/app/utilities/state.py b/src/lightning/app/utilities/state.py
deleted file mode 100644
index b366142948102..0000000000000
--- a/src/lightning/app/utilities/state.py
+++ /dev/null
@@ -1,323 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import enum
-import json
-import os
-from copy import deepcopy
-from time import sleep
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-from deepdiff import DeepDiff
-from requests import Session
-from requests.exceptions import ConnectionError
-
-from lightning.app.core.constants import APP_SERVER_HOST, APP_SERVER_PORT
-from lightning.app.storage.drive import _maybe_create_drive
-from lightning.app.utilities.app_helpers import AppStatePlugin, BaseStatePlugin, Logger
-from lightning.app.utilities.network import LightningClient, _configure_session
-
-logger = Logger(__name__)
-
-# GLOBAL APP STATE
-_LAST_STATE = None
-_STATE = None
-
-
-class AppStateType(enum.Enum):
- STREAMLIT = enum.auto()
- DEFAULT = enum.auto()
-
-
-def headers_for(context: Dict[str, str]) -> Dict[str, str]:
- return {
- "X-Lightning-Session-UUID": context.get("token", ""),
- "X-Lightning-Session-ID": context.get("session_id", ""),
- "X-Lightning-Type": context.get("type", ""),
- }
-
-
-class AppState:
- _APP_PRIVATE_KEYS: Tuple[str, ...] = (
- "_use_localhost",
- "_host",
- "_session_id",
- "_state",
- "_last_state",
- "_url",
- "_port",
- "_request_state",
- "_store_state",
- "_send_state",
- "_my_affiliation",
- "_find_state_under_affiliation",
- "_plugin",
- "_attach_plugin",
- "_authorized",
- "is_authorized",
- "_debug",
- "_session",
- )
- _MY_AFFILIATION: Tuple[str, ...] = ()
-
- def __init__(
- self,
- host: Optional[str] = None,
- port: Optional[int] = None,
- last_state: Optional[Dict] = None,
- state: Optional[Dict] = None,
- my_affiliation: Tuple[str, ...] = None,
- plugin: Optional[BaseStatePlugin] = None,
- ) -> None:
- """The AppState class enables Frontend users to interact with their application state.
-
- When the state isn't defined, it would be pulled from the app REST API Server.
- If the state gets modified by the user, the new state would be sent to the API Server.
-
- Arguments:
- host: Rest API Server current host
- port: Rest API Server current port
- last_state: The state pulled on first access.
- state: The state modified by the user.
- my_affiliation: A tuple describing the affiliation this app state represents. When storing a state dict
- on this AppState, this affiliation will be used to reduce the scope of the given state.
- plugin: A plugin to handle authorization.
-
- """
- self._use_localhost = "LIGHTNING_APP_STATE_URL" not in os.environ
- self._host = host or ("http://127.0.0.1" if self._use_localhost else None)
- self._port = port or (APP_SERVER_PORT if self._use_localhost else None)
- self._last_state = last_state
- self._state = state
- self._session_id = "1234"
- self._my_affiliation = my_affiliation if my_affiliation is not None else AppState._MY_AFFILIATION
- self._authorized = None
- self._attach_plugin(plugin)
- self._session = self._configure_session()
-
- @property
- def _url(self) -> str:
- if self._host is None:
- app_ip = ""
-
- if "LIGHTNING_CLOUD_PROJECT_ID" in os.environ and "LIGHTNING_CLOUD_APP_ID" in os.environ:
- client = LightningClient()
- app_instance = client.lightningapp_instance_service_get_lightningapp_instance(
- os.environ.get("LIGHTNING_CLOUD_PROJECT_ID"),
- os.environ.get("LIGHTNING_CLOUD_APP_ID"),
- )
- app_ip = app_instance.status.ip_address
-
- # TODO: Don't hard code port 8080 here
- self._host = f"http://{app_ip}:8080" if app_ip else APP_SERVER_HOST
- return f"{self._host}:{self._port}" if self._use_localhost else self._host
-
- def _attach_plugin(self, plugin: Optional[BaseStatePlugin]) -> None:
- plugin = plugin if plugin is not None else AppStatePlugin()
- self._plugin = plugin
-
- @staticmethod
- def _find_state_under_affiliation(state, my_affiliation: Tuple[str, ...]) -> Dict[str, Any]:
- """This method is used to extract the subset of the app state associated with the given affiliation.
-
- For example, if the affiliation is ``("root", "subflow")``, then the returned state will be
- ``state["flows"]["subflow"]``.
-
- """
- children_state = state
- for name in my_affiliation:
- if name in children_state["flows"]:
- children_state = children_state["flows"][name]
- elif name in children_state["works"]:
- children_state = children_state["works"][name]
- else:
- raise ValueError(f"Failed to extract the state under the affiliation '{my_affiliation}'.")
- return children_state
-
- def _store_state(self, state: Dict[str, Any]) -> None:
- # Relying on the global variable to ensure the
- # deep_diff is done on the entire state.
- global _LAST_STATE
- global _STATE
- _LAST_STATE = deepcopy(state)
- _STATE = state
- # If the affiliation is passed, the AppState was created in a LightningFlow context.
- # The state should be only the one of this LightningFlow and its children.
- self._last_state = self._find_state_under_affiliation(_LAST_STATE, self._my_affiliation)
- self._state = self._find_state_under_affiliation(_STATE, self._my_affiliation)
-
- def send_delta(self) -> None:
- app_url = f"{self._url}/api/v1/delta"
- deep_diff = DeepDiff(_LAST_STATE, _STATE, verbose_level=2)
- assert self._plugin is not None
- # TODO: Find how to prevent the infinite loop on refresh without storing the DeepDiff
- if self._plugin.should_update_app(deep_diff):
- data = {"delta": json.loads(deep_diff.to_json())}
- headers = headers_for(self._plugin.get_context())
- try:
- # TODO: Send the delta directly to the REST API.
- response = self._session.post(app_url, json=data, headers=headers)
- except ConnectionError as ex:
- raise AttributeError("Failed to connect and send the app state. Is the app running?") from ex
-
- if response and response.status_code != 200:
- raise Exception(f"The response from the server was {response.status_code}. Your inputs were rejected.")
-
- def _request_state(self) -> None:
- if self._state is not None:
- return
- app_url = f"{self._url}/api/v1/state"
- headers = headers_for(self._plugin.get_context()) if self._plugin else {}
-
- response_json = {}
-
- # Sometimes the state URL can return an empty JSON when things are being set-up,
- # so we wait for it to be ready here.
- while response_json == {}:
- sleep(0.5)
- try:
- response = self._session.get(app_url, headers=headers, timeout=1)
- except ConnectionError as ex:
- raise AttributeError("Failed to connect and fetch the app state. Is the app running?") from ex
-
- self._authorized = response.status_code
- if self._authorized != 200:
- return
-
- response_json = response.json()
-
- logger.debug(f"GET STATE {response} {response_json}")
- self._store_state(response_json)
-
- def __getattr__(self, name: str) -> Union[Any, "AppState"]:
- if name in self._APP_PRIVATE_KEYS:
- return object.__getattr__(self, name)
-
- # The state needs to be fetched on access if it doesn't exist.
- self._request_state()
-
- if name in self._state.get("vars", {}):
- value = self._state["vars"][name]
- if isinstance(value, dict):
- return _maybe_create_drive("root." + ".".join(self._my_affiliation), value)
- return value
-
- if name in self._state.get("works", {}):
- return AppState(
- self._host, self._port, last_state=self._last_state["works"][name], state=self._state["works"][name]
- )
-
- if name in self._state.get("flows", {}):
- return AppState(
- self._host,
- self._port,
- last_state=self._last_state["flows"][name],
- state=self._state["flows"][name],
- )
-
- if name in self._state.get("structures", {}):
- return AppState(
- self._host,
- self._port,
- last_state=self._last_state["structures"][name],
- state=self._state["structures"][name],
- )
-
- raise AttributeError(
- f"Failed to access '{name}' through `AppState`. The state provides:"
- f" Variables: {list(self._state['vars'].keys())},"
- f" Components: {list(self._state.get('flows', {}).keys()) + list(self._state.get('works', {}).keys())}",
- )
-
- def __getitem__(self, key: str):
- return self.__getattr__(key)
-
- def __setattr__(self, name: str, value: Any) -> None:
- if name in self._APP_PRIVATE_KEYS:
- object.__setattr__(self, name, value)
- return
-
- # The state needs to be fetched on access if it doesn't exist.
- self._request_state()
-
- # TODO: Find a way to aggregate deltas to avoid making
- # request for each attribute change.
- if name in self._state["vars"]:
- self._state["vars"][name] = value
- self.send_delta()
-
- elif name in self._state["flows"]:
- raise AttributeError("You shouldn't set the flows directly onto the state. Use its attributes instead.")
-
- elif name in self._state["works"]:
- raise AttributeError("You shouldn't set the works directly onto the state. Use its attributes instead.")
-
- else:
- raise AttributeError(
- f"Failed to access '{name}' through `AppState`. The state provides:"
- f" Variables: {list(self._state['vars'].keys())},"
- f" Components: {list(self._state['flows'].keys()) + list(self._state['works'].keys())}",
- )
-
- def __repr__(self) -> str:
- return str(self._state)
-
- def __bool__(self) -> bool:
- return bool(self._state)
-
- def __len__(self) -> int:
- # The state needs to be fetched on access if it doesn't exist.
- self._request_state()
-
- keys = []
- for component in ["flows", "works", "structures"]:
- keys.extend(list(self._state.get(component, {})))
- return len(keys)
-
- def items(self) -> List[Dict[str, Any]]:
- # The state needs to be fetched on access if it doesn't exist.
- self._request_state()
-
- items = []
- for component in ["flows", "works"]:
- state = self._state.get(component, {})
- last_state = self._last_state.get(component, {})
- for name, state_value in state.items():
- v = AppState(
- self._host,
- self._port,
- last_state=last_state[name],
- state=state_value,
- )
- items.append((name, v))
-
- structures = self._state.get("structures", {})
- last_structures = self._last_state.get("structures", {})
- if structures:
- for component in ["flows", "works"]:
- state = structures.get(component, {})
- last_state = last_structures.get(component, {})
- for name, state_value in state.items():
- v = AppState(
- self._host,
- self._port,
- last_state=last_state[name],
- state=state_value,
- )
- items.append((name, v))
- return items
-
- @staticmethod
- def _configure_session() -> Session:
- return _configure_session()
diff --git a/src/lightning/app/utilities/tracer.py b/src/lightning/app/utilities/tracer.py
deleted file mode 100644
index fe44f91947305..0000000000000
--- a/src/lightning/app/utilities/tracer.py
+++ /dev/null
@@ -1,193 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import functools
-import inspect
-import runpy
-import sys
-import time
-from pathlib import Path
-from typing import Any, Optional
-
-
-def get_default_args(func):
- signature = inspect.signature(func)
- return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
-
-
-def wrap_fn(fn, cls, method_name, trace, stack_level=1, pre_fn=None, post_fn=None, is_class_method=None):
- """Wrap a function so that its execution can be traced and its args and return values modified."""
- class_name = cls.__qualname__
-
- @functools.wraps(fn)
- def fn_with_tracing(self, *args: Any, **kwargs: Any):
- if class_name not in trace:
- trace[class_name] = {}
-
- self_id = id(self)
- stack = inspect.stack()
- frame = stack[stack_level]
- frame_id = id(frame)
- stack_len = len(stack) - 1
-
- if self_id not in trace[class_name]:
- trace[class_name][self_id] = {}
-
- if method_name not in trace[class_name][self_id]:
- trace[class_name][self_id][method_name] = {}
-
- if frame_id not in trace[class_name][self_id][method_name]:
- trace[class_name][self_id][method_name][frame_id] = {}
-
- trace_entry = trace[class_name][self_id][method_name][frame_id]
-
- if pre_fn:
- # If a pre_fn is specified, it can both record information
- # in a trace, as well as return modified args and kwargs
- # that will be provided to the actual fn being wrappped
- pre_trace, args, kwargs = pre_fn(self, *args, **kwargs)
- trace_entry["pre"] = pre_trace
-
- # We record the invocation and the calling location in the trace
- trace_entry["frame"] = {
- "filename": frame.filename,
- "lineno": frame.lineno,
- "function": frame.function,
- "depth": stack_len,
- }
-
- # we cache the dfeault parameters used during the function call
- trace_entry["default_args"] = get_default_args(fn)
-
- # we cache also the parameters used during the function call
- trace_entry["call_args"] = kwargs
-
- trace_entry["call"] = {"start": time.time_ns()}
-
- ret = fn(self, *args, **kwargs) if not is_class_method else fn(*args, **kwargs)
-
- trace_entry["call"]["end"] = time.time_ns()
-
- if post_fn:
- # If a post_fn is specified, it can both record information
- # in a trace, as well as modify the value returned from fn
- post_trace, ret = post_fn(self, ret)
- trace_entry["post"] = post_trace
-
- return ret
-
- return fn_with_tracing
-
-
-class Tracer:
- def __init__(self):
- self.methods = []
- self.orig = {}
- self.res = {}
-
- def add_traced(self, cls, method_name, stack_level=1, pre_fn=None, post_fn=None):
- """Record the fact that we will want to trace method_name in class cls.
-
- Optionally provide two functions that will execute prior to and after the method. The functions also have a
- chance to modify the input arguments and the return values of the methods.
-
- """
- self.methods.append((cls, method_name, stack_level, pre_fn, post_fn))
-
- def _instrument(self):
- """Modify classes by wrapping methods that need to be traced.
-
- Initialize the output trace dict.
-
- """
- self.res = {}
- for cls, method, stack_level, pre_fn, post_fn in self.methods:
- fn = getattr(cls, method)
- # this checks if the passed function is a class method
- fn_is_class_method: bool = hasattr(fn, "__self__")
-
- if cls not in self.orig:
- self.orig[cls] = {}
- self.orig[cls][method] = fn
- wrapped_fn = wrap_fn(
- fn,
- cls,
- method,
- self.res,
- stack_level=stack_level,
- pre_fn=pre_fn,
- post_fn=post_fn,
- is_class_method=fn_is_class_method,
- )
-
- # this is needed to wrap class methods
- if fn_is_class_method:
- wrapped_fn = classmethod(wrapped_fn)
-
- setattr(cls, method, wrapped_fn)
-
- def _restore(self):
- """Restore original methods so classes go back to their initial state."""
- for cls in self.orig:
- for method in self.orig[cls]:
- setattr(cls, method, self.orig[cls][method])
-
- def _cleanup(self):
- """Cleanup trace by converting trace[class_name][instance_id][method_name][frame_id] to
- trace[class_name][][method_name][] thereby removing references to instance ids."""
- out = {}
- for class_name in self.res:
- out[class_name] = []
- for self_id in self.res[class_name]:
- instance = self.res[class_name][self_id]
- out_instance = {"id": self_id}
- for method_name, method in instance.items():
- frames = []
- for frame_id, frame in method.items():
- frame["id"] = frame_id
- frames.append(frame)
- out_instance[method_name] = frames
- out[class_name].append(out_instance)
- self.res = out
-
- def trace(self, *args: Any, init_globals=None) -> Optional[dict]:
- """Execute the command-line arguments in args after instrumenting for tracing.
-
- Restore the classes to their initial state after tracing.
-
- """
- args = list(args)
- script = args[0]
- script_dir = Path(script).parent.absolute()
-
- sys_path = sys.path[:]
- sys_argv = sys.argv[:]
-
- sys.path.append(str(script_dir))
-
- sys.argv = args
-
- self._instrument()
-
- res = runpy.run_path(script, run_name="__main__", init_globals=init_globals or globals())
-
- self._restore()
- self._cleanup()
-
- sys.path = sys_path[:]
- sys.argv = sys_argv[:]
-
- res["tracer_res"] = self.res
-
- return res
diff --git a/src/lightning/app/utilities/tree.py b/src/lightning/app/utilities/tree.py
deleted file mode 100644
index b60d8a85dc30b..0000000000000
--- a/src/lightning/app/utilities/tree.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Utilities for traversing the tree of components in an app."""
-
-from typing import TYPE_CHECKING, Type
-
-import lightning.app
-
-if TYPE_CHECKING:
- from lightning.app.utilities.types import Component, ComponentTuple
-
-
-def breadth_first(root: "Component", types: Type["ComponentTuple"] = None):
- """Returns a generator that walks through the tree of components breadth-first.
-
- Arguments:
- root: The root component of the tree
- types: If provided, only the component types in this list will be visited.
-
- """
- yield from _BreadthFirstVisitor(root, types)
-
-
-class _BreadthFirstVisitor:
- def __init__(self, root: "Component", types: Type["ComponentTuple"] = None) -> None:
- self.queue = [root]
- self.types = types
-
- def __iter__(self):
- return self
-
- def __next__(self):
- from lightning.app.structures import Dict
-
- while self.queue:
- component = self.queue.pop(0)
-
- if isinstance(component, lightning.app.LightningFlow):
- components = [getattr(component, el) for el in sorted(component._flows)]
- for struct_name in sorted(component._structures):
- structure = getattr(component, struct_name)
- if isinstance(structure, Dict):
- values = sorted(structure.items(), key=lambda x: x[0])
- else:
- values = sorted(((v.name, v) for v in structure), key=lambda x: x[0])
- for _, value in values:
- if isinstance(value, lightning.app.LightningFlow):
- components.append(value)
- self.queue += components
- self.queue += component.works(recurse=False)
-
- if any(isinstance(component, t) for t in self.types):
- return component
-
- raise StopIteration
-
-
-class _DepthFirstVisitor:
- def __init__(self, root: "Component", types: Type["ComponentTuple"] = None) -> None:
- self.stack = [root]
- self.types = types
-
- def __iter__(self):
- return self
-
- def __next__(self):
- while self.stack:
- component = self.stack.pop()
-
- if isinstance(component, lightning.app.LightningFlow):
- self.stack += list(component.flows.values())
- self.stack += component.works(recurse=False)
-
- if any(isinstance(component, t) for t in self.types):
- return component
-
- raise StopIteration
diff --git a/src/lightning/app/utilities/types.py b/src/lightning/app/utilities/types.py
deleted file mode 100644
index d14a025e179b3..0000000000000
--- a/src/lightning/app/utilities/types.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import typing as t
-from typing import Protocol, runtime_checkable
-
-from lightning.app.core import LightningFlow, LightningWork
-from lightning.app.structures import Dict, List
-
-Component = t.Union[LightningFlow, LightningWork, Dict, List]
-ComponentTuple = (LightningFlow, LightningWork, Dict, List)
-
-
-@runtime_checkable
-class Hashable(Protocol):
- def to_dict(self) -> t.Dict[str, t.Any]:
- """Convert to dictionaty."""
diff --git a/src/lightning/app/utilities/warnings.py b/src/lightning/app/utilities/warnings.py
deleted file mode 100644
index f0787103ea54d..0000000000000
--- a/src/lightning/app/utilities/warnings.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-class LightningFlowWarning(UserWarning):
- """Warning used to inform users of misuse with Lightning Flow."""
diff --git a/src/lightning/data/README.md b/src/lightning/data/README.md
index efd51a37e48a0..525a7e14f894d 100644
--- a/src/lightning/data/README.md
+++ b/src/lightning/data/README.md
@@ -15,7 +15,7 @@ We developed `StreamingDataset` to optimize training of large datasets stored on
Specifically crafted for multi-gpu & multi-node (with [DDP](https://lightning.ai/docs/pytorch/stable/accelerators/gpu_intermediate.html), [FSDP](https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html), etc...), distributed training with large models, it enhances accuracy, performance, and user-friendliness. Now, training efficiently is possible regardless of the data's location. Simply stream in the required data when needed.
-The `StreamingDataset` is compatible with any data type, including **images, text, video, audio, geo-spatial, and multimodal data** and it is a drop-in replacement for your PyTorch [IterableDataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset) class. For example, it is used by [Lit-GPT](https://github.com/Lightning-AI/lit-gpt/blob/main/pretrain/tinyllama.py) to pretrain LLMs.
+The `StreamingDataset` is compatible with any data type, including **images, text, video, audio, geo-spatial, and multimodal data** and it is a drop-in replacement for your PyTorch [IterableDataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset) class. For example, it is used by [Lit-GPT](https://github.com/Lightning-AI/lit-gpt/blob/main/litgpt/data/tinyllama.py) to pretrain LLMs.
@@ -284,7 +284,7 @@ for batch in tqdm(train_dataloader):
Lightning Data provides a stateful `StreamingDataLoader`. This simplifies resuming training over large datasets.
-Note: The `StreamingDataLoader` is used by [Lit-GPT](https://github.com/Lightning-AI/lit-gpt/blob/main/pretrain/tinyllama.py) to pretrain LLMs. The statefulness still works when using a mixture of datasets with the `CombinedStreamingDataset`.
+Note: The `StreamingDataLoader` is used by [Lit-GPT](https://github.com/Lightning-AI/lit-gpt/blob/main/litgpt/data/tinyllama.py) to pretrain LLMs. The statefulness still works when using a mixture of datasets with the `CombinedStreamingDataset`.
```python
import os
diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index d88f2ec12827a..9ad21cea7225a 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -5,53 +5,66 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
-## [unReleased] - 2024-MM-DD
+## [2.4.0] - 2024-08-06
### Added
-- Enabled consolidating distributed checkpoints through `fabric consolidate` in the new CLI [#19560](https://github.com/Lightning-AI/pytorch-lightning/pull/19560))
-
--
-
--
+- Made saving non-distributed checkpoints fully atomic ([#20011](https://github.com/Lightning-AI/pytorch-lightning/pull/20011))
+- Added a flag `verbose` to the `seed_everything()` function ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108))
+- Added support for PyTorch 2.4 ([#20028](https://github.com/Lightning-AI/pytorch-lightning/pull/20028))
+- Added support for Python 3.12 ([#20078](https://github.com/Lightning-AI/pytorch-lightning/pull/20078))
### Changed
-- Renamed `lightning run model` to `fabric run` ([#19442](https://github.com/Lightning-AI/pytorch-lightning/pull/19442), [#19527](https://github.com/Lightning-AI/pytorch-lightning/pull/19527))
-
+- Changed the implementation of how seeds are chosen for dataloader workers when using `seed_everything(..., workers=True)` ([#20055](https://github.com/Lightning-AI/pytorch-lightning/pull/20055))
+- NumPy is no longer a required dependency ([#20090](https://github.com/Lightning-AI/pytorch-lightning/issues/20090))
-- The `Fabric.rank_zero_first` context manager now uses a barrier without timeout to avoid long-running tasks to be interrupted ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))
+### Removed
+- Removed support for PyTorch 2.1 ([#20009](https://github.com/Lightning-AI/lightning/pull/20009))
+- Removed support for Python 3.8 ([#20071](https://github.com/Lightning-AI/lightning/pull/20071))
-- Fabric now raises an error if you forget to call `fabric.backward()` when it is needed by the strategy or precision selection ([#19447](https://github.com/Lightning-AI/lightning/pull/19447), [#19493](https://github.com/Lightning-AI/lightning/pull/19493))
+### Fixed
+- Fixed an attribute error when loading a checkpoint into a quantized model using the `_lazy_load()` function ([#20121](https://github.com/Lightning-AI/lightning/pull/20121))
+- Fixed `_optimizer_to_device` logic for special 'step' key in optimizer state causing performance regression ([#20019](https://github.com/Lightning-AI/lightning/pull/20019))
-- `_BackwardSyncControl` can now control what to do when gradient accumulation is disabled ([#19577](https://github.com/Lightning-AI/lightning/pull/19577))
+## [2.3.0] - 2024-06-13
-### Deprecated
+### Added
--
+- Added sanitization for classes before logging them as hyperparameters ([#19771](https://github.com/Lightning-AI/pytorch-lightning/pull/19771))
+- Enabled consolidating distributed checkpoints through `fabric consolidate` in the new CLI ([#19560](https://github.com/Lightning-AI/pytorch-lightning/pull/19560))
+- Added the ability to explicitly mark forward methods in Fabric via `_FabricModule.mark_forward_method()` ([#19690](https://github.com/Lightning-AI/pytorch-lightning/pull/19690))
+- Added support for PyTorch 2.3 ([#19708](https://github.com/Lightning-AI/pytorch-lightning/pull/19708))
+- Added `ModelParallelStrategy` to support 2D parallelism ([#19846](https://github.com/Lightning-AI/pytorch-lightning/pull/19846), [#19852](https://github.com/Lightning-AI/pytorch-lightning/pull/19852), [#19870](https://github.com/Lightning-AI/pytorch-lightning/pull/19870), [#19872](https://github.com/Lightning-AI/pytorch-lightning/pull/19872))
+- Added a call to `torch.distributed.destroy_process_group` in atexit handler if process group needs destruction ([#19931](https://github.com/Lightning-AI/pytorch-lightning/pull/19931))
+- Added support for configuring hybrid-sharding by passing a tuple for the `FSDPStrategy(device_mesh=...)` argument ([#19504](https://github.com/Lightning-AI/pytorch-lightning/pull/19504))
--
+### Changed
--
+- Renamed `lightning run model` to `fabric run` ([#19442](https://github.com/Lightning-AI/pytorch-lightning/pull/19442), [#19527](https://github.com/Lightning-AI/pytorch-lightning/pull/19527))
+- The `Fabric.rank_zero_first` context manager now uses a barrier without timeout to avoid interrupting long-running tasks ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))
+- Fabric now raises an error if you forget to call `fabric.backward()` when it is needed by the strategy or precision selection ([#19447](https://github.com/Lightning-AI/lightning/pull/19447), [#19493](https://github.com/Lightning-AI/lightning/pull/19493))
+- `_BackwardSyncControl` can now control what to do when gradient accumulation is disabled ([#19577](https://github.com/Lightning-AI/lightning/pull/19577))
### Removed
--
+- Removed support for PyTorch 1.13 ([#19706](https://github.com/Lightning-AI/lightning/pull/19706))
+
+### Fixed
+
+- Fixed a matrix shape mismatch issue when running a model loaded from a quantized checkpoint (bitsandbytes) ([#19886](https://github.com/Lightning-AI/lightning/pull/19886))
--
--
+## [2.2.2] - 2024-04-11
### Fixed
- Fixed an issue causing a TypeError when using `torch.compile` as a decorator ([#19627](https://github.com/Lightning-AI/pytorch-lightning/pull/19627))
-
- Fixed issue where some model methods couldn't be monkeypatched after being Fabric wrapped ([#19705](https://github.com/Lightning-AI/pytorch-lightning/pull/19705))
-
--
+- Fixed an issue causing weights to be reset in `Fabric.setup()` when using FSDP ([#19755](https://github.com/Lightning-AI/pytorch-lightning/pull/19755))
## [2.2.1] - 2024-03-04
diff --git a/src/lightning/fabric/__init__.py b/src/lightning/fabric/__init__.py
index 75752d8b94884..921d3d61e60fe 100644
--- a/src/lightning/fabric/__init__.py
+++ b/src/lightning/fabric/__init__.py
@@ -21,7 +21,7 @@
_logger.propagate = False
-# In PyTorch 2.0+, setting this variable will force `torch.cuda.is_available()` and `torch.cuda.device_count()`
+# Setting this variable will force `torch.cuda.is_available()` and `torch.cuda.device_count()`
# to use an NVML-based implementation that doesn't poison forks.
# https://github.com/pytorch/pytorch/issues/83973
os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1"
@@ -37,9 +37,6 @@
__all__ = ["Fabric", "seed_everything", "is_wrapped"]
-# for compatibility with namespace packages
-__import__("pkg_resources").declare_namespace(__name__)
-
if os.environ.get("POSSIBLE_USER_WARNINGS", "").lower() in ("0", "off"):
disable_possible_user_warnings()
diff --git a/src/lightning/fabric/accelerators/cuda.py b/src/lightning/fabric/accelerators/cuda.py
index 8613c6549e4c9..4afc9be723fc2 100644
--- a/src/lightning/fabric/accelerators/cuda.py
+++ b/src/lightning/fabric/accelerators/cuda.py
@@ -11,18 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
-import warnings
-from contextlib import contextmanager
from functools import lru_cache
-from typing import Generator, List, Optional, Union, cast
+from typing import List, Optional, Union
import torch
from typing_extensions import override
from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.rank_zero import rank_zero_info
@@ -144,211 +140,15 @@ def _get_all_visible_cuda_devices() -> List[int]:
return list(range(num_cuda_devices()))
-# TODO: Remove once minimum supported PyTorch version is 2.0
-@contextmanager
-def _patch_cuda_is_available() -> Generator:
- """Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if possible."""
- if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0 and not _TORCH_GREATER_EQUAL_2_0:
- # we can safely patch is_available if both torch has CUDA compiled and the NVML count is succeeding
- # otherwise, patching is_available could lead to attribute errors or infinite recursion
- orig_check = torch.cuda.is_available
- torch.cuda.is_available = is_cuda_available
- try:
- yield
- finally:
- torch.cuda.is_available = orig_check
- else:
- yield
-
-
-@lru_cache(1)
def num_cuda_devices() -> int:
- """Returns the number of available CUDA devices.
-
- Unlike :func:`torch.cuda.device_count`, this function does its best not to create a CUDA context for fork support,
- if the platform allows it.
-
- """
- if _TORCH_GREATER_EQUAL_2_0:
- return torch.cuda.device_count()
-
- # Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879
- # TODO: Remove once minimum supported PyTorch version is 2.0
- nvml_count = _device_count_nvml()
- return torch.cuda.device_count() if nvml_count < 0 else nvml_count
+ """Returns the number of available CUDA devices."""
+ return torch.cuda.device_count()
def is_cuda_available() -> bool:
- """Returns a bool indicating if CUDA is currently available.
-
- Unlike :func:`torch.cuda.is_available`, this function does its best not to create a CUDA context for fork support,
- if the platform allows it.
-
- """
+ """Returns a bool indicating if CUDA is currently available."""
# We set `PYTORCH_NVML_BASED_CUDA_CHECK=1` in lightning.fabric.__init__.py
- return torch.cuda.is_available() if _TORCH_GREATER_EQUAL_2_0 else num_cuda_devices() > 0
-
-
-# TODO: Remove once minimum supported PyTorch version is 2.0
-def _parse_visible_devices() -> Union[List[int], List[str]]:
- """Parse CUDA_VISIBLE_DEVICES environment variable."""
- var = os.getenv("CUDA_VISIBLE_DEVICES")
- if var is None:
- return list(range(64))
-
- def _strtoul(s: str) -> int:
- """Return -1 or positive integer sequence string starts with,"""
- if not s:
- return -1
- for idx, c in enumerate(s):
- if not (c.isdigit() or (idx == 0 and c in "+-")):
- break
- if idx + 1 == len(s):
- idx += 1
- return int(s[:idx]) if idx > 0 else -1
-
- def parse_list_with_prefix(lst: str, prefix: str) -> List[str]:
- rcs: List[str] = []
- for elem in lst.split(","):
- # Repeated id results in empty set
- if elem in rcs:
- return cast(List[str], [])
- # Anything other but prefix is ignored
- if not elem.startswith(prefix):
- break
- rcs.append(elem)
- return rcs
-
- if var.startswith("GPU-"):
- return parse_list_with_prefix(var, "GPU-")
- if var.startswith("MIG-"):
- return parse_list_with_prefix(var, "MIG-")
- # CUDA_VISIBLE_DEVICES uses something like strtoul
- # which makes `1gpu2,2ampere` is equivalent to `1,2`
- rc: List[int] = []
- for elem in var.split(","):
- x = _strtoul(elem.strip())
- # Repeated ordinal results in empty set
- if x in rc:
- return cast(List[int], [])
- # Negative value aborts the sequence
- if x < 0:
- break
- rc.append(x)
- return rc
-
-
-# TODO: Remove once minimum supported PyTorch version is 2.0
-def _raw_device_count_nvml() -> int:
- """Return number of devices as reported by NVML or negative value if NVML discovery/initialization failed."""
- from ctypes import CDLL, byref, c_int
-
- nvml_h = CDLL("libnvidia-ml.so.1")
- rc = nvml_h.nvmlInit()
- if rc != 0:
- warnings.warn("Can't initialize NVML")
- return -1
- dev_count = c_int(-1)
- rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
- if rc != 0:
- warnings.warn("Can't get nvml device count")
- return -1
- del nvml_h
- return dev_count.value
-
-
-# TODO: Remove once minimum supported PyTorch version is 2.0
-def _raw_device_uuid_nvml() -> Optional[List[str]]:
- """Return list of device UUID as reported by NVML or None if NVM discovery/initialization failed."""
- from ctypes import CDLL, byref, c_int, c_void_p, create_string_buffer
-
- nvml_h = CDLL("libnvidia-ml.so.1")
- rc = nvml_h.nvmlInit()
- if rc != 0:
- warnings.warn("Can't initialize NVML")
- return None
- dev_count = c_int(-1)
- rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
- if rc != 0:
- warnings.warn("Can't get nvml device count")
- return None
- uuids: List[str] = []
- for idx in range(dev_count.value):
- dev_id = c_void_p()
- rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
- if rc != 0:
- warnings.warn("Can't get device handle")
- return None
- buf_len = 96
- buf = create_string_buffer(buf_len)
- rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len)
- if rc != 0:
- warnings.warn("Can't get device UUID")
- return None
- uuids.append(buf.raw.decode("ascii").strip("\0"))
- del nvml_h
- return uuids
-
-
-# TODO: Remove once minimum supported PyTorch version is 2.0
-def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List[int]:
- """Given the set of partial uuids and list of known uuids builds a set of ordinals excluding ambiguous partials
- IDs."""
-
- def uuid_to_orinal(candidate: str, uuids: List[str]) -> int:
- best_match = -1
- for idx, uuid in enumerate(uuids):
- if not uuid.startswith(candidate):
- continue
- # Ambigous candidate
- if best_match != -1:
- return -1
- best_match = idx
- return best_match
-
- rc: List[int] = []
- for candidate in candidates:
- idx = uuid_to_orinal(candidate, uuids)
- # First invalid ordinal stops parsing
- if idx < 0:
- break
- # Duplicates result in empty set
- if idx in rc:
- return cast(List[int], [])
- rc.append(idx)
- return rc
-
-
-# TODO: Remove once minimum supported PyTorch version is 2.0
-def _device_count_nvml() -> int:
- """Return number of devices as reported by NVML taking CUDA_VISIBLE_DEVICES into account.
-
- Negative value is returned if NVML discovery or initialization has failed.
-
- """
- visible_devices = _parse_visible_devices()
- if not visible_devices:
- return 0
- try:
- if isinstance(visible_devices[0], str):
- # Skip MIG parsing
- if visible_devices[0].startswith("MIG-"):
- return -1
- uuids = _raw_device_uuid_nvml()
- if uuids is None:
- return -1
- visible_devices = _transform_uuid_to_ordinals(cast(List[str], visible_devices), uuids)
- else:
- raw_cnt = _raw_device_count_nvml()
- if raw_cnt <= 0:
- return raw_cnt
- # Trim the list up to a maximum available device
- for idx, val in enumerate(visible_devices):
- if cast(int, val) >= raw_cnt:
- return idx
- except (OSError, AttributeError):
- return -1
- return len(visible_devices)
+ return torch.cuda.is_available()
def _is_ampere_or_later(device: Optional[torch.device] = None) -> bool:
@@ -375,7 +175,7 @@ def _check_cuda_matmul_precision(device: torch.device) -> None:
def _clear_cuda_memory() -> None:
# strangely, the attribute may be undefined when torch.compile is used
- if _TORCH_GREATER_EQUAL_2_0 and hasattr(torch._C, "_cuda_clearCublasWorkspaces"):
+ if hasattr(torch._C, "_cuda_clearCublasWorkspaces"):
# https://github.com/pytorch/pytorch/issues/95668
torch._C._cuda_clearCublasWorkspaces()
torch.cuda.empty_cache()
diff --git a/src/lightning/fabric/accelerators/mps.py b/src/lightning/fabric/accelerators/mps.py
index d0f36698616d4..75497169cda0f 100644
--- a/src/lightning/fabric/accelerators/mps.py
+++ b/src/lightning/fabric/accelerators/mps.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
import platform
from functools import lru_cache
from typing import List, Optional, Union
@@ -70,7 +71,8 @@ def auto_device_count() -> int:
@lru_cache(1)
def is_available() -> bool:
"""MPS is only available on a machine with the ARM-based Apple Silicon processors."""
- return torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64")
+ mps_disabled = os.getenv("DISABLE_MPS", "0") == "1"
+ return not mps_disabled and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64")
@classmethod
@override
diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py
index d8c6fe47b6630..7c81afa916196 100644
--- a/src/lightning/fabric/cli.py
+++ b/src/lightning/fabric/cli.py
@@ -56,17 +56,17 @@ def _legacy_main() -> None:
Raises deprecation warning and runs through fabric cli if necessary, else runs the entrypoint directly
"""
- print(
- "`lightning run model` is deprecated and will be removed in future versions."
- " Please call `fabric run` instead."
- )
- args = sys.argv[1:]
- if args and args[0] == "run" and args[1] == "model":
+ hparams = sys.argv[1:]
+ if len(hparams) >= 2 and hparams[0] == "run" and hparams[1] == "model":
+ print(
+ "`lightning run model` is deprecated and will be removed in future versions."
+ " Please call `fabric run` instead."
+ )
_main()
return
if _LIGHTNING_SDK_AVAILABLE:
- subprocess.run([sys.executable, "-m", "lightning_sdk.cli.entrypoint"] + args)
+ subprocess.run([sys.executable, "-m", "lightning_sdk.cli.entrypoint"] + hparams)
return
@click.group()
@@ -140,7 +140,7 @@ def _main() -> None:
type=click.Choice(get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_STR_ALIAS)),
default=None,
help=(
- "Double precision (``64-true`` or ``64``), full precision (``32-true`` or ``64``), "
+ "Double precision (``64-true`` or ``64``), full precision (``32-true`` or ``32``), "
"half precision (``16-mixed`` or ``16``) or bfloat16 precision (``bf16-mixed`` or ``bf16``)"
),
)
diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py
index edbfd77721a95..9fb66255830c6 100644
--- a/src/lightning/fabric/connector.py
+++ b/src/lightning/fabric/connector.py
@@ -62,6 +62,7 @@
)
from lightning.fabric.strategies.ddp import _DDP_FORK_ALIASES
from lightning.fabric.strategies.fsdp import _FSDP_ALIASES, FSDPStrategy
+from lightning.fabric.strategies.model_parallel import ModelParallelStrategy
from lightning.fabric.utilities import rank_zero_info, rank_zero_warn
from lightning.fabric.utilities.device_parser import _determine_root_gpu_device
from lightning.fabric.utilities.imports import _IS_INTERACTIVE
@@ -312,7 +313,8 @@ def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str
f" using {accelerator_name} accelerator."
)
- def _choose_auto_accelerator(self) -> str:
+ @staticmethod
+ def _choose_auto_accelerator() -> str:
"""Choose the accelerator type (str) based on availability when ``accelerator='auto'``."""
if XLAAccelerator.is_available():
return "tpu"
@@ -429,7 +431,7 @@ def _check_strategy_and_fallback(self) -> None:
f" platform. We recommend `Fabric(strategy='ddp_spawn')` instead."
)
if (
- strategy_flag in _FSDP_ALIASES or isinstance(self._strategy_flag, FSDPStrategy)
+ strategy_flag in _FSDP_ALIASES or type(self._strategy_flag) is FSDPStrategy
) and self._accelerator_flag not in ("cuda", "gpu"):
raise ValueError(
"You selected the FSDP strategy but FSDP is only available on GPU. Set `Fabric(accelerator='gpu', ...)`"
@@ -460,6 +462,12 @@ def _check_and_init_precision(self) -> Precision:
return DeepSpeedPrecision(self._precision_input) # type: ignore
if isinstance(self.strategy, FSDPStrategy):
return FSDPPrecision(precision=self._precision_input) # type: ignore[arg-type]
+ mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true")
+ if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported:
+ raise ValueError(
+ f"The `ModelParallelStrategy` does not support `Fabric(..., precision={self._precision_input!r})`."
+ f" Choose a different precision among: {', '.join(mp_precision_supported)}."
+ )
if self._precision_input in ("16-true", "bf16-true"):
return HalfPrecision(self._precision_input) # type: ignore
if self._precision_input == "32-true":
diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index 4b9c14eb06e62..0ff5b04b30b0a 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -51,10 +51,8 @@
FSDPStrategy,
SingleDeviceStrategy,
Strategy,
- XLAFSDPStrategy,
XLAStrategy,
)
-from lightning.fabric.strategies.fsdp import _has_meta_device_parameters
from lightning.fabric.strategies.launchers import _MultiProcessingLauncher, _XLALauncher
from lightning.fabric.strategies.strategy import TBroadcast, _Sharded
from lightning.fabric.utilities import move_data_to_device
@@ -67,7 +65,7 @@
)
from lightning.fabric.utilities.device_dtype_mixin import _update_properties
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper, _InfiniteBarrier
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
+from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers
from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from lightning.fabric.utilities.registry import _load_external_callbacks
from lightning.fabric.utilities.seed import seed_everything
@@ -227,8 +225,7 @@ def setup(
_reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
- FSDP etc.). Only applies on PyTorch >= 2.1. Set it to ``False`` if compiling DDP/FSDP is causing
- issues.
+ FSDP etc.). Set it to ``False`` if compiling DDP/FSDP is causing issues.
Returns:
The tuple containing wrapped module and the optimizers, in the same order they were passed in.
@@ -294,8 +291,7 @@ def setup_module(
_reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
- FSDP etc.). Only applies on PyTorch >= 2.1. Set it to ``False`` if compiling DDP/FSDP is causing
- issues.
+ FSDP etc.). Set it to ``False`` if compiling DDP/FSDP is causing issues.
Returns:
The wrapped model.
@@ -699,26 +695,14 @@ def sharded_model(self) -> ContextManager:
def init_tensor(self) -> ContextManager:
"""Tensors that you instantiate under this context manager will be created on the device right away and have
- the right data type depending on the precision setting in Fabric.
-
- The automatic device placement under this context manager is only supported with PyTorch 2.0 and newer.
-
- """
- if not _TORCH_GREATER_EQUAL_2_0 and self.device.type != "cpu":
- rank_zero_warn(
- "`Fabric.init_tensor()` can't place tensors on the device directly"
- " with PyTorch < 2.0. Parameters will remain on CPU until `Fabric.setup()` is called."
- " Upgrade to PyTorch >= 2.0 to fully utilize this feature.",
- category=PossibleUserWarning,
- )
+ the right data type depending on the precision setting in Fabric."""
return self._strategy.tensor_init_context()
def init_module(self, empty_init: Optional[bool] = None) -> ContextManager:
"""Instantiate the model and its parameters under this context manager to reduce peak memory usage.
The parameters get created on the device and with the right data type right away without wasting memory being
- allocated unnecessarily. The automatic device placement under this context manager is only supported with
- PyTorch 2.0 and newer.
+ allocated unnecessarily.
Args:
empty_init: Whether to initialize the model with empty weights (uninitialized memory).
@@ -727,13 +711,6 @@ def init_module(self, empty_init: Optional[bool] = None) -> ContextManager:
"""
self._validate_launched()
- if not _TORCH_GREATER_EQUAL_2_0 and self.device.type != "cpu":
- rank_zero_warn(
- "`Fabric.init_module()` can't place the model parameters on the device directly"
- " with PyTorch < 2.0. Parameters will remain on CPU until `Fabric.setup()` is called."
- " Upgrade to PyTorch >= 2.0 to fully utilize this feature.",
- category=PossibleUserWarning,
- )
return self._strategy.module_init_context(empty_init=empty_init)
def save(
@@ -932,7 +909,7 @@ def log_dict(self, metrics: Mapping[str, Any], step: Optional[int] = None) -> No
logger.log_metrics(metrics=metrics, step=step)
@staticmethod
- def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) -> int:
+ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None, verbose: bool = True) -> int:
r"""Helper function to seed everything without explicitly importing Lightning.
See :func:`~lightning.fabric.utilities.seed.seed_everything` for more details.
@@ -942,7 +919,7 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None)
# Lightning sets `workers=False` by default to avoid breaking reproducibility, but since this is a new
# release, we can afford to do it.
workers = True
- return seed_everything(seed=seed, workers=workers)
+ return seed_everything(seed=seed, workers=workers, verbose=verbose)
def _wrap_and_launch(self, to_run: Callable, *args: Any, **kwargs: Any) -> Any:
self._launched = True
@@ -1036,14 +1013,8 @@ def _validate_setup(self, module: nn.Module, optimizers: Sequence[Optimizer]) ->
if any(isinstance(opt, _FabricOptimizer) for opt in optimizers):
raise ValueError("An optimizer should be passed only once to the `setup` method.")
- if isinstance(self._strategy, (FSDPStrategy, XLAFSDPStrategy)) and not _TORCH_GREATER_EQUAL_2_0:
- raise RuntimeError(
- f"The `{type(self).__name__}` requires the model and optimizer(s) to be set up separately."
- " Create and set up the model first through `model = self.setup_module(model)`. Then create the"
- " optimizer and set it up: `optimizer = self.setup_optimizer(optimizer)`."
- )
if isinstance(self._strategy, FSDPStrategy) and any(
- _has_meta_device_parameters(optimizer) for optimizer in optimizers
+ _has_meta_device_parameters_or_buffers(optimizer) for optimizer in optimizers
):
raise RuntimeError(
"The optimizer has references to the model's meta-device parameters. Materializing them is"
@@ -1071,7 +1042,7 @@ def _validate_setup_optimizers(self, optimizers: Sequence[Optimizer]) -> None:
if any(isinstance(opt, _FabricOptimizer) for opt in optimizers):
raise ValueError("An optimizer should be passed only once to the `setup_optimizers` method.")
- if any(_has_meta_device_parameters(optimizer) for optimizer in optimizers):
+ if any(_has_meta_device_parameters_or_buffers(optimizer) for optimizer in optimizers):
raise RuntimeError(
"The optimizer has references to the model's meta-device parameters. Materializing them is"
" currently not supported. Create the optimizer after setting up the model, then call"
diff --git a/src/lightning/fabric/loggers/csv_logs.py b/src/lightning/fabric/loggers/csv_logs.py
index 9e82041b528c1..4dbb56fa691db 100644
--- a/src/lightning/fabric/loggers/csv_logs.py
+++ b/src/lightning/fabric/loggers/csv_logs.py
@@ -172,7 +172,6 @@ def _get_next_version(self) -> int:
versions_root = os.path.join(self._root_dir, self.name)
if not _is_dir(self._fs, versions_root, strict=True):
- log.warning("Missing logger folder: %s", versions_root)
return 0
existing_versions = []
diff --git a/src/lightning/fabric/loggers/tensorboard.py b/src/lightning/fabric/loggers/tensorboard.py
index eb0039f86e956..685c832088818 100644
--- a/src/lightning/fabric/loggers/tensorboard.py
+++ b/src/lightning/fabric/loggers/tensorboard.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
import os
from argparse import Namespace
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union
@@ -30,8 +29,6 @@
from lightning.fabric.utilities.types import _PATH
from lightning.fabric.wrappers import _unwrap_objects
-log = logging.getLogger(__name__)
-
_TENSORBOARD_AVAILABLE = RequirementCache("tensorboard")
_TENSORBOARDX_AVAILABLE = RequirementCache("tensorboardX")
if TYPE_CHECKING:
@@ -107,7 +104,7 @@ def __init__(
self._prefix = prefix
self._fs = get_filesystem(root_dir)
- self._experiment: Optional["SummaryWriter"] = None
+ self._experiment: Optional[SummaryWriter] = None
self._kwargs = kwargs
@property
@@ -305,8 +302,6 @@ def _get_next_version(self) -> int:
try:
listdir_info = self._fs.listdir(save_dir)
except OSError:
- # TODO(fabric): This message can be confusing (did user do something wrong?). Improve it or remove it.
- log.warning("Missing logger folder: %s", save_dir)
return 0
existing_versions = []
diff --git a/src/lightning/fabric/plugins/collectives/torch_collective.py b/src/lightning/fabric/plugins/collectives/torch_collective.py
index abba459895c5b..0dea3033f3dff 100644
--- a/src/lightning/fabric/plugins/collectives/torch_collective.py
+++ b/src/lightning/fabric/plugins/collectives/torch_collective.py
@@ -95,11 +95,11 @@ def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[T
@override
def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None:
- dist.send(tensor, dst, tag=tag, group=self.group)
+ dist.send(tensor, dst, tag=tag, group=self.group) # type: ignore[arg-type]
@override
def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tensor:
- dist.recv(tensor, src, tag=tag, group=self.group)
+ dist.recv(tensor, src, tag=tag, group=self.group) # type: ignore[arg-type]
return tensor
def all_gather_object(self, object_list: List[Any], obj: Any) -> List[Any]:
diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py
index 0ec21247c9881..c624e821af28c 100644
--- a/src/lightning/fabric/plugins/precision/amp.py
+++ b/src/lightning/fabric/plugins/precision/amp.py
@@ -20,9 +20,9 @@
from torch.optim import LBFGS, Optimizer
from typing_extensions import override
-from lightning.fabric.accelerators.cuda import _patch_cuda_is_available
from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.types import Optimizable
@@ -40,7 +40,7 @@ def __init__(
self,
precision: Literal["16-mixed", "bf16-mixed"],
device: str,
- scaler: Optional[torch.cuda.amp.GradScaler] = None,
+ scaler: Optional["torch.amp.GradScaler"] = None,
) -> None:
if precision not in ("16-mixed", "bf16-mixed"):
raise ValueError(
@@ -50,9 +50,7 @@ def __init__(
self.precision = precision
if scaler is None and self.precision == "16-mixed":
- with _patch_cuda_is_available():
- # if possible, we defer CUDA initialization to support strategies that will attempt forks
- scaler = torch.cuda.amp.GradScaler()
+ scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
if scaler is not None and self.precision == "bf16-mixed":
raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device
diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py
index 12a0ac3998b6e..394415452890a 100644
--- a/src/lightning/fabric/plugins/precision/bitsandbytes.py
+++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py
@@ -43,7 +43,7 @@
class BitsandbytesPrecision(Precision):
- """Plugin for quantizing weights with `bitsandbytes `__.
+ """Plugin for quantizing weights with `bitsandbytes `__.
.. warning:: This is an :ref:`experimental ` feature.
@@ -184,11 +184,15 @@ def _replace_param(
if param.device.type == "meta":
if isinstance(param, bnb.nn.Params4bit):
return bnb.nn.Params4bit(
- data,
+ data=data,
requires_grad=data.requires_grad,
quant_state=quant_state,
+ blocksize=param.blocksize,
compress_statistics=param.compress_statistics,
quant_type=param.quant_type,
+ quant_storage=param.quant_storage,
+ module=param.module,
+ bnb_quantized=param.bnb_quantized,
)
return torch.nn.Parameter(data, requires_grad=data.requires_grad)
param.data = data
@@ -233,9 +237,9 @@ def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torc
"""Inplace quantize."""
if weight is None:
weight = self.weight.data
- if weight.data.type == torch.int8:
- # already quantized
- return
+ if weight.data.dtype == torch.int8:
+ # already quantized
+ return
assert isinstance(self.weight, bnb.nn.Int8Params)
self.weight = self.quantize(self.weight, weight, device)
@@ -317,11 +321,12 @@ def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torc
"""Inplace quantize."""
if weight is None:
weight = self.weight.data
- if weight.data.type == torch.uint8:
- # already quantized
- return
+ if weight.data.dtype == torch.uint8:
+ # already quantized
+ return
assert isinstance(self.weight, bnb.nn.Params4bit)
self.weight = self.quantize(self.weight, weight, device)
+ self.weight.bnb_quantized = True
@staticmethod
def quantize(
@@ -337,6 +342,7 @@ def quantize(
blocksize=params4bit.blocksize,
compress_statistics=params4bit.compress_statistics,
quant_type=params4bit.quant_type,
+ quant_storage=params4bit.quant_storage,
)
return _replace_param(params4bit, w_4bit, quant_state)
diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py
index 161ad98f43475..179fc21cdd90d 100644
--- a/src/lightning/fabric/plugins/precision/fsdp.py
+++ b/src/lightning/fabric/plugins/precision/fsdp.py
@@ -23,7 +23,6 @@
from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.types import Optimizable
if TYPE_CHECKING:
@@ -78,21 +77,18 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca
def mixed_precision_config(self) -> "TorchMixedPrecision":
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
- # With PyTorch < 2.0, FSDP uses the noneness of `param_dtype` as a proxy for the `_uses_param_mixed_precision`
- # property. In order to avoid FSDP assertion failures, we therefore avoid setting `param_dtype` to
- # `torch.float32` here with PyTorch < 2.0.
if self.precision == "16-mixed":
- param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
+ param_dtype = torch.float32
reduce_dtype = buffer_dtype = torch.float16
elif self.precision == "bf16-mixed":
- param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
+ param_dtype = torch.float32
reduce_dtype = buffer_dtype = torch.bfloat16
elif self.precision == "16-true":
param_dtype = reduce_dtype = buffer_dtype = torch.float16
elif self.precision == "bf16-true":
param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
elif self.precision == "32-true":
- param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
+ param_dtype = torch.float32
reduce_dtype = buffer_dtype = torch.float32
else:
raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.")
diff --git a/src/lightning/fabric/strategies/__init__.py b/src/lightning/fabric/strategies/__init__.py
index ff48b152750ef..f561c4f426aac 100644
--- a/src/lightning/fabric/strategies/__init__.py
+++ b/src/lightning/fabric/strategies/__init__.py
@@ -17,6 +17,7 @@
from lightning.fabric.strategies.deepspeed import DeepSpeedStrategy # noqa: F401
from lightning.fabric.strategies.dp import DataParallelStrategy # noqa: F401
from lightning.fabric.strategies.fsdp import FSDPStrategy # noqa: F401
+from lightning.fabric.strategies.model_parallel import ModelParallelStrategy # noqa: F401
from lightning.fabric.strategies.parallel import ParallelStrategy # noqa: F401
from lightning.fabric.strategies.registry import _StrategyRegistry
from lightning.fabric.strategies.single_device import SingleDeviceStrategy # noqa: F401
diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py
index 0ec5df1a6b0ae..c38780655ce6e 100644
--- a/src/lightning/fabric/strategies/ddp.py
+++ b/src/lightning/fabric/strategies/ddp.py
@@ -200,6 +200,13 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
description=f"DDP strategy with `start_method={start_method!r}`",
start_method=start_method,
)
+ strategy_registry.register(
+ "ddp_find_unused_parameters_true",
+ cls,
+ description="Alias for `find_unused_parameters=True` and `start_method='popen'`",
+ find_unused_parameters=True,
+ start_method="popen",
+ )
def _setup_distributed(self) -> None:
self._set_world_ranks()
diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index 2a1a1272b498e..e71b8e2db3d58 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -43,6 +43,7 @@
from deepspeed import DeepSpeedEngine
_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
+_DEEPSPEED_GREATER_EQUAL_0_14_1 = RequirementCache("deepspeed>=0.14.1")
# TODO(fabric): Links in the docstrings to PL-specific deepspeed user docs need to be replaced.
@@ -291,7 +292,7 @@ def __init__(
self.hysteresis = hysteresis
self.min_loss_scale = min_loss_scale
- self._deepspeed_engine: Optional["DeepSpeedEngine"] = None
+ self._deepspeed_engine: Optional[DeepSpeedEngine] = None
@property
def zero_stage_3(self) -> bool:
@@ -498,7 +499,10 @@ def load_checkpoint(
)
engine = engines[0]
- from deepspeed.runtime import DeepSpeedOptimizer
+ if _DEEPSPEED_GREATER_EQUAL_0_14_1:
+ from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer
+ else:
+ from deepspeed.runtime import DeepSpeedOptimizer
optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())
@@ -594,7 +598,7 @@ def _initialize_engine(
) -> Tuple["DeepSpeedEngine", Optimizer]:
"""Initialize one model and one optimizer with an optional learning rate scheduler.
- This calls :func:`deepspeed.initialize` internally.
+ This calls ``deepspeed.initialize`` internally.
"""
import deepspeed
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index ed89629f720e8..e7fdd29f6287f 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
+import warnings
from contextlib import ExitStack, nullcontext
from datetime import timedelta
from functools import partial
@@ -36,7 +37,7 @@
from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
from torch import Tensor
-from torch.nn import Module, Parameter
+from torch.nn import Module
from torch.optim import Optimizer
from typing_extensions import TypeGuard, override
@@ -63,37 +64,33 @@
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import (
- _TORCH_GREATER_EQUAL_2_0,
- _TORCH_GREATER_EQUAL_2_1,
_TORCH_GREATER_EQUAL_2_2,
_TORCH_GREATER_EQUAL_2_3,
)
-from lightning.fabric.utilities.init import _EmptyInit
+from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers
from lightning.fabric.utilities.load import _METADATA_FILENAME, _lazy_load, _materialize_tensors, _move_state_into
from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_only, rank_zero_warn
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import _PATH, _Stateful
if TYPE_CHECKING:
+ from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy
+ from torch.distributed.fsdp.wrap import ModuleWrapPolicy
- if _TORCH_GREATER_EQUAL_2_0:
- from torch.distributed.fsdp.wrap import ModuleWrapPolicy
-
- _POLICY = Union[Set[Type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy]
- else:
- _POLICY = Union[Set[Type[Module]], Callable[[Module, bool, int], bool]] # type: ignore[misc]
-
+ _POLICY = Union[Set[Type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy]
_SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]]
+
_FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload")
+# TODO: Switch to new state-dict APIs
+warnings.filterwarnings("ignore", category=FutureWarning, message=".*FSDP.state_dict_type.*") # from torch >= 2.4
+
class FSDPStrategy(ParallelStrategy, _Sharded):
r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed.
- .. warning:: This is an :ref:`experimental ` feature.
-
Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar
@@ -125,10 +122,14 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
- ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated.
- ``"NO_SHARD"``: No sharding (identical to regular DDP).
- ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but
- replicates across machines.
+ replicates across machines. See also the `device_mesh` parameter below.
Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value.
+ device_mesh: A tuple `(replication size, sharding size)` that defines over how many devices to shard and
+ replicate the model. The product of the two numbers must equal the world size. Only valid in combination
+ with the `HYBRID_SHARD` sharding strategy.
+
state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.
- ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file.
@@ -154,6 +155,7 @@ def __init__(
activation_checkpointing_policy: Optional["_POLICY"] = None,
sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD",
state_dict_type: Literal["full", "sharded"] = "sharded",
+ device_mesh: Optional[Union[Tuple[int, int], "DeviceMesh"]] = None,
**kwargs: Any,
) -> None:
super().__init__(
@@ -168,9 +170,13 @@ def __init__(
self._backward_sync_control = _FSDPBackwardSyncControl()
self._fsdp_kwargs = _auto_wrap_policy_kwargs(auto_wrap_policy, kwargs)
- if _TORCH_GREATER_EQUAL_2_0:
- # Enables joint setup of model and optimizer, multiple optimizer param groups, and `torch.compile()`
- self._fsdp_kwargs.setdefault("use_orig_params", True)
+ # Enables joint setup of model and optimizer, multiple optimizer param groups, and `torch.compile()`
+ self._fsdp_kwargs.setdefault("use_orig_params", True)
+
+ if device_mesh is not None:
+ if not _TORCH_GREATER_EQUAL_2_2:
+ raise ValueError("The `device_mesh` argument is only supported in torch >= 2.2.")
+ self._fsdp_kwargs["device_mesh"] = device_mesh
self._activation_checkpointing_kwargs = _activation_checkpointing_kwargs(
activation_checkpointing, activation_checkpointing_policy
@@ -253,18 +259,18 @@ def setup_environment(self) -> None:
super().setup_environment()
self._setup_distributed()
+ # If `device_mesh` in `_fsdp_kwargs` was provided as a tuple, convert it into a `DeviceMesh` object here
+ if isinstance(self._fsdp_kwargs.get("device_mesh"), tuple):
+ from torch.distributed.device_mesh import init_device_mesh
+
+ self._fsdp_kwargs["device_mesh"] = init_device_mesh("cuda", self._fsdp_kwargs["device_mesh"])
+
@override
def setup_module_and_optimizers(
self, module: Module, optimizers: List[Optimizer]
) -> Tuple[Module, List[Optimizer]]:
"""Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel`
module and sets `use_orig_params=True` to keep the reference to the original parameters in the optimizer."""
- if not _TORCH_GREATER_EQUAL_2_0:
- raise NotImplementedError(
- f"The `{type(self).__name__}` does not support the joint setup of module and optimizer(s)."
- " Please do it in this order: Create the model, call `setup_module`, create the optimizer,"
- " call `setup_optimizer`."
- )
use_orig_params = self._fsdp_kwargs.get("use_orig_params")
if use_orig_params is False:
raise ValueError(
@@ -284,7 +290,7 @@ def setup_module(self, module: Module) -> Module:
if any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()):
# The user has wrapped their submodules manually, don't apply the auto wrap policy.
- if _has_meta_device_parameters(module):
+ if _has_meta_device_parameters_or_buffers(module):
rank_zero_warn(
"The model is already wrapped in `FSDP` but there are still parameters on the meta device."
)
@@ -322,7 +328,7 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
if self._fsdp_kwargs.get("use_orig_params"):
return super().setup_optimizer(optimizer)
if not _optimizer_has_flat_params(optimizer):
- # We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True`
+ # We avoid this limitation by setting `use_orig_params=True`
raise ValueError(
"The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the optimizer"
" after setting up the model."
@@ -337,15 +343,12 @@ def module_to_device(self, module: Module) -> None:
def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
precision_init_ctx = self.precision.module_init_context()
module_sharded_ctx = self.module_sharded_context()
- empty_ctx = _EmptyInit(enabled=bool(empty_init))
stack = ExitStack()
- if _TORCH_GREATER_EQUAL_2_1 and empty_init:
+ if empty_init:
# Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
# 1) materialize module 2) call `reset_parameters()` 3) shard the module.
# These operations are applied to each submodule 'bottom up' in the module hierarchy.
stack.enter_context(torch.device("meta"))
- else:
- stack.enter_context(empty_ctx)
stack.enter_context(precision_init_ctx)
stack.enter_context(module_sharded_ctx)
return stack
@@ -406,7 +409,7 @@ def clip_gradients_norm(
# the root must be wrapped
raise TypeError(
"Gradient clipping with FSDP is only possible if the module passed to"
- f" `{self.__class__.__name__}.clip_gradients_norm` is wrapped in `FullyShardedDataParallel`."
+ f" `{type(self).__name__}.clip_gradients_norm` is wrapped in `FullyShardedDataParallel`."
f" Got: {module.__class__.__name__}."
)
self.precision.unscale_gradients(optimizer)
@@ -428,11 +431,6 @@ def save_checkpoint(
creates a metadata file `meta.pt` with the rest of the user's state (only saved from rank 0).
"""
- if not _TORCH_GREATER_EQUAL_2_0:
- raise NotImplementedError(
- "Saving and loading checkpoints with the `FSDPStrategy` is not supported in PyTorch < 2.0."
- " Please upgrade `torch` or file an issue: `https://github.com/Lightning-AI/lightning/issues`."
- )
if storage_options is not None:
raise TypeError(
"`FSDPStrategy.save_checkpoint(..., storage_options=...)` is not supported because"
@@ -524,17 +522,7 @@ def load_checkpoint(
state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None,
strict: bool = True,
) -> Dict[str, Any]:
- """Load the contents from a checkpoint and restore the state of the given objects.
-
- The strategy currently only supports saving and loading sharded checkpoints which are stored in form of a
- directory of multiple files rather than a single file.
-
- """
- if not _TORCH_GREATER_EQUAL_2_0:
- raise NotImplementedError(
- "Saving and loading checkpoints with the `FSDPStrategy` is not supported in PyTorch < 2.0."
- " Please upgrade `torch` or file an issue: `https://github.com/Lightning-AI/lightning/issues`."
- )
+ """Load the contents from a checkpoint and restore the state of the given objects."""
if not state:
raise ValueError(
f"Got FSDPStrategy.load_checkpoint(..., state={state!r}) but a state with at least "
@@ -545,6 +533,8 @@ def load_checkpoint(
path = Path(self.broadcast(path))
if isinstance(state, Module):
+ from lightning.fabric.strategies.model_parallel import _load_raw_module_state_from_path
+
_load_raw_module_state_from_path(path, module=state, world_size=self.world_size, strict=strict)
return {}
@@ -555,7 +545,6 @@ def load_checkpoint(
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
- from torch.distributed.fsdp import OptimStateKeyType
modules = {key: module for key, module in state.items() if _has_fsdp_modules(module)}
if len(modules) == 0:
@@ -614,29 +603,27 @@ def load_checkpoint(
return metadata
if _is_full_checkpoint(path):
- checkpoint = _lazy_load(path) if _TORCH_GREATER_EQUAL_2_0 else torch.load(path, map_location="cpu")
+ checkpoint = _lazy_load(path)
+
+ from lightning.fabric.strategies.model_parallel import (
+ _load_raw_module_state,
+ _rekey_optimizer_state_if_needed,
+ )
+
_load_raw_module_state(checkpoint.pop(module_key), module=module, world_size=self.world_size, strict=strict)
if isinstance(state, Module):
return {}
- if _TORCH_GREATER_EQUAL_2_0:
- # Materialize lazy tensors if there are any left in the checkpoint
- # The `torch.Optimizer.load_state_dict` method can't load lazy tensors because of deepcopy pickle issues
- checkpoint = _materialize_tensors(checkpoint)
+ # Materialize lazy tensors if there are any left in the checkpoint
+ # The `torch.Optimizer.load_state_dict` method can't load lazy tensors because of deepcopy pickle issues
+ checkpoint = _materialize_tensors(checkpoint)
# Load optimizer states
for optim_key, optim in optimizers.items():
# rank0_only should be false because we need to load the optimizer state on all ranks
with _get_full_state_dict_context(module, world_size=self.world_size, rank0_only=False):
- temp_state_dict = checkpoint.pop(optim_key)
-
- # Handling the case where the optimizer state is saved from a normal optimizer
- if isinstance(list(temp_state_dict["state"].keys())[0], int):
- temp_state_dict = FSDP.rekey_optim_state_dict(
- temp_state_dict, OptimStateKeyType.PARAM_NAME, module
- )
-
+ temp_state_dict = _rekey_optimizer_state_if_needed(checkpoint.pop(optim_key), module)
optim_state_dict = FSDP.optim_state_dict_to_load(
optim_state_dict=temp_state_dict,
model=module,
@@ -710,18 +697,13 @@ def _activation_checkpointing_kwargs(
classes = tuple(activation_checkpointing)
else:
classes = (activation_checkpointing,)
- if _TORCH_GREATER_EQUAL_2_1:
- rank_zero_deprecation(
- f"`FSDPStrategy(activation_checkpointing={activation_checkpointing})` is deprecated, use "
- f"`FSDPStrategy(activation_checkpointing_policy={set(classes)})` instead."
- )
+ rank_zero_deprecation(
+ f"`FSDPStrategy(activation_checkpointing={activation_checkpointing})` is deprecated, use "
+ f"`FSDPStrategy(activation_checkpointing_policy={set(classes)})` instead."
+ )
return {"check_fn": lambda submodule: isinstance(submodule, classes)}
if isinstance(activation_checkpointing_policy, set):
- if _TORCH_GREATER_EQUAL_2_1:
- return _auto_wrap_policy_kwargs(activation_checkpointing_policy, {})
- return {"check_fn": lambda submodule: isinstance(submodule, tuple(activation_checkpointing_policy))}
- if not _TORCH_GREATER_EQUAL_2_1:
- raise ValueError("`activation_checkpointing_policy` requires torch >= 2.1.0. HINT: `pip install -U torch`")
+ return _auto_wrap_policy_kwargs(activation_checkpointing_policy, {})
return {"auto_wrap_policy": activation_checkpointing_policy}
@@ -729,15 +711,10 @@ def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: Dict) -> Dict:
if policy is None:
return kwargs
if isinstance(policy, set):
- if _TORCH_GREATER_EQUAL_2_1:
- from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+ from torch.distributed.fsdp.wrap import ModuleWrapPolicy
- policy = ModuleWrapPolicy(policy)
- else:
- from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+ policy = ModuleWrapPolicy(policy)
- # this is not transformer specific despite the name
- policy = partial(transformer_auto_wrap_policy, transformer_layer_cls=policy)
kwargs["auto_wrap_policy"] = policy
return kwargs
@@ -779,7 +756,7 @@ def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
# the root must be wrapped
raise TypeError(
"Blocking backward sync is only possible if the module passed to"
- f" `{self.__class__.__name__}.no_backward_sync` is wrapped in `FullyShardedDataParallel`."
+ f" `{type(self).__name__}.no_backward_sync` is wrapped in `FullyShardedDataParallel`."
f" Got: {module.__class__.__name__}."
)
return module.no_sync()
@@ -840,27 +817,17 @@ def _get_full_state_dict_context(
) -> Generator[None, None, None]:
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+ from torch.distributed.fsdp.api import FullOptimStateDictConfig
- # In PyTorch <= 2.0, offload to CPU in combination with `world_size=1` is not possible
- offload_to_cpu = world_size > 1 or _TORCH_GREATER_EQUAL_2_1
- state_dict_config = FullStateDictConfig(offload_to_cpu=offload_to_cpu, rank0_only=rank0_only)
-
- if _TORCH_GREATER_EQUAL_2_0:
- from torch.distributed.fsdp.api import FullOptimStateDictConfig
+ state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=rank0_only)
+ optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=rank0_only)
+ state_dict_type_context = FSDP.state_dict_type(
+ module=module,
+ state_dict_type=StateDictType.FULL_STATE_DICT,
+ state_dict_config=state_dict_config,
+ optim_state_dict_config=optim_state_dict_config,
+ )
- optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=offload_to_cpu, rank0_only=rank0_only)
- state_dict_type_context = FSDP.state_dict_type(
- module=module,
- state_dict_type=StateDictType.FULL_STATE_DICT,
- state_dict_config=state_dict_config,
- optim_state_dict_config=optim_state_dict_config,
- )
- else:
- state_dict_type_context = FSDP.state_dict_type(
- module=module,
- state_dict_type=StateDictType.FULL_STATE_DICT,
- state_dict_config=state_dict_config,
- )
return state_dict_type_context # type: ignore[return-value]
@@ -879,38 +846,6 @@ def _has_fsdp_modules(module: object) -> TypeGuard[Module]:
return isinstance(module, Module) and any(isinstance(m, FullyShardedDataParallel) for m in module.modules())
-def _load_raw_module_state_from_path(path: Path, module: Module, world_size: int, strict: bool = True) -> None:
- """Loads the state dict from a file path into the FSDP module."""
- if not _is_full_checkpoint(path):
- raise ValueError(
- "Failed to load checkpoint directly into the model. The given path must be a single file containing the"
- f" full state dict: {path}"
- )
- # Use `lazy_load` instead of `torch.load` here to avoid storing a copy of the full checkpoint per rank
- _load_raw_module_state(state_dict=_lazy_load(path), module=module, world_size=world_size, strict=strict)
-
-
-def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, world_size: int, strict: bool = True) -> None:
- """Loads the state dict into the module by gathering all weights first and then and writing back to each shard."""
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-
- if not isinstance(module, FSDP):
- module.load_state_dict(state_dict, strict=strict)
- else:
- with _get_full_state_dict_context(module, world_size=world_size, rank0_only=False):
- module.load_state_dict(state_dict, strict=strict)
-
-
-def _has_meta_device_parameters(obj: Union[Module, Optimizer]) -> bool:
- if isinstance(obj, Optimizer):
- return any(
- t.is_meta for param_group in obj.param_groups for t in param_group["params"] if isinstance(t, Parameter)
- )
- if isinstance(obj, Module):
- return any(t.is_meta for t in obj.parameters())
- raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}")
-
-
def _move_torchmetrics_to_device(module: torch.nn.Module, device: torch.device) -> None:
# FSDP doesn't move modules without parameters (e.g. Metrics) to the device
# https://github.com/pytorch/pytorch/issues/113113
@@ -929,7 +864,7 @@ def _distributed_checkpoint_save(converted_state: Dict[str, Any], path: Path) ->
# let torch automatically infer the writer to use. This might also support fsspec paths in the future
# https://github.com/pytorch/pytorch/issues/118036
- save(converted_state, checkpoint_id=path) # type: ignore[call-arg]
+ save(converted_state, checkpoint_id=path)
else: # deprecated
from torch.distributed.checkpoint import FileSystemWriter
@@ -948,7 +883,7 @@ def _distributed_checkpoint_load(module_state: Dict[str, Any], path: Path) -> No
# let torch automatically infer the reader to use. This might also support fsspec paths in the future
# https://github.com/pytorch/pytorch/issues/118036
- load(module_state, checkpoint_id=path) # type: ignore[call-arg]
+ load(module_state, checkpoint_id=path)
else: # deprecated
from torch.distributed.checkpoint import FileSystemReader
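
The new `device_mesh` argument of `FSDPStrategy` accepts a tuple `(replication size, sharding size)` whose product must equal the world size. The following is a minimal sketch of that constraint and of the rank layout it implies, not code from this patch. It assumes ranks are arranged row-major over the mesh (which is how `init_device_mesh` is understood to build it), and `hybrid_shard_groups` is a hypothetical helper name used only for illustration.

```python
from typing import List, Tuple


def hybrid_shard_groups(device_mesh: Tuple[int, int], world_size: int) -> List[List[int]]:
    """Return one list of ranks per replica; each inner list is one sharding group.

    Hypothetical helper for illustration only. Assumes a row-major rank layout.
    """
    replication_size, sharding_size = device_mesh
    if replication_size * sharding_size != world_size:
        raise ValueError(
            f"device_mesh {device_mesh} covers {replication_size * sharding_size} devices,"
            f" but the world size is {world_size}."
        )
    return [
        list(range(replica * sharding_size, (replica + 1) * sharding_size))
        for replica in range(replication_size)
    ]


# 2 machines with 4 GPUs each: shard within a machine, replicate across machines
layout = hybrid_shard_groups((2, 4), world_size=8)
print(layout)
```

With `(2, 4)` on 8 devices, each row is a group of ranks that jointly hold one sharded copy of the model, and the two rows are replicas of each other.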
diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py
new file mode 100644
index 0000000000000..86b93d35e66f3
--- /dev/null
+++ b/src/lightning/fabric/strategies/model_parallel.py
@@ -0,0 +1,594 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import itertools
+import shutil
+from contextlib import ExitStack
+from datetime import timedelta
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generator, Literal, Optional, TypeVar, Union
+
+import torch
+from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
+from torch import Tensor
+from torch.nn import Module
+from torch.optim import Optimizer
+from typing_extensions import TypeGuard, override
+
+from lightning.fabric.plugins import CheckpointIO
+from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
+from lightning.fabric.strategies.fsdp import (
+ _distributed_checkpoint_load,
+ _distributed_checkpoint_save,
+ _get_full_state_dict_context,
+ _is_full_checkpoint,
+ _is_sharded_checkpoint,
+)
+from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
+from lightning.fabric.strategies.parallel import ParallelStrategy
+from lightning.fabric.strategies.strategy import (
+ TBroadcast,
+ _apply_filter,
+ _BackwardSyncControl,
+ _validate_keys_for_strict_loading,
+)
+from lightning.fabric.utilities.distributed import (
+ ReduceOp,
+ _distributed_is_initialized,
+ _get_default_process_group_backend_for_device,
+ _init_dist_connection,
+ _sync_ddp_if_available,
+)
+from lightning.fabric.utilities.distributed import group as _group
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3, _TORCH_GREATER_EQUAL_2_4
+from lightning.fabric.utilities.init import _materialize_distributed_module
+from lightning.fabric.utilities.load import _METADATA_FILENAME, _lazy_load, _move_state_into
+from lightning.fabric.utilities.rank_zero import rank_zero_only
+from lightning.fabric.utilities.seed import reset_seed
+from lightning.fabric.utilities.types import _PATH, _Stateful
+
+if TYPE_CHECKING:
+ from torch.distributed.device_mesh import DeviceMesh
+
+TModel = TypeVar("TModel", bound=Module)
+
+
+class ModelParallelStrategy(ParallelStrategy):
+ """Enables user-defined parallelism applied to a model.
+
+ .. warning:: This is an :ref:`experimental ` feature.
+
+ Currently supports up to 2D parallelism. Specifically, it supports the combination of
+ Fully Sharded Data-Parallel 2 (FSDP2) with Tensor Parallelism (DTensor). These APIs are still
+ experimental in PyTorch. Requires PyTorch 2.4 or newer.
+
+ Arguments:
+ parallelize_fn: A function that applies parallelisms to a module. The strategy will provide the
+ model and device mesh as input.
+ data_parallel_size: The number of devices within a data-parallel group. Defaults to ``"auto"``, which
+ sets this size to the number of nodes in the cluster.
+ tensor_parallel_size: The number of devices within a tensor-parallel group. Defaults to ``"auto"``, which
+ sets this size to the number of GPUs in a single node.
+ save_distributed_checkpoint: If ``True``, each rank saves its shard of weights and optimizer states to a file.
+ The checkpoint is a folder with as many files as the world size.
+ If ``False``, the full weights and optimizer states get assembled on rank 0 and saved to a single file.
+
+ """
+
+ def __init__(
+ self,
+ parallelize_fn: Callable[[TModel, "DeviceMesh"], TModel],
+ data_parallel_size: Union[Literal["auto"], int] = "auto",
+ tensor_parallel_size: Union[Literal["auto"], int] = "auto",
+ save_distributed_checkpoint: bool = True,
+ process_group_backend: Optional[str] = None,
+ timeout: Optional[timedelta] = default_pg_timeout,
+ ) -> None:
+ super().__init__()
+ if not _TORCH_GREATER_EQUAL_2_4:
+ raise ImportError(f"{type(self).__name__} requires PyTorch 2.4 or higher.")
+ self._parallelize_fn = parallelize_fn
+ self._data_parallel_size = data_parallel_size
+ self._tensor_parallel_size = tensor_parallel_size
+ self._num_nodes = 1
+ self._save_distributed_checkpoint = save_distributed_checkpoint
+ self._process_group_backend: Optional[str] = process_group_backend
+ self._timeout: Optional[timedelta] = timeout
+ self._backward_sync_control = _ParallelBackwardSyncControl()
+
+ self._device_mesh: Optional[DeviceMesh] = None
+
+ @property
+ def device_mesh(self) -> "DeviceMesh":
+ if self._device_mesh is None:
+ raise RuntimeError("Accessing the device mesh before processes have initialized is not allowed.")
+ return self._device_mesh
+
+ @property
+ @override
+ def checkpoint_io(self) -> CheckpointIO:
+ raise NotImplementedError(f"The `{type(self).__name__}` does not use the `CheckpointIO` plugin interface.")
+
+ @checkpoint_io.setter
+ @override
+ def checkpoint_io(self, io: CheckpointIO) -> None:
+ raise NotImplementedError(f"The `{type(self).__name__}` does not support setting a `CheckpointIO` plugin.")
+
+ @property
+ @override
+ def root_device(self) -> torch.device:
+ assert self.parallel_devices is not None
+ return self.parallel_devices[self.local_rank]
+
+ @property
+ def num_nodes(self) -> int:
+ return self._num_nodes
+
+ @num_nodes.setter
+ def num_nodes(self, num_nodes: int) -> None:
+ self._num_nodes = num_nodes
+
+ @property
+ def num_processes(self) -> int:
+ return len(self.parallel_devices) if self.parallel_devices is not None else 0
+
+ @property
+ @override
+ def distributed_sampler_kwargs(self) -> Dict[str, Any]:
+ assert self.device_mesh is not None
+ data_parallel_mesh = self.device_mesh["data_parallel"]
+ return {"num_replicas": data_parallel_mesh.size(), "rank": data_parallel_mesh.get_local_rank()}
+
+ @property
+ def process_group_backend(self) -> Optional[str]:
+ return self._process_group_backend
+
+ @override
+ def _configure_launcher(self) -> None:
+ assert self.cluster_environment is not None
+ if not self.cluster_environment.creates_processes_externally:
+ self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
+
+ @override
+ def setup_environment(self) -> None:
+ super().setup_environment()
+ self._setup_distributed()
+ if self._data_parallel_size == "auto":
+ self._data_parallel_size = self.num_nodes
+ if self._tensor_parallel_size == "auto":
+ self._tensor_parallel_size = self.num_processes
+ self._device_mesh = _setup_device_mesh(
+ self._data_parallel_size, self._tensor_parallel_size, self.world_size, self.root_device
+ )
+
+ @override
+ def setup_module(self, module: Module) -> Module:
+ from torch.distributed.fsdp import FullyShardedDataParallel
+
+ if any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()):
+ raise TypeError(
+ "Found modules that are wrapped with `torch.distributed.fsdp.FullyShardedDataParallel`."
+ f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.4."
+ )
+
+ module = self._parallelize_fn(module, self.device_mesh) # type: ignore[arg-type]
+ if not isinstance(module, Module):
+ raise TypeError(
+ f"The `parallelize_fn` must return a `nn.Module` instance, but got: {type(module).__name__}"
+ )
+ _materialize_distributed_module(module, self.root_device)
+ return module
+
+ @override
+ def module_to_device(self, module: Module) -> None:
+ pass
+
+ @override
+ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
+ precision_init_ctx = self.precision.module_init_context()
+ stack = ExitStack()
+ if empty_init:
+ # Materialization happens in `setup_module`
+ # TODO: Introduce `Fabric.materialize(module)` to give user control over materialization
+ stack.enter_context(torch.device("meta"))
+ stack.enter_context(precision_init_ctx)
+ return stack
+
+ @override
+ def all_reduce(
+ self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
+ ) -> Tensor:
+ if isinstance(tensor, Tensor):
+ return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
+ return tensor
+
+ @override
+ def barrier(self, *args: Any, **kwargs: Any) -> None:
+ if not _distributed_is_initialized():
+ return
+ if torch.distributed.get_backend() == "nccl":
+ torch.distributed.barrier(device_ids=[self.root_device.index])
+ else:
+ torch.distributed.barrier()
+
+ @override
+ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
+ if not _distributed_is_initialized():
+ return obj
+
+ obj = [obj]
+ torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
+ return obj[0]
+
+ @override
+ def save_checkpoint(
+ self,
+ path: _PATH,
+ state: Dict[str, Union[Module, Optimizer, Any]],
+ storage_options: Optional[Any] = None,
+ filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None,
+ ) -> None:
+ """Save model, optimizer, and other state to a checkpoint on disk.
+
+ If distributed checkpointing is enabled (default), the checkpoint gets saved as a directory containing one file
+ per process, with model- and optimizer shards stored per file. Additionally, it creates a metadata file
+ `meta.pt` with the rest of the user's state (only saved from rank 0).
+ If distributed checkpointing is disabled (``save_distributed_checkpoint=False``), the checkpoint will be
+ written to a single file containing the weights, optimizer state and other metadata.
+
+ """
+ if storage_options is not None:
+ raise TypeError(
+ f"`{type(self).__name__}.save_checkpoint(..., storage_options=...)` is not supported because"
+ f" `{type(self).__name__}` does not use the `CheckpointIO`."
+ )
+ if filter is not None and self._save_distributed_checkpoint:
+ # https://github.com/pytorch/pytorch/issues/105379
+ raise NotImplementedError(
+ f"{type(self).__name__} doesn't support loading distributed filtered checkpoints,"
+ " so saving them is disabled."
+ )
+ # broadcast the path from rank 0 to ensure all the states are saved in a common path
+ path = Path(self.broadcast(path))
+ _save_checkpoint(
+ path=path,
+ state=state,
+ full_state_dict=(not self._save_distributed_checkpoint),
+ rank=self.global_rank,
+ filter=filter,
+ )
+
+ @override
+ def load_checkpoint(
+ self,
+ path: _PATH,
+ state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None,
+ strict: bool = True,
+ ) -> Dict[str, Any]:
+ """Load the contents from a checkpoint and restore the state of the given objects."""
+ if not state:
+ raise ValueError(
+ f"Got {type(self).__name__}.load_checkpoint(..., state={state!r}) but a state with at least "
+ " a model instance to reload is required. Pass it in like so:"
+ f" {type(self).__name__}.load_checkpoint(..., state={{'model': model, ...}})"
+ )
+ # broadcast the path from rank 0 to ensure all the states are loaded from a common path
+ path = Path(self.broadcast(path))
+
+ if isinstance(state, Module):
+ _load_raw_module_state_from_path(path, module=state, world_size=self.world_size, strict=strict)
+ return {}
+
+ if isinstance(state, Optimizer):
+ raise NotImplementedError(
+ "Loading a single optimizer object from a checkpoint is not supported yet with"
+ f" {type(self).__name__}."
+ )
+
+ return _load_checkpoint(path=path, state=state, strict=strict)
+
+ def _setup_distributed(self) -> None:
+ reset_seed()
+ self._set_world_ranks()
+ self._process_group_backend = self._get_process_group_backend()
+ assert self.cluster_environment is not None
+ _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+
+ def _get_process_group_backend(self) -> str:
+ return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
+
+ def _set_world_ranks(self) -> None:
+ if self.cluster_environment is not None:
+ self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+ self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+ # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
+ # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
+ rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
+
+
+class _ParallelBackwardSyncControl(_BackwardSyncControl):
+ @override
+ def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
+ """Blocks gradient synchronization inside the FSDP2 modules."""
+ return _FSDPNoSync(module=module, enabled=enabled)
+
+
+class _FSDPNoSync(ContextManager):
+ def __init__(self, module: Module, enabled: bool) -> None:
+ self._module = module
+ self._enabled = enabled
+
+ def _set_requires_grad_sync(self, requires_grad_sync: bool) -> None:
+ from torch.distributed._composable.fsdp import FSDPModule
+
+ for mod in self._module.modules():
+ if isinstance(mod, FSDPModule):
+ mod.set_requires_gradient_sync(requires_grad_sync, recurse=False)
+
+ def __enter__(self) -> None:
+ self._set_requires_grad_sync(not self._enabled)
+
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+ self._set_requires_grad_sync(self._enabled)
+
+
+def _save_checkpoint(
+ path: Path,
+ state: Dict[str, Union[Module, Optimizer, Any]],
+ full_state_dict: bool,
+ rank: int,
+ filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None,
+) -> None:
+ if path.is_dir() and full_state_dict and not _is_sharded_checkpoint(path):
+ raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}")
+
+ modules = [module for module in state.values() if _has_dtensor_modules(module)]
+ if len(modules) == 0:
+ raise ValueError(
+ "Could not find a distributed model in the provided checkpoint state. Please provide the model as"
+ " part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure"
+ " you set up the model (and optimizers if any) through the strategy before saving the checkpoint."
+ )
+ if len(modules) > 1:
+ raise ValueError(
+ "Found multiple distributed models in the given state. Saving distributed checkpoints is"
+ " currently limited to a single model per checkpoint. To save multiple models, call the"
+ " save method for each model separately with a different path."
+ )
+ module = modules[0]
+
+ from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, get_optimizer_state_dict
+
+ state_dict_options = StateDictOptions(full_state_dict=full_state_dict, cpu_offload=True)
+
+ # replace the modules and optimizer objects in the state with their local state dict
+ # and separate the user's metadata
+ converted_state: Dict[str, Any] = {}
+ metadata: Dict[str, Any] = {}
+ for key, obj in state.items():
+ converted: Any
+ if isinstance(obj, Module):
+ converted = get_model_state_dict(obj, options=state_dict_options)
+ target_dict = converted_state
+ elif isinstance(obj, Optimizer):
+ converted = get_optimizer_state_dict(module, obj, options=state_dict_options)
+ target_dict = converted_state
+ else: # everything not a module or optimizer is considered metadata
+ converted = obj.state_dict() if isinstance(obj, _Stateful) else obj
+ target_dict = metadata
+ _apply_filter(key, filter or {}, converted, target_dict)
+
+ if full_state_dict:
+ if _is_sharded_checkpoint(path):
+ shutil.rmtree(path)
+ converted_state.update(metadata)
+ if rank == 0:
+ torch.save(converted_state, path)
+ else:
+ if path.is_file():
+ path.unlink()
+ path.mkdir(parents=True, exist_ok=True)
+ _distributed_checkpoint_save(converted_state, path)
+ if rank == 0:
+ torch.save(metadata, path / _METADATA_FILENAME)
+
+
+def _load_checkpoint(
+ path: Path,
+ state: Dict[str, Union[Module, Optimizer, Any]],
+ strict: bool = True,
+ optimizer_states_from_list: bool = False,
+) -> Dict[str, Any]:
+ from torch.distributed.checkpoint.state_dict import (
+ StateDictOptions,
+ get_model_state_dict,
+ get_optimizer_state_dict,
+ set_optimizer_state_dict,
+ )
+
+ modules = {key: module for key, module in state.items() if _has_dtensor_modules(module)}
+ if len(modules) == 0:
+ raise ValueError(
+ "Could not find a distributed model in the provided checkpoint state. Please provide the model as"
+ " part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure"
+ " you set up the model (and optimizers if any) through the strategy before loading the checkpoint."
+ )
+ optimizers = {key: optim for key, optim in state.items() if isinstance(optim, Optimizer)}
+ if len(modules) > 1:
+ raise ValueError(
+ "Found multiple distributed models in the given state. Loading distributed checkpoints is"
+ " currently limited to a single model per checkpoint. To load multiple models, call the"
+ " load method for each model separately with a different path."
+ )
+ module_key, module = list(modules.items())[0]
+
+ if _is_sharded_checkpoint(path):
+ state_dict_options = StateDictOptions(cpu_offload=True)
+
+ module_state = {module_key: get_model_state_dict(module)}
+ _distributed_checkpoint_load(module_state, path)
+ module.load_state_dict(module_state[module_key], strict=strict)
+
+ # the optimizer states must be loaded separately
+ for optim_key, optim in optimizers.items():
+ optim_state = {optim_key: get_optimizer_state_dict(module, optim)}
+ _distributed_checkpoint_load(optim_state, path)
+ set_optimizer_state_dict(module, optim, optim_state_dict=optim_state[optim_key], options=state_dict_options)
+
+ # Load metadata (anything not a module or optimizer)
+ metadata = torch.load(path / _METADATA_FILENAME)
+ requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
+ _validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict)
+ for key in requested_metadata_keys:
+ if key not in metadata:
+ continue
+ state[key] = metadata.pop(key)
+
+ # return the remaining metadata that wasn't requested as part of `state`
+ return metadata
+
+ if _is_full_checkpoint(path):
+ checkpoint = torch.load(path, mmap=True, map_location="cpu", weights_only=False)
+ _load_raw_module_state(checkpoint.pop(module_key), module, strict=strict)
+
+ state_dict_options = StateDictOptions(
+ broadcast_from_rank0=True,
+ full_state_dict=True,
+ strict=strict,
+ )
+ for optimizer_idx, (optimizer_name, optimizer) in enumerate(optimizers.items()):
+ if optimizer_states_from_list:
+ # This code path is only used by `lightning.pytorch`, which saves optimizer states as a list
+ # rather than individual states at the top level.
+ optimizer_state = checkpoint["optimizer_states"][optimizer_idx]
+ else:
+ optimizer_state = checkpoint.pop(optimizer_name)
+
+ optimizer_state = _rekey_optimizer_state_if_needed(optimizer_state, module)
+ set_optimizer_state_dict(
+ module,
+ optimizer,
+ optim_state_dict=optimizer_state,
+ options=state_dict_options,
+ )
+
+ requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
+ _validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict)
+
+ # Load metadata (anything not a module or optimizer)
+ _move_state_into(source=checkpoint, destination=state, keys=requested_metadata_keys)
+
+ # return the remaining metadata that wasn't requested as part of `state`
+ return checkpoint
+
+ raise ValueError(
+ f"The path {str(path)!r} does not point to a valid checkpoint. Make sure the path points to either a"
+ " directory with distributed checkpoint shards, or a single file with a full checkpoint."
+ )
+
+
+def _setup_device_mesh(
+ data_parallel_size: int,
+ tensor_parallel_size: int,
+ world_size: int,
+ device: torch.device,
+) -> "DeviceMesh":
+ from torch.distributed.device_mesh import init_device_mesh
+
+ if data_parallel_size * tensor_parallel_size != world_size:
+ raise RuntimeError(
+ f"The sizes `data_parallel_size={data_parallel_size}` and"
+ f" `tensor_parallel_size={tensor_parallel_size}` multiplied should equal the world size"
+ f" ({world_size})."
+ )
+ return init_device_mesh(
+ device_type=device.type,
+ mesh_shape=(data_parallel_size, tensor_parallel_size),
+ mesh_dim_names=("data_parallel", "tensor_parallel"),
+ )
+
+
+def _has_dtensor_modules(module: object) -> TypeGuard[Module]:
+ from torch.distributed._tensor import DTensor
+
+ return isinstance(module, Module) and any(isinstance(t, DTensor) for t in module.parameters())
+
+
+def _load_raw_module_state_from_path(path: Path, module: Module, world_size: int, strict: bool = True) -> None:
+ """Loads the state dict from a file path into the FSDP module."""
+ if not _is_full_checkpoint(path):
+ raise ValueError(
+ "Failed to load checkpoint directly into the model. The given path must be a single file containing the"
+ f" full state dict: {path}"
+ )
+ # Use `lazy_load`/`mmap` instead to avoid storing a copy of the full checkpoint per rank
+ state_dict = torch.load(path, mmap=True, map_location="cpu") if _TORCH_GREATER_EQUAL_2_3 else _lazy_load(path)
+ _load_raw_module_state(state_dict=state_dict, module=module, world_size=world_size, strict=strict)
+
+
+def _load_raw_module_state(
+ state_dict: Dict[str, Any], module: Module, world_size: int = 1, strict: bool = True
+) -> None:
+ """Loads the state dict into the module by gathering all weights first and then writing back to each shard."""
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+ if _has_dtensor_modules(module):
+ from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
+
+ state_dict_options = StateDictOptions(
+ broadcast_from_rank0=True,
+ full_state_dict=True,
+ # must be set False to allow loading each param separately below
+ strict=False,
+ )
+
+ for submodule_name, submodule in module.named_modules():
+ for param_name, _ in _named_parameters_and_buffers_to_load(submodule):
+ full_param_name = f"{submodule_name}{'.' if submodule_name else ''}{param_name}"
+ if full_param_name not in state_dict:
+ if not strict:
+ continue
+ raise KeyError(
+ f"The model contains a key '{full_param_name}' that does not exist in the loaded checkpoint."
+ " To disable strict loading, set `strict=False`."
+ )
+ local_state_dict = {param_name: state_dict[full_param_name]}
+ set_model_state_dict(submodule, local_state_dict, options=state_dict_options)
+
+ elif isinstance(module, FSDP):
+ with _get_full_state_dict_context(module, world_size=world_size, rank0_only=False):
+ module.load_state_dict(state_dict, strict=strict)
+ else:
+ module.load_state_dict(state_dict, strict=strict)
+
+
+def _named_parameters_and_buffers_to_load(module: Module) -> Generator:
+ """Returns parameters and buffers, with non-persistent buffers excluded."""
+ for param_name, param in itertools.chain(
+ module.named_buffers(recurse=False),
+ module.named_parameters(recurse=False),
+ ):
+ if param_name in module._non_persistent_buffers_set:
+ continue
+ yield param_name, param
+
+
+def _rekey_optimizer_state_if_needed(optimizer_state_dict: Dict[str, Any], module: Module) -> Dict[str, Any]:
+ """Handles the case where the optimizer state is saved from a normal optimizer and converts the keys to parameter
+ names."""
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+ from torch.distributed.fsdp import OptimStateKeyType
+
+ if isinstance(list(optimizer_state_dict["state"].keys())[0], int):
+ optimizer_state_dict = FSDP.rekey_optim_state_dict(optimizer_state_dict, OptimStateKeyType.PARAM_NAME, module)
+ return optimizer_state_dict
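A minimal sketch of the size check performed by `_setup_device_mesh` above, written without `torch` so it can run anywhere. The helper name `check_mesh_shape` is hypothetical and not part of Lightning; it only mirrors the condition that the data-parallel and tensor-parallel sizes must multiply to the world size.

```python
from typing import Tuple


def check_mesh_shape(data_parallel_size: int, tensor_parallel_size: int, world_size: int) -> Tuple[int, int]:
    """Hypothetical helper: return the 2D mesh shape, or raise if it does not cover the world size."""
    if data_parallel_size * tensor_parallel_size != world_size:
        raise RuntimeError(
            f"The sizes `data_parallel_size={data_parallel_size}` and"
            f" `tensor_parallel_size={tensor_parallel_size}` multiplied should equal the world size"
            f" ({world_size})."
        )
    # Same ordering as the ("data_parallel", "tensor_parallel") mesh dimensions in the diff
    return (data_parallel_size, tensor_parallel_size)


# 8 processes split into 4 data-parallel groups of 2 tensor-parallel ranks each
mesh_shape = check_mesh_shape(data_parallel_size=4, tensor_parallel_size=2, world_size=8)
```

With 8 processes, a shape such as 3 x 2 would be rejected because two ranks would be left without a place in the mesh.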
diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py
index 1c64f97394fa2..6bfed6a270b68 100644
--- a/src/lightning/fabric/strategies/strategy.py
+++ b/src/lightning/fabric/strategies/strategy.py
@@ -29,7 +29,6 @@
from lightning.fabric.strategies.launchers.launcher import _Launcher
from lightning.fabric.strategies.registry import _StrategyRegistry
from lightning.fabric.utilities.apply_func import move_data_to_device
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.init import _EmptyInit
from lightning.fabric.utilities.types import _PATH, Optimizable, ReduceOp, _Stateful
@@ -122,8 +121,7 @@ def tensor_init_context(self) -> ContextManager:
"""Controls how tensors get created (device, dtype)."""
precision_init_ctx = self.precision.tensor_init_context()
stack = ExitStack()
- if _TORCH_GREATER_EQUAL_2_0:
- stack.enter_context(self.root_device)
+ stack.enter_context(self.root_device)
stack.enter_context(precision_init_ctx)
return stack
@@ -140,8 +138,7 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManag
"""
precision_module_ctx = self.precision.module_init_context()
stack = ExitStack()
- if _TORCH_GREATER_EQUAL_2_0:
- stack.enter_context(self.root_device)
+ stack.enter_context(self.root_device)
stack.enter_context(_EmptyInit(enabled=bool(empty_init)))
stack.enter_context(precision_module_ctx)
return stack
diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py
index 1b53292ff1581..e4c080d8110db 100644
--- a/src/lightning/fabric/strategies/xla_fsdp.py
+++ b/src/lightning/fabric/strategies/xla_fsdp.py
@@ -39,7 +39,6 @@
_validate_keys_for_strict_loading,
)
from lightning.fabric.utilities.cloud_io import get_filesystem
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.init import _EmptyInit
from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn
from lightning.fabric.utilities.types import _PATH, Optimizable, ReduceOp
@@ -57,7 +56,7 @@ class XLAFSDPStrategy(ParallelStrategy, _Sharded):
.. warning:: This is an :ref:`experimental ` feature.
- For more information check out https://github.com/pytorch/xla/blob/master/docs/fsdp.md
+ For more information check out https://github.com/pytorch/xla/blob/v2.5.0/docs/fsdp.md
Args:
auto_wrap_policy: Same as ``auto_wrap_policy`` parameter in
@@ -420,11 +419,6 @@ def save_checkpoint(
consolidated checkpoint combining all of the sharded checkpoints.
"""
- if not _TORCH_GREATER_EQUAL_2_0:
- raise NotImplementedError(
- "Saving and loading checkpoints with the `XLAFSDPStrategy` is not supported in PyTorch < 2.0."
- " Please upgrade `torch`."
- )
# broadcast the path from rank 0 to ensure all the states are saved in a common path
path = Path(self.broadcast(path))
if path.is_dir() and any(path.iterdir()):
@@ -527,11 +521,6 @@ def load_checkpoint(
directory of multiple files rather than a single file.
"""
- if not _TORCH_GREATER_EQUAL_2_0:
- raise NotImplementedError(
- "Saving and loading checkpoints with the `FSDPStrategy` is not supported in PyTorch < 2.0."
- " Please upgrade `torch` or file an issue: `https://github.com/Lightning-AI/lightning/issues`."
- )
if not state:
raise ValueError(
f"Got `XLAFSDPStrategy.load_checkpoint(..., state={state!r})` but a state with at least "
diff --git a/src/lightning/fabric/utilities/apply_func.py b/src/lightning/fabric/utilities/apply_func.py
index 15a4b48ed4518..d43565f494d3c 100644
--- a/src/lightning/fabric/utilities/apply_func.py
+++ b/src/lightning/fabric/utilities/apply_func.py
@@ -15,19 +15,22 @@
from abc import ABC
from functools import partial
-from typing import Any, Callable, List, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, List, Tuple, Union
-import numpy as np
import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
+from lightning.fabric.utilities.imports import _NUMPY_AVAILABLE
from lightning.fabric.utilities.types import _DEVICE
+if TYPE_CHECKING:
+ import numpy as np
+
_BLOCKING_DEVICE_TYPES = ("cpu", "mps")
-def _from_numpy(value: np.ndarray, device: _DEVICE) -> Tensor:
+def _from_numpy(value: "np.ndarray", device: _DEVICE) -> Tensor:
return torch.from_numpy(value).to(device)
@@ -36,9 +39,13 @@ def _from_numpy(value: np.ndarray, device: _DEVICE) -> Tensor:
(bool, partial(torch.tensor, dtype=torch.uint8)),
(int, partial(torch.tensor, dtype=torch.int)),
(float, partial(torch.tensor, dtype=torch.float)),
- (np.ndarray, _from_numpy),
]
+if _NUMPY_AVAILABLE:
+ import numpy as np
+
+ CONVERSION_DTYPES.append((np.ndarray, _from_numpy))
+
class _TransferableDataType(ABC):
"""A custom type for data that can be moved to a torch device via ``.to(...)``.
diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py
index af795a4801ef0..7ecc9eea501a6 100644
--- a/src/lightning/fabric/utilities/cloud_io.py
+++ b/src/lightning/fabric/utilities/cloud_io.py
@@ -33,6 +33,7 @@
def _load(
path_or_url: Union[IO, _PATH],
map_location: _MAP_LOCATION_TYPE = None,
+ weights_only: bool = False,
) -> Any:
"""Loads a checkpoint.
@@ -46,15 +47,21 @@ def _load(
return torch.load(
path_or_url,
map_location=map_location, # type: ignore[arg-type] # upstream annotation is not correct
+ weights_only=weights_only,
)
if str(path_or_url).startswith("http"):
return torch.hub.load_state_dict_from_url(
str(path_or_url),
map_location=map_location, # type: ignore[arg-type]
+ weights_only=weights_only,
)
fs = get_filesystem(path_or_url)
with fs.open(path_or_url, "rb") as f:
- return torch.load(f, map_location=map_location) # type: ignore[arg-type]
+ return torch.load(
+ f,
+ map_location=map_location, # type: ignore[arg-type]
+ weights_only=weights_only,
+ )
def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem:
@@ -76,7 +83,10 @@ def _atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None
bytesbuffer = io.BytesIO()
log.debug(f"Saving checkpoint: {filepath}")
torch.save(checkpoint, bytesbuffer)
- with fsspec.open(filepath, "wb") as f:
+
+ # We use a transaction here to avoid file corruption if the save gets interrupted
+ fs, urlpath = fsspec.core.url_to_fs(str(filepath))
+ with fs.transaction, fs.open(urlpath, "wb") as f:
f.write(bytesbuffer.getvalue())
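The transaction in `_atomic_save` above is meant to keep an interrupted save from leaving a corrupted file behind. As a rough standard-library analogue of that idea (not the `fsspec` implementation, and not code from this PR), one can write to a temporary file next to the target and rename it only once the write has finished. The name `atomic_write_bytes` is hypothetical.

```python
import os
import tempfile
from pathlib import Path


def atomic_write_bytes(data: bytes, filepath: Path) -> None:
    """Hypothetical helper: write `data` so that `filepath` is either fully written or left untouched."""
    # The temporary file lives in the same directory so the final rename stays on one filesystem
    fd, tmp_name = tempfile.mkstemp(dir=filepath.parent, prefix=filepath.name, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        # os.replace swaps the finished file into place in a single step
        os.replace(tmp_name, filepath)
    except BaseException:
        # On any failure, remove the partial temporary file and leave the target as it was
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise


with tempfile.TemporaryDirectory() as tmp_dir:
    target = Path(tmp_dir) / "checkpoint.bin"
    atomic_write_bytes(b"payload", target)
    written = target.read_bytes()
    leftovers = sorted(p.name for p in Path(tmp_dir).iterdir())
```

After a successful write only the target file remains in the directory; the temporary file never becomes visible under the final name.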
diff --git a/src/lightning/fabric/utilities/device_dtype_mixin.py b/src/lightning/fabric/utilities/device_dtype_mixin.py
index a9e614cbbd1d2..9f06dc50cfbef 100644
--- a/src/lightning/fabric/utilities/device_dtype_mixin.py
+++ b/src/lightning/fabric/utilities/device_dtype_mixin.py
@@ -109,14 +109,12 @@ def half(self) -> Self:
def _update_properties(
root: torch.nn.Module, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
) -> None:
- def apply_fn(module: Union[_DeviceDtypeModuleMixin, Module]) -> None:
+ for module in root.modules():
if not isinstance(module, _DeviceDtypeModuleMixin):
- return
+ continue
# cannot use `module.to()` because we don't actually want to move the model in case there are multiple
# devices types (such as partial meta parameters)
if device is not None:
module._device = device
if dtype is not None:
module._dtype = dtype
-
- root.apply(apply_fn)
diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py
index 30bfe4e254a07..0e6c52dfb09b9 100644
--- a/src/lightning/fabric/utilities/distributed.py
+++ b/src/lightning/fabric/utilities/distributed.py
@@ -1,6 +1,8 @@
+import atexit
import contextlib
import logging
import os
+import signal
import time
from contextlib import nullcontext
from datetime import timedelta
@@ -12,10 +14,11 @@
from lightning_utilities.core.imports import package_available
from torch import Tensor
from torch.utils.data import Dataset, DistributedSampler, Sampler
-from typing_extensions import Self, override
+from typing_extensions import Self, TypeGuard, override
from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
from lightning.fabric.utilities.data import _num_cpus_available
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.rank_zero import rank_zero_info
from lightning.fabric.utilities.types import _PATH, ReduceOp
@@ -28,6 +31,8 @@ class group: # type: ignore
if TYPE_CHECKING:
+ from torch.distributed._tensor import DTensor
+
from lightning.fabric.plugins import ClusterEnvironment
from lightning.fabric.strategies import Strategy
@@ -291,6 +296,10 @@ def _init_dist_connection(
log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)
+ if torch_distributed_backend == "nccl":
+ # PyTorch >= 2.4 warns about an undestroyed NCCL process group, so we destroy it at program exit
+ atexit.register(_destroy_dist_connection)
+
# On rank=0 let everyone know training is starting
rank_zero_info(
f"{'-' * 100}\n"
@@ -300,6 +309,14 @@ def _init_dist_connection(
)
+def _destroy_dist_connection() -> None:
+ # Don't allow Ctrl+C to interrupt this handler
+ signal.signal(signal.SIGINT, signal.SIG_IGN)
+ if _distributed_is_initialized():
+ torch.distributed.destroy_process_group()
+ signal.signal(signal.SIGINT, signal.SIG_DFL)
+
+
def _get_default_process_group_backend_for_device(device: torch.device) -> str:
return "nccl" if device.type == "cuda" else "gloo"
@@ -413,3 +430,11 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self.barrier()
if self.group is not None:
torch.distributed.destroy_process_group(self.group)
+
+
+def _is_dtensor(tensor: Tensor) -> TypeGuard["DTensor"]:
+ if _TORCH_GREATER_EQUAL_2_4:
+ from torch.distributed._tensor import DTensor
+
+ return isinstance(tensor, DTensor)
+ return False
diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py
index cc069a2a73338..a1c5a6f6dcd1b 100644
--- a/src/lightning/fabric/utilities/imports.py
+++ b/src/lightning/fabric/utilities/imports.py
@@ -17,7 +17,10 @@
import platform
import sys
-from lightning_utilities.core.imports import compare_version
+from lightning_utilities.core.imports import RequirementCache, compare_version
+
+_NUMPY_AVAILABLE = RequirementCache("numpy")
+
_IS_WINDOWS = platform.system() == "Windows"
@@ -26,13 +29,12 @@
# 2. The inspection mode via `python -i`: https://stackoverflow.com/a/6879085/1162383
_IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive)
-_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0")
-_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0")
_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0")
-_TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0", use_base_version=True)
-_TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1
+_TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0")
+_TORCH_EQUAL_2_4_0 = compare_version("torch", operator.eq, "2.4.0")
+_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")
+_TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")
-_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
_UTILITIES_GREATER_EQUAL_0_10 = compare_version("lightning_utilities", operator.ge, "0.10.0")
diff --git a/src/lightning/fabric/utilities/init.py b/src/lightning/fabric/utilities/init.py
index e1b80f7a55fe1..c92dfd8c2e82b 100644
--- a/src/lightning/fabric/utilities/init.py
+++ b/src/lightning/fabric/utilities/init.py
@@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
-from typing import Any, Callable, Dict, Optional, Sequence
+from typing import Any, Callable, Dict, Optional, Sequence, Union
import torch
+from torch.nn import Module, Parameter
+from torch.optim import Optimizer
from torch.overrides import TorchFunctionMode
from typing_extensions import override
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
+from lightning.fabric.utilities.rank_zero import rank_zero_warn
from lightning.fabric.utilities.types import _DEVICE
@@ -56,10 +58,8 @@ def __torch_function__(
return func(*args, **kwargs)
-def _materialize(module: torch.nn.Module, device: _DEVICE) -> None:
+def _materialize(module: Module, device: _DEVICE) -> None:
"""Materialize a module."""
- if not _TORCH_GREATER_EQUAL_2_1:
- raise RuntimeError("recurse=False requires torch 2.1")
module.to_empty(device=device, recurse=False)
if not hasattr(module, "reset_parameters"):
raise TypeError(
@@ -69,8 +69,45 @@ def _materialize(module: torch.nn.Module, device: _DEVICE) -> None:
module.reset_parameters()
-def _materialize_meta_tensors(module: torch.nn.Module, device: _DEVICE) -> None:
+def _materialize_meta_tensors(module: Module, device: _DEVICE) -> None:
"""Materialize all tensors in a given module."""
for module in module.modules():
- if any(t.is_meta for t in itertools.chain(module.parameters(recurse=False), module.buffers(recurse=False))):
+ if _has_meta_device_parameters_or_buffers(module, recurse=False):
_materialize(module, device)
+
+
+def _materialize_distributed_module(module: Module, device: torch.device) -> None:
+ # Reference: https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md#meta-device-initialization
+ # TODO: Introduce `Fabric.materialize(module)` to give user control when materialization should happen
+ # TODO: Make `torchmetrics.Metric` compatible with the `to_empty()` + `reset_parameters()` semantics
+ if not _has_meta_device_parameters_or_buffers(module):
+ return
+
+ module.to_empty(device=device) # has to be called on the root module
+
+ uninitialized_modules = set()
+ for submodule in module.modules():
+ if all(False for _ in itertools.chain(submodule.parameters(recurse=False), submodule.buffers(recurse=False))):
+ # module has no parameters or buffers
+ continue
+ if callable(reset_method := getattr(submodule, "reset_parameters", None)):
+ reset_method()
+ else:
+ uninitialized_modules.add(type(submodule).__name__)
+
+ if uninitialized_modules:
+ rank_zero_warn(
+ "Parameter initialization incomplete. The following modules have parameters or buffers with uninitialized"
+ " memory because they don't define a `reset_parameters()` method for re-initialization:"
+ f" {', '.join(uninitialized_modules)}"
+ )
+
+
+def _has_meta_device_parameters_or_buffers(obj: Union[Module, Optimizer], recurse: bool = True) -> bool:
+ if isinstance(obj, Optimizer):
+ return any(
+ t.is_meta for param_group in obj.param_groups for t in param_group["params"] if isinstance(t, Parameter)
+ )
+ if isinstance(obj, Module):
+ return any(t.is_meta for t in itertools.chain(obj.parameters(recurse=recurse), obj.buffers(recurse=recurse)))
+ raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}")
diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py
index 29ccca9e4375f..a1c3b6933b2f6 100644
--- a/src/lightning/fabric/utilities/load.py
+++ b/src/lightning/fabric/utilities/load.py
@@ -25,10 +25,7 @@
from torch.nn import Parameter
from typing_extensions import override
-from lightning.fabric.utilities.imports import (
- _TORCH_GREATER_EQUAL_2_0,
- _TORCH_GREATER_EQUAL_2_3,
-)
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from lightning.fabric.utilities.types import _PATH, _Stateful
_METADATA_FILENAME = "meta.pt"
@@ -143,6 +140,10 @@ def __torch_function__(
loaded_args = [(arg._load_tensor() if isinstance(arg, _NotYetLoadedTensor) else arg) for arg in args]
return func(*loaded_args, **kwargs)
+ @property
+ def device(self) -> torch.device:
+ return torch.device(self.storageinfo[3])
+
def __getattr__(self, name: str) -> Any:
# These properties don't require materialization and can be accessed through the meta tensor directly
if name in {
@@ -163,7 +164,7 @@ def __getattr__(self, name: str) -> Any:
return getattr(self.metatensor, name)
# materializing these is needed for quantization (see lit-gpt)
- if name in {"contiguous", "cuda", "half"}:
+ if name in {"contiguous", "cuda", "half", "data", "to"}:
return getattr(self._load_tensor(), name)
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
@@ -202,8 +203,6 @@ def persistent_load(self, pid: tuple) -> "TypedStorage":
def _lazy_load(filename: _PATH) -> Any:
- if not _TORCH_GREATER_EQUAL_2_0:
- raise NotImplementedError("Lazy-loading is only supported with PyTorch >= 2.0.")
if not os.path.isfile(filename):
raise FileNotFoundError(f"Path {str(filename)!r} does not exist or is not a file.")
file_reader = torch.PyTorchFileReader(str(filename))
diff --git a/src/lightning/fabric/utilities/logger.py b/src/lightning/fabric/utilities/logger.py
index 2604a0d926d21..07b76ad9b04d8 100644
--- a/src/lightning/fabric/utilities/logger.py
+++ b/src/lightning/fabric/utilities/logger.py
@@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
+import json
from argparse import Namespace
from dataclasses import asdict, is_dataclass
from typing import Any, Dict, Mapping, MutableMapping, Optional, Union
-import numpy as np
from torch import Tensor
+from lightning.fabric.utilities.imports import _NUMPY_AVAILABLE
+
def _convert_params(params: Optional[Union[Dict[str, Any], Namespace]]) -> Dict[str, Any]:
"""Ensure parameters are a dict or convert to dict if necessary.
@@ -52,8 +55,11 @@ def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]:
"""
def _sanitize_callable(val: Any) -> Any:
- # Give them one chance to return a value. Don't go rabbit hole of recursive call
+ if inspect.isclass(val):
+ # If it's a class, don't try to instantiate it, just return the name
+ return val.__name__
if callable(val):
+ # Callables get a chance to return a name
try:
_val = val()
if callable(_val):
@@ -89,7 +95,7 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent
result: Dict[str, Any] = {}
for k, v in params.items():
new_key = parent_key + delimiter + str(k) if parent_key else str(k)
- if is_dataclass(v):
+ if is_dataclass(v) and not isinstance(v, type):
v = asdict(v)
elif isinstance(v, Namespace):
v = vars(v)
@@ -124,14 +130,33 @@ def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
"""
for k in params:
- # convert relevant np scalars to python types first (instead of str)
- if isinstance(params[k], (np.bool_, np.integer, np.floating)):
- params[k] = params[k].item()
- elif type(params[k]) not in [bool, int, float, str, Tensor]:
+ if _NUMPY_AVAILABLE:
+ import numpy as np
+
+ if isinstance(params[k], (np.bool_, np.integer, np.floating)):
+ params[k] = params[k].item()
+ if type(params[k]) not in [bool, int, float, str, Tensor]:
params[k] = str(params[k])
return params
+def _convert_json_serializable(params: Dict[str, Any]) -> Dict[str, Any]:
+ """Convert non-serializable objects in params to string."""
+ return {k: str(v) if not _is_json_serializable(v) else v for k, v in params.items()}
+
+
+def _is_json_serializable(value: Any) -> bool:
+ """Test whether a variable can be encoded as json."""
+ if value is None or isinstance(value, (bool, int, float, str, list, dict)): # fast path
+ return True
+ try:
+ json.dumps(value)
+ return True
+ except (TypeError, OverflowError):
+ # OverflowError is raised if number is too large to encode
+ return False
+
+
def _add_prefix(
metrics: Mapping[str, Union[Tensor, float]], prefix: str, separator: str
) -> Mapping[str, Union[Tensor, float]]:
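The logger changes above add a JSON-serializability check and guard `_flatten_dict` against dataclass classes (as opposed to instances). The sketch below restates both behaviors in a standalone form using only the standard library. The function names without the leading underscore and the `Config` dataclass are illustrative stand-ins, not Lightning's public API.

```python
import json
from dataclasses import asdict, dataclass, is_dataclass
from typing import Any, Dict


def is_json_serializable(value: Any) -> bool:
    """Mirror of the check in the diff: fast path for common types, otherwise try to encode."""
    if value is None or isinstance(value, (bool, int, float, str, list, dict)):
        return True
    try:
        json.dumps(value)
        return True
    except (TypeError, OverflowError):
        return False


def convert_json_serializable(params: Dict[str, Any]) -> Dict[str, Any]:
    """Replace values that cannot be JSON-encoded with their string representation."""
    return {k: str(v) if not is_json_serializable(v) else v for k, v in params.items()}


@dataclass
class Config:  # illustrative dataclass, not from the PR
    lr: float = 0.1


# A tuple can be JSON-encoded (as a list) and is kept; a set cannot and becomes a string
params = {"lr": 0.1, "layers": (64, 32), "tags": {"a"}, "path": None}
converted = convert_json_serializable(params)

# `is_dataclass` is true for both the class and its instances, but `asdict` only accepts
# instances, hence the extra `not isinstance(v, type)` condition in `_flatten_dict`
flatten_instance = is_dataclass(Config()) and not isinstance(Config(), type)
flatten_class = is_dataclass(Config) and not isinstance(Config, type)
instance_as_dict = asdict(Config())
```

Without the extra condition, a hyperparameter whose value is a dataclass type rather than an instance would reach `asdict` and raise a `TypeError`.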
diff --git a/src/lightning/fabric/utilities/optimizer.py b/src/lightning/fabric/utilities/optimizer.py
index e2605ceca4670..2c57ec9d1f64a 100644
--- a/src/lightning/fabric/utilities/optimizer.py
+++ b/src/lightning/fabric/utilities/optimizer.py
@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from collections.abc import MutableMapping
from typing import Iterable
-from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.optim import Optimizer
-from lightning.fabric.utilities.apply_func import move_data_to_device
+from lightning.fabric.utilities.apply_func import apply_to_collection, move_data_to_device
from lightning.fabric.utilities.types import _DEVICE
@@ -31,4 +31,12 @@ def _optimizers_to_device(optimizers: Iterable[Optimizer], device: _DEVICE) -> N
def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
"""Moves the state of a single optimizer to the device."""
for p, v in optimizer.state.items():
- optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)
+ if not isinstance(v, MutableMapping):
+ # Support for custom optimizers
+ optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)
+ continue
+ for key, val in v.items():
+ # The 'step' parameter needs to remain unmoved (possibly on the CPU) since that is where the optimizer
+ # needs it. See https://github.com/pytorch/pytorch/issues/74424
+ if key != "step":
+ v[key] = move_data_to_device(val, device)
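A small sketch of the control flow in `_optimizer_to_device` above, using plain dictionaries and strings in place of tensors so it runs without `torch`. The names `move_state_except_step` and `fake_move` are hypothetical, and the non-mapping branch simply calls the move function where the real code uses `apply_to_collection`.

```python
from collections.abc import MutableMapping
from typing import Any, Callable, Dict


def move_state_except_step(state: Dict[Any, Any], move: Callable[[Any], Any]) -> None:
    """Hypothetical helper: apply `move` to every per-parameter state entry except 'step'."""
    for param, value in state.items():
        if not isinstance(value, MutableMapping):
            # Custom optimizers may store something other than a dict per parameter
            state[param] = move(value)
            continue
        for key, val in value.items():
            # 'step' stays where it is, matching the comment in the diff
            if key != "step":
                value[key] = move(val)


def fake_move(obj: Any) -> Any:
    """Stand-in for `move_data_to_device`: relabel strings, pass anything else through."""
    return obj.replace("cpu", "cuda") if isinstance(obj, str) else obj


state = {
    "weight": {"step": "cpu:step", "exp_avg": "cpu:exp_avg"},
    "bias": "cpu:custom-state",
}
move_state_except_step(state, fake_move)
```

After the call, `exp_avg` and the custom state carry the new device label while `step` is unchanged.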
diff --git a/src/lightning/fabric/utilities/registry.py b/src/lightning/fabric/utilities/registry.py
index a3b2ac442b46b..9ad0f90221429 100644
--- a/src/lightning/fabric/utilities/registry.py
+++ b/src/lightning/fabric/utilities/registry.py
@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from importlib.metadata import entry_points
from inspect import getmembers, isclass
from types import ModuleType
from typing import Any, List, Type, Union
from lightning_utilities import is_overridden
-from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
+from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_10_0
_log = logging.getLogger(__name__)
@@ -35,16 +36,9 @@ def _load_external_callbacks(group: str) -> List[Any]:
A list of all callbacks collected from external factories.
"""
- if _PYTHON_GREATER_EQUAL_3_8_0:
- from importlib.metadata import entry_points
-
- factories = (
- entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {}) # type: ignore[arg-type]
- )
- else:
- from pkg_resources import iter_entry_points
-
- factories = iter_entry_points(group) # type: ignore[assignment]
+ factories = (
+ entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {}) # type: ignore[arg-type]
+ )
external_callbacks: List[Any] = []
for factory in factories:
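With the Python 3.8 fallback removed, `_load_external_callbacks` above only needs to distinguish the two `importlib.metadata.entry_points` call styles. The sketch below shows that lookup on its own; `load_factories` and the group name are hypothetical, and the version check is inlined instead of using `_PYTHON_GREATER_EQUAL_3_10_0`.

```python
import sys
from importlib.metadata import entry_points
from typing import Any, List


def load_factories(group: str) -> List[Any]:
    """Hypothetical helper: list the entry points registered under `group`."""
    if sys.version_info >= (3, 10):
        # Python 3.10+ supports selecting a group directly
        factories = entry_points(group=group)
    else:
        # Older versions return a dict-like mapping from group name to entry points
        factories = entry_points().get(group, [])
    return list(factories)


# A group nobody registers yields an empty list on either code path
found = load_factories("hypothetical.group.that.does.not.exist")
```

Each returned entry point can then be resolved with `.load()`, as the loop over `factories` in the function above does.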
diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py
index b274bce88fcdf..a2d627828a77e 100644
--- a/src/lightning/fabric/utilities/seed.py
+++ b/src/lightning/fabric/utilities/seed.py
@@ -3,20 +3,21 @@
import random
from random import getstate as python_get_rng_state
from random import setstate as python_set_rng_state
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
-import numpy as np
import torch
+from lightning.fabric.utilities.imports import _NUMPY_AVAILABLE
from lightning.fabric.utilities.rank_zero import _get_rank, rank_prefixed_message, rank_zero_only, rank_zero_warn
log = logging.getLogger(__name__)
-max_seed_value = np.iinfo(np.uint32).max
-min_seed_value = np.iinfo(np.uint32).min
+max_seed_value = 4294967295 # 2^32 - 1 (uint32)
+min_seed_value = 0
-def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
+
+def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose: bool = True) -> int:
r"""Function that sets the seed for pseudo-random number generators in: torch, numpy, and Python's random module.
In addition, sets the following environment variables:
@@ -31,6 +32,7 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
Trainer with a ``worker_init_fn``. If the user already provides such a function
for their dataloaders, setting this argument will have no influence. See also:
:func:`~lightning.fabric.utilities.seed.pl_worker_init_function`.
+ verbose: Whether to print a message on each rank with the seed being set.
"""
if seed is None:
@@ -51,10 +53,15 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
seed = 0
- log.info(rank_prefixed_message(f"Seed set to {seed}", _get_rank()))
+ if verbose:
+ log.info(rank_prefixed_message(f"Seed set to {seed}", _get_rank()))
+
os.environ["PL_GLOBAL_SEED"] = str(seed)
random.seed(seed)
- np.random.seed(seed)
+ if _NUMPY_AVAILABLE:
+ import numpy as np
+
+ np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"
@@ -72,7 +79,7 @@ def reset_seed() -> None:
if seed is None:
return
workers = os.environ.get("PL_SEED_WORKERS", "0")
- seed_everything(int(seed), workers=bool(int(workers)))
+ seed_everything(int(seed), workers=bool(int(workers)), verbose=False)
def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover
@@ -91,24 +98,38 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:
log.debug(
f"Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}"
)
- ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
- # use 128 bits (4 x 32-bit words)
- np.random.seed(ss.generate_state(4))
- # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
- torch_ss, stdlib_ss = ss.spawn(2)
- torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0])
- # use 128 bits expressed as an integer
- stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
- random.seed(stdlib_seed)
+ seed_sequence = _generate_seed_sequence(base_seed, worker_id, global_rank, count=4)
+ torch.manual_seed(seed_sequence[0]) # torch takes a 64-bit seed
+ random.seed((seed_sequence[1] << 32) | seed_sequence[2]) # combine two 64-bit seeds
+ if _NUMPY_AVAILABLE:
+ import numpy as np
+
+ np.random.seed(seed_sequence[3] & 0xFFFFFFFF) # numpy takes 32-bit seed only
+
+
+def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> List[int]:
+ """Generates a sequence of seeds from a base seed, worker id and rank using the linear congruential generator (LCG)
+ algorithm."""
+ # Combine base seed, worker id and rank into a unique 64-bit number
+ combined_seed = (base_seed << 32) | (worker_id << 16) | global_rank
+ seeds = []
+ for _ in range(count):
+ # x_(n+1) = (a * x_n + c) mod m, with c = 1, m = 2^64 and a = D. Knuth's multiplier constant
+ combined_seed = (combined_seed * 6364136223846793005 + 1) & ((1 << 64) - 1)
+ seeds.append(combined_seed)
+ return seeds
def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
r"""Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python."""
states = {
"torch": torch.get_rng_state(),
- "numpy": np.random.get_state(),
"python": python_get_rng_state(),
}
+ if _NUMPY_AVAILABLE:
+ import numpy as np
+
+ states["numpy"] = np.random.get_state()
if include_cuda:
states["torch.cuda"] = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else []
return states
@@ -121,6 +142,9 @@ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
# torch.cuda rng_state is only included since v1.8.
if "torch.cuda" in rng_state_dict:
torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
- np.random.set_state(rng_state_dict["numpy"])
+ if _NUMPY_AVAILABLE and "numpy" in rng_state_dict:
+ import numpy as np
+
+ np.random.set_state(rng_state_dict["numpy"])
version, state, gauss = rng_state_dict["python"]
python_set_rng_state((version, tuple(state), gauss))
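Note on the `seed.py` change above: the new worker seeding can be reproduced outside Lightning with a few lines of plain Python. The sketch below is illustrative only. `generate_seed_sequence` is a local stand-in name for the private `_generate_seed_sequence` helper added in this diff, and the example inputs (`base_seed=42`, `worker_id=1`, `global_rank=0`) are arbitrary assumptions. It shows how the four derived values map onto the torch, stdlib and NumPy seeds as done in `pl_worker_init_function`.

```python
from typing import List

MASK_64 = (1 << 64) - 1
# Multiplier taken from the diff above; the increment c is 1 and the modulus is 2^64.
LCG_MULTIPLIER = 6364136223846793005


def generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> List[int]:
    """Derive `count` 64-bit seeds from (base_seed, worker_id, global_rank) with an LCG."""
    # Pack the three inputs into a single starting state, as the diff does.
    state = (base_seed << 32) | (worker_id << 16) | global_rank
    seeds = []
    for _ in range(count):
        # x_(n+1) = (a * x_n + c) mod 2^64
        state = (state * LCG_MULTIPLIER + 1) & MASK_64
        seeds.append(state)
    return seeds


# Hypothetical inputs, chosen only for illustration.
seeds = generate_seed_sequence(base_seed=42, worker_id=1, global_rank=0, count=4)

torch_seed = seeds[0]                      # would be passed to torch.manual_seed
stdlib_seed = (seeds[1] << 32) | seeds[2]  # would be passed to random.seed
numpy_seed = seeds[3] & 0xFFFFFFFF         # np.random.seed accepts 32-bit values only
```

Because the multiplier is odd, each LCG step is a bijection modulo 2^64, so two workers whose packed starting states differ get different seeds at every position in the sequence.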
diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py
index b9bfd1e269d71..6f0513465cab5 100644
--- a/src/lightning/fabric/utilities/testing/_runif.py
+++ b/src/lightning/fabric/utilities/testing/_runif.py
@@ -17,14 +17,14 @@
from typing import Dict, List, Optional, Tuple
import torch
-from lightning_utilities.core.imports import compare_version
+from lightning_utilities.core.imports import RequirementCache, compare_version
from packaging.version import Version
from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.accelerators.cuda import num_cuda_devices
from lightning.fabric.accelerators.mps import MPSAccelerator
from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
def _runif_reasons(
@@ -112,18 +112,15 @@ def _runif_reasons(
reasons.append("Standalone execution")
kwargs["standalone"] = True
- if deepspeed and not _DEEPSPEED_AVAILABLE:
+ if deepspeed and not (
+ _DEEPSPEED_AVAILABLE and not _TORCH_GREATER_EQUAL_2_4 and RequirementCache(module="deepspeed.utils")
+ ):
reasons.append("Deepspeed")
if dynamo:
- if _TORCH_GREATER_EQUAL_2_1:
- from torch._dynamo.eval_frame import is_dynamo_supported
+ from torch._dynamo.eval_frame import is_dynamo_supported
- cond = not is_dynamo_supported()
- else:
- cond = sys.platform == "win32" or sys.version_info >= (3, 11)
- cond |= not _TORCH_GREATER_EQUAL_2_0
- if cond:
+ if not is_dynamo_supported():
reasons.append("torch.dynamo")
return reasons, kwargs
diff --git a/src/lightning/fabric/utilities/throughput.py b/src/lightning/fabric/utilities/throughput.py
index c340686346c0f..6743da7b34085 100644
--- a/src/lightning/fabric/utilities/throughput.py
+++ b/src/lightning/fabric/utilities/throughput.py
@@ -18,7 +18,6 @@
import torch
from typing_extensions import override
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn
if TYPE_CHECKING:
@@ -292,11 +291,9 @@ def measure_flops(
FLOPs will be included in the result.
"""
- if not _TORCH_GREATER_EQUAL_2_1:
- raise ImportError("`measure_flops` requires PyTorch >= 2.1.")
from torch.utils.flop_counter import FlopCounterMode
- flop_counter = FlopCounterMode(model, display=False)
+ flop_counter = FlopCounterMode(display=False)
with flop_counter:
if loss_fn is None:
forward_fn()
diff --git a/src/lightning/fabric/utilities/types.py b/src/lightning/fabric/utilities/types.py
index c4bc32f3cf319..2e18dc89b05b2 100644
--- a/src/lightning/fabric/utilities/types.py
+++ b/src/lightning/fabric/utilities/types.py
@@ -28,10 +28,10 @@
import torch
from torch import Tensor
-from torch.optim import Optimizer
-from typing_extensions import TypeAlias, overload
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
+# TODO: Unused import, but lightning_habana imports these from here
+from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau # noqa: F401
+from typing_extensions import TypeAlias, overload
UntypedStorage: TypeAlias = torch.UntypedStorage
@@ -42,7 +42,6 @@
]
_PARAMETERS = Iterator[torch.nn.Parameter]
-
if torch.distributed.is_available():
from torch.distributed import ProcessGroup, ReduceOp
@@ -70,49 +69,6 @@ def size(self) -> int: ...
def rank(self) -> int: ...
-# Inferred from `torch.optim.lr_scheduler.pyi`
-# Missing attributes were added to improve typing
-@runtime_checkable
-class LRScheduler(_Stateful[str], Protocol):
- optimizer: Optimizer
- base_lrs: List[float]
-
- def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None: ...
-
- def step(self, epoch: Optional[int] = None) -> None: ...
-
-
-_TORCH_LRSCHEDULER: TypeAlias = (
- torch.optim.lr_scheduler.LRScheduler # type: ignore[valid-type]
- if _TORCH_GREATER_EQUAL_2_0
- else torch.optim.lr_scheduler._LRScheduler
-)
-
-
-# Inferred from `torch.optim.lr_scheduler.pyi`
-# Missing attributes were added to improve typing
-@runtime_checkable
-class ReduceLROnPlateau(_Stateful[str], Protocol):
- in_cooldown: bool
- optimizer: Optimizer
-
- def __init__(
- self,
- optimizer: Optimizer,
- mode: str = ...,
- factor: float = ...,
- patience: int = ...,
- verbose: bool = ...,
- threshold: float = ...,
- threshold_mode: str = ...,
- cooldown: int = ...,
- min_lr: float = ...,
- eps: float = ...,
- ) -> None: ...
-
- def step(self, metrics: Union[float, int, Tensor], epoch: Optional[int] = None) -> None: ...
-
-
@runtime_checkable
class Steppable(Protocol):
"""To structurally type ``optimizer.step()``"""
diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py
index 093b355e2c376..c57f1974a6bba 100644
--- a/src/lightning/fabric/wrappers.py
+++ b/src/lightning/fabric/wrappers.py
@@ -14,8 +14,8 @@
import inspect
from copy import deepcopy
from functools import partial, wraps
+from types import MethodType
from typing import (
- TYPE_CHECKING,
Any,
Callable,
Dict,
@@ -35,6 +35,7 @@
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch import nn as nn
+from torch._dynamo import OptimizedModule
from torch.nn.modules.module import _IncompatibleKeys
from torch.optim import Optimizer
from torch.utils.data import DataLoader
@@ -45,12 +46,8 @@
from lightning.fabric.utilities import move_data_to_device
from lightning.fabric.utilities.data import _set_sampler_epoch
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.types import Optimizable
-if TYPE_CHECKING:
- from torch._dynamo import OptimizedModule
-
T_destination = TypeVar("T_destination", bound=Dict[str, Any])
_LIGHTNING_MODULE_STEP_METHODS = ("training_step", "validation_step", "test_step", "predict_step")
@@ -127,6 +124,7 @@ def __init__(
self._forward_module = forward_module
self._original_module = original_module or forward_module
self._strategy = strategy
+ self._forward_methods = set(_LIGHTNING_MODULE_STEP_METHODS)
self._fabric_module_initialized = True
@property
@@ -169,6 +167,20 @@ def load_state_dict( # type: ignore[override]
) -> _IncompatibleKeys:
return self._original_module.load_state_dict(state_dict=state_dict, strict=strict, **kwargs)
+ def mark_forward_method(self, method: Union[MethodType, str]) -> None:
+ """Mark a method as a 'forward' method to prevent it bypassing the strategy wrapper (e.g., DDP)."""
+ if not isinstance(method, (MethodType, str)):
+ raise TypeError(f"Expected a method or a string, but got: {type(method).__name__}")
+ name = method if isinstance(method, str) else method.__name__
+ if name == "forward":
+ raise ValueError("You cannot mark the forward method itself as a forward method.")
+ if not isinstance(getattr(self._original_module, name, None), MethodType):
+ raise AttributeError(
+ f"You marked '{name}' as a forward method, but `{type(self._original_module).__name__}.{name}` does not"
+ f" exist or is not a method."
+ )
+ self._forward_methods.add(name)
+
def _redirection_through_forward(self, method_name: str) -> Callable:
assert method_name != "forward"
original_forward = self._original_module.forward
@@ -211,8 +223,8 @@ def _wrapped_method(*args: Any, **kwargs: Any) -> Any:
if module_called:
raise RuntimeError(
f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
- " model. This will bypass the wrapper from the strategy and result in incorrect behavior in"
- " `.backward()`. You should pass your inputs through `forward()`.",
+ " model. To avoid issues with the currently selected strategy, explicitly mark it as a"
+ f" forward method with `fabric_model.mark_forward_method({name!r})` after `fabric.setup()`."
)
for handle in handles:
handle.remove()
@@ -235,8 +247,12 @@ def _register_backward_hook(self, tensor: Tensor) -> Tensor:
@override
def __getattr__(self, item: Any) -> Any:
- if item in _LIGHTNING_MODULE_STEP_METHODS and self._forward_module != self._original_module:
- # Special support for `LightningModule`, to prevent bypassing DDP's forward
+ if (
+ item != "_forward_methods"
+ and item in self._forward_methods
+ and self._forward_module != self._original_module
+ ):
+ # Special support for methods marked by `mark_forward_method` to prevent bypassing DDP's forward
return self._redirection_through_forward(item)
try:
@@ -329,26 +345,17 @@ def _unwrap(
return obj
types = [_FabricModule, _FabricOptimizer, _FabricDataLoader]
- if _TORCH_GREATER_EQUAL_2_0:
- from torch._dynamo import OptimizedModule
-
- types.append(OptimizedModule)
+ types.append(OptimizedModule)
return apply_to_collection(collection, dtype=tuple(types), function=_unwrap)
-def _unwrap_compiled(obj: Union[Any, "OptimizedModule"]) -> Tuple[Union[Any, nn.Module], Optional[Dict[str, Any]]]:
+def _unwrap_compiled(obj: Union[Any, OptimizedModule]) -> Tuple[Union[Any, nn.Module], Optional[Dict[str, Any]]]:
"""Removes the :class:`torch._dynamo.OptimizedModule` around the object if it is wrapped.
Use this function before instance checks against e.g. :class:`_FabricModule`.
"""
- if not _TORCH_GREATER_EQUAL_2_0:
- # obj can't be an `OptimizedModule` anyway
- return obj, None
-
- from torch._dynamo import OptimizedModule
-
if isinstance(obj, OptimizedModule):
if (compile_kwargs := getattr(obj, "_compile_kwargs", None)) is None:
raise RuntimeError(
@@ -359,10 +366,7 @@ def _unwrap_compiled(obj: Union[Any, "OptimizedModule"]) -> Tuple[Union[Any, nn.
return obj, None
-def _to_compiled(module: nn.Module, compile_kwargs: Dict[str, Any]) -> "OptimizedModule":
- if not _TORCH_GREATER_EQUAL_2_0:
- raise RuntimeError("Converting to a compiled module is only supported in PyTorch >= 2.0.0")
-
+def _to_compiled(module: nn.Module, compile_kwargs: Dict[str, Any]) -> OptimizedModule:
return torch.compile(module, **compile_kwargs) # type: ignore[return-value]
@@ -414,5 +418,4 @@ def _capture(*args: Any, **kwargs: Any) -> Any:
return _capture
-if _TORCH_GREATER_EQUAL_2_0:
- torch.compile = _capture_compile_kwargs(torch.compile)
+torch.compile = _capture_compile_kwargs(torch.compile)
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index be7b66a27ca66..1b251b8fb06fa 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -4,54 +4,78 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
-## [unReleased] - 2024-MM-DD
+
+## [2.4.0] - 2024-08-06
### Added
-- The `ModelSummary` and `RichModelSummary` callbacks now display the training mode of each layer in the column "Mode" ([#19468](https://github.com/Lightning-AI/lightning/pull/19468))
+- Made saving non-distributed checkpoints fully atomic ([#20011](https://github.com/Lightning-AI/pytorch-lightning/pull/20011))
+- Added `dump_stats` flag to `AdvancedProfiler` ([#19703](https://github.com/Lightning-AI/pytorch-lightning/issues/19703))
+- Added a flag `verbose` to the `seed_everything()` function ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108))
+- Added support for PyTorch 2.4 ([#20010](https://github.com/Lightning-AI/pytorch-lightning/pull/20010))
+- Added support for Python 3.12 ([#20078](https://github.com/Lightning-AI/pytorch-lightning/pull/20078))
+- The `TQDMProgressBar` now provides an option to retain prior training epoch bars ([#19578](https://github.com/Lightning-AI/pytorch-lightning/pull/19578))
+- Added the count of modules in train and eval mode to the printed `ModelSummary` table ([#20159](https://github.com/Lightning-AI/pytorch-lightning/pull/20159))
-- Added `load_from_checkpoint` support for `LightningCLI` when using dependency injection ([#18105](https://github.com/Lightning-AI/lightning/pull/18105))
+### Changed
-- Added robust timer duration parsing with an informative error message when parsing fails ([#19513](https://github.com/Lightning-AI/pytorch-lightning/pull/19513))
+- Triggering KeyboardInterrupt (Ctrl+C) during `.fit()`, `.validate()`, `.test()` or `.predict()` now terminates all processes launched by the Trainer and exits the program ([#19976](https://github.com/Lightning-AI/pytorch-lightning/pull/19976))
+- Changed the implementation of how seeds are chosen for dataloader workers when using `seed_everything(..., workers=True)` ([#20055](https://github.com/Lightning-AI/pytorch-lightning/pull/20055))
+- NumPy is no longer a required dependency ([#20090](https://github.com/Lightning-AI/pytorch-lightning/issues/20090))
-- Added `on_exception` hook to `LightningDataModule` ([#19601](https://github.com/Lightning-AI/pytorch-lightning/pull/19601))
+### Removed
--
+- Removed support for PyTorch 2.1 ([#20009](https://github.com/Lightning-AI/lightning/pull/20009))
+- Removed support for Python 3.8 ([#20071](https://github.com/Lightning-AI/lightning/pull/20071))
-### Changed
+### Fixed
-- The `prepare_data()` hook in `LightningModule` and `LightningDataModule` is now subject to a barrier without timeout to avoid long-running tasks to be interrupted ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))
+- Avoid LightningCLI saving hyperparameters with `class_path` and `init_args` since this would be a breaking change ([#20068](https://github.com/Lightning-AI/pytorch-lightning/pull/20068))
+- Fixed an issue that would cause too many printouts of the seed info when using `seed_everything()` ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108))
+- Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814))
+- Fixed `_optimizer_to_device` logic for special 'step' key in optimizer state causing performance regression ([#20019](https://github.com/Lightning-AI/lightning/pull/20019))
+- Fixed parameter counts in `ModelSummary` when model has distributed parameters (DTensor) ([#20163](https://github.com/Lightning-AI/pytorch-lightning/pull/20163))
-- Relaxed the requirement for custom batch samplers to expose `drop_last` for prediction ([#19678](https://github.com/Lightning-AI/pytorch-lightning/pull/19678))
--
+## [2.3.0] - 2024-06-13
-### Deprecated
+### Added
--
+- The `ModelSummary` and `RichModelSummary` callbacks now display the training mode of each layer in the column "Mode" ([#19468](https://github.com/Lightning-AI/lightning/pull/19468))
+- Added `load_from_checkpoint` support for `LightningCLI` when using dependency injection ([#18105](https://github.com/Lightning-AI/lightning/pull/18105))
+- Added robust timer duration parsing with an informative error message when parsing fails ([#19513](https://github.com/Lightning-AI/pytorch-lightning/pull/19513))
+- Added `on_exception` hook to `LightningDataModule` ([#19601](https://github.com/Lightning-AI/pytorch-lightning/pull/19601))
+- Added support for PyTorch 2.3 ([#19708](https://github.com/Lightning-AI/pytorch-lightning/pull/19708))
+- Added `ModelParallelStrategy` to support 2D parallelism ([#19878](https://github.com/Lightning-AI/pytorch-lightning/pull/19878), [#19888](https://github.com/Lightning-AI/pytorch-lightning/pull/19888))
+- Added a call to `torch.distributed.destroy_process_group` in atexit handler if process group needs destruction ([#19931](https://github.com/Lightning-AI/pytorch-lightning/pull/19931))
+- Added support for configuring hybrid-sharding by passing a tuple for the `FSDPStrategy(device_mesh=...)` argument ([#19504](https://github.com/Lightning-AI/pytorch-lightning/pull/19504))
--
+### Changed
--
+- The `prepare_data()` hook in `LightningModule` and `LightningDataModule` is now subject to a barrier without timeout to avoid interrupting long-running tasks ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))
+- Relaxed the requirement for custom batch samplers to expose `drop_last` for prediction ([#19678](https://github.com/Lightning-AI/pytorch-lightning/pull/19678))
+- It is no longer allowed to skip `training_step()` by returning `None` in distributed training ([#19918](https://github.com/Lightning-AI/pytorch-lightning/pull/19918))
### Removed
- Removed the Bagua integration (`Trainer(strategy="bagua")`) ([#19445](https://github.com/Lightning-AI/lightning/pull/19445))
-
--
-
--
+- Removed support for PyTorch 1.13 ([#19706](https://github.com/Lightning-AI/lightning/pull/19706))
### Fixed
-- Fixed a KeyError when saving a FSDP sharded checkpoint and setting `save_weights_only=True` ([#19524](https://github.com/Lightning-AI/pytorch-lightning/pull/19524))
+- Fixed a matrix shape mismatch issue when running a model loaded from a quantized checkpoint (bitsandbytes) ([#19886](https://github.com/Lightning-AI/lightning/pull/19886))
+- Fixed `WandbLogger.log_hyperparameters()` raising an error if hyperparameters are not JSON serializable ([#19769](https://github.com/Lightning-AI/pytorch-lightning/pull/19769))
+- Fixed an issue with the LightningCLI not being able to set the `ModelCheckpoint(save_last=...)` argument ([#19808](https://github.com/Lightning-AI/pytorch-lightning/pull/19808))
+- Fixed an issue causing a ValueError for certain objects such as TorchMetrics when dumping hyperparameters to YAML ([#19804](https://github.com/Lightning-AI/pytorch-lightning/pull/19804))
+- Fixed resetting `epoch_loop.restarting` to avoid full validation run after `LearningRateFinder` ([#19818](https://github.com/Lightning-AI/pytorch-lightning/issues/19818))
-- Fixed an issue causing a TypeError when using `torch.compile` as a decorator ([#19627](https://github.com/Lightning-AI/pytorch-lightning/pull/19627))
+## [2.2.2] - 2024-04-11
--
+### Fixed
--
+- Fixed a KeyError when saving a FSDP sharded checkpoint and setting `save_weights_only=True` ([#19524](https://github.com/Lightning-AI/pytorch-lightning/pull/19524))
+- Fixed an issue causing a TypeError when using `torch.compile` as a decorator ([#19627](https://github.com/Lightning-AI/pytorch-lightning/pull/19627))
## [2.2.1] - 2024-03-04
diff --git a/src/lightning/pytorch/__init__.py b/src/lightning/pytorch/__init__.py
index b8c89f075de50..53e7d5b7c1f7c 100644
--- a/src/lightning/pytorch/__init__.py
+++ b/src/lightning/pytorch/__init__.py
@@ -33,8 +33,6 @@
__all__ = ["Trainer", "LightningDataModule", "LightningModule", "Callback", "seed_everything"]
-# for compatibility with namespace packages
-__import__("pkg_resources").declare_namespace(__name__)
LIGHTNING_LOGO: str = """
####
diff --git a/src/lightning/pytorch/callbacks/lr_monitor.py b/src/lightning/pytorch/callbacks/lr_monitor.py
index 357cfceefa03e..6a94c7ece70a3 100644
--- a/src/lightning/pytorch/callbacks/lr_monitor.py
+++ b/src/lightning/pytorch/callbacks/lr_monitor.py
@@ -44,6 +44,8 @@ class LearningRateMonitor(Callback):
according to the ``interval`` key of each scheduler. Defaults to ``None``.
log_momentum: option to also log the momentum values of the optimizer, if the optimizer
has the ``momentum`` or ``betas`` attribute. Defaults to ``False``.
+ log_weight_decay: option to also log the weight decay values of the optimizer. Defaults to
+ ``False``.
Raises:
MisconfigurationException:
@@ -58,7 +60,7 @@ class LearningRateMonitor(Callback):
Logging names are automatically determined based on optimizer class name.
In case of multiple optimizers of same type, they will be named ``Adam``,
- ``Adam-1`` etc. If a optimizer has multiple parameter groups they will
+ ``Adam-1`` etc. If an optimizer has multiple parameter groups they will
be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a
``name`` keyword in the construction of the learning rate schedulers.
A ``name`` keyword can also be used for parameter groups in the
diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index 6c5dd01df15c7..9587da0f4600b 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -27,7 +27,7 @@
from copy import deepcopy
from datetime import timedelta
from pathlib import Path
-from typing import Any, Dict, Literal, Optional, Set
+from typing import Any, Dict, Literal, Optional, Set, Union
from weakref import proxy
import torch
@@ -94,7 +94,9 @@ class ModelCheckpoint(Checkpoint):
Please note that the monitors are checked every ``every_n_epochs`` epochs.
If ``save_top_k >= 2`` and the callback is called multiple times inside an epoch, and the filename remains
unchanged, the name of the saved file will be appended with a version count starting with ``v1`` to avoid
- collisions unless ``enable_version_counter`` is set to False.
+ collisions unless ``enable_version_counter`` is set to False. The version counter is unrelated to the top-k
+ ranking of the checkpoint, and we recommend formatting the filename to include the monitored metric to avoid
+ collisions.
mode: one of {min, max}.
If ``save_top_k != 0``, the decision to overwrite the current save file is made
based on either the maximization or the minimization of the monitored quantity.
@@ -216,7 +218,7 @@ def __init__(
filename: Optional[str] = None,
monitor: Optional[str] = None,
verbose: bool = False,
- save_last: Optional[Literal[True, False, "link"]] = None,
+ save_last: Optional[Union[bool, Literal["link"]]] = None,
save_top_k: int = 1,
save_weights_only: bool = False,
mode: str = "min",
diff --git a/src/lightning/pytorch/callbacks/model_summary.py b/src/lightning/pytorch/callbacks/model_summary.py
index ddff91ed2949e..89c31b2cc65e8 100644
--- a/src/lightning/pytorch/callbacks/model_summary.py
+++ b/src/lightning/pytorch/callbacks/model_summary.py
@@ -66,9 +66,17 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
total_parameters = model_summary.total_parameters
trainable_parameters = model_summary.trainable_parameters
model_size = model_summary.model_size
+ total_training_modes = model_summary.total_training_modes
if trainer.is_global_zero:
- self.summarize(summary_data, total_parameters, trainable_parameters, model_size, **self._summarize_kwargs)
+ self.summarize(
+ summary_data,
+ total_parameters,
+ trainable_parameters,
+ model_size,
+ total_training_modes,
+ **self._summarize_kwargs,
+ )
def _summary(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Union[DeepSpeedSummary, Summary]:
from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy
@@ -83,12 +91,14 @@ def summarize(
total_parameters: int,
trainable_parameters: int,
model_size: float,
+ total_training_modes: Dict[str, int],
**summarize_kwargs: Any,
) -> None:
summary_table = _format_summary_table(
total_parameters,
trainable_parameters,
model_size,
+ total_training_modes,
*summary_data,
)
log.info("\n" + summary_table)
diff --git a/src/lightning/pytorch/callbacks/progress/progress_bar.py b/src/lightning/pytorch/callbacks/progress/progress_bar.py
index c89f6f7b739f2..785bf65af4361 100644
--- a/src/lightning/pytorch/callbacks/progress/progress_bar.py
+++ b/src/lightning/pytorch/callbacks/progress/progress_bar.py
@@ -48,7 +48,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
"""
def __init__(self) -> None:
- self._trainer: Optional["pl.Trainer"] = None
+ self._trainer: Optional[pl.Trainer] = None
self._current_eval_dataloader_idx: Optional[int] = None
@property
diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py
index 4fcca7d0b0e65..896de71267835 100644
--- a/src/lightning/pytorch/callbacks/progress/rich_progress.py
+++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py
@@ -148,7 +148,7 @@ def __init__(
self._trainer = trainer
self._tasks: Dict[Union[int, TaskID], Any] = {}
self._current_task_id = 0
- self._metrics: Dict[Union[str, "Style"], Any] = {}
+ self._metrics: Dict[Union[str, Style], Any] = {}
self._style = style
self._text_delimiter = text_delimiter
self._metrics_format = metrics_format
@@ -187,9 +187,6 @@ def _generate_metrics_texts(self) -> Generator[str, None, None]:
value = f"{value:{self._metrics_format}}"
yield f"{name}: {value}"
-else:
- Task, Style = Any, Any # type: ignore[assignment, misc]
-
@dataclass
class RichProgressBarTheme:
@@ -209,14 +206,14 @@ class RichProgressBarTheme:
"""
- description: Union[str, Style] = "white"
- progress_bar: Union[str, Style] = "#6206E0"
- progress_bar_finished: Union[str, Style] = "#6206E0"
- progress_bar_pulse: Union[str, Style] = "#6206E0"
- batch_progress: Union[str, Style] = "white"
- time: Union[str, Style] = "grey54"
- processing_speed: Union[str, Style] = "grey70"
- metrics: Union[str, Style] = "white"
+ description: Union[str, "Style"] = ""
+ progress_bar: Union[str, "Style"] = "#6206E0"
+ progress_bar_finished: Union[str, "Style"] = "#6206E0"
+ progress_bar_pulse: Union[str, "Style"] = "#6206E0"
+ batch_progress: Union[str, "Style"] = ""
+ time: Union[str, "Style"] = "dim"
+ processing_speed: Union[str, "Style"] = "dim underline"
+ metrics: Union[str, "Style"] = "italic"
metrics_text_delimiter: str = " "
metrics_format: str = ".3f"
@@ -274,16 +271,15 @@ def __init__(
self._console_kwargs = console_kwargs or {}
self._enabled: bool = True
self.progress: Optional[CustomProgress] = None
- self.train_progress_bar_id: Optional["TaskID"]
- self.val_sanity_progress_bar_id: Optional["TaskID"] = None
- self.val_progress_bar_id: Optional["TaskID"]
- self.test_progress_bar_id: Optional["TaskID"]
- self.predict_progress_bar_id: Optional["TaskID"]
+ self.train_progress_bar_id: Optional[TaskID]
+ self.val_sanity_progress_bar_id: Optional[TaskID] = None
+ self.val_progress_bar_id: Optional[TaskID]
+ self.test_progress_bar_id: Optional[TaskID]
+ self.predict_progress_bar_id: Optional[TaskID]
self._reset_progress_bar_ids()
- self._metric_component: Optional["MetricsTextColumn"] = None
+ self._metric_component: Optional[MetricsTextColumn] = None
self._progress_stopped: bool = False
self.theme = theme
- self._update_for_light_colab_theme()
@property
def refresh_rate(self) -> float:
@@ -298,36 +294,29 @@ def is_disabled(self) -> bool:
return not self.is_enabled
@property
- def train_progress_bar(self) -> Task:
+ def train_progress_bar(self) -> "Task":
assert self.progress is not None
assert self.train_progress_bar_id is not None
return self.progress.tasks[self.train_progress_bar_id]
@property
- def val_sanity_check_bar(self) -> Task:
+ def val_sanity_check_bar(self) -> "Task":
assert self.progress is not None
assert self.val_sanity_progress_bar_id is not None
return self.progress.tasks[self.val_sanity_progress_bar_id]
@property
- def val_progress_bar(self) -> Task:
+ def val_progress_bar(self) -> "Task":
assert self.progress is not None
assert self.val_progress_bar_id is not None
return self.progress.tasks[self.val_progress_bar_id]
@property
- def test_progress_bar(self) -> Task:
+ def test_progress_bar(self) -> "Task":
assert self.progress is not None
assert self.test_progress_bar_id is not None
return self.progress.tasks[self.test_progress_bar_id]
- def _update_for_light_colab_theme(self) -> None:
- if _detect_light_colab_theme():
- attributes = ["description", "batch_progress", "metrics"]
- for attr in attributes:
- if getattr(self.theme, attr) == "white":
- setattr(self.theme, attr, "black")
-
@override
def disable(self) -> None:
self._enabled = False
@@ -452,7 +441,7 @@ def on_validation_batch_start(
def _add_task(self, total_batches: Union[int, float], description: str, visible: bool = True) -> "TaskID":
assert self.progress is not None
return self.progress.add_task(
- f"[{self.theme.description}]{description}",
+ f"[{self.theme.description}]{description}" if self.theme.description else description,
total=total_batches,
visible=visible,
)
@@ -659,20 +648,3 @@ def __getstate__(self) -> Dict:
state["progress"] = None
state["_console"] = None
return state
-
-
-def _detect_light_colab_theme() -> bool:
- """Detect if it's light theme in Colab."""
- try:
- import get_ipython
- except (NameError, ModuleNotFoundError):
- return False
- ipython = get_ipython()
- if "google.colab" in str(ipython.__class__):
- try:
- from google.colab import output
-
- return output.eval_js('document.documentElement.matches("[theme=light]")')
- except ModuleNotFoundError:
- return False
- return False
diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
index bf9e238a01ead..cf9cd71614674 100644
--- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
+++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
@@ -96,15 +96,15 @@ class TQDMProgressBar(ProgressBar):
Set it to ``0`` to disable the display.
process_position: Set this to a value greater than ``0`` to offset the progress bars by this many lines.
This is useful when you have progress bars defined elsewhere and want to show all of them
- together. This corresponds to
- :paramref:`~lightning.pytorch.trainer.trainer.Trainer.process_position` in the
- :class:`~lightning.pytorch.trainer.trainer.Trainer`.
+ together.
+ leave: If set to ``True``, leaves the finished progress bar in the terminal at the end of the epoch.
+ Default: ``False``
"""
BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
- def __init__(self, refresh_rate: int = 1, process_position: int = 0):
+ def __init__(self, refresh_rate: int = 1, process_position: int = 0, leave: bool = False):
super().__init__()
self._refresh_rate = self._resolve_refresh_rate(refresh_rate)
self._process_position = process_position
@@ -113,6 +113,7 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0):
self._val_progress_bar: Optional[_tqdm] = None
self._test_progress_bar: Optional[_tqdm] = None
self._predict_progress_bar: Optional[_tqdm] = None
+ self._leave = leave
def __getstate__(self) -> Dict:
# can't pickle the tqdm objects
@@ -262,6 +263,8 @@ def on_train_start(self, *_: Any) -> None:
@override
def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
+ if self._leave:
+ self.train_progress_bar = self.init_train_tqdm()
self.train_progress_bar.reset(convert_inf(self.total_train_batches))
self.train_progress_bar.initial = 0
self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
@@ -279,6 +282,8 @@ def on_train_batch_end(
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not self.train_progress_bar.disable:
self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
+ if self._leave:
+ self.train_progress_bar.close()
@override
def on_train_end(self, *_: Any) -> None:
diff --git a/src/lightning/pytorch/callbacks/pruning.py b/src/lightning/pytorch/callbacks/pruning.py
index e5e214d1d14d1..e83a9de06375c 100644
--- a/src/lightning/pytorch/callbacks/pruning.py
+++ b/src/lightning/pytorch/callbacks/pruning.py
@@ -205,14 +205,14 @@ def __init__(
raise MisconfigurationException(
f"`pruning_fn` is expected to be a str in {list(_PYTORCH_PRUNING_FUNCTIONS.keys())}"
f" or a PyTorch `BasePruningMethod`. Found: {pruning_fn}."
- " HINT: if passing a `BasePruningMethod`, pass the the class, not an instance"
+ " HINT: if passing a `BasePruningMethod`, pass the class, not an instance"
)
# need to ignore typing here since pytorch base class does not define the PRUNING_TYPE attribute
if use_global_unstructured and pruning_fn.PRUNING_TYPE != "unstructured": # type: ignore
raise MisconfigurationException(
- 'Only the "unstructured" PRUNING_TYPE is supported with `use_global_unstructured=True`.' # type: ignore
- f" Found method {pruning_fn} of type {pruning_fn.PRUNING_TYPE}. "
+ 'Only the "unstructured" PRUNING_TYPE is supported with `use_global_unstructured=True`.'
+ f" Found method {pruning_fn} of type {pruning_fn.PRUNING_TYPE}. " # type: ignore[union-attr]
)
self.pruning_fn = pruning_fn
@@ -308,7 +308,7 @@ def apply_lottery_ticket_hypothesis(self) -> None:
def _apply_local_pruning(self, amount: float) -> None:
for module, name in self._parameters_to_prune:
- self.pruning_fn(module, name=name, amount=amount)
+ self.pruning_fn(module, name=name, amount=amount) # type: ignore[call-arg]
def _resolve_global_kwargs(self, amount: float) -> Dict[str, Any]:
self._global_kwargs["amount"] = amount
diff --git a/src/lightning/pytorch/callbacks/rich_model_summary.py b/src/lightning/pytorch/callbacks/rich_model_summary.py
index f551a9c397531..c6c429b4bd2f5 100644
--- a/src/lightning/pytorch/callbacks/rich_model_summary.py
+++ b/src/lightning/pytorch/callbacks/rich_model_summary.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List, Tuple
+from typing import Any, Dict, List, Tuple
from typing_extensions import override
@@ -71,6 +71,7 @@ def summarize(
total_parameters: int,
trainable_parameters: int,
model_size: float,
+ total_training_modes: Dict[str, int],
**summarize_kwargs: Any,
) -> None:
from rich import get_console
@@ -110,5 +111,7 @@ def summarize(
grid.add_row(f"[bold]Non-trainable params[/]: {parameters[1]}")
grid.add_row(f"[bold]Total params[/]: {parameters[2]}")
grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {parameters[3]}")
+ grid.add_row(f"[bold]Modules in train mode[/]: {total_training_modes['train']}")
+ grid.add_row(f"[bold]Modules in eval mode[/]: {total_training_modes['eval']}")
console.print(grid)
diff --git a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py
index 731f161683102..737084ced426d 100644
--- a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py
+++ b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py
@@ -17,15 +17,15 @@
"""
from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Union, cast
+from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast
import torch
from torch import Tensor, nn
+from torch.optim.lr_scheduler import LRScheduler
from torch.optim.swa_utils import SWALR
from typing_extensions import override
import lightning.pytorch as pl
-from lightning.fabric.utilities.types import LRScheduler
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.strategies import DeepSpeedStrategy
from lightning.pytorch.strategies.fsdp import FSDPStrategy
@@ -42,7 +42,7 @@ def __init__(
swa_lrs: Union[float, List[float]],
swa_epoch_start: Union[int, float] = 0.8,
annealing_epochs: int = 10,
- annealing_strategy: str = "cos",
+ annealing_strategy: Literal["cos", "linear"] = "cos",
avg_fn: Optional[_AVG_FN] = None,
device: Optional[Union[torch.device, str]] = torch.device("cpu"),
):
@@ -123,7 +123,7 @@ def __init__(
self._avg_fn = avg_fn or self.avg_fn
self._device = device
self._model_contains_batch_norm: Optional[bool] = None
- self._average_model: Optional["pl.LightningModule"] = None
+ self._average_model: Optional[pl.LightningModule] = None
self._initialized = False
self._swa_scheduler: Optional[LRScheduler] = None
self._scheduler_state: Optional[Dict] = None
@@ -303,14 +303,14 @@ def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> No
dtype=module.running_var.dtype,
)
self.momenta[module] = module.momentum
- module.momentum = None # type: ignore[assignment]
+ module.momentum = None
assert module.num_batches_tracked is not None
module.num_batches_tracked *= 0
def reset_momenta(self) -> None:
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
for bn_module in self.momenta:
- bn_module.momentum = self.momenta[bn_module] # type: ignore[assignment]
+ bn_module.momentum = self.momenta[bn_module]
@staticmethod
def update_parameters(
diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py
index a6854b9bf6d89..26af335f7be93 100644
--- a/src/lightning/pytorch/cli.py
+++ b/src/lightning/pytorch/cli.py
@@ -23,11 +23,11 @@
from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.core.rank_zero import _warn
from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
from typing_extensions import override
import lightning.pytorch as pl
from lightning.fabric.utilities.cloud_io import get_filesystem
-from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, seed_everything
from lightning.pytorch.core.mixins.hparams_mixin import _given_hyperparameters_context
from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -63,15 +63,15 @@ def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any
# LightningCLI requires the ReduceLROnPlateau defined here, thus it shouldn't accept the one from pytorch:
-LRSchedulerTypeTuple = (_TORCH_LRSCHEDULER, ReduceLROnPlateau)
-LRSchedulerTypeUnion = Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau]
-LRSchedulerType = Union[Type[_TORCH_LRSCHEDULER], Type[ReduceLROnPlateau]]
+LRSchedulerTypeTuple = (LRScheduler, ReduceLROnPlateau)
+LRSchedulerTypeUnion = Union[LRScheduler, ReduceLROnPlateau]
+LRSchedulerType = Union[Type[LRScheduler], Type[ReduceLROnPlateau]]
# Type aliases intended for convenience of CLI developers
ArgsType = Optional[Union[List[str], Dict[str, Any], Namespace]]
OptimizerCallable = Callable[[Iterable], Optimizer]
-LRSchedulerCallable = Callable[[Optimizer], Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau]]
+LRSchedulerCallable = Callable[[Optimizer], Union[LRScheduler, ReduceLROnPlateau]]
class LightningArgumentParser(ArgumentParser):
@@ -534,7 +534,7 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
self.config = parser.parse_args(args)
def _add_instantiators(self) -> None:
- self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False))
+ self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False, skip_none=False))
if "subcommand" in self.config:
self.config_dump = self.config_dump[self.config.subcommand]
@@ -791,8 +791,18 @@ def __init__(self, cli: LightningCLI, key: str) -> None:
self.key = key
def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType:
+ hparams = self.cli.config_dump.get(self.key, {})
+ if "class_path" in hparams:
+ # To keep hparams backwards compatible and identical irrespective of subclass_mode, the parameters are
+ # stored directly, and the class_path is stored under a special key `_class_path` to signal that it is for
+ # internal use.
+ hparams = {
+ "_class_path": hparams["class_path"],
+ **hparams.get("init_args", {}),
+ **hparams.get("dict_kwargs", {}),
+ }
with _given_hyperparameters_context(
- hparams=self.cli.config_dump.get(self.key, {}),
+ hparams=hparams,
instantiator="lightning.pytorch.cli.instantiate_module",
):
return class_type(*args, **kwargs)
@@ -800,10 +810,14 @@ def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> M
def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType:
parser = ArgumentParser(exit_on_error=False)
- if "class_path" in config:
- parser.add_subclass_arguments(class_type, "module")
+ if "_class_path" in config:
+ parser.add_subclass_arguments(class_type, "module", fail_untyped=False)
+ config = {
+ "class_path": config["_class_path"],
+ "dict_kwargs": {k: v for k, v in config.items() if k != "_class_path"},
+ }
else:
- parser.add_class_arguments(class_type, "module")
+ parser.add_class_arguments(class_type, "module", fail_untyped=False)
cfg = parser.parse_object({"module": config})
init = parser.instantiate_classes(cfg)
return init.module
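The hyperparameter handling in the `cli.py` hunks above can be sketched in isolation. This is a minimal, library-free illustration rather than the actual Lightning code: `flatten_hparams` and `restore_config` are hypothetical helper names, and the sample `class_path` value is made up. It only shows how a `class_path`/`init_args`/`dict_kwargs` section is flattened into plain hparams with a `_class_path` key, and how that flat dict is turned back into a subclass config.

```python
from typing import Any, Dict


def flatten_hparams(section: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper mirroring the flattening step in `__call__`."""
    if "class_path" not in section:
        # No subclass mode: the section already holds the plain parameters.
        return section
    return {
        "_class_path": section["class_path"],
        **section.get("init_args", {}),
        **section.get("dict_kwargs", {}),
    }


def restore_config(hparams: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper mirroring the inverse step in `instantiate_module`."""
    if "_class_path" not in hparams:
        return hparams
    return {
        "class_path": hparams["_class_path"],
        # All remaining keys are passed back as keyword arguments.
        "dict_kwargs": {k: v for k, v in hparams.items() if k != "_class_path"},
    }


# Made-up example section, as it might appear in a dumped config.
section = {"class_path": "my_pkg.MyModel", "init_args": {"hidden": 64}, "dict_kwargs": {"extra": True}}
flat = flatten_hparams(section)
restored = restore_config(flat)
```

Note that the round trip is not symmetric: after restoring, former `init_args` entries end up under `dict_kwargs`, which is presumably why the diff also passes `fail_untyped=False` when adding the class arguments.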
diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py
index 982fe46412ac5..6cb8f79f09284 100644
--- a/src/lightning/pytorch/core/datamodule.py
+++ b/src/lightning/pytorch/core/datamodule.py
@@ -81,7 +81,7 @@ def teardown(self):
def __init__(self) -> None:
super().__init__()
# Pointer to the trainer object
- self.trainer: Optional["pl.Trainer"] = None
+ self.trainer: Optional[pl.Trainer] = None
@classmethod
def from_datasets(
@@ -235,7 +235,7 @@ def load_from_checkpoint(
"""
loaded = _load_from_checkpoint(
- cls, # type: ignore[arg-type]
+ cls,
checkpoint_path,
map_location=map_location,
hparams_file=hparams_file,
diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py
index 4a4cad3d5f080..5495a0262036d 100644
--- a/src/lightning/pytorch/core/hooks.py
+++ b/src/lightning/pytorch/core/hooks.py
@@ -19,7 +19,6 @@
from torch import Tensor
from torch.optim.optimizer import Optimizer
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.pytorch.utilities import move_data_to_device
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
@@ -158,8 +157,7 @@ def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: in
def on_validation_model_zero_grad(self) -> None:
"""Called by the training loop to release gradients before entering the validation loop."""
- zero_grad_kwargs = {} if _TORCH_GREATER_EQUAL_2_0 else {"set_to_none": True}
- self.zero_grad(**zero_grad_kwargs)
+ self.zero_grad()
def on_validation_model_eval(self) -> None:
"""Called when the validation loop starts.
diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index 3075e8952b148..782fc40d928ef 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -17,9 +17,11 @@
import numbers
import weakref
from contextlib import contextmanager
+from io import BytesIO
from pathlib import Path
from typing import (
IO,
+ TYPE_CHECKING,
Any,
Callable,
Dict,
@@ -50,7 +52,6 @@
from lightning.fabric.utilities.apply_func import convert_to_tensors
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
-from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
from lightning.fabric.wrappers import _FabricOptimizer
from lightning.pytorch.callbacks.callback import Callback
@@ -66,7 +67,7 @@
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1
from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
-from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_debug, rank_zero_warn
+from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
from lightning.pytorch.utilities.types import (
_METRIC,
@@ -76,6 +77,9 @@
OptimizerLRScheduler,
)
+if TYPE_CHECKING:
+ from torch.distributed.device_mesh import DeviceMesh
+
_ONNX_AVAILABLE = RequirementCache("onnx")
warning_cache = WarningCache()
@@ -110,6 +114,7 @@ class LightningModule(
"trainer",
"fabric",
"strict_loading",
+ "device_mesh",
]
+ _DeviceDtypeModuleMixin.__jit_unused_properties__
+ HyperparametersMixin.__jit_unused_properties__
@@ -124,7 +129,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
# pointer to the trainer object
- self._trainer: Optional["pl.Trainer"] = None
+ self._trainer: Optional[pl.Trainer] = None
# attributes that can be set by user
self._example_input_array: Optional[Union[Tensor, Tuple, Dict]] = None
@@ -135,13 +140,15 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._current_fx_name: Optional[str] = None
self._param_requires_grad_state: Dict[str, bool] = {}
self._metric_attributes: Optional[Dict[int, str]] = None
- self._register_sharded_tensor_state_dict_hooks_if_available()
self._compiler_ctx: Optional[Dict[str, Any]] = None
# attributes only used when using fabric
- self._fabric: Optional["lf.Fabric"] = None
+ self._fabric: Optional[lf.Fabric] = None
self._fabric_optimizers: List[_FabricOptimizer] = []
+ # access to device mesh in `configure_model()` hook
+ self._device_mesh: Optional[DeviceMesh] = None
+
@overload
def optimizers(
self, use_pl_optimizer: Literal[True] = True
@@ -217,9 +224,6 @@ def trainer(self, trainer: Optional["pl.Trainer"]) -> None:
for v in self.children():
if isinstance(v, LightningModule):
v.trainer = trainer # type: ignore[assignment]
- # https://github.com/pytorch/pytorch/issues/95857
- if not _TORCH_GREATER_EQUAL_2_0 and trainer is not None and not isinstance(trainer, weakref.ProxyTypes):
- trainer = weakref.proxy(trainer)
self._trainer = trainer
@property
@@ -322,6 +326,12 @@ def loggers(self) -> Union[List[Logger], List[FabricLogger]]:
return self._trainer.loggers
return []
+ @property
+ def device_mesh(self) -> Optional["DeviceMesh"]:
+ """Strategies like ``ModelParallelStrategy`` will create a device mesh that can be accessed in the
+ :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` hook to parallelize the LightningModule."""
+ return self._device_mesh
+
def _call_batch_hook(self, hook_name: str, *args: Any) -> Any:
trainer = self._trainer
if trainer:
@@ -394,7 +404,7 @@ def log(
The default behavior per hook is documented here: :ref:`extensions/logging:Automatic Logging`.
Args:
- name: key to log.
+ name: key to log. Must be identical across all processes if using DDP or any other distributed strategy.
value: value to log. Can be a ``float``, ``Tensor``, or a ``Metric``.
prog_bar: if ``True`` logs to the progress bar.
logger: if ``True`` logs to the logger.
@@ -558,6 +568,7 @@ def log_dict(
Args:
dictionary: key value pairs.
+ Keys must be identical across all processes if using DDP or any other distributed strategy.
The values can be a ``float``, ``Tensor``, ``Metric``, or ``MetricCollection``.
prog_bar: if ``True`` logs to the progress base.
logger: if ``True`` logs to the logger.
@@ -1354,7 +1365,7 @@ def _verify_is_manual_optimization(self, fn_name: str) -> None:
)
@torch.no_grad()
- def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
+ def to_onnx(self, file_path: Union[str, Path, BytesIO], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
"""Saves the model in ONNX format.
Args:
@@ -1377,10 +1388,8 @@ def forward(self, x):
model.to_onnx("export.onnx", input_sample, export_params=True)
"""
- if _TORCH_GREATER_EQUAL_2_0 and not _ONNX_AVAILABLE:
- raise ModuleNotFoundError(
- f"`torch>=2.0` requires `onnx` to be installed to use `{type(self).__name__}.to_onnx()`"
- )
+ if not _ONNX_AVAILABLE:
+ raise ModuleNotFoundError(f"`{type(self).__name__}.to_onnx()` requires `onnx` to be installed.")
mode = self.training
@@ -1395,6 +1404,7 @@ def forward(self, x):
input_sample = self._on_before_batch_transfer(input_sample)
input_sample = self._apply_batch_transfer_handler(input_sample)
+ file_path = str(file_path) if isinstance(file_path, Path) else file_path
torch.onnx.export(self, input_sample, file_path, **kwargs)
self.train(mode)
@@ -1572,7 +1582,7 @@ def load_from_checkpoint(
"""
loaded = _load_from_checkpoint(
- cls, # type: ignore[arg-type]
+ cls,
checkpoint_path,
map_location,
hparams_file,
@@ -1587,24 +1597,6 @@ def __getstate__(self) -> Dict[str, Any]:
state["_trainer"] = None
return state
- def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
- """Adds ShardedTensor state dict hooks if ShardedTensors are supported.
-
- These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly.
-
- """
- if _TORCH_GREATER_EQUAL_2_1:
- # ShardedTensor is deprecated in favor of DistributedTensor
- return
- if _IS_WINDOWS or not torch.distributed.is_available():
- rank_zero_debug("Could not register sharded tensor state dict hooks")
- return
-
- from torch.distributed._shard.sharded_tensor import pre_load_state_dict_hook, state_dict_hook
-
- self._register_state_dict_hook(state_dict_hook)
- self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
-
@contextmanager
def _jit_is_scripting() -> Generator:
diff --git a/src/lightning/pytorch/core/optimizer.py b/src/lightning/pytorch/core/optimizer.py
index b7a63a8e17cab..777dca0b51dfe 100644
--- a/src/lightning/pytorch/core/optimizer.py
+++ b/src/lightning/pytorch/core/optimizer.py
@@ -19,10 +19,11 @@
import torch
from torch import optim
from torch.optim import Optimizer
+from torch.optim.lr_scheduler import ReduceLROnPlateau
from typing_extensions import override
import lightning.pytorch as pl
-from lightning.fabric.utilities.types import Optimizable, ReduceLROnPlateau, _Stateful
+from lightning.fabric.utilities.types import Optimizable, _Stateful
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py
index f8e9c8300337a..521192f500b53 100644
--- a/src/lightning/pytorch/core/saving.py
+++ b/src/lightning/pytorch/core/saving.py
@@ -359,7 +359,7 @@ def save_hparams_to_yaml(config_yaml: _PATH, hparams: Union[dict, Namespace], us
try:
v = v.name if isinstance(v, Enum) else v
yaml.dump(v)
- except TypeError:
+ except (TypeError, ValueError):
warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")
hparams[k] = type(v).__name__
else:
diff --git a/src/lightning/pytorch/demos/__init__.py b/src/lightning/pytorch/demos/__init__.py
index 5ffa4afc92527..fa91d7cac9fde 100644
--- a/src/lightning/pytorch/demos/__init__.py
+++ b/src/lightning/pytorch/demos/__init__.py
@@ -1 +1,2 @@
+from lightning.pytorch.demos.lstm import LightningLSTM, SequenceSampler, SimpleLSTM # noqa: F401
from lightning.pytorch.demos.transformer import LightningTransformer, Transformer, WikiText2 # noqa: F401
diff --git a/src/lightning/pytorch/demos/boring_classes.py b/src/lightning/pytorch/demos/boring_classes.py
index 3dd7bd8b1afc8..fd2660228146e 100644
--- a/src/lightning/pytorch/demos/boring_classes.py
+++ b/src/lightning/pytorch/demos/boring_classes.py
@@ -18,9 +18,9 @@
import torch.nn.functional as F
from torch import Tensor
from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset
-from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
from lightning.pytorch import LightningDataModule, LightningModule
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.utilities.types import STEP_OUTPUT
@@ -134,7 +134,7 @@ def validation_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
def test_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
return {"y": self.step(batch)}
- def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_TORCH_LRSCHEDULER]]:
+ def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[LRScheduler]]:
optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]
diff --git a/src/lightning/pytorch/demos/lstm.py b/src/lightning/pytorch/demos/lstm.py
new file mode 100644
index 0000000000000..672b61ad0eff9
--- /dev/null
+++ b/src/lightning/pytorch/demos/lstm.py
@@ -0,0 +1,98 @@
+"""Demo of a simple LSTM language model.
+
+Code is adapted from the PyTorch examples at
+https://github.com/pytorch/examples/blob/main/word_language_model
+
+"""
+
+from typing import Iterator, List, Optional, Sized, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader, Sampler
+
+from lightning.pytorch.core import LightningModule
+from lightning.pytorch.demos.transformer import WikiText2
+
+
+class SimpleLSTM(nn.Module):
+ def __init__(
+ self, vocab_size: int = 33278, ninp: int = 512, nhid: int = 512, nlayers: int = 4, dropout: float = 0.2
+ ):
+ super().__init__()
+ self.vocab_size = vocab_size
+ self.drop = nn.Dropout(dropout)
+ self.encoder = nn.Embedding(vocab_size, ninp)
+ self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout, batch_first=True)
+ self.decoder = nn.Linear(nhid, vocab_size)
+ self.nlayers = nlayers
+ self.nhid = nhid
+ self.init_weights()
+
+ def init_weights(self) -> None:
+ nn.init.uniform_(self.encoder.weight, -0.1, 0.1)
+ nn.init.zeros_(self.decoder.bias)
+ nn.init.uniform_(self.decoder.weight, -0.1, 0.1)
+
+ def forward(self, input: Tensor, hidden: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
+ emb = self.drop(self.encoder(input))
+ output, hidden = self.rnn(emb, hidden)
+ output = self.drop(output)
+ decoded = self.decoder(output).view(-1, self.vocab_size)
+ return F.log_softmax(decoded, dim=1), hidden
+
+ def init_hidden(self, batch_size: int) -> Tuple[Tensor, Tensor]:
+ weight = next(self.parameters())
+ return (
+ weight.new_zeros(self.nlayers, batch_size, self.nhid),
+ weight.new_zeros(self.nlayers, batch_size, self.nhid),
+ )
+
+
+class SequenceSampler(Sampler[List[int]]):
+ def __init__(self, dataset: Sized, batch_size: int) -> None:
+ super().__init__()
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.chunk_size = len(self.dataset) // self.batch_size
+
+ def __iter__(self) -> Iterator[List[int]]:
+ n = len(self.dataset)
+ for i in range(self.chunk_size):
+ yield list(range(i, n - (n % self.batch_size), self.chunk_size))
+
+ def __len__(self) -> int:
+ return self.chunk_size
+
+
+class LightningLSTM(LightningModule):
+ def __init__(self, vocab_size: int = 33278):
+ super().__init__()
+ self.model = SimpleLSTM(vocab_size=vocab_size)
+ self.hidden: Optional[Tuple[Tensor, Tensor]] = None
+
+ def on_train_epoch_end(self) -> None:
+ self.hidden = None
+
+ def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
+ input, target = batch
+ if self.hidden is None:
+ self.hidden = self.model.init_hidden(input.size(0))
+ self.hidden = (self.hidden[0].detach(), self.hidden[1].detach())
+ output, self.hidden = self.model(input, self.hidden)
+ loss = F.nll_loss(output, target.view(-1))
+ self.log("train_loss", loss, prog_bar=True)
+ return loss
+
+ def prepare_data(self) -> None:
+ WikiText2(download=True)
+
+ def train_dataloader(self) -> DataLoader:
+ dataset = WikiText2()
+ return DataLoader(dataset, batch_sampler=SequenceSampler(dataset, batch_size=20))
+
+ def configure_optimizers(self) -> Optimizer:
+ return torch.optim.SGD(self.parameters(), lr=20.0)
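The index pattern produced by `SequenceSampler` in the new `lstm.py` file may be easier to see without PyTorch. The sketch below reimplements only the arithmetic of `__iter__` as a standalone function; `sequence_batches` is a hypothetical name used for illustration, and the dataset size of 10 is an arbitrary example.

```python
from typing import Iterator, List


def sequence_batches(n: int, batch_size: int) -> Iterator[List[int]]:
    """Yield index batches the way `SequenceSampler.__iter__` does (illustrative copy)."""
    chunk_size = n // batch_size
    for i in range(chunk_size):
        # Position j in every batch walks through chunk j of the dataset:
        # j * chunk_size, j * chunk_size + 1, ...
        yield list(range(i, n - (n % batch_size), chunk_size))


# 10 samples, batch size 3: the trailing sample (index 9) is dropped.
batches = list(sequence_batches(n=10, batch_size=3))
```

Each batch contains exactly `batch_size` indices, and the same position in consecutive batches advances through contiguous samples. That ordering is what makes it meaningful for `LightningLSTM.training_step` to detach and reuse the hidden state from the previous batch.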
diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py
index 833c15d91cbdd..ac83b5539f249 100644
--- a/src/lightning/pytorch/demos/transformer.py
+++ b/src/lightning/pytorch/demos/transformer.py
@@ -85,11 +85,10 @@ def forward(self, x: Tensor) -> Tensor:
if self.pe is None:
# 1) can't use buffer, see https://github.com/pytorch/pytorch/issues/68407
# 2) can't use parameter becauses pe gets sliced and DDP requires all params to participate in forward
- # 3) can't make it a `requires_grad=False` parameter because FSDP in PyTorch < 2.1 needs all params to
- # require grad
+ # TODO: Could make this a `nn.Parameter` with `requires_grad=False`
self.pe = self._init_pos_encoding(device=x.device)
- x + self.pe[: x.size(0), :]
+ x = x + self.pe[: x.size(0), :]
return self.dropout(x)
def _init_pos_encoding(self, device: torch.device) -> Tensor:
diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py
index ccb6d62de866a..277af5c85f539 100644
--- a/src/lightning/pytorch/loggers/comet.py
+++ b/src/lightning/pytorch/loggers/comet.py
@@ -268,7 +268,7 @@ def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperi
self.logger.experiment.some_comet_function()
"""
- if self._experiment is not None:
+ if self._experiment is not None and self._experiment.alive:
return self._experiment
if self._future_experiment_key is not None:
diff --git a/src/lightning/pytorch/loggers/csv_logs.py b/src/lightning/pytorch/loggers/csv_logs.py
index fdaeb18e9199f..caca0c181c6ff 100644
--- a/src/lightning/pytorch/loggers/csv_logs.py
+++ b/src/lightning/pytorch/loggers/csv_logs.py
@@ -19,7 +19,6 @@
"""
-import logging
import os
from argparse import Namespace
from typing import Any, Dict, Optional, Union
@@ -35,8 +34,6 @@
from lightning.pytorch.loggers.logger import Logger
from lightning.pytorch.utilities.rank_zero import rank_zero_only
-log = logging.getLogger(__name__)
-
class ExperimentWriter(_FabricExperimentWriter):
r"""Experiment writer for CSVLogger.
diff --git a/src/lightning/pytorch/loggers/logger.py b/src/lightning/pytorch/loggers/logger.py
index c3051d34a7b09..40e8ed8c4a13e 100644
--- a/src/lightning/pytorch/loggers/logger.py
+++ b/src/lightning/pytorch/loggers/logger.py
@@ -15,11 +15,11 @@
import functools
import operator
+import statistics
from abc import ABC
from collections import defaultdict
from typing import Any, Callable, Dict, Mapping, Optional, Sequence
-import numpy as np
from typing_extensions import override
from lightning.fabric.loggers import Logger as FabricLogger
@@ -100,7 +100,7 @@ def method(*args: Any, **kwargs: Any) -> None:
def merge_dicts( # pragma: no cover
dicts: Sequence[Mapping],
agg_key_funcs: Optional[Mapping] = None,
- default_func: Callable[[Sequence[float]], float] = np.mean,
+ default_func: Callable[[Sequence[float]], float] = statistics.mean,
) -> Dict:
"""Merge a sequence with dictionaries into one dictionary by aggregating the same keys with some given function.
@@ -126,7 +126,7 @@ def merge_dicts( # pragma: no cover
>>> d2 = {'a': 1.1, 'b': 2.2, 'v': 1, 'd': {'d1': 2, 'd2': 3}}
>>> d3 = {'a': 1.1, 'v': 2.3, 'd': {'d3': 3, 'd4': {'d5': 1}}}
>>> dflt_func = min
- >>> agg_funcs = {'a': np.mean, 'v': max, 'd': {'d1': sum}}
+ >>> agg_funcs = {'a': statistics.mean, 'v': max, 'd': {'d1': sum}}
>>> pprint.pprint(merge_dicts([d1, d2, d3], agg_funcs, dflt_func))
{'a': 1.3,
'b': 2.0,
diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py
index cf874b6d690bf..1b15014cd0528 100644
--- a/src/lightning/pytorch/loggers/mlflow.py
+++ b/src/lightning/pytorch/loggers/mlflow.py
@@ -42,6 +42,7 @@
log = logging.getLogger(__name__)
LOCAL_FILE_URI_PREFIX = "file:"
_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0", "mlflow")
+_MLFLOW_SYNCHRONOUS_AVAILABLE = RequirementCache("mlflow>=2.8.0", "mlflow")
class MLFlowLogger(Logger):
@@ -100,6 +101,8 @@ def any_lightning_module_function_or_hook(self):
artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
default.
run_id: The run identifier of the experiment. If not provided, a new run is started.
+ synchronous: Hints to MLflow whether to block execution on every logging call until it completes, where
+ applicable. Requires mlflow >= 2.8.0.
Raises:
ModuleNotFoundError:
@@ -120,9 +123,12 @@ def __init__(
prefix: str = "",
artifact_location: Optional[str] = None,
run_id: Optional[str] = None,
+ synchronous: Optional[bool] = None,
):
if not _MLFLOW_AVAILABLE:
raise ModuleNotFoundError(str(_MLFLOW_AVAILABLE))
+ if synchronous is not None and not _MLFLOW_SYNCHRONOUS_AVAILABLE:
+ raise ModuleNotFoundError("`synchronous` requires mlflow>=2.8.0")
super().__init__()
if not tracking_uri:
tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}"
@@ -138,7 +144,7 @@ def __init__(
self._checkpoint_callback: Optional[ModelCheckpoint] = None
self._prefix = prefix
self._artifact_location = artifact_location
-
+ self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
self._initialized = False
from mlflow.tracking import MlflowClient
@@ -233,7 +239,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
# Log in chunks of 100 parameters (the maximum allowed by MLflow).
for idx in range(0, len(params_list), 100):
- self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100])
+ self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100], **self._log_batch_kwargs)
@override
@rank_zero_only
@@ -261,7 +267,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
k = new_k
metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0))
- self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list)
+ self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list, **self._log_batch_kwargs)
@override
@rank_zero_only
@@ -354,7 +360,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non
aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
# Artifact path on mlflow
- artifact_path = f"model/checkpoints/{Path(p).stem}"
+ artifact_path = Path(p).stem
# Log the checkpoint
self.experiment.log_artifact(self._run_id, p, artifact_path)
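Reviewer note: the `_log_batch_kwargs` change above forwards `synchronous` only when the caller set it explicitly, so clients that predate the parameter never receive it. A minimal sketch of that pattern follows. `FakeClient` and `build_log_batch_kwargs` are hypothetical stand-ins for illustration, not part of mlflow or Lightning.

```python
from typing import Any, Dict, List, Optional


class FakeClient:
    """Hypothetical stand-in for a tracking client; records the kwargs it receives."""

    def __init__(self) -> None:
        self.calls: List[Dict[str, Any]] = []

    def log_batch(self, run_id: str, **kwargs: Any) -> None:
        self.calls.append({"run_id": run_id, **kwargs})


def build_log_batch_kwargs(synchronous: Optional[bool]) -> Dict[str, bool]:
    # Include the key only when it was set explicitly; `None` means "not specified".
    return {} if synchronous is None else {"synchronous": synchronous}


client = FakeClient()
client.log_batch(run_id="r1", metrics=[], **build_log_batch_kwargs(None))
client.log_batch(run_id="r1", metrics=[], **build_log_batch_kwargs(False))
print(client.calls)
```

Using `None` as the "unset" sentinel keeps `False` available as a meaningful value.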
diff --git a/src/lightning/pytorch/loggers/tensorboard.py b/src/lightning/pytorch/loggers/tensorboard.py
index 50ead5cfd39d8..88e026f6945e0 100644
--- a/src/lightning/pytorch/loggers/tensorboard.py
+++ b/src/lightning/pytorch/loggers/tensorboard.py
@@ -16,7 +16,6 @@
------------------
"""
-import logging
import os
from argparse import Namespace
from typing import Any, Dict, Optional, Union
@@ -36,8 +35,6 @@
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn
-log = logging.getLogger(__name__)
-
class TensorBoardLogger(Logger, FabricTensorBoardLogger):
r"""Log to local or remote file system in `TensorBoard `_ format.
@@ -245,7 +242,6 @@ def _get_next_version(self) -> int:
try:
listdir_info = self._fs.listdir(root_dir)
except OSError:
- log.warning("Missing logger folder: %s", root_dir)
return 0
existing_versions = []
diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py
index 4025f2cd18004..20f8d02a7ab9b 100644
--- a/src/lightning/pytorch/loggers/wandb.py
+++ b/src/lightning/pytorch/loggers/wandb.py
@@ -26,7 +26,12 @@
from torch import Tensor
from typing_extensions import override
-from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
+from lightning.fabric.utilities.logger import (
+ _add_prefix,
+ _convert_json_serializable,
+ _convert_params,
+ _sanitize_callable_params,
+)
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
@@ -43,7 +48,7 @@
class WandbLogger(Logger):
- r"""Log using `Weights and Biases `_.
+ r"""Log using `Weights and Biases `_.
**Installation and set-up**
@@ -248,7 +253,7 @@ def any_lightning_module_function_or_hook(self):
See Also:
- `Demo in Google Colab `__ with hyperparameter search and model logging
- - `W&B Documentation `__
+ - `W&B Documentation `__
Args:
name: Display name for the run.
@@ -419,6 +424,7 @@ def watch(
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = _convert_params(params)
params = _sanitize_callable_params(params)
+ params = _convert_json_serializable(params)
self.experiment.config.update(params, allow_val_change=True)
@override
diff --git a/src/lightning/pytorch/loops/optimization/automatic.py b/src/lightning/pytorch/loops/optimization/automatic.py
index 82666a0e21e5f..2ce6acab11a37 100644
--- a/src/lightning/pytorch/loops/optimization/automatic.py
+++ b/src/lightning/pytorch/loops/optimization/automatic.py
@@ -314,8 +314,14 @@ def _training_step(self, kwargs: OrderedDict) -> ClosureResult:
"""
trainer = self.trainer
- # manually capture logged metrics
training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
self.trainer.strategy.post_training_step() # unused hook - call anyway for backward compatibility
+ if training_step_output is None and trainer.world_size > 1:
+ raise RuntimeError(
+ "Skipping the `training_step` by returning None in distributed training is not supported."
+ " It is recommended that you rewrite your training logic to avoid having to skip the step in the first"
+ " place."
+ )
+
return self.output_result_cls.from_training_step_output(training_step_output, trainer.accumulate_grad_batches)
diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py
index 8ca54184b477a..99ea5c4254d62 100644
--- a/src/lightning/pytorch/loops/utilities.py
+++ b/src/lightning/pytorch/loops/utilities.py
@@ -21,7 +21,6 @@
import lightning.pytorch as pl
from lightning.fabric.utilities.distributed import _distributed_is_initialized
-from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch.accelerators.xla import XLAAccelerator
from lightning.pytorch.callbacks.timer import Timer
@@ -171,9 +170,6 @@ def _decorator(self: _Loop, *args: Any, **kwargs: Any) -> Any:
elif isinstance(self.trainer.strategy, FSDPStrategy):
# https://github.com/pytorch/pytorch/issues/95957
context_manager = torch.no_grad
- elif _TORCH_EQUAL_2_0 and self.trainer.lightning_module._compiler_ctx is not None:
- # avoid: `RuntimeError: Inference tensors do not track version counter` fixed in v2.1
- context_manager = torch.no_grad
elif self.inference_mode:
context_manager = torch.inference_mode
else:
diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py
index 70ce9a87fb37a..e63ccd6912b63 100644
--- a/src/lightning/pytorch/plugins/precision/amp.py
+++ b/src/lightning/pytorch/plugins/precision/amp.py
@@ -18,8 +18,8 @@
from typing_extensions import override
import lightning.pytorch as pl
-from lightning.fabric.accelerators.cuda import _patch_cuda_is_available
from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.types import Optimizable
from lightning.pytorch.plugins.precision.precision import Precision
from lightning.pytorch.utilities import GradClipAlgorithmType
@@ -40,7 +40,7 @@ def __init__(
self,
precision: Literal["16-mixed", "bf16-mixed"],
device: str,
- scaler: Optional[torch.cuda.amp.GradScaler] = None,
+ scaler: Optional["torch.amp.GradScaler"] = None,
) -> None:
if precision not in ("16-mixed", "bf16-mixed"):
raise ValueError(
@@ -50,9 +50,7 @@ def __init__(
self.precision = precision
if scaler is None and self.precision == "16-mixed":
- with _patch_cuda_is_available():
- # if possible, we defer CUDA initialization to support strategies that will attempt forks
- scaler = torch.cuda.amp.GradScaler()
+ scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
if scaler is not None and self.precision == "bf16-mixed":
raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device
diff --git a/src/lightning/pytorch/plugins/precision/bitsandbytes.py b/src/lightning/pytorch/plugins/precision/bitsandbytes.py
index 62acc7bf77c8d..3a2daa828bc3c 100644
--- a/src/lightning/pytorch/plugins/precision/bitsandbytes.py
+++ b/src/lightning/pytorch/plugins/precision/bitsandbytes.py
@@ -16,7 +16,7 @@
class BitsandbytesPrecision(Precision, FabricBNBPrecision):
- """Plugin for quantizing weights with `bitsandbytes `__.
+ """Plugin for quantizing weights with `bitsandbytes `__.
.. warning:: This is an :ref:`experimental ` feature.
diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py
index 829e93a52a645..e1e90281cf3af 100644
--- a/src/lightning/pytorch/plugins/precision/deepspeed.py
+++ b/src/lightning/pytorch/plugins/precision/deepspeed.py
@@ -113,7 +113,7 @@ def backward( # type: ignore[override]
"You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
" the backward logic internally."
)
- deepspeed_engine: "deepspeed.DeepSpeedEngine" = model.trainer.model
+ deepspeed_engine: deepspeed.DeepSpeedEngine = model.trainer.model
deepspeed_engine.backward(tensor, *args, **kwargs)
@override
@@ -135,7 +135,7 @@ def optimizer_step( # type: ignore[override]
"Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
)
# DeepSpeed handles the optimizer step internally
- deepspeed_engine: "deepspeed.DeepSpeedEngine" = model.trainer.model
+ deepspeed_engine: deepspeed.DeepSpeedEngine = model.trainer.model
return deepspeed_engine.step(**kwargs)
@override
diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py
index c41199adb480e..e6c684967ed40 100644
--- a/src/lightning/pytorch/plugins/precision/fsdp.py
+++ b/src/lightning/pytorch/plugins/precision/fsdp.py
@@ -22,7 +22,6 @@
from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
from lightning.fabric.plugins.precision.fsdp import _PRECISION_INPUT
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.types import Optimizable
from lightning.pytorch.plugins.precision.precision import Precision
from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -87,21 +86,18 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
def mixed_precision_config(self) -> "TorchMixedPrecision":
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
- # With PyTorch < 2.0, FSDP uses the noneness of `param_dtype` as a proxy for the `_uses_param_mixed_precision`
- # property. In order to avoid FSDP assertion failures, we therefore avoid setting `param_dtype` to
- # `torch.float32` here with PyTorch < 2.0.
if self.precision == "16-mixed":
- param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
+ param_dtype = torch.float32
reduce_dtype = buffer_dtype = torch.float16
elif self.precision == "bf16-mixed":
- param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
+ param_dtype = torch.float32
reduce_dtype = buffer_dtype = torch.bfloat16
elif self.precision == "16-true":
param_dtype = reduce_dtype = buffer_dtype = torch.float16
elif self.precision == "bf16-true":
param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
elif self.precision == "32-true":
- param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
+ param_dtype = torch.float32
reduce_dtype = buffer_dtype = torch.float32
else:
raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
diff --git a/src/lightning/pytorch/plugins/precision/precision.py b/src/lightning/pytorch/plugins/precision/precision.py
index cd35e2e515de2..51bdddb18f814 100644
--- a/src/lightning/pytorch/plugins/precision/precision.py
+++ b/src/lightning/pytorch/plugins/precision/precision.py
@@ -95,7 +95,7 @@ def _after_closure(self, model: "pl.LightningModule", optimizer: Steppable) -> N
def _wrap_closure(
self,
model: "pl.LightningModule",
- optimizer: Optimizer,
+ optimizer: Steppable,
closure: Callable[[], Any],
) -> Any:
"""This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``
diff --git a/src/lightning/pytorch/profilers/advanced.py b/src/lightning/pytorch/profilers/advanced.py
index 9e166b5e34d9f..467b47124eb60 100644
--- a/src/lightning/pytorch/profilers/advanced.py
+++ b/src/lightning/pytorch/profilers/advanced.py
@@ -16,13 +16,17 @@
import cProfile
import io
import logging
+import os
import pstats
+import tempfile
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
from typing_extensions import override
+from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.pytorch.profilers.profiler import Profiler
+from lightning.pytorch.utilities.rank_zero import rank_zero_only
log = logging.getLogger(__name__)
@@ -40,6 +44,7 @@ def __init__(
dirpath: Optional[Union[str, Path]] = None,
filename: Optional[str] = None,
line_count_restriction: float = 1.0,
+ dump_stats: bool = False,
) -> None:
"""
Args:
@@ -54,6 +59,8 @@ def __init__(
reported for each action. either an integer (to select a count of lines),
or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
+ dump_stats: Whether to save raw profiler results. When ``True`` then ``dirpath`` must be provided.
+
Raises:
ValueError:
If you attempt to stop recording an action which was never started.
@@ -61,6 +68,7 @@ def __init__(
super().__init__(dirpath=dirpath, filename=filename)
self.profiled_actions: Dict[str, cProfile.Profile] = {}
self.line_count_restriction = line_count_restriction
+ self.dump_stats = dump_stats
@override
def start(self, action_name: str) -> None:
@@ -75,10 +83,27 @@ def stop(self, action_name: str) -> None:
raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")
pr.disable()
+ def _dump_stats(self, action_name: str, profile: cProfile.Profile) -> None:
+ assert self.dirpath
+ dst_filepath = os.path.join(self.dirpath, self._prepare_filename(action_name=action_name, extension=".prof"))
+ dst_fs = get_filesystem(dst_filepath)
+ dst_fs.mkdirs(self.dirpath, exist_ok=True)
+ # temporarily save to local since pstats can only dump into a local file
+ with tempfile.TemporaryDirectory(
+ prefix="test", suffix=str(rank_zero_only.rank), dir=os.getcwd()
+ ) as tmp_dir, dst_fs.open(dst_filepath, "wb") as dst_file:
+ src_filepath = os.path.join(tmp_dir, "tmp.prof")
+ profile.dump_stats(src_filepath)
+ src_fs = get_filesystem(src_filepath)
+ with src_fs.open(src_filepath, "rb") as src_file:
+ dst_file.write(src_file.read())
+
@override
def summary(self) -> str:
recorded_stats = {}
for action_name, pr in self.profiled_actions.items():
+ if self.dump_stats:
+ self._dump_stats(action_name, pr)
s = io.StringIO()
ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats("cumulative")
ps.print_stats(self.line_count_restriction)
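Reviewer note: `_dump_stats` above writes the profile to a temporary local file first and then copies the bytes to the destination filesystem. The sketch below shows the same round trip using only the standard library, assuming a local destination. The helper name `dump_profile_bytes` and the file name `fit-run.prof` are made up for illustration.

```python
import cProfile
import os
import pstats
import tempfile


def dump_profile_bytes(profile: cProfile.Profile) -> bytes:
    """Write the profile to a temporary local file and return its raw bytes."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        src_filepath = os.path.join(tmp_dir, "tmp.prof")
        profile.dump_stats(src_filepath)
        with open(src_filepath, "rb") as src_file:
            return src_file.read()


profile = cProfile.Profile()
profile.enable()
sum(range(1000))
profile.disable()

raw = dump_profile_bytes(profile)

# The bytes can be written to any destination and loaded again with `pstats.Stats`.
with tempfile.TemporaryDirectory() as out_dir:
    dst_filepath = os.path.join(out_dir, "fit-run.prof")
    with open(dst_filepath, "wb") as dst_file:
        dst_file.write(raw)
    stats = pstats.Stats(dst_filepath)
```

In the patch, the destination handle comes from `get_filesystem(...)`, which lets the same copy step target remote storage.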
diff --git a/src/lightning/pytorch/profilers/pytorch.py b/src/lightning/pytorch/profilers/pytorch.py
index 96b8d319bb884..a26b3d321d2e0 100644
--- a/src/lightning/pytorch/profilers/pytorch.py
+++ b/src/lightning/pytorch/profilers/pytorch.py
@@ -28,6 +28,7 @@
from typing_extensions import override
from lightning.fabric.accelerators.cuda import is_cuda_available
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch.profilers.profiler import Profiler
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
@@ -295,14 +296,14 @@ def __init__(
self._emit_nvtx = emit_nvtx
self._export_to_chrome = export_to_chrome
self._row_limit = row_limit
- self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total"
+ self._sort_by_key = sort_by_key or _default_sort_by_key(profiler_kwargs)
self._record_module_names = record_module_names
self._profiler_kwargs = profiler_kwargs
self._table_kwargs = table_kwargs if table_kwargs is not None else {}
self.profiler: Optional[_PROFILER] = None
- self.function_events: Optional["EventList"] = None
- self._lightning_module: Optional["LightningModule"] = None # set by ProfilerConnector
+ self.function_events: Optional[EventList] = None
+ self._lightning_module: Optional[LightningModule] = None # set by ProfilerConnector
self._register: Optional[RegisterRecordFunction] = None
self._parent_profiler: Optional[ContextManager] = None
self._recording_map: Dict[str, record_function] = {}
@@ -400,13 +401,19 @@ def _default_schedule() -> Optional[Callable]:
return None
def _default_activities(self) -> List["ProfilerActivity"]:
- activities: List["ProfilerActivity"] = []
+ activities: List[ProfilerActivity] = []
if not _KINETO_AVAILABLE:
return activities
- if self._profiler_kwargs.get("use_cpu", True):
+ if _TORCH_GREATER_EQUAL_2_4:
activities.append(ProfilerActivity.CPU)
- if self._profiler_kwargs.get("use_cuda", is_cuda_available()):
- activities.append(ProfilerActivity.CUDA)
+ if is_cuda_available():
+ activities.append(ProfilerActivity.CUDA)
+ else:
+ # `use_cpu` and `use_cuda` are deprecated in PyTorch >= 2.4
+ if self._profiler_kwargs.get("use_cpu", True):
+ activities.append(ProfilerActivity.CPU)
+ if self._profiler_kwargs.get("use_cuda", is_cuda_available()):
+ activities.append(ProfilerActivity.CUDA)
return activities
@override
@@ -565,3 +572,13 @@ def teardown(self, stage: Optional[str]) -> None:
self._recording_map = {}
super().teardown(stage=stage)
+
+
+def _default_sort_by_key(profiler_kwargs: dict) -> str:
+ activities = profiler_kwargs.get("activities", [])
+ is_cuda = (
+ profiler_kwargs.get("use_cuda", False) # `use_cuda` is deprecated in PyTorch >= 2.4
+ or (activities and ProfilerActivity.CUDA in activities)
+ or (not activities and is_cuda_available())
+ )
+ return f"{'cuda' if is_cuda else 'cpu'}_time_total"
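Reviewer note: the decision logic in `_default_sort_by_key` can be checked without PyTorch. In the sketch below, `Activity` is a hypothetical stand-in for `ProfilerActivity`, and CUDA availability is passed in as a plain argument instead of calling `is_cuda_available()`.

```python
from enum import Enum
from typing import Any, Dict


class Activity(Enum):
    """Hypothetical stand-in for the profiler activity enum."""

    CPU = "cpu"
    CUDA = "cuda"


def default_sort_by_key(profiler_kwargs: Dict[str, Any], cuda_available: bool) -> str:
    activities = profiler_kwargs.get("activities", [])
    is_cuda = (
        profiler_kwargs.get("use_cuda", False)  # legacy flag still wins when set
        or (activities and Activity.CUDA in activities)  # explicit activity list
        or (not activities and cuda_available)  # nothing specified: follow the hardware
    )
    return f"{'cuda' if is_cuda else 'cpu'}_time_total"


print(default_sort_by_key({}, cuda_available=False))
print(default_sort_by_key({"activities": [Activity.CPU]}, cuda_available=True))
```

An explicit `activities` list takes precedence over hardware detection, so a CPU-only list sorts by CPU time even on a CUDA machine.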
diff --git a/src/lightning/pytorch/strategies/__init__.py b/src/lightning/pytorch/strategies/__init__.py
index 14ffe52870ba5..9c2b2a6a3a621 100644
--- a/src/lightning/pytorch/strategies/__init__.py
+++ b/src/lightning/pytorch/strategies/__init__.py
@@ -18,6 +18,7 @@
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy
from lightning.pytorch.strategies.fsdp import FSDPStrategy
+from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy
from lightning.pytorch.strategies.parallel import ParallelStrategy
from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
from lightning.pytorch.strategies.single_xla import SingleDeviceXLAStrategy # noqa: F401
@@ -31,6 +32,7 @@
"DDPStrategy",
"DeepSpeedStrategy",
"FSDPStrategy",
+ "ModelParallelStrategy",
"ParallelStrategy",
"SingleDeviceStrategy",
"Strategy",
diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py
index 2cc099be39d26..1eaa5bab75fbe 100644
--- a/src/lightning/pytorch/strategies/deepspeed.py
+++ b/src/lightning/pytorch/strategies/deepspeed.py
@@ -24,6 +24,7 @@
import torch
from torch.nn import Module
from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau
from typing_extensions import override
import lightning.pytorch as pl
@@ -37,7 +38,7 @@
)
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.fabric.utilities.seed import reset_seed
-from lightning.fabric.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau
+from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.accelerators.cuda import CUDAAccelerator
from lightning.pytorch.core.optimizer import _init_optimizers_and_lr_schedulers
from lightning.pytorch.plugins.precision import Precision
@@ -336,10 +337,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
assert self.accelerator is not None
self.accelerator.setup(trainer)
- # we set the device so that optimizers can be created with distributed comms.
- assert self.lightning_module is not None
- self.lightning_module._device = self.root_device
-
assert self.model is not None
self.model = self.precision_plugin.convert_module(self.model)
self.model = self._setup_model(self.model)
@@ -417,7 +414,7 @@ def _setup_model_and_optimizer(
) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]:
"""Initialize one model and one optimizer with an optional learning rate scheduler.
- This calls :func:`deepspeed.initialize` internally.
+ This calls ``deepspeed.initialize`` internally.
"""
import deepspeed
diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py
index 4c6f1fec5fe15..ab6e579c3071f 100644
--- a/src/lightning/pytorch/strategies/fsdp.py
+++ b/src/lightning/pytorch/strategies/fsdp.py
@@ -16,7 +16,21 @@
from contextlib import contextmanager, nullcontext
from datetime import timedelta
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Mapping, Optional, Set, Type, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Generator,
+ List,
+ Literal,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+ Type,
+ Union,
+)
import torch
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
@@ -37,16 +51,15 @@
_distributed_checkpoint_save,
_get_full_state_dict_context,
_get_sharded_state_dict_context,
- _has_meta_device_parameters,
_init_cpu_offload,
_init_sharding_strategy,
_is_full_checkpoint,
_is_sharded_checkpoint,
- _load_raw_module_state,
_move_torchmetrics_to_device,
_optimizer_has_flat_params,
_setup_activation_checkpointing,
)
+from lightning.fabric.strategies.model_parallel import _load_raw_module_state
from lightning.fabric.utilities.distributed import (
_distributed_is_initialized,
_get_default_process_group_backend_for_device,
@@ -54,11 +67,8 @@
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
-from lightning.fabric.utilities.imports import (
- _TORCH_GREATER_EQUAL_2_0,
- _TORCH_GREATER_EQUAL_2_1,
-)
-from lightning.fabric.utilities.init import _EmptyInit
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
+from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers
from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.fabric.utilities.seed import reset_seed
@@ -74,15 +84,11 @@
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
if TYPE_CHECKING:
+ from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy
+ from torch.distributed.fsdp.wrap import ModuleWrapPolicy
- if _TORCH_GREATER_EQUAL_2_0:
- from torch.distributed.fsdp.wrap import ModuleWrapPolicy
-
- _POLICY = Union[Set[Type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy]
- else:
- _POLICY = Union[Set[Type[Module]], Callable[[Module, bool, int], bool]] # type: ignore[misc]
-
+ _POLICY = Union[Set[Type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy]
_SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]]
@@ -92,8 +98,6 @@
class FSDPStrategy(ParallelStrategy):
r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed.
- .. warning:: This is an :ref:`experimental ` feature.
-
Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar
@@ -125,10 +129,14 @@ class FSDPStrategy(ParallelStrategy):
- ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated.
- ``"NO_SHARD"``: No sharding (identical to regular DDP).
- ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but
- replicates across machines.
+ replicates across machines. See also the `device_mesh` parameter below.
Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value.
+ device_mesh: A tuple `(replication size, sharding size)` that defines over how many devices to shard and
+ replicate the model. The product of the two numbers must equal the world size. Only valid in combination
+ with the `HYBRID_SHARD` sharding strategy.
+
state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.
- ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file.
@@ -158,6 +166,7 @@ def __init__(
activation_checkpointing_policy: Optional["_POLICY"] = None,
sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD",
state_dict_type: Literal["full", "sharded"] = "full",
+ device_mesh: Optional[Union[Tuple[int, int], "DeviceMesh"]] = None,
**kwargs: Any,
) -> None:
super().__init__(
@@ -173,22 +182,21 @@ def __init__(
self.cpu_offload = _init_cpu_offload(cpu_offload)
self.mixed_precision = mixed_precision
self.kwargs = _auto_wrap_policy_kwargs(auto_wrap_policy, kwargs)
+
+ if device_mesh is not None:
+ if not _TORCH_GREATER_EQUAL_2_2:
+ raise ValueError("The `device_mesh` argument is only supported in torch >= 2.2.")
+ self.kwargs["device_mesh"] = device_mesh
+
self.sharding_strategy = _init_sharding_strategy(sharding_strategy, self.kwargs)
- if _TORCH_GREATER_EQUAL_2_0:
- # Avoids the need for user to reference params in `configure_optimizers` via
- # `self.trainer.model.parameters()` and enables support for multiple parameter groups.
- self.kwargs.setdefault("use_orig_params", True)
+ # Avoids the need for user to reference params in `configure_optimizers` via
+ # `self.trainer.model.parameters()` and enables support for multiple parameter groups.
+ self.kwargs.setdefault("use_orig_params", True)
self._activation_checkpointing_kwargs = _activation_checkpointing_kwargs(
activation_checkpointing, activation_checkpointing_policy
)
-
- if state_dict_type == "sharded" and not _TORCH_GREATER_EQUAL_2_0:
- raise NotImplementedError(
- "Saving checkpoints with `FSDPStrategy(state_dict_type='sharded')` is not supported in PyTorch < 2.0."
- " Please upgrade `torch`."
- )
self._state_dict_type = state_dict_type
@property
@@ -260,6 +268,12 @@ def setup_environment(self) -> None:
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+ # if 'device_mesh' in the `kwargs` is provided as a tuple, update it into the `DeviceMesh` object here
+ if isinstance(self.kwargs.get("device_mesh"), tuple):
+ from torch.distributed.device_mesh import init_device_mesh
+
+ self.kwargs["device_mesh"] = init_device_mesh("cuda", self.kwargs["device_mesh"])
+
def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
@@ -284,7 +298,7 @@ def _setup_model(self, model: Module) -> Module:
from torch.distributed.fsdp import FullyShardedDataParallel
if any(isinstance(mod, FullyShardedDataParallel) for mod in model.modules()):
- if _has_meta_device_parameters(model):
+ if _has_meta_device_parameters_or_buffers(model):
rank_zero_warn(
"The model is already wrapped in `FSDP` but there are still parameters on the meta device."
)
@@ -321,10 +335,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
self.model = self._layer_sync.apply(self.model)
- # we set the device so that optimizers can be created with distributed comms.
- assert self.lightning_module is not None
- self.lightning_module._device = self.root_device
-
self.model = self.precision_plugin.convert_module(self.model)
if is_overridden("configure_sharded_model", self.lightning_module):
@@ -355,8 +365,8 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
invalid_params_error = False
try:
- # In PyTorch < 2.0, or if `use_orig_params=False` the user needs to do access
- # `self.trainer.model.parameters()` in configure_optimizers()
+ # If `use_orig_params=False`, the user needs to access `self.trainer.model.parameters()` in
+ # `configure_optimizers()`
super().setup_optimizers(trainer)
except ValueError as ex:
if "optimizer got an empty parameter list" not in str(ex):
@@ -364,7 +374,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
invalid_params_error = True
if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
- # We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True`
+ # We avoid this limitation by setting `use_orig_params=True`
raise ValueError(
"The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
" optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the"
@@ -380,14 +390,10 @@ def model_to_device(self) -> None:
@contextmanager
@override
def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
- empty_init_context: Union[torch.device, _EmptyInit, nullcontext]
- if _TORCH_GREATER_EQUAL_2_1 and empty_init:
- # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
- # 1) materialize module 2) call `reset_parameters()` 3) shard the module.
- # These operations are applied to each submodule 'bottom up' in the module hierarchy.
- empty_init_context = torch.device("meta")
- else:
- empty_init_context = _EmptyInit(enabled=bool(empty_init))
+ # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
+ # 1) materialize module 2) call `reset_parameters()` 3) shard the module.
+ # These operations are applied to each submodule 'bottom up' in the module hierarchy.
+ empty_init_context = torch.device("meta") if empty_init else nullcontext()
with empty_init_context, self.precision_plugin.tensor_init_context():
yield
@@ -517,10 +523,6 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr
@override
def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
- if not _TORCH_GREATER_EQUAL_2_0:
- rank_zero_warn("FSDP in Lightning with PyTorch < 2.0 does not support saving the optimizer state.")
- return {}
-
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import OptimStateKeyType
@@ -629,7 +631,7 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
return metadata
if _is_full_checkpoint(path):
- checkpoint = _lazy_load(path) if _TORCH_GREATER_EQUAL_2_0 else torch.load(path, map_location="cpu")
+ checkpoint = _lazy_load(path)
_load_raw_module_state(
checkpoint.pop("state_dict"),
module=self.model,
@@ -637,10 +639,9 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
strict=self.lightning_module.strict_loading,
)
- if _TORCH_GREATER_EQUAL_2_0:
- # Materialize lazy tensors if there are any left in the checkpoint
- # The `torch.Optimizer.load_state_dict` method can't load lazy tensors because of deepcopy pickle issues
- checkpoint = _materialize_tensors(checkpoint)
+ # Materialize lazy tensors if there are any left in the checkpoint
+ # The `torch.Optimizer.load_state_dict` method can't load lazy tensors because of deepcopy pickle issues
+ checkpoint = _materialize_tensors(checkpoint)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import OptimStateKeyType
@@ -649,9 +650,6 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
if optimizer_states is None or self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
# If the optimizer states are not present, we don't need to do anything (backward compatibility)
return checkpoint
- if not _TORCH_GREATER_EQUAL_2_0:
- rank_zero_warn("FSDP in Lightning with PyTorch < 2.0 does not support loading the optimizer state.")
- return checkpoint
if len(self.optimizers) != len(optimizer_states):
raise RuntimeError(
f"You have configured {len(self.optimizers)} optimizers but the checkpoint contains"
diff --git a/src/lightning/pytorch/strategies/launchers/multiprocessing.py b/src/lightning/pytorch/strategies/launchers/multiprocessing.py
index aa96da63adb65..05e3fed561ccb 100644
--- a/src/lightning/pytorch/strategies/launchers/multiprocessing.py
+++ b/src/lightning/pytorch/strategies/launchers/multiprocessing.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import io
import logging
import os
import queue
@@ -19,7 +20,6 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Union
-import numpy as np
import torch
import torch.backends.cudnn
import torch.multiprocessing as mp
@@ -226,7 +226,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]:
"""Gather extra state from the Trainer and return it as a dictionary for sending back to the main process. To
- avoid issues with memory sharing, we cast the data to numpy.
+ avoid issues with memory sharing, we convert tensors to bytes.
Args:
trainer: reference to the Trainer.
@@ -236,14 +236,15 @@ def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]:
process this output.
"""
- callback_metrics: dict = apply_to_collection(
- trainer.callback_metrics, Tensor, lambda x: x.cpu().numpy()
- ) # send as numpy to avoid issues with memory sharing
- return {"callback_metrics": callback_metrics}
+ callback_metrics = apply_to_collection(trainer.callback_metrics, Tensor, lambda t: t.cpu())
+ buffer = io.BytesIO()
+ torch.save(callback_metrics, buffer)
+ # send tensors as bytes to avoid issues with memory sharing
+ return {"callback_metrics_bytes": buffer.getvalue()}
def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, Any]) -> None:
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we
- cast back the data to ``torch.Tensor``.
+ convert bytes back to ``torch.Tensor``.
Args:
trainer: reference to the Trainer.
@@ -252,14 +253,15 @@ def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, An
"""
# NOTE: `get_extra_results` needs to be called before
- callback_metrics = extra["callback_metrics"]
- trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)))
+ callback_metrics_bytes = extra["callback_metrics_bytes"]
+ callback_metrics = torch.load(io.BytesIO(callback_metrics_bytes), weights_only=True)
+ trainer.callback_metrics.update(callback_metrics)
@override
def kill(self, signum: _SIGNUM) -> None:
for proc in self.procs:
if proc.is_alive() and proc.pid is not None:
- log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
+ log.debug(f"Process {os.getpid()} is terminating {proc.pid} with {signum}")
with suppress(ProcessLookupError):
os.kill(proc.pid, signum)
diff --git a/src/lightning/pytorch/strategies/launchers/subprocess_script.py b/src/lightning/pytorch/strategies/launchers/subprocess_script.py
index 03dbbc52365fb..d2035d03d2589 100644
--- a/src/lightning/pytorch/strategies/launchers/subprocess_script.py
+++ b/src/lightning/pytorch/strategies/launchers/subprocess_script.py
@@ -107,7 +107,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
@override
def kill(self, signum: _SIGNUM) -> None:
for proc in self.procs:
- log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
+ log.debug(f"Process {os.getpid()} is terminating {proc.pid} with {signum}")
# this skips subprocesses already terminated
proc.send_signal(signum)
diff --git a/src/lightning/pytorch/strategies/model_parallel.py b/src/lightning/pytorch/strategies/model_parallel.py
new file mode 100644
index 0000000000000..fb45166378c78
--- /dev/null
+++ b/src/lightning/pytorch/strategies/model_parallel.py
@@ -0,0 +1,363 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import shutil
+from contextlib import contextmanager, nullcontext
+from datetime import timedelta
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Mapping, Optional, Union
+
+import torch
+from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
+from torch import Tensor
+from torch.optim import Optimizer
+from typing_extensions import override
+
+import lightning.pytorch as pl
+from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
+from lightning.fabric.strategies.model_parallel import (
+ _distributed_checkpoint_save,
+ _is_sharded_checkpoint,
+ _load_checkpoint,
+ _setup_device_mesh,
+)
+from lightning.fabric.utilities.distributed import (
+ _distributed_is_initialized,
+ _get_default_process_group_backend_for_device,
+ _init_dist_connection,
+ _sync_ddp_if_available,
+)
+from lightning.fabric.utilities.distributed import group as _group
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
+from lightning.fabric.utilities.init import _materialize_distributed_module
+from lightning.fabric.utilities.load import _METADATA_FILENAME
+from lightning.fabric.utilities.optimizer import _optimizers_to_device
+from lightning.fabric.utilities.seed import reset_seed
+from lightning.fabric.utilities.types import _PATH, ReduceOp
+from lightning.pytorch.core.optimizer import LightningOptimizer
+from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
+from lightning.pytorch.strategies.parallel import ParallelStrategy
+from lightning.pytorch.strategies.strategy import TBroadcast
+from lightning.pytorch.trainer.states import TrainerFn
+from lightning.pytorch.utilities.model_helpers import is_overridden
+from lightning.pytorch.utilities.rank_zero import rank_zero_only
+
+if TYPE_CHECKING:
+ from torch.distributed.device_mesh import DeviceMesh
+
+
+class ModelParallelStrategy(ParallelStrategy):
+ """Enables user-defined parallelism applied to a model.
+
+ .. warning:: This is an :ref:`experimental ` feature.
+
+ Currently supports up to 2D parallelism. Specifically, it supports the combination of
+ Fully Sharded Data-Parallel 2 (FSDP2) with Tensor Parallelism (DTensor). These APIs are still
+ experimental in PyTorch (see https://pytorch.org/docs/stable/distributed.tensor.parallel.html).
+ Requires PyTorch 2.4 or newer.
+
+ Arguments:
+ data_parallel_size: The number of devices within a data-parallel group. Defaults to ``"auto"``, which
+ sets this size to the number of nodes in the cluster.
+ tensor_parallel_size: The number of devices within a tensor-parallel group. Defaults to ``"auto"``, which
+ sets this size to the number of GPUs in a single node.
+ save_distributed_checkpoint: If ``True``, each rank saves its shard of weights and optimizer states to a file.
+ The checkpoint is a folder with as many files as the world size.
+ If ``False``, the full weights and optimizer states get assembled on rank 0 and saved to a single file.
+
+ """
+
+ def __init__(
+ self,
+ data_parallel_size: Union[Literal["auto"], int] = "auto",
+ tensor_parallel_size: Union[Literal["auto"], int] = "auto",
+ save_distributed_checkpoint: bool = True,
+ process_group_backend: Optional[str] = None,
+ timeout: Optional[timedelta] = default_pg_timeout,
+ ) -> None:
+ super().__init__()
+ if not _TORCH_GREATER_EQUAL_2_4:
+ raise ImportError(f"{type(self).__name__} requires PyTorch 2.4 or higher.")
+ self._data_parallel_size = data_parallel_size
+ self._tensor_parallel_size = tensor_parallel_size
+ self._save_distributed_checkpoint = save_distributed_checkpoint
+ self._process_group_backend: Optional[str] = process_group_backend
+ self._timeout: Optional[timedelta] = timeout
+ self._device_mesh: Optional[DeviceMesh] = None
+ self.num_nodes = 1
+
+ @property
+ def device_mesh(self) -> "DeviceMesh":
+ if self._device_mesh is None:
+ raise RuntimeError("Accessing the device mesh before processes have initialized is not allowed.")
+ return self._device_mesh
+
+ @property
+ @override
+ def root_device(self) -> torch.device:
+ assert self.parallel_devices is not None
+ return self.parallel_devices[self.local_rank]
+
+ @property
+ def num_processes(self) -> int:
+ return len(self.parallel_devices) if self.parallel_devices is not None else 0
+
+ @property
+ @override
+ def distributed_sampler_kwargs(self) -> Dict[str, Any]:
+ assert self.device_mesh is not None
+ data_parallel_mesh = self.device_mesh["data_parallel"]
+ return {"num_replicas": data_parallel_mesh.size(), "rank": data_parallel_mesh.get_local_rank()}
+
+ @property
+ def process_group_backend(self) -> Optional[str]:
+ return self._process_group_backend
+
+ @property
+ @override
+ def restore_checkpoint_after_setup(self) -> bool:
+ return True
+
+ @property
+ @override
+ def lightning_restore_optimizer(self) -> bool:
+ return False
+
+ @override
+ def _configure_launcher(self) -> None:
+ assert self.cluster_environment is not None
+ if not self.cluster_environment.creates_processes_externally:
+ self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
+
+ @override
+ def setup_environment(self) -> None:
+ super().setup_environment()
+ self._setup_distributed()
+ if self._data_parallel_size == "auto":
+ self._data_parallel_size = self.num_nodes
+ if self._tensor_parallel_size == "auto":
+ self._tensor_parallel_size = self.num_processes
+ self._device_mesh = _setup_device_mesh(
+ self._data_parallel_size, self._tensor_parallel_size, self.world_size, self.root_device
+ )
+ # Users can access device mesh in `LightningModule.configure_model()`
+ assert self.lightning_module is not None
+ self.lightning_module._device_mesh = self._device_mesh
+
+ @override
+ def setup(self, trainer: "pl.Trainer") -> None:
+ from torch.distributed.fsdp import FullyShardedDataParallel
+
+ assert self.model is not None
+ assert self.accelerator is not None
+ self.accelerator.setup(trainer)
+
+ if not is_overridden("configure_model", self.lightning_module):
+ raise TypeError(
+ f"When using the {type(self).__name__}, you are required to override the `configure_model()` hook in"
+ f" the LightningModule and apply parallelization there."
+ )
+ if any(isinstance(mod, FullyShardedDataParallel) for mod in self.model.modules()):
+ raise TypeError(
+ "Found modules that are wrapped with `torch.distributed.fsdp.FullyShardedDataParallel`."
+ f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.4."
+ )
+
+ _materialize_distributed_module(self.model, self.root_device)
+
+ self.model = self.precision_plugin.convert_module(self.model)
+ self.model_to_device() # move all remaining layers if any left on CPU.
+
+ self.barrier()
+
+ if trainer.state.fn == TrainerFn.FITTING:
+ self.setup_optimizers(trainer)
+ self.setup_precision_plugin()
+ if trainer.state.fn == TrainerFn.FITTING:
+ _optimizers_to_device(self.optimizers, self.root_device)
+
+ @override
+ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
+ # If we're setting up for evaluation after fitting, we need to discard the optimizers
+ # since we're rewrapping the model, otherwise optimizer param references are no longer valid
+ # and subsequent checkpoint saving can fail
+ self._reset_optimizers_and_schedulers()
+
+ return super().setup_optimizers(trainer)
+
+ @override
+ def model_to_device(self) -> None:
+ assert self.model is not None
+ self.model.to(self.root_device)
+
+ @contextmanager
+ @override
+ def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
+ # Materialization happens in `setup()`
+ empty_init_context = torch.device("meta") if empty_init else nullcontext()
+ with empty_init_context, self.precision_plugin.tensor_init_context():
+ yield
+
+ @override
+ def barrier(self, name: Optional[str] = None) -> None:
+ if not _distributed_is_initialized():
+ return
+ if torch.distributed.get_backend() == "nccl":
+ torch.distributed.barrier(device_ids=self._determine_device_ids())
+ else:
+ torch.distributed.barrier()
+
+ @override
+ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
+ if not _distributed_is_initialized():
+ return obj
+
+ obj = [obj]
+ torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
+ return obj[0]
+
+ @override
+ def reduce(
+ self,
+ tensor: Union[Tensor, Any],
+ group: Optional[Any] = None,
+ reduce_op: Optional[Union[ReduceOp, str]] = "mean",
+ ) -> Tensor:
+ if isinstance(tensor, Tensor):
+ return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
+ return tensor
+
+ def _determine_device_ids(self) -> List[int]:
+ return [self.root_device.index]
+
+ @override
+ def teardown(self) -> None:
+ assert self.cluster_environment is not None
+ assert self.accelerator is not None
+ self.cluster_environment.teardown()
+ self.precision_plugin.teardown()
+ self.accelerator.teardown()
+
+ @override
+ def lightning_module_state_dict(self) -> Dict[str, Any]:
+ """Collects the state dict of the model.
+
+ Only returns a non-empty state dict on rank 0 if ``save_distributed_checkpoint=False``.
+
+ """
+ from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
+
+ state_dict_options = StateDictOptions(full_state_dict=(not self._save_distributed_checkpoint), cpu_offload=True)
+ assert self.model is not None
+ return get_model_state_dict(self.model, options=state_dict_options)
+
+ @override
+ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None:
+ # Override to do nothing, the strategy already loaded the states in `load_checkpoint()`
+ pass
+
+ @override
+ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Any]:
+ """Collects the state of the given optimizer.
+
+ Only returns a non-empty state dict on rank 0 if ``save_distributed_checkpoint=False``.
+
+ """
+ from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+ from torch.distributed.fsdp import OptimStateKeyType
+
+ state_dict_options = StateDictOptions(full_state_dict=(not self._save_distributed_checkpoint), cpu_offload=True)
+ if isinstance(optimizer, LightningOptimizer):
+ optimizer = optimizer._optimizer
+
+ assert self.model is not None
+
+ state_dict = get_optimizer_state_dict(self.model, optimizer, options=state_dict_options)
+ if not self._save_distributed_checkpoint and self.global_rank == 0:
+ # Store the optimizer state dict in standard format
+ state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, self.model)
+ return state_dict
+
+ @override
+ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
+ # Override to do nothing, the strategy already loaded the states in `load_checkpoint()`
+ pass
+
+ @override
+ def save_checkpoint(
+ self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
+ ) -> None:
+ if storage_options is not None:
+ raise TypeError(
+ f"`{type(self).__name__}.save_checkpoint(..., storage_options=...)` is not supported because"
+ f" `{type(self).__name__}` does not use the `CheckpointIO`."
+ )
+ # broadcast the path from rank 0 to ensure all the checkpoints are saved to a common path
+ path = Path(self.broadcast(filepath))
+ if path.is_dir() and not self._save_distributed_checkpoint and not _is_sharded_checkpoint(path):
+ raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}")
+
+ if self._save_distributed_checkpoint:
+ if path.is_file():
+ path.unlink()
+ path.mkdir(parents=True, exist_ok=True)
+
+ converted_state = {"state_dict": checkpoint.pop("state_dict")}
+ converted_state.update({
+ f"optimizer_{idx}": optim_state
+ for idx, optim_state in enumerate(checkpoint.pop("optimizer_states", []))
+ })
+ _distributed_checkpoint_save(converted_state, path)
+
+ if self.global_rank == 0:
+ torch.save(checkpoint, path / _METADATA_FILENAME)
+ else:
+ if _is_sharded_checkpoint(path):
+ shutil.rmtree(path)
+ return super().save_checkpoint(checkpoint=checkpoint, filepath=path)
+
+ @override
+ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
+ # broadcast the path from rank 0 to ensure all the states are loaded from a common path
+ path = Path(self.broadcast(checkpoint_path))
+ state = {
+ "state_dict": self.model,
+ **{f"optimizer_{idx}": optimizer for idx, optimizer in enumerate(self.optimizers)},
+ }
+ assert self.lightning_module is not None
+ return _load_checkpoint(
+ path=path,
+ state=state,
+ strict=self.lightning_module.strict_loading,
+ optimizer_states_from_list=True,
+ )
+
+ def _setup_distributed(self) -> None:
+ super().setup_environment()
+ reset_seed()
+ self.set_world_ranks()
+ self._process_group_backend = self._get_process_group_backend()
+ assert self.cluster_environment is not None
+ _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+
+ def _get_process_group_backend(self) -> str:
+ return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
+
+ def set_world_ranks(self) -> None:
+ if self.cluster_environment is not None:
+ self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+ self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+ # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
+ # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
+ rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
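The `"auto"` defaults in `setup_environment()` and the rank arithmetic in `set_world_ranks()` can be checked without a cluster. The following is a minimal sketch, not the strategy's implementation: `resolve_mesh_shape`, `global_rank` and `mesh_coordinates` are hypothetical helpers, and both the size check and the row-major layout (data-parallel as the outer dimension, tensor-parallel as the inner one) are assumptions about `_setup_device_mesh`, which is not shown in this diff.

```python
from typing import Tuple, Union


def resolve_mesh_shape(
    data_parallel_size: Union[str, int],
    tensor_parallel_size: Union[str, int],
    num_nodes: int,
    num_processes: int,
) -> Tuple[int, int]:
    """Resolve "auto" sizes the way `setup_environment()` does in the diff above."""
    dp = num_nodes if data_parallel_size == "auto" else int(data_parallel_size)
    tp = num_processes if tensor_parallel_size == "auto" else int(tensor_parallel_size)
    world_size = num_nodes * num_processes
    # Assumption: the mesh has to cover every process exactly once.
    if dp * tp != world_size:
        raise ValueError(f"{dp} x {tp} does not match the world size {world_size}")
    return dp, tp


def global_rank(node_rank: int, local_rank: int, num_processes: int) -> int:
    """Same formula as `set_world_ranks()`."""
    return node_rank * num_processes + local_rank


def mesh_coordinates(rank: int, tp: int) -> Tuple[int, int]:
    """Assumed row-major layout: (data-parallel index, tensor-parallel index)."""
    return rank // tp, rank % tp


if __name__ == "__main__":
    dp, tp = resolve_mesh_shape("auto", "auto", num_nodes=2, num_processes=4)
    rank = global_rank(node_rank=1, local_rank=2, num_processes=4)
    print(dp, tp, rank, mesh_coordinates(rank, tp))
```

With the defaults, each tensor-parallel group stays inside one node and data parallelism spans the nodes. Because `distributed_sampler_kwargs` takes its rank from the data-parallel dimension, all ranks of one tensor-parallel group read the same batches.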
diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py
index f2acd8ac98eba..314007f497f59 100644
--- a/src/lightning/pytorch/strategies/strategy.py
+++ b/src/lightning/pytorch/strategies/strategy.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
from abc import ABC, abstractmethod
-from contextlib import contextmanager, nullcontext
+from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union
import torch
@@ -26,7 +26,6 @@
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.utilities import move_data_to_device
from lightning.fabric.utilities.distributed import ReduceOp
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.init import _EmptyInit
from lightning.fabric.utilities.optimizer import _optimizer_to_device, _optimizers_to_device
from lightning.fabric.utilities.types import _PATH
@@ -53,7 +52,7 @@ def __init__(
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[Precision] = None,
) -> None:
- self._accelerator: Optional["pl.accelerators.Accelerator"] = accelerator
+ self._accelerator: Optional[pl.accelerators.Accelerator] = accelerator
self._checkpoint_io: Optional[CheckpointIO] = checkpoint_io
self._precision_plugin: Optional[Precision] = None
# Call the precision setter for input validation
@@ -509,9 +508,8 @@ def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[No
If ``None``, the strategy will decide. Some strategies may not support all options.
"""
- device_context = self.root_device if _TORCH_GREATER_EQUAL_2_0 else nullcontext()
empty_init_context = _EmptyInit(enabled=bool(empty_init))
- with empty_init_context, device_context, self.precision_plugin.tensor_init_context():
+ with empty_init_context, self.root_device, self.precision_plugin.tensor_init_context():
yield
@contextmanager
diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py
index befd7f0df84dc..4c3bc5ef41bdd 100644
--- a/src/lightning/pytorch/trainer/call.py
+++ b/src/lightning/pytorch/trainer/call.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import signal
from copy import deepcopy
from typing import Any, Callable, Dict, Optional, Type, Union
@@ -20,10 +21,12 @@
import lightning.pytorch as pl
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.pytorch.callbacks import Checkpoint, EarlyStopping
+from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher
+from lightning.pytorch.trainer.connectors.signal_connector import _get_sigkill_signal
from lightning.pytorch.trainer.states import TrainerStatus
from lightning.pytorch.utilities.exceptions import _TunerExitException
from lightning.pytorch.utilities.model_helpers import is_overridden
-from lightning.pytorch.utilities.rank_zero import rank_zero_warn
+from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
log = logging.getLogger(__name__)
@@ -49,12 +52,17 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg
trainer.state.status = TrainerStatus.FINISHED
trainer.state.stage = None
- # TODO: Unify both exceptions below, where `KeyboardError` doesn't re-raise
except KeyboardInterrupt as exception:
- rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
- # user could press Ctrl+c many times... only shutdown once
- if not trainer.interrupted:
- _interrupt(trainer, exception)
+ rank_zero_info("\nDetected KeyboardInterrupt, attempting graceful shutdown ...")
+ # user could press Ctrl+C many times, disable KeyboardInterrupt for shutdown
+ signal.signal(signal.SIGINT, signal.SIG_IGN)
+ _interrupt(trainer, exception)
+ trainer._teardown()
+ launcher = trainer.strategy.launcher
+ if isinstance(launcher, _SubprocessScriptLauncher):
+ launcher.kill(_get_sigkill_signal())
+ exit(1)
+
except BaseException as exception:
_interrupt(trainer, exception)
trainer._teardown()
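The new `KeyboardInterrupt` branch disables further Ctrl+C handling before it cleans up, so repeated key presses cannot interrupt the teardown. A minimal stand-alone sketch of that pattern follows. `run_with_graceful_interrupt` and its two callbacks are hypothetical names; unlike the Trainer code, the sketch returns an exit code instead of calling `exit(1)` and restores the previous handler so that it can be called more than once. It must run in the main thread, because `signal.signal` is only allowed there.

```python
import signal
from typing import Callable


def run_with_graceful_interrupt(fn: Callable[[], None], teardown: Callable[[], None]) -> int:
    """Run `fn`; on Ctrl+C, ignore further SIGINT while `teardown` runs.

    Returns 0 on normal completion and 1 after an interrupt.
    """
    try:
        fn()
        return 0
    except KeyboardInterrupt:
        # Ignore additional Ctrl+C presses for the duration of the cleanup.
        previous_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
        try:
            teardown()
        finally:
            # Not done by the Trainer (it exits); restored here for reuse.
            signal.signal(signal.SIGINT, previous_handler)
        return 1
```

In the diff, the cleanup step corresponds to `_interrupt(trainer, exception)` followed by `trainer._teardown()`, after which child processes started by a `_SubprocessScriptLauncher` are killed.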
diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
index 269a1c4c7b754..06f3ee366bcaa 100644
--- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -53,6 +53,7 @@
DDPStrategy,
DeepSpeedStrategy,
FSDPStrategy,
+ ModelParallelStrategy,
ParallelStrategy,
SingleDeviceStrategy,
SingleDeviceXLAStrategy,
@@ -327,7 +328,8 @@ def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str
f" using {accelerator_name} accelerator."
)
- def _choose_auto_accelerator(self) -> str:
+ @staticmethod
+ def _choose_auto_accelerator() -> str:
"""Choose the accelerator type (str) based on availability."""
if XLAAccelerator.is_available():
return "tpu"
@@ -455,10 +457,11 @@ def _check_strategy_and_fallback(self) -> None:
strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag
if (
- strategy_flag in FSDPStrategy.get_registered_strategies() or isinstance(self._strategy_flag, FSDPStrategy)
+ strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy
) and self._accelerator_flag not in ("cuda", "gpu"):
- raise MisconfigurationException(
- f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but GPU accelerator is not used."
+ raise ValueError(
+ f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but got:"
+ f" {self._accelerator_flag}"
)
if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods():
raise ValueError(
@@ -527,6 +530,16 @@ def _validate_precision_choice(self) -> None:
self.accelerator, CUDAAccelerator
):
raise RuntimeError("Bitsandbytes is only supported on CUDA GPUs.")
+ mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true")
+ if (
+ isinstance(self._strategy_flag, ModelParallelStrategy)
+ and self._precision_flag not in mp_precision_supported
+ ):
+ raise ValueError(
+ f"The `ModelParallelStrategy` does not support `Trainer(..., precision={self._precision_flag!r})`."
+ f" Choose a different precision among: {', '.join(mp_precision_supported)}."
+ )
+
if _habana_available_and_importable():
from lightning_habana import HPUAccelerator
@@ -599,6 +612,7 @@ def is_distributed(self) -> bool:
DDPStrategy,
FSDPStrategy,
DeepSpeedStrategy,
+ ModelParallelStrategy,
XLAStrategy,
]
if _habana_available_and_importable():
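The precision check added to `_validate_precision_choice` is a plain membership test against a fixed tuple. The sketch below isolates it for illustration; `check_model_parallel_precision` is a hypothetical helper, and only the tuple of supported values is taken from the diff.

```python
from typing import Tuple

# Copied from `mp_precision_supported` in the diff above.
MP_PRECISION_SUPPORTED: Tuple[str, ...] = ("32-true", "bf16-mixed", "bf16-true", "16-true")


def check_model_parallel_precision(precision: str) -> str:
    """Return `precision` unchanged if it is supported, otherwise raise `ValueError`."""
    if precision not in MP_PRECISION_SUPPORTED:
        raise ValueError(
            f"Unsupported precision {precision!r}."
            f" Choose a different precision among: {', '.join(MP_PRECISION_SUPPORTED)}."
        )
    return precision
```

Note that `"16-mixed"` is not in the tuple, so it is rejected by this check.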
diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
index 7e0ef433031bd..583105c3660e0 100644
--- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
+++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
@@ -24,7 +24,6 @@
from lightning.fabric.utilities import move_data_to_device
from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
from lightning.fabric.utilities.distributed import _distributed_is_initialized
-from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_0
from lightning.pytorch.utilities.data import extract_batch_size
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_0_0
@@ -92,7 +91,7 @@ def _generate_sync_fn(self) -> None:
fn = self.no_op if self.fn is None or not self.should or self.rank_zero_only else self.fn
# save the function as `_fn` as the meta are being re-created and the object references need to match.
# ignore typing, bad support for `partial`: mypy/issues/1484
- self._fn: Callable = partial(fn, reduce_op=self.op, group=self.group) # type: ignore [arg-type]
+ self._fn: Callable = partial(fn, reduce_op=self.op, group=self.group) # type: ignore[arg-type,operator,misc]
@property
def __call__(self) -> Any:
@@ -112,7 +111,7 @@ class _Metadata:
on_step: bool = False
on_epoch: bool = True
# https://github.com/pytorch/pytorch/issues/96197
- reduce_fx: Callable = "mean" if _TORCH_EQUAL_2_0 else torch.mean # type: ignore[assignment]
+ reduce_fx: Callable = torch.mean
enable_graph: bool = False
add_dataloader_idx: bool = True
dataloader_idx: Optional[int] = None
@@ -305,9 +304,7 @@ def __repr__(self) -> str:
@override
def to(self, *args: Any, **kwargs: Any) -> "_ResultMetric":
- d = self.__dict__
- if _TORCH_GREATER_EQUAL_2_0: # https://github.com/pytorch/pytorch/issues/96198
- d = dict(d)
+ d = dict(self.__dict__)
self.__dict__.update(apply_to_collection(d, (Tensor, Metric), move_data_to_device, *args, **kwargs))
return self
@@ -364,7 +361,7 @@ def log(
on_step: bool = False,
on_epoch: bool = True,
# https://github.com/pytorch/pytorch/issues/96197
- reduce_fx: Callable = "mean" if _TORCH_EQUAL_2_0 else torch.mean, # type: ignore[assignment]
+ reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
sync_dist: bool = False,
sync_dist_fn: Callable = _Sync.no_op,
@@ -403,26 +400,19 @@ def log(
# register logged value if it doesn't exist
if key not in self:
- self.register_key(key, meta, value)
+ metric = _ResultMetric(meta, isinstance(value, Tensor))
+ self[key] = metric
# check the stored metadata and the current one match
elif meta != self[key].meta:
raise MisconfigurationException(
f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed"
)
+ self[key].to(value.device)
batch_size = self._extract_batch_size(self[key], batch_size, meta)
self.update_metrics(key, value, batch_size)
- def register_key(self, key: str, meta: _Metadata, value: _VALUE) -> None:
- """Create one _ResultMetric object per value.
-
- Value can be provided as a nested collection
-
- """
- metric = _ResultMetric(meta, isinstance(value, Tensor)).to(value.device)
- self[key] = metric
-
def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
result_metric = self[key]
# performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
diff --git a/src/lightning/pytorch/trainer/connectors/signal_connector.py b/src/lightning/pytorch/trainer/connectors/signal_connector.py
index 728d8b6b6ee43..05a975326005f 100644
--- a/src/lightning/pytorch/trainer/connectors/signal_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/signal_connector.py
@@ -2,7 +2,6 @@
import os
import re
import signal
-import sys
import threading
from subprocess import call
from types import FrameType
@@ -10,7 +9,7 @@
import lightning.pytorch as pl
from lightning.fabric.plugins.environments import SLURMEnvironment
-from lightning.fabric.utilities.imports import _IS_WINDOWS, _PYTHON_GREATER_EQUAL_3_8_0
+from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_info
# copied from signal.pyi
@@ -54,7 +53,7 @@ def register_signal_handlers(self) -> None:
sigterm_handlers.append(self._sigterm_handler_fn)
# Windows seems to have signal incompatibilities
- if not self._is_on_windows():
+ if not _IS_WINDOWS:
sigusr = environment.requeue_signal if isinstance(environment, SLURMEnvironment) else signal.SIGUSR1
assert sigusr is not None
if sigusr_handlers and not self._has_already_handler(sigusr):
@@ -134,30 +133,8 @@ def _get_current_signal_handlers() -> Dict[_SIGNUM, _HANDLER]:
@staticmethod
def _valid_signals() -> Set[signal.Signals]:
- """Returns all valid signals supported on the current platform.
-
- Behaves identically to :func:`signals.valid_signals` in Python 3.8+ and implements the equivalent behavior for
- older Python versions.
-
- """
- if _PYTHON_GREATER_EQUAL_3_8_0:
- return signal.valid_signals()
- if _IS_WINDOWS:
- # supported signals on Windows: https://docs.python.org/3/library/signal.html#signal.signal
- return {
- signal.SIGABRT,
- signal.SIGFPE,
- signal.SIGILL,
- signal.SIGINT,
- signal.SIGSEGV,
- signal.SIGTERM,
- signal.SIGBREAK,
- }
- return set(signal.Signals)
-
- @staticmethod
- def _is_on_windows() -> bool:
- return sys.platform == "win32"
+ """Returns all valid signals supported on the current platform."""
+ return signal.valid_signals()
@staticmethod
def _has_already_handler(signum: _SIGNUM) -> bool:
@@ -172,3 +149,7 @@ def __getstate__(self) -> Dict:
state = self.__dict__.copy()
state["_original_handlers"] = {}
return state
+
+
+def _get_sigkill_signal() -> _SIGNUM:
+ return signal.SIGTERM if _IS_WINDOWS else signal.SIGKILL
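`_get_sigkill_signal()` is needed because `signal.SIGKILL` is not defined on Windows. The sketch below shows the same selection in isolation, together with one way such a signal might be delivered to a child process through `Popen.send_signal`, as the launcher's `kill()` does. `pick_kill_signal` and `terminate_child` are hypothetical helpers, not part of Lightning.

```python
import signal
import subprocess
import sys


def pick_kill_signal(is_windows: bool = sys.platform == "win32") -> int:
    """Return SIGTERM on Windows (which has no SIGKILL) and SIGKILL elsewhere."""
    if is_windows:
        return signal.SIGTERM
    return signal.SIGKILL


def terminate_child(proc: subprocess.Popen) -> int:
    """Send the platform's kill signal to `proc` and wait for it to exit."""
    proc.send_signal(pick_kill_signal())
    return proc.wait(timeout=10)


if __name__ == "__main__":
    print("selected signal:", pick_kill_signal())
```

The attribute `signal.SIGKILL` is only looked up on the non-Windows branch, which is why the function is safe to import on every platform.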
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 6436fc54b7bed..406f686efe732 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -23,7 +23,6 @@
import logging
import math
import os
-import warnings
from contextlib import contextmanager
from datetime import timedelta
from typing import Any, Dict, Generator, Iterable, List, Optional, Union
@@ -35,7 +34,6 @@
import lightning.pytorch as pl
from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.accelerators import Accelerator
from lightning.pytorch.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBar
@@ -83,10 +81,6 @@
from lightning.pytorch.utilities.warnings import PossibleUserWarning
log = logging.getLogger(__name__)
-# warnings to ignore in trainer
-warnings.filterwarnings(
- "ignore", message="torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead"
-)
class Trainer:
@@ -207,7 +201,7 @@ def __init__(
across epochs or during iteration-based training.
Default: ``1.0``.
- check_val_every_n_epoch: Perform a validation loop every after every `N` training epochs. If ``None``,
+ check_val_every_n_epoch: Perform a validation loop after every `N` training epochs. If ``None``,
validation will be done solely based on the number of training batches, requiring ``val_check_interval``
to be an integer value.
Default: ``1``.
@@ -1018,9 +1012,7 @@ def _teardown(self) -> None:
def _run_stage(self) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
# wait for all to join if on distributed
self.strategy.barrier("run-stage")
-
- zero_grad_kwargs = {} if _TORCH_GREATER_EQUAL_2_0 else {"set_to_none": True}
- self.lightning_module.zero_grad(**zero_grad_kwargs)
+ self.lightning_module.zero_grad()
if self.evaluating:
return self._evaluation_loop.run()
@@ -1084,8 +1076,7 @@ def init_module(self, empty_init: Optional[bool] = None) -> Generator:
the right data type depending on the precision setting in the Trainer.
The parameters and tensors get created on the device and with the right data type right away without wasting
- memory being allocated unnecessarily. The automatic device placement under this context manager is only
- supported with PyTorch 2.0 and newer.
+ memory being allocated unnecessarily.
Args:
empty_init: Whether to initialize the model with empty weights (uninitialized memory).
@@ -1093,13 +1084,6 @@ def init_module(self, empty_init: Optional[bool] = None) -> Generator:
Set this to ``True`` if you are loading a checkpoint into a large model.
"""
- if not _TORCH_GREATER_EQUAL_2_0 and self.strategy.root_device.type != "cpu":
- rank_zero_warn(
- "`Trainer.init_module()` can't place tensors on the device directly"
- " with PyTorch < 2.0. Parameters will remain on CPU until the trainer starts."
- " Upgrade to PyTorch >= 2.0 to fully utilize this feature.",
- category=PossibleUserWarning,
- )
if is_overridden("model_sharded_context", self.strategy, parent=Strategy):
# warning instead of error so that code changes are not required when changing strategies
# this is a limitation because processes are not expected to have been launched when this is called
@@ -1665,8 +1649,8 @@ def _results(self) -> Optional[_ResultCollection]:
def estimated_stepping_batches(self) -> Union[int, float]:
r"""The estimated number of batches that will ``optimizer.step()`` during training.
- This accounts for gradient accumulation and the current trainer configuration. This might sets up your training
- dataloader if hadn't been set up already.
+ This accounts for gradient accumulation and the current trainer configuration. This might be used when setting
+ up your training dataloader, if it hasn't been set up already.
.. code-block:: python
diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py
index f39788b8ea290..d756d3d76597c 100644
--- a/src/lightning/pytorch/tuner/lr_finder.py
+++ b/src/lightning/pytorch/tuner/lr_finder.py
@@ -16,19 +16,19 @@
import os
import uuid
from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import torch
from lightning_utilities.core.imports import RequirementCache
+from torch.optim.lr_scheduler import LRScheduler
from typing_extensions import override
import lightning.pytorch as pl
-from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.parsing import lightning_hasattr, lightning_setattr
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
-from lightning.pytorch.utilities.types import STEP_OUTPUT, LRScheduler, LRSchedulerConfig
+from lightning.pytorch.utilities.types import STEP_OUTPUT, LRSchedulerConfig
# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
@@ -127,13 +127,14 @@ def _exchange_scheduler(self, trainer: "pl.Trainer") -> None:
args = (optimizer, self.lr_max, self.num_training)
scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
- scheduler = cast(LRScheduler, scheduler)
trainer.strategy.optimizers = [optimizer]
trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step")]
_validate_optimizers_attached(trainer.optimizers, trainer.lr_scheduler_configs)
- def plot(self, suggest: bool = False, show: bool = False, ax: Optional["Axes"] = None) -> Optional["plt.Figure"]:
+ def plot(
+ self, suggest: bool = False, show: bool = False, ax: Optional["Axes"] = None
+ ) -> Optional[Union["plt.Figure", "plt.SubFigure"]]:
"""Plot results from lr_find run
Args:
suggest: if True, will mark suggested lr to use with a red point
@@ -152,10 +153,11 @@ def plot(self, suggest: bool = False, show: bool = False, ax: Optional["Axes"] =
lrs = self.results["lr"]
losses = self.results["loss"]
+ fig: Optional[Union[plt.Figure, plt.SubFigure]]
if ax is None:
fig, ax = plt.subplots()
else:
- fig = ax.figure # type: ignore[assignment]
+ fig = ax.figure
# Plot loss as a function of the learning rate
ax.plot(lrs, losses)
@@ -191,7 +193,7 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]
losses = losses[torch.isfinite(losses)]
if len(losses) < 2:
- # computing np.gradient requires at least 2 points
+ # computing torch.gradient requires at least 2 points
log.error(
"Failed to compute suggestion for learning rate because there are not enough points. Increase the loop"
" iteration limits or the size of your dataset/dataloader."
@@ -302,6 +304,7 @@ def _lr_find(
trainer._checkpoint_connector.restore(ckpt_path)
trainer.strategy.remove_checkpoint(ckpt_path)
trainer.fit_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True
+ trainer.fit_loop.epoch_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True
trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
return lr_finder
@@ -439,7 +442,7 @@ def on_train_batch_end(
self.losses.append(smoothed_loss)
-class _LinearLR(_TORCH_LRSCHEDULER):
+class _LinearLR(LRScheduler):
"""Linearly increases the learning rate between two boundaries over a number of iterations.
Args:
@@ -459,8 +462,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
self.num_iter = num_iter
super().__init__(optimizer, last_epoch)
- # mypy can't follow the _TORCH_LRSCHEDULER TypeAlias, so ignore "no base method" error
- @override # type: ignore[misc]
+ @override
def get_lr(self) -> List[float]:
curr_iter = self.last_epoch + 1
r = curr_iter / self.num_iter
@@ -477,7 +479,7 @@ def lr(self) -> Union[float, List[float]]:
return self._lr
-class _ExponentialLR(_TORCH_LRSCHEDULER):
+class _ExponentialLR(LRScheduler):
"""Exponentially increases the learning rate between two boundaries over a number of iterations.
Arguments:
@@ -497,8 +499,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
self.num_iter = num_iter
super().__init__(optimizer, last_epoch)
- # mypy can't follow the _TORCH_LRSCHEDULER TypeAlias, so ignore "no base method" error
- @override # type: ignore[misc]
+ @override
def get_lr(self) -> List[float]:
curr_iter = self.last_epoch + 1
r = curr_iter / self.num_iter
diff --git a/src/lightning/pytorch/utilities/__init__.py b/src/lightning/pytorch/utilities/__init__.py
index 5cd0af5ac7813..c3ba77b46e8b7 100644
--- a/src/lightning/pytorch/utilities/__init__.py
+++ b/src/lightning/pytorch/utilities/__init__.py
@@ -13,7 +13,7 @@
# limitations under the License.
"""General utilities."""
-import numpy
+import torch
from lightning.fabric.utilities import (
LightningEnum,
@@ -55,6 +55,6 @@
"suggested_max_num_workers",
]
-FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
-FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
-FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps
+FLOAT16_EPSILON = torch.finfo(torch.float16).eps
+FLOAT32_EPSILON = torch.finfo(torch.float32).eps
+FLOAT64_EPSILON = torch.finfo(torch.float64).eps
diff --git a/src/lightning/pytorch/utilities/_pytree.py b/src/lightning/pytorch/utilities/_pytree.py
index 34f7db1486b71..f5f48b481c879 100644
--- a/src/lightning/pytorch/utilities/_pytree.py
+++ b/src/lightning/pytorch/utilities/_pytree.py
@@ -25,7 +25,7 @@ def _tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]:
child_pytrees, context = flatten_fn(pytree)
result: List[Any] = []
- children_specs: List["TreeSpec"] = []
+ children_specs: List[TreeSpec] = []
for child in child_pytrees:
flat, child_spec = _tree_flatten(child)
result += flat
diff --git a/src/lightning/pytorch/utilities/compile.py b/src/lightning/pytorch/utilities/compile.py
index a77ed553d418e..cb2433e04bb1a 100644
--- a/src/lightning/pytorch/utilities/compile.py
+++ b/src/lightning/pytorch/utilities/compile.py
@@ -14,14 +14,14 @@
from typing import Union
import torch
+from torch._dynamo import OptimizedModule
import lightning.pytorch as pl
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1
from lightning.pytorch.strategies import DDPStrategy, DeepSpeedStrategy, FSDPStrategy, SingleDeviceStrategy, Strategy
from lightning.pytorch.utilities.model_helpers import _check_mixed_imports
-def from_compiled(model: "torch._dynamo.OptimizedModule") -> "pl.LightningModule":
+def from_compiled(model: OptimizedModule) -> "pl.LightningModule":
"""Returns an instance LightningModule from the output of ``torch.compile``.
.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
@@ -33,11 +33,6 @@ def from_compiled(model: "torch._dynamo.OptimizedModule") -> "pl.LightningModule
Use this method to obtain a LightningModule that still runs with all the optimizations from ``torch.compile``.
"""
- if not _TORCH_GREATER_EQUAL_2_0:
- raise ModuleNotFoundError("`from_compiled` requires torch>=2.0")
-
- from torch._dynamo import OptimizedModule
-
if not isinstance(model, OptimizedModule):
raise ValueError(f"`model` is required to be a `OptimizedModule`. Found a `{type(model).__name__}` instead.")
@@ -60,11 +55,7 @@ def from_compiled(model: "torch._dynamo.OptimizedModule") -> "pl.LightningModule
}
orig_module.forward = model.dynamo_ctx(orig_module.forward) # type: ignore[method-assign]
- if not _TORCH_GREATER_EQUAL_2_1: # https://github.com/pytorch/pytorch/issues/95630
- orig_module.forward._torchdynamo_inline = orig_module.forward
orig_module.training_step = model.dynamo_ctx(orig_module.training_step) # type: ignore[method-assign]
- if not _TORCH_GREATER_EQUAL_2_1: # https://github.com/pytorch/pytorch/issues/95630
- orig_module.training_step._torchdynamo_inline = orig_module.training_step
orig_module.validation_step = model.dynamo_ctx(orig_module.validation_step) # type: ignore[method-assign]
orig_module.test_step = model.dynamo_ctx(orig_module.test_step) # type: ignore[method-assign]
orig_module.predict_step = model.dynamo_ctx(orig_module.predict_step) # type: ignore[method-assign]
@@ -82,11 +73,6 @@ def to_uncompiled(model: Union["pl.LightningModule", "torch._dynamo.OptimizedMod
Note: this method will in-place modify the ``LightningModule`` that is passed in.
"""
- if not _TORCH_GREATER_EQUAL_2_0:
- raise ModuleNotFoundError("`to_uncompiled` requires torch>=2.0")
-
- from torch._dynamo import OptimizedModule
-
if isinstance(model, OptimizedModule):
original = model._orig_mod
if not isinstance(original, pl.LightningModule):
@@ -117,13 +103,6 @@ def to_uncompiled(model: Union["pl.LightningModule", "torch._dynamo.OptimizedMod
def _maybe_unwrap_optimized(model: object) -> "pl.LightningModule":
- if not _TORCH_GREATER_EQUAL_2_0:
- if not isinstance(model, pl.LightningModule):
- _check_mixed_imports(model)
- raise TypeError(f"`model` must be a `LightningModule`, got `{type(model).__qualname__}`")
- return model
- from torch._dynamo import OptimizedModule
-
if isinstance(model, OptimizedModule):
return from_compiled(model)
if isinstance(model, pl.LightningModule):
diff --git a/src/lightning/pytorch/utilities/model_helpers.py b/src/lightning/pytorch/utilities/model_helpers.py
index dde6cf2a33b02..36adedf4a831b 100644
--- a/src/lightning/pytorch/utilities/model_helpers.py
+++ b/src/lightning/pytorch/utilities/model_helpers.py
@@ -127,6 +127,9 @@ def wrapper(*args: Any, **kwargs: Any) -> _R_co:
return wrapper
-# trick static type checkers into thinking it's a @classmethod
-# https://github.com/microsoft/pyright/issues/5865
-_restricted_classmethod = classmethod if TYPE_CHECKING else _restricted_classmethod_impl
+if TYPE_CHECKING:
+ # trick static type checkers into thinking it's a @classmethod
+ # https://github.com/microsoft/pyright/issues/5865
+ _restricted_classmethod = classmethod
+else:
+ _restricted_classmethod = _restricted_classmethod_impl
diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary.py b/src/lightning/pytorch/utilities/model_summary/model_summary.py
index ef2827d3b7eed..c40dc94568a51 100644
--- a/src/lightning/pytorch/utilities/model_summary/model_summary.py
+++ b/src/lightning/pytorch/utilities/model_summary/model_summary.py
@@ -25,7 +25,7 @@
from torch.utils.hooks import RemovableHandle
import lightning.pytorch as pl
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
+from lightning.fabric.utilities.distributed import _is_dtensor
from lightning.pytorch.utilities.model_helpers import _ModuleMode
from lightning.pytorch.utilities.rank_zero import WarningCache
@@ -107,10 +107,7 @@ def hook_with_kwargs(_: nn.Module, args: Any, kwargs: Any, out: Any) -> None:
handle = None
if not isinstance(self._module, torch.jit.ScriptModule):
- if _TORCH_GREATER_EQUAL_2_0:
- handle = self._module.register_forward_hook(hook_with_kwargs, with_kwargs=True)
- else:
- handle = self._module.register_forward_hook(hook)
+ handle = self._module.register_forward_hook(hook_with_kwargs, with_kwargs=True)
return handle
@@ -139,7 +136,7 @@ def layer_type(self) -> str:
@property
def num_parameters(self) -> int:
"""Returns the number of parameters in this module."""
- return sum(math.prod(p.shape) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
+ return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._module.parameters())
@property
def training(self) -> bool:
@@ -191,6 +188,8 @@ class ModelSummary:
0 Non-trainable params
132 K Total params
0.530 Total estimated model params size (MB)
+ 3 Modules in train mode
+ 0 Modules in eval mode
>>> ModelSummary(model, max_depth=-1) # doctest: +NORMALIZE_WHITESPACE
| Name | Type | Params | Mode | In sizes | Out sizes
----------------------------------------------------------------------
@@ -202,6 +201,8 @@ class ModelSummary:
0 Non-trainable params
132 K Total params
0.530 Total estimated model params size (MB)
+ 3 Modules in train mode
+ 0 Modules in eval mode
"""
@@ -256,15 +257,19 @@ def param_nums(self) -> List[int]:
def training_modes(self) -> List[bool]:
return [layer.training for layer in self._layer_summary.values()]
+ @property
+ def total_training_modes(self) -> Dict[str, int]:
+ modes = [layer.training for layer in self._model.modules()]
+ modes = modes[1:] # exclude the root module
+ return {"train": modes.count(True), "eval": modes.count(False)}
+
@property
def total_parameters(self) -> int:
- return sum(p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
+ return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters())
@property
def trainable_parameters(self) -> int:
- return sum(
- p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters() if p.requires_grad
- )
+ return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters() if p.requires_grad)
@property
def total_layer_params(self) -> int:
@@ -355,8 +360,9 @@ def __str__(self) -> str:
total_parameters = self.total_parameters
trainable_parameters = self.trainable_parameters
model_size = self.model_size
+ total_training_modes = self.total_training_modes
- return _format_summary_table(total_parameters, trainable_parameters, model_size, *arrays)
+ return _format_summary_table(total_parameters, trainable_parameters, model_size, total_training_modes, *arrays)
def __repr__(self) -> str:
return str(self)
@@ -376,6 +382,7 @@ def _format_summary_table(
total_parameters: int,
trainable_parameters: int,
model_size: float,
+ total_training_modes: Dict[str, int],
*cols: Tuple[str, List[str]],
) -> str:
"""Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big
@@ -412,6 +419,10 @@ def _format_summary_table(
summary += "Total params"
summary += "\n" + s.format(get_formatted_model_size(model_size), 10)
summary += "Total estimated model params size (MB)"
+ summary += "\n" + s.format(total_training_modes["train"], 10)
+ summary += "Modules in train mode"
+ summary += "\n" + s.format(total_training_modes["eval"], 10)
+ summary += "Modules in eval mode"
return summary
@@ -458,10 +469,11 @@ def get_human_readable_count(number: int) -> str:
return f"{number:,.1f} {labels[index]}"
-def _is_lazy_weight_tensor(p: Tensor) -> bool:
+def _tensor_has_shape(p: Tensor) -> bool:
from torch.nn.parameter import UninitializedParameter
- if isinstance(p, UninitializedParameter):
+ # DTensor is a subtype of `UninitializedParameter`, but the shape is known
+ if isinstance(p, UninitializedParameter) and not _is_dtensor(p):
warning_cache.warn(
"The total number of parameters detected may be inaccurate because the model contains"
" an instance of `UninitializedParameter`. To get an accurate number, set `self.example_input_array`"
diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py b/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py
index c3c9cfe9823a3..57d9ae5024b58 100644
--- a/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py
+++ b/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py
@@ -25,7 +25,7 @@
NOT_APPLICABLE,
LayerSummary,
ModelSummary,
- _is_lazy_weight_tensor,
+ _tensor_has_shape,
get_human_readable_count,
)
@@ -40,7 +40,7 @@ class DeepSpeedLayerSummary(LayerSummary):
@override
def num_parameters(self) -> int:
"""Returns the number of parameters in this module."""
- return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
+ return sum(deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._module.parameters())
@property
def average_shard_parameters(self) -> int:
@@ -49,7 +49,7 @@ def average_shard_parameters(self) -> int:
def partitioned_size(p: Parameter) -> int:
return p.partitioned_size() if RequirementCache("deepspeed<0.6.6") else p.partition_numel()
- return sum(partitioned_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
+ return sum(partitioned_size(p) if not _tensor_has_shape(p) else 0 for p in self._module.parameters())
class DeepSpeedSummary(ModelSummary):
@@ -71,13 +71,13 @@ def summarize(self) -> Dict[str, DeepSpeedLayerSummary]: # type: ignore[overrid
@property
@override
def total_parameters(self) -> int:
- return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
+ return sum(deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._model.parameters())
@property
@override
def trainable_parameters(self) -> int:
return sum(
- deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0
+ deepspeed_param_size(p) if not _tensor_has_shape(p) else 0
for p in self._model.parameters()
if p.requires_grad
)
diff --git a/src/lightning/pytorch/utilities/testing/_runif.py b/src/lightning/pytorch/utilities/testing/_runif.py
index c3e0262d9906f..03b3afd61b875 100644
--- a/src/lightning/pytorch/utilities/testing/_runif.py
+++ b/src/lightning/pytorch/utilities/testing/_runif.py
@@ -15,7 +15,6 @@
from lightning_utilities.core.imports import RequirementCache
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.testing import _runif_reasons as fabric_run_if
from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
@@ -94,7 +93,7 @@ def _runif_reasons(
if sklearn and not _SKLEARN_AVAILABLE:
reasons.append("scikit-learn")
- if onnx and _TORCH_GREATER_EQUAL_2_0 and not _ONNX_AVAILABLE:
+ if onnx and not _ONNX_AVAILABLE:
reasons.append("onnx")
return reasons, kwargs
diff --git a/src/lightning/pytorch/utilities/types.py b/src/lightning/pytorch/utilities/types.py
index 203df53f22ba7..c1b971e924a52 100644
--- a/src/lightning/pytorch/utilities/types.py
+++ b/src/lightning/pytorch/utilities/types.py
@@ -38,10 +38,11 @@
import torch
from torch import Tensor
from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau
from torchmetrics import Metric
from typing_extensions import NotRequired, Required
-from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER, LRScheduler, ProcessGroup, ReduceLROnPlateau
+from lightning.fabric.utilities.types import ProcessGroup
_NUMBER = Union[int, float]
_METRIC = Union[Metric, Tensor, _NUMBER]
@@ -76,15 +77,15 @@ def no_sync(self) -> Generator: ...
# todo: improve LRSchedulerType naming/typing
-LRSchedulerTypeTuple = (_TORCH_LRSCHEDULER, torch.optim.lr_scheduler.ReduceLROnPlateau)
-LRSchedulerTypeUnion = Union[_TORCH_LRSCHEDULER, torch.optim.lr_scheduler.ReduceLROnPlateau]
-LRSchedulerType = Union[Type[_TORCH_LRSCHEDULER], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
+LRSchedulerTypeTuple = (LRScheduler, ReduceLROnPlateau)
+LRSchedulerTypeUnion = Union[LRScheduler, ReduceLROnPlateau]
+LRSchedulerType = Union[Type[LRScheduler], Type[ReduceLROnPlateau]]
LRSchedulerPLType = Union[LRScheduler, ReduceLROnPlateau]
@dataclass
class LRSchedulerConfig:
- scheduler: Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau]
+ scheduler: Union[LRScheduler, ReduceLROnPlateau]
# no custom name
name: Optional[str] = None
# after epoch is over
@@ -106,7 +107,7 @@ class LRSchedulerConfigType(TypedDict, total=False):
frequency: int
reduce_on_plateau: bool
monitor: Optional[str]
- scrict: bool
+ strict: bool
class OptimizerLRSchedulerConfig(TypedDict):
diff --git a/src/lightning/store/README.md b/src/lightning/store/README.md
deleted file mode 100644
index 0f2fdf61f76de..0000000000000
--- a/src/lightning/store/README.md
+++ /dev/null
@@ -1,41 +0,0 @@
-## Getting Started
-
-- Login to lightning.ai (_optional_) \<-- takes less than a minute. ⏩
-- Store your models on the cloud \<-- simple call: `upload_model(...)`. 🗳️
-- Share it with your friends \<-- just share the "username/model_name" (and version if required) format. :handshake:
-- They download using a simple call: `download_model("username/model_name", version="your_version")`. :wink:
-- Lightning :zap: fast, isn't it?. :heart:
-
-## Usage
-
-**Storing to the cloud**
-
-```python
-import lightning as L
-
-# Upload a checkpoint:
-L.store.upload_model("mnist_model", "mnist_model.ckpt")
-
-# Optionally provide a version:
-L.store.upload_model("mnist_model", "mnist_model.ckpt", version="1.0.0")
-```
-
-**List your models**
-
-```python
-import lightning as L
-
-models = L.store.list_models()
-
-print([model.name for model in models])
-# ['username/mnist_model']
-```
-
-**Downloading from the cloud**
-
-```python
-import lightning as L
-
-# Download a checkpoint
-L.store.download_model("username/mnist_model", "any_path.ckpt")
-```
diff --git a/src/lightning/store/__init__.py b/src/lightning/store/__init__.py
deleted file mode 100644
index 951bd7262e698..0000000000000
--- a/src/lightning/store/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from lightning.store.store import download_model, list_models, upload_model
-
-__all__ = ["download_model", "upload_model", "list_models"]
diff --git a/src/lightning/store/store.py b/src/lightning/store/store.py
deleted file mode 100644
index f2389f8b8aa63..0000000000000
--- a/src/lightning/store/store.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os
-from typing import List
-
-from lightning_cloud.openapi import V1Model, V1UploadModelRequest
-
-from lightning.app.utilities.cloud import _get_project
-from lightning.store.utils import _Client, _download_file_from_url, _upload_file_to_url
-
-
-def upload_model(
- name: str,
- path: str,
- version: str = "latest",
- progress_bar: bool = True,
-) -> None:
- """Upload a model to the lightning cloud.
-
- Args:
- name:
- The model name.
- path:
- The path to the checkpoint to be uploaded.
- version:
- The version of the model to be uploaded. If not provided, default will be latest (not overridden).
- progress_bar:
- A progress bar to show the uploading status. Disable this if not needed, by setting to `False`.
-
- """
- client = _Client()
- user = client.auth_service_get_user()
- # TODO: Allow passing this
- project_id = _get_project(client).project_id
-
- # TODO: Post model parts if the file size is over threshold
- body = V1UploadModelRequest(
- name=f"{user.username}/{name}",
- version=version,
- project_id=project_id,
- )
- model = client.models_store_upload_model(body)
-
- _upload_file_to_url(model.upload_url, path, progress_bar=progress_bar)
-
-
-def download_model(
- name: str,
- path: str,
- version: str = "latest",
- progress_bar: bool = True,
-) -> None:
- """Download a model from the lightning cloud.
-
- Args:
- name:
- The unique name of the model to be downloaded. Format: `<username>/<model_name>`.
- path:
- The path to download the model to.
- version:
- The version of the model to be uploaded. If not provided, default will be latest (not overridden).
- progress_bar:
- Show progress on download.
-
- """
- client = _Client()
- download_url = client.models_store_download_model(name=name, version=version).download_url
- _download_file_from_url(download_url, os.path.abspath(path), progress_bar=progress_bar)
-
-
-def list_models() -> List[V1Model]:
- """List your models in the lightning cloud.
-
- Returns:
- A list of model objects.
-
- """
- client = _Client()
- # TODO: Allow passing this
- project_id = _get_project(client).project_id
- return client.models_store_list_models(project_id=project_id).models
diff --git a/src/lightning/store/utils.py b/src/lightning/store/utils.py
deleted file mode 100644
index f3b8b9b7ae883..0000000000000
--- a/src/lightning/store/utils.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os
-
-import requests
-from lightning_cloud.openapi import AuthServiceApi, ModelsStoreApi, ProjectsServiceApi
-from lightning_cloud.rest_client import create_swagger_client
-from tqdm import tqdm
-from tqdm.utils import CallbackIOWrapper
-
-
-def _upload_file_to_url(url: str, path: str, progress_bar: bool) -> None:
- if progress_bar:
- file_size = os.path.getsize(path)
- with open(path, "rb") as fd, tqdm(
- desc="Uploading",
- total=file_size,
- unit="B",
- unit_scale=True,
- unit_divisor=1000,
- ) as t:
- reader_wrapper = CallbackIOWrapper(t.update, fd, "read")
- response = requests.put(url, data=reader_wrapper)
- response.raise_for_status()
- else:
- with open(path, "rb") as fo:
- requests.put(url, data=fo)
-
-
-def _download_file_from_url(url: str, path: str, progress_bar: bool) -> None:
- with requests.get(url, stream=True) as req_stream:
- total_size_in_bytes = int(req_stream.headers.get("content-length", 0))
- block_size = 1000 * 1000 # 1 MB
-
- download_progress_bar = None
- if progress_bar:
- download_progress_bar = tqdm(
- desc="Downloading",
- total=total_size_in_bytes,
- unit="B",
- unit_scale=True,
- unit_divisor=1000,
- )
- with open(path, "wb") as f:
- for chunk in req_stream.iter_content(chunk_size=block_size):
- if download_progress_bar:
- download_progress_bar.update(len(chunk))
- f.write(chunk)
- if download_progress_bar:
- download_progress_bar.close()
-
-
-class _Client(AuthServiceApi, ModelsStoreApi, ProjectsServiceApi):
- def __init__(self):
- api_client = create_swagger_client()
- super().__init__(api_client)
diff --git a/src/lightning_app/MANIFEST.in b/src/lightning_app/MANIFEST.in
deleted file mode 100644
index a8e251508baf5..0000000000000
--- a/src/lightning_app/MANIFEST.in
+++ /dev/null
@@ -1,11 +0,0 @@
-include src/version.info
-include src/lightning_app/version.info
-include src/lightning_app/CHANGELOG.md
-include src/lightning_app/README.md
-recursive-include requirements/app *.txt
-include .actions/assistant.py
-recursive-include src/lightning_app/cli/*-template *
-# TODO: remove this once lightning-ui package is ready as a dependency
-recursive-include src/lightning_app/ui *
-include src/lightning_app/components/serve/catimage.png
-include src/lightning_app/py.typed # marker file for PEP 561
diff --git a/src/lightning_app/README.md b/src/lightning_app/README.md
deleted file mode 100644
index bf28077071b36..0000000000000
--- a/src/lightning_app/README.md
+++ /dev/null
@@ -1,146 +0,0 @@
-
-
-
-
-**With Lightning Apps, you build exactly what you need: from production-ready, multi-cloud ML systems to simple research demos.**
-
-______________________________________________________________________
-
-
- Website •
- Docs •
- Getting started •
- Help •
- Slack
-
-
-[](https://pypi.org/project/lightning_app/)
-[](https://badge.fury.io/py/lightning_app)
-[](https://pepy.tech/project/lightning-app)
-[](https://anaconda.org/conda-forge/lightning_app)
-
-
-
-
-
-## From production-ready, multi-cloud ML systems to simple research demos.
-
-Lightning Apps enable researchers, data scientists, and software engineers to build, share and iterate on highly scalable, complex AI workflows using the tools and technologies of their choice without any of the cloud boilerplate.
-
-With Lightning Apps, your favorite components can work together on any machine at any scale.
-
-# Getting started
-
-## Install Lightning
-
-
-
-Prerequisites
-
-> TIP: We strongly recommend creating a virtual environment first.
-> Don’t know what this is? Follow our [beginner guide here](https://lightning.ai/docs/stable/install/installation.html).
-
-- Python 3.8.x or later (3.8.x, 3.9.x, 3.10.x, ...)
-- Git
-- Set up an alias for python=python3
-- Add the root folder of Lightning to the Environment Variables to PATH
-- (quick-start app requirement) Install Z shell (zsh)
-
-
-
-```bash
-pip install -U lightning
-```
-
-## Run your first Lightning App
-
-1. Install a simple training and deployment app by typing:
-
-```bash
-lightning install app lightning/quick-start
-```
-
-2. If everything was successful, move into the new directory:
-
-```bash
-cd lightning-quick-start
-```
-
-3. Run the app locally
-
-```bash
-lightning run app app.py
-```
-
-4. Alternatively, run it on the public Lightning Cloud to share your app!
-
-```bash
-lightning run app app.py --cloud
-```
-
-[Read this guide](https://lightning.ai/docs/stable/levels/basic/) to learn the basics of Lightning Apps in 15 minutes.
-
-# Features
-
-Lightning Apps consist of a root [LightningFlow](https://lightning.ai/docs/stable/glossary/app_tree.html) component, that optionally contains a tree of 2 types of components: [LightningFlow](https://lightning.ai/lightning-docs/core_api/lightning_flow.html) 🌊 and [LightningWork](https://lightning.ai/lightning-docs/core_api/lightning_work/) ⚒️. Key functionality includes:
-
-- A shared state between components.
-- A constantly running event loop for reactivity.
-- Dynamic attachment of components at runtime.
-- Start and stop functionality of your works.
-
-Lightning Apps can run [locally](https://lightning.ai/lightning-docs/workflows/run_on_private_cloud.html) 💻 or [on the cloud](https://lightning.ai/lightning-docs/core_api/lightning_work/compute.html) 🌩️.
-
-Easy communication 🛰️ between components is supported with:
-
-- [Directional state updates](https://lightning.ai/lightning-docs/core_api/lightning_app/communication.html?highlight=directional%20state) from the Works to the Flow creating an event: When creating interactive apps, you will likely want your components to share information with each other. You might rely on that information to control their execution, share progress in the UI, trigger a sequence of operations, or more.
-- [Storage](https://lightning.ai/lightning-docs/api_reference/storage.html): The Lightning Storage system makes it easy to share files between LightningWork so you can run your app both locally and in the cloud without changing the code.
- - [Path](https://lightning.ai/docs/app/stable/api_reference/generated/lightning.app.storage.path.Path.html#lightning_app.storage.path.Path): The Path object is a reference to a specific file or directory from a LightningWork and can be used to transfer those files to another LightningWork (one way, from source to destination).
- - [Payload](https://lightning.ai/docs/app/stable/api_reference/generated/lightning.app.storage.payload.Payload.html#lightning_app.storage.payload.Payload): The Payload object enables transferring of Python objects from one work to another in a similar fashion as Path.
- - [Drive](https://lightning.ai/docs/app/stable/api_reference/generated/lightning.app.storage.drive.Drive.html#lightning_app.storage.drive.Drive): The Drive object provides a central place for your components to share data. The drive acts as an isolated folder and any component can access it by knowing its name.
-
-Lightning Apps have built-in support for [adding UIs](https://lightning.ai/lightning-docs/workflows/add_web_ui/) 🎨:
-
-- [StaticWebFrontEnd](https://lightning.ai/docs/app/stable/api_reference/generated/lightning.app.frontend.web.StaticWebFrontend.html#lightning_app.frontend.web.StaticWebFrontend): A frontend that serves static files from a directory using FastAPI.
-- [StreamlitFrontend](https://lightning.ai/docs/app/stable/api_reference/generated/lightning.app.frontend.stream_lit.StreamlitFrontend.html#lightning_app.frontend.stream_lit.StreamlitFrontend): A frontend for wrapping Streamlit code in your LightningFlow.
-- [ServeGradio](https://lightning.ai/docs/app/stable/api_reference/generated/lightning.app.components.serve.gradio_server.ServeGradio.html#servegradio): This class enables you to quickly create a `gradio` based UI for your Lightning App.
-
-[Scheduling](https://lightning.ai/lightning-docs/glossary/scheduling.html) ⏲️: The Lightning Scheduling system makes it easy to schedule your components execution with any arbitrary conditions.
-
-Advanced users who need full control over the environment a LightningWork runs in can [specify a custom Docker image](https://lightning.ai/lightning-docs/glossary/build_config/build_config_advanced.html?highlight=docker) 🐋 that will be deployed in the cloud.
-
-[Environment variables](https://lightning.ai/lightning-docs/glossary/environment_variables.html?highlight=environment%20variables) 💬: If your app is using secrets or values, such as API keys or access tokens, use environment variables to avoid sticking them in the source code.
-
-Ready to use [built-in components](https://lightning.ai/lightning-docs/api_reference/components.html?highlight=built%20components) 🧱:
-
-- [PopenPythonScript](https://lightning.ai/docs/app/stable/api_reference/generated/lightning.app.components.python.popen.PopenPythonScript.html#lightning_app.components.python.popen.PopenPythonScript): This class enables you to easily run a Python Script.
-- [ModelInferenceAPI](https://lightning.ai/docs/app/stable/api_reference/generated/lightning.app.components.serve.serve.ModelInferenceAPI.html#lightning_app.components.serve.serve.ModelInferenceAPI): This class enables you to easily get your model served.
-
-# App gallery
-
-The [Lightning AI website](https://lightning.ai/) features a curated gallery of Lightning Apps and components that makes it easy to get started. A few highlights:
-
-| App | Description |
-| ------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
-| Train & Demo PyTorch Lightning | Train a model using PyTorch Lightning and deploy it to an interactive demo. Use this Lightning App as a starting point for building more complex apps around your models. |
-| Lightning Sweeper | Run a hyperparameter sweep over any model script across hundreds of cloud machines at once. This Lightning App uses Optuna to provide advanced tuning algorithms (from grid and random search to Hyperband). |
-| Flashy | Flashy, the auto-AI Lightning App, selects the best deep learning model for your image or text datasets. It automatically uses state-of-the-art models from Torchvision, TIMM and Hugging Face. |
-
-## Current limitations
-
-- Lightning requires Python 3.8.x or later (3.8.x, 3.9.x, 3.10.x).
-- For now, you can only run a single app locally at a time.
-- You are required to install the Lightning App requirements locally, even when starting the app on the cloud.
-- Multiple works cannot share the same machine.
-- To run on the cloud, you will need access to a browser.
-- Frontends only support the HTTP protocol. TCP support is coming in the future.
-- App Flow Frontends cannot be changed after startup, but the layout can be updated reactively.
-- Authentication is not supported.
-
-## Asking for help
-
-If you have any questions please:
-
-1. [Read the docs](https://lightning.ai/lightning-docs/).
-1. [Search through existing Discussions](https://github.com/Lightning-ai/lightning/discussions), or [add a new question](https://github.com/Lightning-ai/lightning/discussions/new)
-1. [Join our Discord community ](https://discord.gg/VptPCZkGNa).
diff --git a/src/lightning_app/__about__.py b/src/lightning_app/__about__.py
deleted file mode 100644
index 3ffbe2420a905..0000000000000
--- a/src/lightning_app/__about__.py
+++ /dev/null
@@ -1,35 +0,0 @@
-#!/usr/bin/env python
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import time
-
-# __version__ = "0.5.1"
-__author__ = "Lightning-AI et al."
-__author_email__ = "name@pytorchlightning.ai"
-__license__ = "Apache-2.0"
-__copyright__ = f"Copyright (c) 2021-{time.strftime('%Y')}, {__author__}."
-__homepage__ = "https://github.com/Lightning-AI/lightning"
-__docs__ = (
- "Use Lightning Apps to build everything from production-ready, multi-cloud ML systems to simple research demos."
-)
-
-__all__ = [
- "__author__",
- "__author_email__",
- "__copyright__",
- "__docs__",
- "__homepage__",
- "__license__",
-]
diff --git a/src/lightning_app/__main__.py b/src/lightning_app/__main__.py
deleted file mode 100644
index dc40614cf3d8f..0000000000000
--- a/src/lightning_app/__main__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from lightning_app.cli.lightning_cli import main
-
-if __name__ == "__main__":
- main()
diff --git a/src/lightning_app/__setup__.py b/src/lightning_app/__setup__.py
deleted file mode 100644
index 080eb6c3a6399..0000000000000
--- a/src/lightning_app/__setup__.py
+++ /dev/null
@@ -1,122 +0,0 @@
-import glob
-import os
-from importlib.util import module_from_spec, spec_from_file_location
-from pathlib import Path
-from types import ModuleType
-from typing import Any, Dict
-
-from setuptools import find_packages
-
-_PROJECT_ROOT = "."
-_SOURCE_ROOT = os.path.join(_PROJECT_ROOT, "src")
-_PACKAGE_ROOT = os.path.join(_SOURCE_ROOT, "lightning_app")
-_PATH_REQUIREMENTS = os.path.join("requirements", "app")
-_FREEZE_REQUIREMENTS = os.environ.get("FREEZE_REQUIREMENTS", "0").lower() in ("1", "true")
-
-
-def _load_py_module(name: str, location: str) -> ModuleType:
- spec = spec_from_file_location(name, location)
- assert spec, f"Failed to load module {name} from {location}"
- py = module_from_spec(spec)
- assert spec.loader, f"ModuleSpec.loader is None for {name} from {location}"
- spec.loader.exec_module(py)
- return py
-
-
-def _load_assistant() -> ModuleType:
- location = os.path.join(_PROJECT_ROOT, ".actions", "assistant.py")
- return _load_py_module("assistant", location)
-
-
-def _prepare_extras() -> Dict[str, Any]:
- assistant = _load_assistant()
- # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras
- # Define package extras. These are only installed if you specify them.
- # From remote, use like `pip install "pytorch-lightning[dev, docs]"`
- # From local copy of repo, use like `PACKAGE_NAME=app pip install ".[dev, docs]"`
- req_files = [Path(p) for p in glob.glob(os.path.join(_PATH_REQUIREMENTS, "*.txt"))]
- common_args = {"path_dir": _PATH_REQUIREMENTS, "unfreeze": "none" if _FREEZE_REQUIREMENTS else "major"}
- extras = {
- p.stem: assistant.load_requirements(file_name=p.name, **common_args)
- for p in req_files
- if p.name not in ("docs.txt", "app.txt")
- }
- extras["extra"] = extras["cloud"] + extras["ui"] + extras["components"]
- extras["all"] = extras["extra"]
- extras["dev"] = extras["all"] + extras["test"] # + extras['docs']
- return extras
-
-
-def _setup_args() -> Dict[str, Any]:
- assistant = _load_assistant()
- about = _load_py_module("about", os.path.join(_PACKAGE_ROOT, "__about__.py"))
- version = _load_py_module("version", os.path.join(_PACKAGE_ROOT, "__version__.py"))
- long_description = assistant.load_readme_description(
- _PACKAGE_ROOT, homepage=about.__homepage__, version=version.version
- )
-
- # TODO: remove this once lightning-ui package is ready as a dependency
- ui_ver_file = os.path.join(_SOURCE_ROOT, "app-ui-version.info")
- if os.path.isfile(ui_ver_file):
- with open(ui_ver_file, encoding="utf-8") as fo:
- ui_version = fo.readlines()[0].strip()
- download_fe_version = {"version": ui_version}
- else:
- print(f"Missing file with FE version: {ui_ver_file}")
- download_fe_version = {}
- assistant._download_frontend(_PACKAGE_ROOT, **download_fe_version)
-
- return {
- "name": "lightning-app",
- "version": version.version,
- "description": about.__docs__,
- "author": about.__author__,
- "author_email": about.__author_email__,
- "url": about.__homepage__,
- "download_url": "https://github.com/Lightning-AI/lightning",
- "license": about.__license__,
- "packages": find_packages(where="src", include=["lightning_app", "lightning_app.*"]),
- "package_dir": {"": "src"},
- "long_description": long_description,
- "long_description_content_type": "text/markdown",
- "include_package_data": True,
- "zip_safe": False,
- "keywords": ["deep learning", "pytorch", "AI"],
- "python_requires": ">=3.8",
- "entry_points": {
- "console_scripts": [
- "lightning_app = lightning_app.cli.lightning_cli:main",
- ],
- },
- "setup_requires": [],
- "install_requires": assistant.load_requirements(
- _PATH_REQUIREMENTS, file_name="app.txt", unfreeze="none" if _FREEZE_REQUIREMENTS else "major"
- ),
- "extras_require": _prepare_extras(),
- "project_urls": {
- "Bug Tracker": "https://github.com/Lightning-AI/lightning/issues",
- "Documentation": "https://lightning.ai/lightning-docs",
- "Source Code": "https://github.com/Lightning-AI/lightning",
- },
- "classifiers": [
- "Environment :: Console",
- "Natural Language :: English",
- # How mature is this project? Common values are
- # 3 - Alpha, 4 - Beta, 5 - Production/Stable
- "Development Status :: 4 - Beta",
- # Indicate who your project is intended for
- "Intended Audience :: Developers",
- "Topic :: Scientific/Engineering :: Artificial Intelligence",
- "Topic :: Scientific/Engineering :: Information Analysis",
- # Pick your license as you wish
- # 'License :: OSI Approved :: BSD License',
- "Operating System :: OS Independent",
- # Specify the Python versions you support here. In particular, ensure
- # that you indicate whether you support Python 2, Python 3 or both.
- "Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.8",
- "Programming Language :: Python :: 3.9",
- "Programming Language :: Python :: 3.10",
- "Programming Language :: Python :: 3.11",
- ],
- }
diff --git a/src/lightning_app/__version__.py b/src/lightning_app/__version__.py
deleted file mode 100644
index 1491508baf4b3..0000000000000
--- a/src/lightning_app/__version__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-import os
-
-_PACKAGE_ROOT = os.path.dirname(__file__)
-_VERSION_PATH = os.path.join(os.path.dirname(_PACKAGE_ROOT), "version.info")
-if not os.path.exists(_VERSION_PATH):
- # relevant for `bdist_wheel`
- _VERSION_PATH = os.path.join(_PACKAGE_ROOT, "version.info")
-with open(_VERSION_PATH, encoding="utf-8") as fo:
- version = fo.readlines()[0].strip()
diff --git a/src/lightning_app/py.typed b/src/lightning_app/py.typed
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/src/lightning_app/shell-folder_code-lives-lightning.info b/src/lightning_app/shell-folder_code-lives-lightning.info
deleted file mode 100644
index 16ef1072850f2..0000000000000
--- a/src/lightning_app/shell-folder_code-lives-lightning.info
+++ /dev/null
@@ -1,2 +0,0 @@
-This folder serves only for building the `lightning-app` package when you install from source code with env variable `PACKAGE_NAME=app`
-Please do not edit these files - you may see some as they are automatically generated/moved from their current location.
diff --git a/src/version.info b/src/version.info
index c1ffed065aef0..997b764c6e5ec 100644
--- a/src/version.info
+++ b/src/version.info
@@ -1 +1 @@
-2.3.0dev
+2.5.0.dev
diff --git a/tests/integrations_app/__init__.py b/tests/integrations_app/__init__.py
deleted file mode 100644
index a3c9eb29e7220..0000000000000
--- a/tests/integrations_app/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from os.path import dirname
-
-_PATH_TESTS_DIR = dirname(dirname(__file__))
diff --git a/tests/integrations_app/apps/collect_failures/__init__.py b/tests/integrations_app/apps/collect_failures/__init__.py
deleted file mode 100644
index b022c62adcdf3..0000000000000
--- a/tests/integrations_app/apps/collect_failures/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# TODO: check the tests to be able to run without this init file
diff --git a/tests/integrations_app/apps/collect_failures/app.py b/tests/integrations_app/apps/collect_failures/app.py
deleted file mode 100644
index 068a541aafbe0..0000000000000
--- a/tests/integrations_app/apps/collect_failures/app.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import logging
-import sys
-import time
-
-from lightning.app import LightningApp, LightningFlow, LightningWork
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-logger.addHandler(logging.StreamHandler(sys.stdout))
-
-
-class SimpleWork(LightningWork):
- def __init__(self):
- super().__init__(cache_calls=False, parallel=True, raise_exception=False)
- self.is_running_now = False
-
- def run(self):
- self.is_running_now = True
- print("work_is_running")
- for i in range(1, 10):
- time.sleep(1)
- if i % 5 == 0:
- raise Exception(f"invalid_value_of_i_{i}")
- print(f"good_value_of_i_{i}")
-
-
-class RootFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.simple_work = SimpleWork()
-
- def run(self):
- print("useless_garbage_log_that_is_always_there_to_overload_logs")
- self.simple_work.run()
- if not self.simple_work.is_running_now:
- pass
- # work is not ready yet
- print("waiting_for_work_to_be_ready")
- else:
- print("flow_and_work_are_running")
- logger.info("logger_flow_work")
- time.sleep(0.1)
-
-
-if __name__ == "__main__":
- app = LightningApp(RootFlow(), log_level="debug")
diff --git a/tests/integrations_app/apps/collect_failures/requirements.txt b/tests/integrations_app/apps/collect_failures/requirements.txt
deleted file mode 100644
index 7800f0fad3fff..0000000000000
--- a/tests/integrations_app/apps/collect_failures/requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-redis
diff --git a/tests/integrations_app/apps/core_features_app/__init__.py b/tests/integrations_app/apps/core_features_app/__init__.py
deleted file mode 100644
index b022c62adcdf3..0000000000000
--- a/tests/integrations_app/apps/core_features_app/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# TODO: check the tests to be able to run without this init file
diff --git a/tests/integrations_app/apps/core_features_app/app.py b/tests/integrations_app/apps/core_features_app/app.py
deleted file mode 100644
index db73ae21c0572..0000000000000
--- a/tests/integrations_app/apps/core_features_app/app.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import os
-
-from lightning.app.core import LightningApp, LightningFlow
-
-
-class EnvVarTestApp(LightningFlow):
- def __init__(self):
- super().__init__()
-
- def run(self):
- # these env vars are set here: tests/integrations_app/test_core_features_app.py:15
- assert os.getenv("FOO", "") == "bar"
- assert os.getenv("BLA", "") == "bloz"
- self.stop()
-
-
-app = LightningApp(EnvVarTestApp())
diff --git a/tests/integrations_app/apps/custom_work_dependencies/__init__.py b/tests/integrations_app/apps/custom_work_dependencies/__init__.py
deleted file mode 100644
index b022c62adcdf3..0000000000000
--- a/tests/integrations_app/apps/custom_work_dependencies/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# TODO: check the tests to be able to run without this init file
diff --git a/tests/integrations_app/apps/custom_work_dependencies/app.py b/tests/integrations_app/apps/custom_work_dependencies/app.py
deleted file mode 100644
index 27f2d3f3ca876..0000000000000
--- a/tests/integrations_app/apps/custom_work_dependencies/app.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import os
-
-from lightning.app import BuildConfig, CloudCompute, LightningApp, LightningFlow, LightningWork
-
-
-class CustomBuildConfig(BuildConfig):
- def build_commands(self):
- return ["sudo apt update", "sudo apt install redis", "pip install lmdb"]
-
-
-class WorkWithCustomDeps(LightningWork):
- def __init__(self, cloud_compute: CloudCompute = CloudCompute(), **kwargs):
- build_config = CustomBuildConfig(requirements=["py"])
- super().__init__(parallel=True, **kwargs, cloud_compute=cloud_compute, cloud_build_config=build_config)
-
- def run(self):
- # installed by the build commands and by requirements in the build config
- import lmdb
-
- print("installed lmdb version:", lmdb.__version__)
-
-
-class WorkWithCustomBaseImage(LightningWork):
- def __init__(self, cloud_compute: CloudCompute = CloudCompute(), **kwargs):
- # this image has been created from ghcr.io/gridai/base-images:v1.8-cpu
- # by just adding an empty file at /content/.e2e_test
- image_tag = os.getenv("LIGHTNING_E2E_TEST_IMAGE_VERSION", "v1.29")
- custom_image = f"ghcr.io/gridai/image-for-testing-custom-images-in-e2e:{image_tag}"
- build_config = BuildConfig(image=custom_image)
- super().__init__(parallel=True, **kwargs, cloud_compute=cloud_compute, cloud_build_config=build_config)
-
- def run(self):
- # checking the existence of the file - this file had been added to the custom base image
- assert ".e2e_test" in os.listdir("/testdir/"), "file not found"
-
-
-class CustomWorkBuildConfigChecker(LightningFlow):
- def run(self):
- # create dynamically the work at runtime
- if not hasattr(self, "work1"):
- self.work1 = WorkWithCustomDeps()
- if not hasattr(self, "work2"):
- self.work2 = WorkWithCustomBaseImage()
-
- self.work1.run()
- self.work2.run()
-
- if self.work1.has_succeeded and self.work2.has_succeeded:
- print("--- Custom Work Dependency checker End ----")
- self.stop()
-
-
-app = LightningApp(CustomWorkBuildConfigChecker())
diff --git a/tests/integrations_app/apps/idle_timeout/__init__.py b/tests/integrations_app/apps/idle_timeout/__init__.py
deleted file mode 100644
index b022c62adcdf3..0000000000000
--- a/tests/integrations_app/apps/idle_timeout/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# TODO: check the tests to be able to run without this init file
diff --git a/tests/integrations_app/apps/idle_timeout/app.py b/tests/integrations_app/apps/idle_timeout/app.py
deleted file mode 100644
index d33df0a616d58..0000000000000
--- a/tests/integrations_app/apps/idle_timeout/app.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import pathlib
-
-from lightning.app import CloudCompute, LightningApp, LightningFlow, LightningWork
-from lightning.app.storage.path import _artifacts_path, _filesystem
-from lightning.app.utilities.enum import WorkStageStatus
-
-
-class SourceFileWriterWork(LightningWork):
- def __init__(self):
- super().__init__(cache_calls=False, parallel=True, cloud_compute=CloudCompute(idle_timeout=5))
- self.counter = 0
- self.value = None
- self.path = None
-
- def run(self):
- self.path = "lit://boring_file.txt"
- with open(self.path, "w") as f:
- f.write("path")
- self.counter += 1
-
-
-class DestinationWork(LightningWork):
- def run(self, path):
- assert path.exists()
-
-
-class RootFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.make_check = True
- self.work = SourceFileWriterWork()
- self.dest_work = DestinationWork(parallel=True)
-
- def run(self):
- if self.work.counter == 0:
- self.work.run()
-
- elif self.work.status.stage == WorkStageStatus.STOPPED and self.make_check:
- succeeded_statuses = [status for status in self.work.statuses if status.stage == WorkStageStatus.SUCCEEDED]
- # Ensure the work succeeded at some point
- assert len(succeeded_statuses) > 0
- succeeded_status = succeeded_statuses[-1]
-
- stopped_statuses = [status for status in self.work.statuses if status.stage == WorkStageStatus.STOPPED]
-
- # We want to check that the work started shutting down within the required timeframe, so we take the first
- # status that has `stage == STOPPED`.
- stopped_status = stopped_statuses[0]
-
- # Note: Account for the controlplane, k8s, SIGTERM handler delays.
- assert (stopped_status.timestamp - succeeded_status.timestamp) < 20
-
- fs = _filesystem()
- destination_path = _artifacts_path(self.work) / pathlib.Path(*self.work.path.resolve().parts[1:])
- assert fs.exists(destination_path)
- self.dest_work.run(self.work.path)
- self.make_check = False
- print("Successfully stopped SourceFileWriterWork.")
-
- if self.dest_work.status.stage == WorkStageStatus.SUCCEEDED:
- print("Stopping work")
- self.dest_work.stop()
-
- if self.dest_work.status.stage == WorkStageStatus.STOPPED:
- print(self.dest_work.statuses)
- print("Application End")
- self.stop()
-
-
-app = LightningApp(RootFlow(), log_level="debug")
diff --git a/tests/integrations_app/conftest.py b/tests/integrations_app/conftest.py
deleted file mode 100644
index dec2f12c7ffe4..0000000000000
--- a/tests/integrations_app/conftest.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import contextlib
-import os
-import shutil
-import threading
-from subprocess import Popen
-
-import psutil
-import pytest
-from lightning.app.storage.path import _storage_root_dir
-from lightning.app.utilities.component import _set_context
-from lightning.app.utilities.packaging import cloud_compute
-from lightning.app.utilities.packaging.app_config import _APP_CONFIG_FILENAME
-from lightning.app.utilities.state import AppState
-
-from integrations_app.public import _PATH_EXAMPLES
-
-GITHUB_APP_URLS = {
- "template_react_ui": "https://github.com/Lightning-AI/lightning-template-react.git",
-}
-
-os.environ["LIGHTNING_DISPATCHED"] = "1"
-
-
-def pytest_sessionstart(*_):
- """Pytest hook that gets called after the Session object has been created and before performing collection and
- entering the run test loop."""
- for name, url in GITHUB_APP_URLS.items():
- app_path = _PATH_EXAMPLES / name
- if not os.path.exists(app_path):
- Popen(["git", "clone", url, name], cwd=_PATH_EXAMPLES).wait(timeout=90)
- else:
- Popen(["git", "pull", "main"], cwd=app_path).wait(timeout=90)
-
-
-def pytest_sessionfinish(session, exitstatus):
- """Pytest hook that gets called after the whole test run has finished, right before returning the exit status to the
- system."""
- # kill all the processes and threads created by parent
- # TODO: this isn't great. Each test should do its own cleanup
- current_process = psutil.Process()
- for child in current_process.children(recursive=True):
- with contextlib.suppress(psutil.NoSuchProcess):
- params = child.as_dict() or {}
- cmd_lines = params.get("cmdline", [])
- # we shouldn't kill the resource tracker from multiprocessing. If we do,
- # `atexit` will throw as it uses resource tracker to try to clean up
- if cmd_lines and "resource_tracker" in cmd_lines[-1]:
- continue
- child.kill()
-
- main_thread = threading.current_thread()
- for t in threading.enumerate():
- if t is not main_thread:
- t.join(0)
-
-
-@pytest.fixture(autouse=True)
-def cleanup():
- from lightning.app.utilities.app_helpers import _LightningAppRef
-
- yield
- _LightningAppRef._app_instance = None
- shutil.rmtree("./storage", ignore_errors=True)
- shutil.rmtree(_storage_root_dir(), ignore_errors=True)
- shutil.rmtree("./.shared", ignore_errors=True)
- if os.path.isfile(_APP_CONFIG_FILENAME):
- os.remove(_APP_CONFIG_FILENAME)
- _set_context(None)
-
-
-@pytest.fixture(autouse=True)
-def clear_app_state_state_variables():
- """Resets global variables in order to prevent interference between tests."""
- yield
- import lightning.app.utilities.state
-
- lightning.app.utilities.state._STATE = None
- lightning.app.utilities.state._LAST_STATE = None
- AppState._MY_AFFILIATION = ()
- if hasattr(cloud_compute, "_CLOUD_COMPUTE_STORE"):
- cloud_compute._CLOUD_COMPUTE_STORE.clear()
diff --git a/tests/integrations_app/flagship/__init__.py b/tests/integrations_app/flagship/__init__.py
deleted file mode 100644
index be56fdce3190d..0000000000000
--- a/tests/integrations_app/flagship/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-import os.path
-
-from integrations_app import _PATH_TESTS_DIR
-
-_PATH_INTEGRATIONS_DIR = os.path.join(_PATH_TESTS_DIR, "_flagship-app")
diff --git a/tests/integrations_app/flagship/test_flashy.py b/tests/integrations_app/flagship/test_flashy.py
deleted file mode 100644
index 429968defa9b7..0000000000000
--- a/tests/integrations_app/flagship/test_flashy.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import contextlib
-from time import sleep
-
-import pytest
-from lightning.app.testing.testing import run_app_in_cloud
-from lightning.app.utilities.imports import _is_playwright_available
-
-from integrations_app.flagship import _PATH_INTEGRATIONS_DIR
-
-if _is_playwright_available():
- import playwright
- from playwright.sync_api import Page, expect
-
-
-# TODO: when this function is moved to the app itself we can just import it, so to keep better aligned
-def validate_app_functionalities(app_page: "Page") -> None:
- """Validate the page after app starts.
-
- This is a direct copy-paste of the validation living in the app repository:
- https://github.com/Lightning-AI/LAI-Flashy-App/blob/main/tests/test_app_gallery.py#L205
-
- app_page: The UI page of the app to be validated.
-
- """
- while True:
- with contextlib.suppress(playwright._impl._api_types.Error, playwright._impl._api_types.TimeoutError):
- app_page.reload()
- sleep(5)
- app_label = app_page.frame_locator("iframe").locator("text=Choose your AI task")
- app_label.wait_for(timeout=30 * 1000)
- break
-
- input_field = app_page.frame_locator("iframe").locator('input:below(:text("Data URL"))').first
- input_field.wait_for(timeout=1000)
- input_field.type("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip")
- sleep(1)
- upload_btn = app_page.frame_locator("iframe").locator('button:has-text("Upload")')
- upload_btn.wait_for(timeout=1000)
- upload_btn.click()
-
- sleep(10)
-
- train_folder_dropdown = app_page.frame_locator("iframe").locator("#mui-2")
- train_folder_dropdown.click()
-
- train_folder = app_page.frame_locator("iframe").locator('text="hymenoptera_data/train"')
- train_folder.scroll_into_view_if_needed()
- train_folder.click()
-
- val_folder_dropdown = app_page.frame_locator("iframe").locator("#mui-3")
- val_folder_dropdown.click()
-
- val_folder = app_page.frame_locator("iframe").locator('text="hymenoptera_data/val"')
- val_folder.scroll_into_view_if_needed()
- val_folder.click()
-
- train_btn = app_page.frame_locator("iframe").locator('button:has-text("Start training!")')
- train_btn.click()
-
- # Sometimes the results don't show until we refresh the page
- sleep(10)
-
- app_page.reload()
-
- app_page.frame_locator("iframe").locator('button:has-text("RESULTS")').click()
- runs = app_page.frame_locator("iframe").locator("table tbody tr")
- expect(runs).to_have_count(1, timeout=120000)
-
-
-@pytest.mark.cloud()
-def test_app_cloud() -> None:
- with run_app_in_cloud(_PATH_INTEGRATIONS_DIR) as (_, view_page, _, _):
- validate_app_functionalities(view_page)
diff --git a/tests/integrations_app/flagship/test_jupyter.py b/tests/integrations_app/flagship/test_jupyter.py
deleted file mode 100644
index 5d33698040b43..0000000000000
--- a/tests/integrations_app/flagship/test_jupyter.py
+++ /dev/null
@@ -1 +0,0 @@
-# This is a placeholder; the real file is currently copied from the original repo/project
diff --git a/tests/integrations_app/flagship/test_muse.py b/tests/integrations_app/flagship/test_muse.py
deleted file mode 100644
index 5d33698040b43..0000000000000
--- a/tests/integrations_app/flagship/test_muse.py
+++ /dev/null
@@ -1 +0,0 @@
-# This is a placeholder; the real file is currently copied from the original repo/project
diff --git a/tests/integrations_app/local/__init__.py b/tests/integrations_app/local/__init__.py
deleted file mode 100644
index 1e7d17cc6b536..0000000000000
--- a/tests/integrations_app/local/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from pathlib import Path
-
-_PATH_APPS = Path(__file__).resolve().parents[1] / "apps"
diff --git a/tests/integrations_app/local/test_collect_failures.py b/tests/integrations_app/local/test_collect_failures.py
deleted file mode 100644
index ca11b7528bd40..0000000000000
--- a/tests/integrations_app/local/test_collect_failures.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import os
-from time import sleep
-
-import pytest
-from lightning.app.testing.testing import run_app_in_cloud
-
-from integrations_app.local import _PATH_APPS
-
-
-@pytest.mark.cloud()
-def test_collect_failures_example_cloud() -> None:
- # logs are in order
- expected_logs = [
- "useless_garbage_log_that_is_always_there_to_overload_logs",
- "waiting_for_work_to_be_ready",
- "work_is_running",
- "flow_and_work_are_running",
- "logger_flow_work",
- "good_value_of_i_1",
- "good_value_of_i_2",
- "good_value_of_i_3",
- "good_value_of_i_4",
- "invalid_value_of_i_5",
- ]
- with run_app_in_cloud(os.path.join(_PATH_APPS, "collect_failures")) as (
- _,
- _,
- fetch_logs,
- _,
- ):
- last_found_log_index = -1
- while len(expected_logs) != 0:
- for index, log in enumerate(fetch_logs()):
- if expected_logs[0] in log:
- print(f"found expected log: {expected_logs[0]}")
- expected_logs.pop(0)
- assert index > last_found_log_index
- if len(expected_logs) == 0:
- break
- sleep(1)
diff --git a/tests/integrations_app/local/test_core_features_app.py b/tests/integrations_app/local/test_core_features_app.py
deleted file mode 100644
index 49f65fdc73a35..0000000000000
--- a/tests/integrations_app/local/test_core_features_app.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import os
-
-from click.testing import CliRunner
-from lightning.app.cli.lightning_cli import run_app
-
-from integrations_app.local import _PATH_APPS
-
-
-def test_core_features_app_example():
- runner = CliRunner()
- result = runner.invoke(
- run_app,
- [
- os.path.join(_PATH_APPS, "core_features_app", "app.py"),
- "--blocking",
- "False",
- "--open-ui",
- "False",
- "--env", # this is to test env variable
- "FOO=bar",
- "--env",
- "BLA=bloz",
- ],
- catch_exceptions=False,
- )
- assert result.exit_code == 0
diff --git a/tests/integrations_app/local/test_custom_work_dependencies.py b/tests/integrations_app/local/test_custom_work_dependencies.py
deleted file mode 100644
index b7867e2c12a02..0000000000000
--- a/tests/integrations_app/local/test_custom_work_dependencies.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import os
-from time import sleep
-
-import pytest
-from lightning.app.testing.testing import run_app_in_cloud
-
-from integrations_app.local import _PATH_APPS
-
-
-@pytest.mark.cloud()
-def test_custom_work_dependencies_example_cloud() -> None:
- # if requirements not installed, the app will fail
- with run_app_in_cloud(
- os.path.join(_PATH_APPS, "custom_work_dependencies"),
- app_name="app.py",
- ) as (_, _, fetch_logs, _):
- has_logs = False
- while not has_logs:
- for log in fetch_logs(["flow"]):
- if "Custom Work Dependency checker End" in log:
- has_logs = True
- print(log)
- sleep(1)
diff --git a/tests/integrations_app/local/test_idle_timeout.py b/tests/integrations_app/local/test_idle_timeout.py
deleted file mode 100644
index 9314cc8b8a99a..0000000000000
--- a/tests/integrations_app/local/test_idle_timeout.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import os
-from time import sleep
-
-import pytest
-from lightning.app.testing.testing import run_app_in_cloud
-
-from integrations_app.local import _PATH_APPS
-
-
-@pytest.mark.cloud()
-def test_idle_timeout_example_cloud() -> None:
- with run_app_in_cloud(os.path.join(_PATH_APPS, "idle_timeout")) as (
- _,
- _,
- fetch_logs,
- _,
- ):
- has_logs = False
- while not has_logs:
- for log in fetch_logs(["flow"]):
- if "Application End" in log:
- has_logs = True
- sleep(1)
diff --git a/tests/integrations_app/public/__init__.py b/tests/integrations_app/public/__init__.py
deleted file mode 100644
index 5ad0242e64c01..0000000000000
--- a/tests/integrations_app/public/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from pathlib import Path
-
-_PATH_EXAMPLES = Path(__file__).resolve().parents[3] / "examples" / "app"
diff --git a/tests/integrations_app/public/test_app_dag.py b/tests/integrations_app/public/test_app_dag.py
deleted file mode 100644
index 8145f1266a138..0000000000000
--- a/tests/integrations_app/public/test_app_dag.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import os
-from time import sleep
-
-import pytest
-from lightning.app.testing.testing import run_app_in_cloud
-
-from integrations_app.public import _PATH_EXAMPLES
-
-
-@pytest.mark.cloud()
-def test_app_dag_example_cloud() -> None:
- with run_app_in_cloud(os.path.join(_PATH_EXAMPLES, "dag")) as (_, _, fetch_logs, _):
- launch_log, finish_log = False, False
- while not (launch_log and finish_log):
- for log in fetch_logs(["flow"]):
- if "Launching a new DAG" in log:
- launch_log = True
- elif "Finished training and evaluating" in log:
- finish_log = True
- sleep(1)
diff --git a/tests/integrations_app/public/test_argparse.py b/tests/integrations_app/public/test_argparse.py
deleted file mode 100644
index 80cb7103ad23c..0000000000000
--- a/tests/integrations_app/public/test_argparse.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import os
-import sys
-
-from lightning.app.testing.testing import application_testing
-from lightning.app.utilities.load_app import _patch_sys_argv
-
-from integrations_app.public import _PATH_EXAMPLES
-
-
-def test_app_argparse_example():
- original_argv = sys.argv
-
- command_line = [
- os.path.join(_PATH_EXAMPLES, "argparse", "app.py"),
- "--app_args",
- "--use_gpu",
- "--without-server",
- ]
- result = application_testing(command_line=command_line)
- assert result.exit_code == 0, result.__dict__
- assert sys.argv == original_argv
-
-
-def test_patch_sys_argv():
- original_argv = sys.argv
-
- sys.argv = expected = ["lightning", "run", "app", "app.py"]
- with _patch_sys_argv():
- assert sys.argv == ["app.py"]
-
- assert sys.argv == expected
-
- sys.argv = expected = ["lightning", "run", "app", "app.py", "--without-server", "--env", "name=something"]
- with _patch_sys_argv():
- assert sys.argv == ["app.py"]
-
- assert sys.argv == expected
-
- sys.argv = expected = ["lightning", "run", "app", "app.py", "--app_args"]
- with _patch_sys_argv():
- assert sys.argv == ["app.py"]
-
- assert sys.argv == expected
-
- sys.argv = expected = ["lightning", "run", "app", "app.py", "--app_args", "--env", "name=something"]
- with _patch_sys_argv():
- assert sys.argv == ["app.py"]
-
- assert sys.argv == expected
-
- sys.argv = expected = [
- "lightning",
- "run",
- "app",
- "app.py",
- "--without-server",
- "--app_args",
- "--use_gpu",
- "--name=hello",
- "--env",
- "name=something",
- ]
- with _patch_sys_argv():
- assert sys.argv == ["app.py", "--use_gpu", "--name=hello"]
-
- assert sys.argv == expected
-
- sys.argv = original_argv
diff --git a/tests/integrations_app/public/test_boring_app.py b/tests/integrations_app/public/test_boring_app.py
deleted file mode 100644
index 8bfb1d4daf999..0000000000000
--- a/tests/integrations_app/public/test_boring_app.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import os
-
-import pytest
-from click.testing import CliRunner
-from lightning.app.cli.lightning_cli import show
-from lightning.app.testing.testing import run_app_in_cloud, wait_for
-
-from integrations_app.public import _PATH_EXAMPLES
-
-
-@pytest.mark.cloud()
-def test_boring_app_example_cloud() -> None:
- with run_app_in_cloud(os.path.join(_PATH_EXAMPLES, "boring"), app_name="app_dynamic.py", debug=True) as (
- _,
- view_page,
- _,
- name,
- ):
-
- def check_hello_there(*_, **__):
- locator = view_page.frame_locator("iframe").locator('ul:has-text("Hello there!")')
- if len(locator.all_text_contents()):
- return True
- return None
-
- wait_for(view_page, check_hello_there)
-
- runner = CliRunner()
- result = runner.invoke(show.commands["logs"], [name])
-
- assert result.exit_code == 0
- assert result.exception is None
- # TODO: Resolve
- # lines = result.output.splitlines()
- # assert any("Received from root.dict.dst_w" in line for line in lines)
- print("Succeeded App!")
diff --git a/tests/integrations_app/public/test_commands_and_api.py b/tests/integrations_app/public/test_commands_and_api.py
deleted file mode 100644
index 84ac4e6814ee6..0000000000000
--- a/tests/integrations_app/public/test_commands_and_api.py
+++ /dev/null
@@ -1,37 +0,0 @@
-import os
-from subprocess import Popen
-from time import sleep
-
-import pytest
-from lightning.app.testing.testing import run_app_in_cloud
-
-from integrations_app.public import _PATH_EXAMPLES
-
-
-@pytest.mark.timeout(300)
-@pytest.mark.cloud()
-def test_commands_and_api_example_cloud() -> None:
- with run_app_in_cloud(os.path.join(_PATH_EXAMPLES, "commands_and_api")) as (
- _,
- view_page,
- fetch_logs,
- app_name,
- ):
- # Connect to the App and send the first & second command with the client
- # Requires to be run within the same process.
- cmd_1 = f"python -m lightning connect app {app_name}"
- cmd_2 = "python -m lightning command with client --name=this"
- cmd_3 = "python -m lightning command without client --name=is"
- cmd_4 = "python -m lightning command without client --name=awesome"
- cmd_5 = "lightning logout"
- process = Popen(" && ".join([cmd_1, cmd_2, cmd_3, cmd_4, cmd_5]), shell=True)
- process.wait()
- "/".join(view_page.url.split("/")[:-2])
-
- # Validate the logs.
- has_logs = False
- while not has_logs:
- for log in fetch_logs():
- if "['this', 'is', 'awesome']" in log:
- has_logs = True
- sleep(1)
diff --git a/tests/integrations_app/public/test_drive.py b/tests/integrations_app/public/test_drive.py
deleted file mode 100644
index 7401b595c55fe..0000000000000
--- a/tests/integrations_app/public/test_drive.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import os
-from time import sleep
-
-import pytest
-from lightning.app.testing.testing import run_app_in_cloud
-
-from integrations_app.public import _PATH_EXAMPLES
-
-
-@pytest.mark.cloud()
-def test_drive_example_cloud() -> None:
- with run_app_in_cloud(os.path.join(_PATH_EXAMPLES, "drive")) as (
- _,
- _,
- fetch_logs,
- _,
- ):
- has_logs = False
- while not has_logs:
- for log in fetch_logs(["flow"]):
- if "Application End!" in log:
- has_logs = True
- sleep(1)
diff --git a/tests/integrations_app/public/test_gradio.py b/tests/integrations_app/public/test_gradio.py
deleted file mode 100644
index 219f720223dfc..0000000000000
--- a/tests/integrations_app/public/test_gradio.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import os
-from unittest import mock
-from unittest.mock import ANY
-
-
-@mock.patch.dict(os.environ, {"LIGHTING_TESTING": "1"})
-@mock.patch("lightning.app.components.serve.gradio_server.gradio")
-def test_serve_gradio(gradio_mock):
- from lightning.app.components.serve.gradio_server import ServeGradio
-
- class MyGradioServe(ServeGradio):
- inputs = gradio_mock.inputs.Image(type="pil")
- outputs = gradio_mock.outputs.Image(type="pil")
- examples = [["./examples/app/components/serve/gradio/beyonce.png"]]
-
- def build_model(self):
- super().build_model()
- return "model"
-
- def predict(self, *args, **kwargs):
- super().predict(*args, **kwargs)
- return "prediction"
-
- comp = MyGradioServe()
- comp.run()
- assert comp.model == "model"
- assert comp.predict() == "prediction"
- gradio_mock.Interface.assert_called_once_with(
- fn=ANY, inputs=ANY, outputs=ANY, examples=ANY, title=None, description=None, theme=ANY
- )
diff --git a/tests/integrations_app/public/test_installation_commands_app.py b/tests/integrations_app/public/test_installation_commands_app.py
deleted file mode 100644
index 1ad4ec55eb4ce..0000000000000
--- a/tests/integrations_app/public/test_installation_commands_app.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import os
-
-import pytest
-from lightning.app.testing.testing import run_app_in_cloud
-
-from integrations_app.public import _PATH_EXAMPLES
-
-
-@pytest.mark.cloud()
-def test_installation_commands_app_example_cloud() -> None:
- # This is expected to pass, since the "setup" flag is passed
- with run_app_in_cloud(
- os.path.join(_PATH_EXAMPLES, "installation_commands"),
- app_name="app.py",
- extra_args=["--setup"],
- debug=True,
- ) as (_, _, fetch_logs, _):
- has_logs = False
- while not has_logs:
- for log in fetch_logs(["work"]):
- if "lmdb successfully installed" in log:
- has_logs = True
diff --git a/tests/integrations_app/public/test_layout.py b/tests/integrations_app/public/test_layout.py
deleted file mode 100644
index c86016c031c47..0000000000000
--- a/tests/integrations_app/public/test_layout.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import os
-
-import pytest
-from click.testing import CliRunner
-from lightning.app.cli.lightning_cli import run_app
-
-from integrations_app.public import _PATH_EXAMPLES
-
-
-@pytest.mark.xfail(strict=False, reason="test is skipped because CI was blocking all the PRs.")
-def test_layout_example():
- runner = CliRunner()
- result = runner.invoke(
- run_app,
- [
- os.path.join(_PATH_EXAMPLES, "layout", "app.py"),
- "--blocking",
- "False",
- "--open-ui",
- "False",
- ],
- catch_exceptions=False,
- )
- assert "Layout End" in str(result.stdout_bytes)
- assert result.exit_code == 0
diff --git a/tests/integrations_app/public/test_multi_node.py b/tests/integrations_app/public/test_multi_node.py
deleted file mode 100644
index ec4597138c8c6..0000000000000
--- a/tests/integrations_app/public/test_multi_node.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import os
-from unittest import mock
-
-import pytest
-from lightning.app.testing.helpers import _RunIf
-from lightning.app.testing.testing import LightningTestApp, application_testing
-from lightning_utilities.core.imports import package_available
-
-from integrations_app.public import _PATH_EXAMPLES
-
-
-class LightningTestMultiNodeApp(LightningTestApp):
- def on_before_run_once(self):
- res = super().on_before_run_once()
- if self.works and all(w.has_stopped for w in self.works):
- assert len(self.works) == 2
- return True
- return res
-
-
-# for the skip to work, the package needs to be installed without editable mode
-_SKIP_LIGHTNING_UNAVAILABLE = pytest.mark.skipif(not package_available("lightning"), reason="script requires lightning")
-
-
-@pytest.mark.parametrize(
- "app_name",
- [
- "train_pytorch.py",
- "train_any.py",
- "train_pytorch_spawn.py",
- pytest.param("train_fabric.py", marks=_SKIP_LIGHTNING_UNAVAILABLE),
- pytest.param("train_lt_script.py", marks=_SKIP_LIGHTNING_UNAVAILABLE),
- pytest.param("train_lt.py", marks=_SKIP_LIGHTNING_UNAVAILABLE),
- ],
-)
-@_RunIf(skip_windows=True) # flaky
-@mock.patch("lightning.app.components.multi_node.base.is_running_in_cloud", return_value=True)
-def test_multi_node_examples(_, app_name, monkeypatch):
- # note: this test will fail locally:
- # * if you installed `lightning.app`, then the examples need to be
- # rewritten to use `lightning.app` imports (CI does this)
- # * if you installed `lightning`, then the imports in this file and mocks
- # need to be changed to use `lightning`.
- monkeypatch.chdir(os.path.join(_PATH_EXAMPLES, "multi_node"))
- command_line = [app_name, "--blocking", "False", "--open-ui", "False", "--setup"]
- result = application_testing(LightningTestMultiNodeApp, command_line)
- assert result.exit_code == 0
diff --git a/tests/integrations_app/public/test_payload.py b/tests/integrations_app/public/test_payload.py
deleted file mode 100644
index d7ab5d43e34ef..0000000000000
--- a/tests/integrations_app/public/test_payload.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import os
-from time import sleep
-
-import pytest
-from lightning.app.testing.testing import run_app_in_cloud
-
-from integrations_app.public import _PATH_EXAMPLES
-
-
-@pytest.mark.cloud()
-def test_payload_example_cloud() -> None:
- with run_app_in_cloud(os.path.join(_PATH_EXAMPLES, "payload")) as (_, _, fetch_logs, _):
- has_logs = False
- while not has_logs:
- for log in fetch_logs(["flow"]):
- if "Application End!" in log:
- has_logs = True
- sleep(1)
diff --git a/tests/integrations_app/public/test_pickle_or_not.py b/tests/integrations_app/public/test_pickle_or_not.py
deleted file mode 100644
index 5d94ff657d819..0000000000000
--- a/tests/integrations_app/public/test_pickle_or_not.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import os
-
-import pytest
-from click.testing import CliRunner
-from lightning.app.cli.lightning_cli import run_app
-
-from integrations_app.public import _PATH_EXAMPLES
-
-
-# TODO: Investigate why it doesn't work
-@pytest.mark.xfail(strict=False, reason="test has been ignored for a while and seems not to be working :(")
-def test_pickle_or_not_example():
- runner = CliRunner()
- result = runner.invoke(
- run_app,
- [
- os.path.join(_PATH_EXAMPLES, "pickle_or_not", "app.py"),
- "--blocking",
- "False",
- "--open-ui",
- "False",
- ],
- catch_exceptions=False,
- )
- assert "Pickle or Not End" in str(result.stdout_bytes)
- assert result.exit_code == 0
diff --git a/tests/integrations_app/public/test_quick_start.py b/tests/integrations_app/public/test_quick_start.py
deleted file mode 100644
index ed1fb399e5cc7..0000000000000
--- a/tests/integrations_app/public/test_quick_start.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import logging
-import os
-from unittest import mock
-
-import pytest
-from click.testing import CliRunner
-from lightning.app import LightningApp
-from lightning.app.cli.lightning_cli import run_app
-from lightning.app.testing.helpers import _RunIf
-from lightning.app.testing.testing import run_app_in_cloud, wait_for
-
-from integrations_app.public import _PATH_EXAMPLES
-
-
-class QuickStartApp(LightningApp):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.root.serve_work._parallel = True
-
- def run_once(self):
- done = super().run_once()
- if self.root.train_work.best_model_path:
- return True
- return done
-
-
-# TODO: Investigate why it doesn't work
-@pytest.mark.xfail(strict=False, reason="test is skipped because CI was blocking all the PRs.")
-@_RunIf(pl=True, skip_windows=True, skip_linux=True)
-def test_quick_start_example(caplog, monkeypatch):
- """This test ensures the Quick Start example properly train and serve PyTorch Lightning."""
- monkeypatch.setattr("logging.getLogger", mock.MagicMock(return_value=logging.getLogger()))
-
- with caplog.at_level(logging.INFO):
- with mock.patch("lightning.app.LightningApp", QuickStartApp):
- runner = CliRunner()
- result = runner.invoke(
- run_app,
- [
- os.path.join(_PATH_EXAMPLES, "lightning-quick-start", "app.py"),
- "--blocking",
- "False",
- "--open-ui",
- "False",
- ],
- catch_exceptions=False,
- )
- assert result.exit_code == 0
-
-
-@pytest.mark.cloud()
-def test_quick_start_example_cloud() -> None:
- with run_app_in_cloud(os.path.join(_PATH_EXAMPLES, "lightning-quick-start")) as (_, view_page, _, _):
-
- def click_gradio_demo(*_, **__):
- button = view_page.locator('button:has-text("Interactive demo")')
- button.wait_for(timeout=5 * 1000)
- button.click()
- return True
-
- wait_for(view_page, click_gradio_demo)
-
- def check_examples(*_, **__):
- locator = view_page.frame_locator("iframe").locator('button:has-text("Submit")')
- locator.wait_for(timeout=10 * 1000)
- if len(locator.all_text_contents()) > 0:
- return True
- return None
-
- wait_for(view_page, check_examples)
diff --git a/tests/integrations_app/public/test_scripts.py b/tests/integrations_app/public/test_scripts.py
deleted file mode 100644
index 0a777784690d3..0000000000000
--- a/tests/integrations_app/public/test_scripts.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import os
-
-import pytest
-from click.testing import CliRunner
-from lightning.app.cli.lightning_cli import run_app
-from lightning.app.testing.helpers import _run_script, _RunIf
-
-from integrations_app.public import _PATH_EXAMPLES
-
-
-@_RunIf(pl=True, skip_windows=True)
-@pytest.mark.parametrize(
- "file",
- [
- pytest.param("component_tracer.py"),
- pytest.param("component_popen.py"),
- ],
-)
-def test_scripts(file):
- _run_script(str(os.path.join(_PATH_EXAMPLES, f"components/python/{file}")))
-
-
-@pytest.mark.xfail(strict=False, reason="causing some issues with CI, not sure if the test is actually needed")
-@_RunIf(pl=True, skip_windows=True)
-def test_components_app_example():
- runner = CliRunner()
- result = runner.invoke(
- run_app,
- [
- os.path.join(_PATH_EXAMPLES, "components/python/app.py"),
- "--blocking",
- "False",
- "--open-ui",
- "False",
- ],
- catch_exceptions=False,
- )
- assert result.exit_code == 0
- assert "tracer script succeed" in result.stdout
diff --git a/tests/integrations_app/public/test_template_react_ui.py b/tests/integrations_app/public/test_template_react_ui.py
deleted file mode 100644
index e9653254df27d..0000000000000
--- a/tests/integrations_app/public/test_template_react_ui.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import os
-from time import sleep
-
-import pytest
-from lightning.app.testing.testing import run_app_in_cloud, wait_for
-
-from integrations_app.public import _PATH_EXAMPLES
-
-
-@pytest.mark.cloud()
-def test_template_react_ui_example_cloud() -> None:
- """This test ensures streamlit works in the cloud by clicking a button and checking the logs."""
- with run_app_in_cloud(os.path.join(_PATH_EXAMPLES, "template_react_ui")) as (
- _,
- view_page,
- fetch_logs,
- _,
- ):
-
- def click_button(*_, **__):
- button = view_page.frame_locator("iframe").locator('button:has-text("Start Printing")')
- button.wait_for(timeout=3 * 1000)
- if button.all_text_contents() == ["Start Printing"]:
- button.click()
- return True
- return None
-
- wait_for(view_page, click_button)
-
- has_logs = False
- while not has_logs:
- for log in fetch_logs():
- if "0: Hello World!" in log:
- has_logs = True
- sleep(1)
diff --git a/tests/integrations_app/public/test_template_streamlit_ui.py b/tests/integrations_app/public/test_template_streamlit_ui.py
deleted file mode 100644
index 34d9b683bc41d..0000000000000
--- a/tests/integrations_app/public/test_template_streamlit_ui.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import os
-from time import sleep
-
-import pytest
-from lightning.app.testing.testing import run_app_in_cloud, wait_for
-
-from integrations_app.public import _PATH_EXAMPLES
-
-
-@pytest.mark.cloud()
-def test_template_streamlit_ui_example_cloud() -> None:
- """This test ensures streamlit works in the cloud by clicking a button and checking the logs."""
- with run_app_in_cloud(os.path.join(_PATH_EXAMPLES, "template_streamlit_ui")) as (
- _,
- view_page,
- fetch_logs,
- _,
- ):
-
- def click_button(*_, **__):
- button = view_page.frame_locator("iframe").locator('button:has-text("Should print to the terminal ?")')
-
- if button.all_text_contents() == ["Should print to the terminal ?"]:
- button.click()
- return True
- return None
-
- wait_for(view_page, click_button)
-
- has_logs = False
- while not has_logs:
- for log in fetch_logs():
- if "Hello World!" in log:
- has_logs = True
- sleep(1)
diff --git a/tests/integrations_app/public/test_v0_app.py b/tests/integrations_app/public/test_v0_app.py
deleted file mode 100644
index a880d52ceb694..0000000000000
--- a/tests/integrations_app/public/test_v0_app.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import os
-from time import sleep
-from typing import Tuple
-from unittest import mock
-from unittest.mock import MagicMock
-
-import pytest
-from lightning.app import LightningApp
-from lightning.app.runners import CloudRuntime
-from lightning.app.testing import EmptyFlow
-from lightning.app.testing.testing import LightningTestApp, application_testing, run_app_in_cloud, wait_for
-from lightning.app.utilities.enum import AppStage
-from lightning.app.utilities.load_app import load_app_from_file
-
-from integrations_app.public import _PATH_EXAMPLES
-
-
-class LightningAppTestInt(LightningTestApp):
- def run_once(self) -> Tuple[bool, float]:
- if self.root.counter == 1:
- print("V0 App End")
- self.stage = AppStage.STOPPING
- return True, 0.0
- return super().run_once()
-
-
-def test_v0_app_example():
- command_line = [
- os.path.join(_PATH_EXAMPLES, "v0", "app.py"),
- "--blocking",
- "False",
- "--open-ui",
- "False",
- ]
- result = application_testing(LightningAppTestInt, command_line)
- assert result.exit_code == 0
-
-
-def run_v0_app(fetch_logs, view_page):
- def check_content(button_name, text_content):
- button = view_page.locator(f'button:has-text("{button_name}")')
- button.wait_for(timeout=3 * 1000)
- button.click()
- view_page.reload()
- locator = view_page.frame_locator("iframe").locator("div")
- locator.wait_for(timeout=3 * 1000)
- assert text_content in " ".join(locator.all_text_contents())
- print(f"Validated {button_name}")
- return True
-
- wait_for(view_page, check_content, "TAB_1", "Hello from component A")
- wait_for(view_page, check_content, "TAB_2", "Hello from component B")
- has_logs = False
- while not has_logs:
- for log in fetch_logs(["flow"]):
- print(log)
- if "'a': 'a', 'b': 'b'" in log:
- has_logs = True
- sleep(1)
-
-
-@pytest.mark.cloud()
-@pytest.mark.skipif(
- os.environ.get("LIGHTNING_BYOC_CLUSTER_ID") is None,
- reason="missing LIGHTNING_BYOC_CLUSTER_ID environment variable",
-)
-def test_v0_app_example_byoc_cloud() -> None:
- with run_app_in_cloud(
- os.path.join(_PATH_EXAMPLES, "v0"),
- extra_args=["--cluster-id", os.environ.get("LIGHTNING_BYOC_CLUSTER_ID")],
- ) as (_, view_page, fetch_logs, _):
- run_v0_app(fetch_logs, view_page)
-
-
-@pytest.mark.cloud()
-def test_v0_app_example_cloud() -> None:
- with run_app_in_cloud(os.path.join(_PATH_EXAMPLES, "v0")) as (
- _,
- view_page,
- fetch_logs,
- _,
- ):
- run_v0_app(fetch_logs, view_page)
-
-
-@mock.patch(
- "lightning.app.runners.cloud.load_app_from_file",
- MagicMock(side_effect=ModuleNotFoundError("Module X not found")),
-)
-def test_load_app_from_file_module_error():
- empty_app = CloudRuntime.load_app_from_file(os.path.join(_PATH_EXAMPLES, "v0", "app.py"))
- assert isinstance(empty_app, LightningApp)
- assert isinstance(empty_app.root, EmptyFlow)
-
-
-def test_load_app_from_file():
- app = load_app_from_file(os.path.join(_PATH_EXAMPLES, "v0", "app.py"))
- assert isinstance(app, LightningApp)
diff --git a/tests/legacy/back-compatible-versions.txt b/tests/legacy/back-compatible-versions.txt
index 1d8e1abccfdd1..eb49457dfa157 100644
--- a/tests/legacy/back-compatible-versions.txt
+++ b/tests/legacy/back-compatible-versions.txt
@@ -98,3 +98,9 @@
2.1.3
2.2.0.post0
2.2.1
+2.2.2
+2.2.5
+2.3.0
+2.3.1
+2.3.2
+2.3.3
diff --git a/tests/run_standalone_tests.sh b/tests/run_standalone_tests.sh
index 0de781b0c47c7..0aa0bacff168a 100755
--- a/tests/run_standalone_tests.sh
+++ b/tests/run_standalone_tests.sh
@@ -26,18 +26,17 @@ export PL_RUN_STANDALONE_TESTS=1
defaults=" -m coverage run --source ${source} --append -m pytest --no-header -v -s --timeout 120 "
echo "Using defaults: ${defaults}"
-# get the testing location as the fist argument
+# get the testing location as the first argument
test_path=$1
printf "source path: $test_path\n"
# collect all tests with parametrization based filtering with PL_RUN_STANDALONE_TESTS
standalone_tests=$(python3 -m pytest $test_path -q --collect-only --pythonwarnings ignore)
-printf "Collected tests: \n $standalone_tests"
+printf "Collected tests: \n $standalone_tests\n"
# match only lines with tests
-parametrizations=$(grep -oP '\S+::test_\S+' <<< "$standalone_tests")
+parametrizations=$(perl -nle 'print $& while m{\S+::test_\S+}g' <<< "$standalone_tests")
# convert the list to be array
parametrizations_arr=($parametrizations)
-
report=''
rm -f standalone_test_output.txt # in case it exists, remove it
@@ -47,7 +46,7 @@ function show_batched_output {
if [ -f standalone_test_output.txt ]; then # if exists
cat standalone_test_output.txt
# heuristic: stop if there's mentions of errors. this can prevent false negatives when only some of the ranks fail
- if grep -iE 'error|exception|traceback|failed' standalone_test_output.txt | grep -vE 'on_exception|xfailed' | grep -qv -f testnames.txt; then
+ if perl -nle 'print if /error|(?Hello from component A
diff --git a/tests/tests_app/cli/launch_data/app_v0/ui/b/index.html b/tests/tests_app/cli/launch_data/app_v0/ui/b/index.html
deleted file mode 100644
index 3bfd9e24cb7f7..0000000000000
--- a/tests/tests_app/cli/launch_data/app_v0/ui/b/index.html
+++ /dev/null
@@ -1 +0,0 @@
-Hello from component B
diff --git a/tests/tests_app/cli/test_cd.py b/tests/tests_app/cli/test_cd.py
deleted file mode 100644
index 64865fa4ac46b..0000000000000
--- a/tests/tests_app/cli/test_cd.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import os
-import sys
-from unittest import mock
-
-import pytest
-from lightning.app.cli.commands import cd
-from lightning.app.cli.commands.pwd import pwd
-
-
-@mock.patch("lightning.app.cli.commands.cd.ls", mock.MagicMock())
-@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows yet")
-def test_cd(monkeypatch):
- """This test validates cd behaves as expected."""
- ls = mock.MagicMock()
- monkeypatch.setattr(cd, "ls", ls)
-
- assert cd.cd("/") == "/"
- assert pwd() == "/"
- ls.ls.return_value = ["hero"]
- assert cd.cd("hero") == "/hero"
- assert pwd() == "/hero"
- ls.ls.return_value = ["something_else"]
- assert f"/hero{os.sep}something_else" == cd.cd("something_else")
- assert f"/hero{os.sep}something_else" == pwd()
- ls.ls.return_value = ["hello"]
- assert f"/hero{os.sep}something_else{os.sep}hello{os.sep}a" == cd.cd("hello/a")
- assert f"/hero{os.sep}something_else{os.sep}hello{os.sep}a" == pwd()
- assert f"/hero{os.sep}something_else" == cd.cd(f"..{os.sep}..")
- ls.ls.return_value = ["something_else"]
- assert f"/hero{os.sep}something_else" == pwd()
- assert cd.cd("..") == "/hero"
- assert pwd() == "/hero"
- assert cd.cd("/") == "/"
- assert pwd() == "/"
- ls.ls.return_value = ["a"]
- assert cd.cd("../a") == "/a"
- assert pwd() == "/a"
- ls.ls.return_value = ["thomas"]
- assert f"/a{os.sep}thomas{os.sep}hello" == cd.cd(f"thomas{os.sep}hello")
- assert f"/a{os.sep}thomas{os.sep}hello" == pwd()
- ls.ls.return_value = ["thomas"]
- assert f"/thomas{os.sep}hello" == cd.cd(f"/thomas{os.sep}hello")
- assert f"/thomas{os.sep}hello" == pwd()
- assert cd.cd("/") == "/"
- ls.ls.return_value = ["name with spaces"]
- assert cd.cd("name with spaces") == "/name with spaces"
- ls.ls.return_value = ["name with spaces 2"]
- assert cd.cd("name with spaces 2") == "/name with spaces/name with spaces 2"
-
- os.remove(cd._CD_FILE)
-
- mock_exit = mock.MagicMock()
- monkeypatch.setattr(cd, "_error_and_exit", mock_exit)
- assert cd.cd("/") == "/"
- ls.ls.return_value = ["project_a"]
- cd.cd("project_b")
- assert mock_exit._mock_call_args.args[0] == "no such file or directory: /project_b"
diff --git a/tests/tests_app/cli/test_cli.py b/tests/tests_app/cli/test_cli.py
deleted file mode 100644
index cfc747d729919..0000000000000
--- a/tests/tests_app/cli/test_cli.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import os
-from unittest import mock
-from unittest.mock import MagicMock
-
-import pytest
-from click.testing import CliRunner
-from lightning.app import __version__
-from lightning.app.cli.lightning_cli import _main, login, logout, run
-from lightning.app.cli.lightning_cli_delete import delete
-from lightning.app.cli.lightning_cli_list import get_list, list_apps
-from lightning.app.utilities.exceptions import _ApiExceptionHandler
-
-
-@pytest.mark.parametrize("command", [_main, run, get_list, delete])
-def test_commands(command):
- runner = CliRunner()
- result = runner.invoke(command)
- assert result.exit_code == 0
-
-
-def test_main_lightning_cli_no_arguments():
- """Validate the Lightning CLI without args."""
- res = os.popen("lightning_app").read()
- assert "login " in res
- assert "logout " in res
- assert "run " in res
- assert "list " in res
- assert "delete " in res
- assert "show " in res
-
-
-def test_main_lightning_cli_help():
- """Validate the Lightning CLI."""
- res = os.popen("lightning_app --help").read()
- assert "login " in res
- assert "logout " in res
- assert "run " in res
- assert "list " in res
- assert "delete " in res
- assert "show " in res
-
- res = os.popen("lightning_app run --help").read()
- assert "app " in res
-
- # hidden run commands should not appear in the help text
- assert "server" not in res
- assert "flow" not in res
- assert "work" not in res
- assert "frontend" not in res
-
- # inspect show group
- res = os.popen("lightning_app show --help").read()
- assert "logs " in res
-
-
-@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
-@mock.patch("lightning.app.cli.cmd_apps._AppManager.list")
-def test_list_apps(list_command: mock.MagicMock):
- runner = CliRunner()
- runner.invoke(list_apps)
-
-
-@mock.patch("lightning.app.utilities.login.Auth._run_server")
-@mock.patch("lightning.app.utilities.login.Auth.clear")
-def test_cli_login(clear: mock.MagicMock, run_server: mock.MagicMock):
- runner = CliRunner()
- runner.invoke(login)
-
- clear.assert_called_once_with()
- run_server.assert_called_once()
-
-
-@mock.patch("pathlib.Path.unlink")
-@mock.patch("pathlib.Path.exists")
-@pytest.mark.parametrize("creds", [True, False])
-def test_cli_logout(exists: mock.MagicMock, unlink: mock.MagicMock, creds: bool):
- exists.return_value = creds
- runner = CliRunner()
- runner.invoke(logout)
-
- exists.assert_called_once_with()
- if creds:
- unlink.assert_called_once_with()
- else:
- unlink.assert_not_called()
-
-
-def test_lightning_cli_version():
- res = os.popen("lightning_app --version").read()
- assert __version__ in res
-
-
-def test_main_catches_api_exceptions():
- assert isinstance(_main, _ApiExceptionHandler)
diff --git a/tests/tests_app/cli/test_cloud_cli.py b/tests/tests_app/cli/test_cloud_cli.py
deleted file mode 100644
index 42f8d5d6b3d90..0000000000000
--- a/tests/tests_app/cli/test_cloud_cli.py
+++ /dev/null
@@ -1,220 +0,0 @@
-import enum
-import logging
-import os
-from dataclasses import dataclass
-from functools import partial
-from unittest import mock
-from unittest.mock import ANY, MagicMock, call
-
-import lightning.app.runners.backends.cloud as cloud_backend
-import pytest
-from click.testing import CliRunner
-from lightning.app.cli.lightning_cli import run_app
-from lightning.app.runners import cloud
-from lightning.app.runners.cloud import CloudRuntime
-from lightning_cloud.openapi import (
- V1CloudSpace,
- V1ListCloudSpacesResponse,
- V1ListLightningappInstancesResponse,
- V1ListMembershipsResponse,
- V1Membership,
-)
-from lightning_cloud.openapi.rest import ApiException
-
-from tests_app import _PROJECT_ROOT
-
-_FILE_PATH = os.path.join(_PROJECT_ROOT, "tests", "tests_app", "core", "scripts", "app_metadata.py")
-
-
-@dataclass
-class AppMetadata:
- id: str
-
-
-@dataclass
-class FakeResponse:
- lightningapps = [AppMetadata(id="my_app")]
-
-
-class FakeLightningClient:
- def cloud_space_service_list_cloud_spaces(self, *args, **kwargs):
- return V1ListCloudSpacesResponse(cloudspaces=[])
-
- def lightningapp_instance_service_list_lightningapp_instances(self, *args, **kwargs):
- return V1ListLightningappInstancesResponse(lightningapps=[])
-
- def lightningapp_service_delete_lightningapp(self, id: str = None):
- assert id == "my_app"
-
- def projects_service_list_memberships(self):
- return V1ListMembershipsResponse(memberships=[V1Membership(name="test-project", project_id="test-project-id")])
-
-
-class CloudRuntimePatch(CloudRuntime):
- def __init__(self, *args, **kwargs):
- super_init = super().__init__
- if hasattr(super_init, "__wrapped__"):
- super_init.__wrapped__(self, *args, **kwargs)
- else:
- super_init(*args, **kwargs)
-
-
-class V1LightningappInstanceState(enum.Enum):
- FAILED = "failed"
- SUCCESS = "success"
-
-
-@dataclass
-class FailedStatus:
- phase = V1LightningappInstanceState.FAILED
-
-
-@dataclass
-class SuccessStatus:
- phase = V1LightningappInstanceState.SUCCESS
-
-
-@dataclass
-class RuntimeErrorResponse:
- id = "my_app"
- source_upload_url = "something"
- status = FailedStatus()
-
-
-@dataclass
-class RuntimeErrorResponse2:
- id = "my_app"
- source_upload_url = ""
- status = SuccessStatus()
-
-
-@dataclass
-class SuccessResponse:
- id = "my_app"
- source_upload_url = "something"
- status = SuccessStatus()
-
-
-@dataclass
-class ExceptionResponse:
- status = FailedStatus()
-
-
-class FakeLightningClientCreate(FakeLightningClient):
- def __init__(self, *args, create_response, **kwargs):
- super().__init__()
- self.create_response = create_response
-
- def cloud_space_service_create_cloud_space(self, *args, **kwargs):
- return V1CloudSpace(id="my_app", name="app")
-
- def cloud_space_service_create_lightning_run(self, project_id, cloudspace_id, body):
- assert project_id == "test-project-id"
- return self.create_response
-
- def cloud_space_service_create_lightning_run_instance(self, project_id, cloudspace_id, id, body):
- assert project_id == "test-project-id"
- return self.create_response
-
-
-@mock.patch("lightning.app.core.queues.QueuingSystem", MagicMock())
-@mock.patch("lightning.app.runners.runtime_type.CloudRuntime", CloudRuntimePatch)
-@pytest.mark.parametrize("create_response", [RuntimeErrorResponse(), RuntimeErrorResponse2()])
-def test_start_app(create_response, monkeypatch):
- monkeypatch.setattr(cloud, "V1LightningappInstanceState", MagicMock())
- monkeypatch.setattr(cloud, "CloudspaceIdRunsBody", MagicMock())
- monkeypatch.setattr(cloud, "V1Flowserver", MagicMock())
- monkeypatch.setattr(cloud, "V1LightningappInstanceSpec", MagicMock())
- monkeypatch.setattr(
- cloud_backend,
- "LightningClient",
- partial(FakeLightningClientCreate, create_response=create_response),
- )
- monkeypatch.setattr(cloud, "LocalSourceCodeDir", MagicMock())
- monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", MagicMock())
-
- runner = CliRunner()
-
- def run():
- result = runner.invoke(run_app, [_FILE_PATH, "--cloud", "--open-ui=False"], catch_exceptions=False)
- assert result.exit_code == 0
-
- if isinstance(create_response, RuntimeErrorResponse):
- cloud.V1LightningappInstanceState.FAILED = V1LightningappInstanceState.FAILED
- with pytest.raises(RuntimeError, match="Failed to create the application"):
- run()
- elif isinstance(create_response, RuntimeErrorResponse2):
- with pytest.raises(RuntimeError, match="The source upload url is empty."):
- run()
- else:
- run()
- mocks_calls = cloud.LocalSourceCodeDir._mock_mock_calls
- assert len(mocks_calls) == 5
- assert str(mocks_calls[0].kwargs["path"]) == os.path.dirname(_FILE_PATH)
- mocks_calls[1].assert_called_once()
- mocks_calls[2].assert_called_once(url="url")
-
- assert cloud.V1Flowserver._mock_call_args_list == [call(name="root.flow_b")]
-
- cloud.V1LightningappInstanceSpec._mock_call_args.assert_called_once(
- app_entrypoint_file=_FILE_PATH,
- enable_app_server=True,
- works=ANY,
- flow_servers=ANY,
- )
-
- cloud.CloudspaceIdRunsBody.assert_called_once()
-
-
-class HttpHeaderDict(dict):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- self.reason = kwargs["reason"]
- self.status = kwargs["status"]
- self.data = kwargs["data"]
-
- def getheaders(self):
- return {}
-
-
-class FakeLightningClientException(FakeLightningClient):
- def __init__(self, *args, message, **kwargs):
- super().__init__()
- self.message = message
-
- def cloud_space_service_list_cloud_spaces(self, *args, **kwargs):
- raise ApiException(
- http_resp=HttpHeaderDict(
- data=self.message,
- reason="",
- status=500,
- )
- )
-
-
-@mock.patch("lightning.app.utilities.network.create_swagger_client", MagicMock())
-@mock.patch("lightning.app.runners.runtime_type.CloudRuntime", CloudRuntimePatch)
-@pytest.mark.parametrize(
- "message",
- [
- "Cannot create a new app, you have reached the maximum number (10) of apps. Either increase your quota or delete some of the existing apps" # noqa E501
- ],
-)
-def test_start_app_exception(message, monkeypatch, caplog):
- monkeypatch.setattr(cloud, "V1LightningappInstanceState", MagicMock())
- monkeypatch.setattr(cloud, "CloudspaceIdRunsBody", MagicMock())
- monkeypatch.setattr(cloud, "V1Flowserver", MagicMock())
- monkeypatch.setattr(cloud, "V1LightningappInstanceSpec", MagicMock())
- monkeypatch.setattr(cloud, "LocalSourceCodeDir", MagicMock())
- monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", MagicMock())
- monkeypatch.setattr(cloud, "logger", logging.getLogger())
-
- runner = CliRunner()
-
- fake_grid_rest_client = partial(FakeLightningClientException, message=message)
- with caplog.at_level(logging.ERROR), mock.patch(
- "lightning.app.runners.backends.cloud.LightningClient", fake_grid_rest_client
- ):
- result = runner.invoke(run_app, [_FILE_PATH, "--cloud", "--open-ui=False"], catch_exceptions=False)
- assert result.exit_code == 1
- assert caplog.messages == [message]
diff --git a/tests/tests_app/cli/test_cmd_apps.py b/tests/tests_app/cli/test_cmd_apps.py
deleted file mode 100644
index ca50d591e0dfa..0000000000000
--- a/tests/tests_app/cli/test_cmd_apps.py
+++ /dev/null
@@ -1,157 +0,0 @@
-from unittest import mock
-from unittest.mock import MagicMock
-
-import pytest as pytest
-from lightning.app.cli.cmd_apps import _AppList, _AppManager
-from lightning_cloud.openapi import (
- Externalv1LightningappInstance,
- V1LightningappInstanceSpec,
- V1LightningappInstanceState,
- V1LightningappInstanceStatus,
- V1LightningworkState,
- V1ListLightningappInstancesResponse,
- V1ListLightningworkResponse,
- V1ListMembershipsResponse,
- V1Membership,
-)
-from rich.text import Text
-
-
-@pytest.mark.parametrize(
- ("current_state", "desired_state", "expected"),
- [
- (
- V1LightningappInstanceStatus(phase=V1LightningappInstanceState.RUNNING),
- V1LightningappInstanceState.DELETED,
- Text("terminating"),
- ),
- (
- V1LightningappInstanceStatus(phase=V1LightningappInstanceState.STOPPED),
- V1LightningappInstanceState.RUNNING,
- Text("restarting"),
- ),
- (
- V1LightningappInstanceStatus(phase=V1LightningappInstanceState.PENDING),
- V1LightningappInstanceState.RUNNING,
- Text("restarting"),
- ),
- (
- V1LightningappInstanceStatus(phase=V1LightningappInstanceState.UNSPECIFIED, start_timestamp=None),
- V1LightningappInstanceState.RUNNING,
- Text("not yet started"),
- ),
- ],
-)
-def test_state_transitions(current_state, desired_state, expected):
- actual = _AppList._textualize_state_transitions(current_state=current_state, desired_state=desired_state)
- assert actual == expected
-
-
-@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
-@mock.patch("lightning.app.utilities.network.LightningClient.lightningapp_instance_service_list_lightningapp_instances")
-@mock.patch("lightning.app.utilities.network.LightningClient.projects_service_list_memberships")
-def test_list_all_apps_paginated(list_memberships: mock.MagicMock, list_instances: mock.MagicMock):
- list_memberships.return_value = V1ListMembershipsResponse(memberships=[V1Membership(project_id="default-project")])
- list_instances.side_effect = [
- V1ListLightningappInstancesResponse(
- lightningapps=[
- Externalv1LightningappInstance(
- name="test1",
- spec=V1LightningappInstanceSpec(desired_state=V1LightningappInstanceState.RUNNING),
- status=V1LightningappInstanceStatus(phase=V1LightningappInstanceState.RUNNING),
- )
- ],
- next_page_token="page-2",
- ),
- V1ListLightningappInstancesResponse(
- lightningapps=[
- Externalv1LightningappInstance(
- name="test2",
- spec=V1LightningappInstanceSpec(desired_state=V1LightningappInstanceState.STOPPED),
- status=V1LightningappInstanceStatus(phase=V1LightningappInstanceState.RUNNING),
- )
- ],
- ),
- ]
-
- cluster_manager = _AppManager()
- cluster_manager.list()
-
- list_memberships.assert_called_once()
- assert list_instances.mock_calls == [
- mock.call(project_id="default-project", limit=100, phase_in=[]),
- mock.call(project_id="default-project", page_token="page-2", limit=100, phase_in=[]),
- ]
-
-
-@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
-@mock.patch("lightning.app.utilities.network.LightningClient.lightningapp_instance_service_list_lightningapp_instances")
-@mock.patch("lightning.app.utilities.network.LightningClient.projects_service_list_memberships")
-def test_list_all_apps(list_memberships: mock.MagicMock, list_instances: mock.MagicMock):
- list_memberships.return_value = V1ListMembershipsResponse(memberships=[V1Membership(project_id="default-project")])
- list_instances.return_value = V1ListLightningappInstancesResponse(lightningapps=[])
-
- cluster_manager = _AppManager()
- cluster_manager.list()
-
- list_memberships.assert_called_once()
- list_instances.assert_called_once_with(project_id="default-project", limit=100, phase_in=[])
-
-
-@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
-@mock.patch("lightning.app.utilities.network.LightningClient.lightningwork_service_list_lightningwork")
-@mock.patch("lightning.app.utilities.network.LightningClient.projects_service_list_memberships")
-def test_list_components(list_memberships: mock.MagicMock, list_components: mock.MagicMock):
- list_memberships.return_value = V1ListMembershipsResponse(memberships=[V1Membership(project_id="default-project")])
- list_components.return_value = V1ListLightningworkResponse(lightningworks=[])
-
- cluster_manager = _AppManager()
- cluster_manager.list_components(app_id="cheese")
-
- list_memberships.assert_called_once()
- list_components.assert_called_once_with(project_id="default-project", app_id="cheese", phase_in=[])
-
-
-@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
-@mock.patch("lightning.app.utilities.network.LightningClient.lightningwork_service_list_lightningwork")
-@mock.patch("lightning.app.utilities.network.LightningClient.projects_service_list_memberships")
-def test_list_components_with_phase(list_memberships: mock.MagicMock, list_components: mock.MagicMock):
- list_memberships.return_value = V1ListMembershipsResponse(memberships=[V1Membership(project_id="default-project")])
- list_components.return_value = V1ListLightningworkResponse(lightningworks=[])
-
- cluster_manager = _AppManager()
- cluster_manager.list_components(app_id="cheese", phase_in=[V1LightningworkState.RUNNING])
-
- list_memberships.assert_called_once()
- list_components.assert_called_once_with(
- project_id="default-project", app_id="cheese", phase_in=[V1LightningworkState.RUNNING]
- )
-
-
-@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
-@mock.patch("lightning.app.utilities.network.LightningClient.lightningapp_instance_service_list_lightningapp_instances")
-@mock.patch("lightning.app.utilities.network.LightningClient.projects_service_list_memberships")
-def test_list_apps_on_cluster(list_memberships: mock.MagicMock, list_instances: mock.MagicMock):
- list_memberships.return_value = V1ListMembershipsResponse(memberships=[V1Membership(project_id="default-project")])
- list_instances.return_value = V1ListLightningappInstancesResponse(lightningapps=[])
-
- cluster_manager = _AppManager()
- cluster_manager.list()
-
- list_memberships.assert_called_once()
- list_instances.assert_called_once_with(project_id="default-project", limit=100, phase_in=[])
-
-
-@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
-@mock.patch(
- "lightning.app.utilities.network.LightningClient.lightningapp_instance_service_delete_lightningapp_instance"
-)
-@mock.patch("lightning.app.cli.cmd_apps._get_project")
-def test_delete_app_on_cluster(get_project_mock: mock.MagicMock, delete_app_mock: mock.MagicMock):
- get_project_mock.return_value = V1Membership(project_id="default-project")
-
- cluster_manager = _AppManager()
- cluster_manager.delete(app_id="12345")
-
- delete_app_mock.assert_called()
- delete_app_mock.assert_called_once_with(project_id="default-project", id="12345")
diff --git a/tests/tests_app/cli/test_cmd_cli_delete.py b/tests/tests_app/cli/test_cmd_cli_delete.py
deleted file mode 100644
index 704a3d89fb481..0000000000000
--- a/tests/tests_app/cli/test_cmd_cli_delete.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import sys
-from unittest import mock
-
-import pytest
-from lightning.app.cli.lightning_cli_delete import _find_selected_app_instance_id
-from lightning_cloud.openapi import Externalv1LightningappInstance
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.")
-@mock.patch("lightning_cloud.login.Auth.authenticate", mock.MagicMock())
-@mock.patch("lightning.app.cli.lightning_cli_delete._AppManager.list_apps")
-def test_app_find_selected_app_instance_id_when_app_name_exists(list_apps_mock: mock.MagicMock):
- list_apps_mock.return_value = [
- Externalv1LightningappInstance(name="app-name", id="app-id"),
- ]
- returned_app_instance_id = _find_selected_app_instance_id(app_name="app-name")
- assert returned_app_instance_id == "app-id"
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.")
-@mock.patch("lightning_cloud.login.Auth.authenticate", mock.MagicMock())
-@mock.patch("lightning.app.cli.lightning_cli_delete._AppManager.list_apps")
-def test_app_find_selected_app_instance_id_when_app_id_exists(list_apps_mock: mock.MagicMock):
- list_apps_mock.return_value = [
- Externalv1LightningappInstance(name="app-name", id="app-id"),
- ]
- returned_app_instance_id = _find_selected_app_instance_id(app_name="app-id")
- assert returned_app_instance_id == "app-id"
diff --git a/tests/tests_app/cli/test_cmd_init.py b/tests/tests_app/cli/test_cmd_init.py
deleted file mode 100644
index c8a572107ba3c..0000000000000
--- a/tests/tests_app/cli/test_cmd_init.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import contextlib
-import os
-import re
-import shutil
-import subprocess
-
-import pytest
-from lightning.app.cli import cmd_init
-from lightning.app.utilities.imports import _IS_MACOS, _IS_WINDOWS
-
-
-def test_validate_init_name():
- # test that a good name works (mix chars)
- value = cmd_init._capture_valid_app_component_name("abc1-cde")
- assert value == "abc1-cde"
-
- # test that a good name works (letters only)
- value = cmd_init._capture_valid_app_component_name("abc-cde")
- assert value == "abc-cde"
-
- # assert bad input
- with pytest.raises(SystemExit) as e:
- value = cmd_init._capture_valid_app_component_name("abc-cde#")
-
- assert "Error: your Lightning app name" in str(e.value)
-
-
-@pytest.mark.skipif(_IS_WINDOWS or _IS_MACOS, reason="unsupported OS") # todo
-@pytest.mark.xfail(strict=False, reason="need app fast_dev_run to work via CLI")
-def test_make_app_template():
- template_name = "app-test-template"
- template_name_folder = re.sub("-", "_", template_name)
-
- # remove the template if there
- template_dir = os.path.join(os.getcwd(), template_name)
- with contextlib.suppress(Exception):
- shutil.rmtree(template_dir)
-
- # create template
- subprocess.check_output(f"lightning init app {template_name}", shell=True)
-
- # make sure app is not in the env
- env_output = subprocess.check_output("pip freeze", shell=True)
- assert template_name not in str(env_output)
-
- # install the app
- env_output = subprocess.check_output(
- f"cd {template_name} && pip install -r requirements.txt && pip install -e .", shell=True
- )
- env_output = subprocess.check_output("pip freeze", shell=True)
- assert template_name in str(env_output)
-
- app_dir = os.path.join(template_dir, f"{template_name_folder}/app.py")
- output = subprocess.check_output(f"lightning run app {app_dir} --fast_dev_run") # noqa
- # TODO: verify output
-
- # clean up the template dir
- with contextlib.suppress(Exception):
- shutil.rmtree(template_dir)
-
-
-@pytest.mark.xfail(strict=False, reason="need component fast_dev_run to work via CLI")
-def test_make_component_template():
- template_name = "component-test-template"
- template_name_folder = re.sub("-", "_", template_name)
-
- # remove the template if there
- template_dir = os.path.join(os.getcwd(), template_name)
- with contextlib.suppress(Exception):
- shutil.rmtree(template_dir)
-
- # create template
- subprocess.check_output(f"lightning init component {template_name}", shell=True)
-
- # make sure app is not in the env
- env_output = subprocess.check_output("pip freeze", shell=True)
- assert template_name not in str(env_output)
-
- # install the app
- env_output = subprocess.check_output(
- f"cd {template_name} && pip install -r requirements.txt && pip install -e .", shell=True
- )
- env_output = subprocess.check_output("pip freeze", shell=True)
- assert template_name in str(env_output)
-
- app_dir = os.path.join(template_dir, f"{template_name_folder}/app.py")
- output = subprocess.check_output(f"lightning run app {app_dir} --fast_dev_run") # noqa
- # TODO: verify output
-
- # clean up the template dir
- with contextlib.suppress(Exception):
- shutil.rmtree(template_dir)
diff --git a/tests/tests_app/cli/test_cmd_install.py b/tests/tests_app/cli/test_cmd_install.py
deleted file mode 100644
index 5fddaea097fb3..0000000000000
--- a/tests/tests_app/cli/test_cmd_install.py
+++ /dev/null
@@ -1,376 +0,0 @@
-import os
-import subprocess
-from pathlib import Path
-from unittest import mock
-
-import pytest
-from click.testing import CliRunner
-from lightning.app.cli import cmd_install, lightning_cli
-from lightning.app.testing.helpers import _RunIf
-
-
-@mock.patch("lightning.app.cli.cmd_install.subprocess", mock.MagicMock())
-def test_valid_org_app_name():
- """Valid organization name."""
- runner = CliRunner()
-
- # assert a bad app name should fail
- fake_app = "fakeuser/impossible/name"
- result = runner.invoke(lightning_cli.cmd_install.install_app, [fake_app])
- assert "app name format must have organization/app-name" in result.output
-
- # assert a good name (but unavailable name) should work
- fake_app = "fakeuser/ALKKLJAUHREKJ21234KLAKJDLF"
- result = runner.invoke(lightning_cli.cmd_install.install_app, [fake_app])
- assert f"app: '{fake_app}' is not available on ⚡ Lightning AI ⚡" in result.output
- assert result.exit_code
-
- # assert a good (and available) name works
- # This should be an app that's always in the gallery
- real_app = "lightning/invideo"
- result = runner.invoke(lightning_cli.cmd_install.install_app, [real_app])
- assert "Press enter to continue:" in result.output
-
-
-@pytest.mark.xfail(strict=False, reason="need to figure out how to authorize git clone from the private repo")
-def test_valid_unpublished_app_name():
- runner = CliRunner()
-
- # assert warning of non official app given
- real_app = "https://github.com/Lightning-AI/install-app"
- with pytest.raises(subprocess.CalledProcessError, match="WARNING"):
- subprocess.check_output(f"lightning install app {real_app}", shell=True, stderr=subprocess.STDOUT)
-
- # assert aborted install
- result = runner.invoke(lightning_cli.cmd_install.install_app, [real_app], input="q")
- assert "Installation aborted!" in result.output
-
- # assert a bad app name should fail
- fake_app = "https://github.com/Lightning-AI/install-appdd"
- result = runner.invoke(lightning_cli.cmd_install.install_app, [fake_app, "--yes"])
- assert "Looks like the github url was not found" in result.output
-
- # assert a good (and available) name works
- result = runner.invoke(lightning_cli.cmd_install.install_app, [real_app])
- assert "Press enter to continue:" in result.output
-
-
-@pytest.mark.xfail(strict=False, reason="need to figure out how to authorize git clone from the private repo")
-def test_app_install(tmpdir, monkeypatch):
- """Tests unpublished app install."""
- monkeypatch.chdir(tmpdir)
-
- real_app = "https://github.com/Lightning-AI/install-app"
- test_app_pip_name = "install-app"
-
- # install app and verify it's in the env
- subprocess.check_output(f"lightning install app {real_app} --yes", shell=True)
- new_env_output = subprocess.check_output("pip freeze", shell=True)
- assert test_app_pip_name in str(new_env_output), f"{test_app_pip_name} should be in the env"
-
-
-@mock.patch("lightning.app.cli.cmd_install.subprocess", mock.MagicMock())
-def test_valid_org_component_name():
- runner = CliRunner()
-
- # assert a bad name should fail
- fake_component = "fakeuser/impossible/name"
- result = runner.invoke(lightning_cli.cmd_install.install_component, [fake_component])
- assert "component name format must have organization/component-name" in result.output
-
- # assert a good name (but unavailable name) should work
- fake_component = "fakeuser/ALKKLJAUHREKJ21234KLAKJDLF"
- result = runner.invoke(lightning_cli.cmd_install.install_component, [fake_component])
- assert f"component: '{fake_component}' is not available on ⚡ Lightning AI ⚡" in result.output
-
- # assert a good (and available) name works
- fake_component = "lightning/lit-slack-messenger"
- result = runner.invoke(lightning_cli.cmd_install.install_component, [fake_component])
- assert "Press enter to continue:" in result.output
-
-
-def test_unpublished_component_url_parsing():
- runner = CliRunner()
-
- # assert a bad name should fail (no git@)
- fake_component = "https://github.com/Lightning-AI/LAI-slack-messenger"
- result = runner.invoke(lightning_cli.cmd_install.install_component, [fake_component])
- assert "Error, your github url must be in the following format" in result.output
-
- # assert a good (and available) name works
- sha = "14f333456ffb6758bd19458e6fa0bf12cf5575e1"
- real_component = f"git+https://github.com/Lightning-AI/LAI-slack-messenger.git@{sha}"
- result = runner.invoke(lightning_cli.cmd_install.install_component, [real_component])
- assert "Press enter to continue:" in result.output
-
-
-@pytest.mark.xfail(strict=False, reason="need to figure out how to authorize pip install from the private repo")
-@pytest.mark.parametrize(
- ("real_component", "test_component_pip_name"),
- [
- ("lightning/lit-slack-messenger", "lit-slack"),
- (
- "git+https://github.com/Lightning-AI/LAI-slack-messenger.git@14f333456ffb6758bd19458e6fa0bf12cf5575e1",
- "lit-slack",
- ),
- ],
-)
-def test_component_install(real_component, test_component_pip_name):
- """Tests both published and unpublished component installs."""
- # uninstall component just in case and verify it's not in the pip output
- env_output = subprocess.check_output(f"pip uninstall {test_component_pip_name} --yes && pip freeze", shell=True)
- assert test_component_pip_name not in str(env_output), f"{test_component_pip_name} should not be in the env"
-
- # install component and verify it's in the env
- new_env_output = subprocess.check_output(
- f"lightning install component {real_component} --yes && pip freeze", shell=True
- )
- assert test_component_pip_name in str(new_env_output), f"{test_component_pip_name} should be in the env"
-
- # clean up for test
- subprocess.run(f"pip uninstall {test_component_pip_name} --yes", shell=True)
- env_output = subprocess.check_output("pip freeze", shell=True)
- assert test_component_pip_name not in str(
- env_output
- ), f"{test_component_pip_name} should not be in the env after cleanup"
-
-
-def test_prompt_actions():
- # TODO: each of these installs must check that a package is installed in the environment correctly
- app_to_use = "lightning/invideo"
-
- runner = CliRunner()
-
- # assert that the user can cancel the command with any letter other than y
- result = runner.invoke(lightning_cli.cmd_install.install_app, [app_to_use], input="b")
- assert "Installation aborted!" in result.output
-
- # assert that the install happens with --yes
- # result = runner.invoke(lightning_cli.cmd_install.install_app, [app_to_use, "--yes"])
- # assert result.exit_code == 0
-
- # assert that the install happens with y
- # result = runner.invoke(lightning_cli.cmd_install.install_app, [app_to_use], input='y')
- # assert result.exit_code == 0
-
- # # assert that the install happens with yes
- # result = runner.invoke(lightning_cli.cmd_install.install_app, [app_to_use], input='yes')
- # assert result.exit_code == 0
-
- # assert that the install happens with pressing enter
- # result = runner.invoke(lightning_cli.cmd_install.install_app, [app_to_use])
-
- # TODO: how to check the output when the user types ctrl+c?
- # result = runner.invoke(lightning_cli.cmd_install.install_app, [app_to_use], input='')
-
-
-@mock.patch("lightning.app.cli.cmd_install.subprocess", mock.MagicMock())
-def test_version_arg_component(tmpdir, monkeypatch):
- monkeypatch.chdir(tmpdir)
- runner = CliRunner()
-
- # Version does not exist
- component_name = "lightning/lit-slack-messenger"
- version_arg = "NOT-EXIST"
- result = runner.invoke(lightning_cli.cmd_install.install_component, [component_name, f"--version={version_arg}"])
- assert f"component: 'Version {version_arg} for {component_name}' is not" in str(result.exception)
- assert result.exit_code == 1
-
- # Version exists
- # This somehow fails in the test but not when you actually run it
- version_arg = "0.0.1"
- runner = CliRunner()
- result = runner.invoke(
- lightning_cli.cmd_install.install_component, [component_name, f"--version={version_arg}", "--yes"]
- )
- assert result.exit_code == 0
-
-
-@mock.patch("lightning.app.cli.cmd_install.subprocess", mock.MagicMock())
-@mock.patch("lightning.app.cli.cmd_install.os.chdir", mock.MagicMock())
-def test_version_arg_app(tmpdir):
- # Version does not exist
- app_name = "lightning/invideo"
- version_arg = "NOT-EXIST"
- runner = CliRunner()
- result = runner.invoke(lightning_cli.cmd_install.install_app, [app_name, f"--version={version_arg}"])
- assert f"app: 'Version {version_arg} for {app_name}' is not" in str(result.exception)
- assert result.exit_code == 1
-
- # Version exists
- version_arg = "0.0.2"
- runner = CliRunner()
- result = runner.invoke(lightning_cli.cmd_install.install_app, [app_name, f"--version={version_arg}", "--yes"])
- assert result.exit_code == 0
-
-
-@mock.patch("lightning.app.cli.cmd_install.subprocess", mock.MagicMock())
-@mock.patch("lightning.app.cli.cmd_install.os.chdir", mock.MagicMock())
-@mock.patch("lightning.app.cli.cmd_install._show_install_app_prompt")
-def test_install_resolve_latest_version(mock_show_install_app_prompt, tmpdir):
- app_name = "lightning/invideo"
- runner = CliRunner()
- with mock.patch("lightning.app.cli.cmd_install.requests.get") as get_api_mock:
- get_api_mock.return_value.json.return_value = {
- "apps": [
- {
- "canDownloadSourceCode": True,
- "version": "0.0.2",
- "name": "lightning/invideo",
- },
- {
- "canDownloadSourceCode": True,
- "version": "0.0.4",
- "name": "lightning/invideo",
- },
- {
- "canDownloadSourceCode": True,
- "version": "0.0.5",
- "name": "another_app",
- },
- ]
- }
- runner.invoke(
- lightning_cli.cmd_install.install_app, [app_name, "--yes"]
- ) # no version specified so latest is installed
- assert mock_show_install_app_prompt.called
- assert mock_show_install_app_prompt.call_args[0][0]["version"] == "0.0.4"
-
-
-def test_proper_url_parsing():
- name = "lightning/invideo"
-
- # make sure org/app-name name is correct
- org, app = cmd_install._validate_name(name, resource_type="app", example="lightning/lit-slack-component")
- assert org == "lightning"
- assert app == "invideo"
-
- # resolve registry (orgs can have a private registry through their environment variables)
- registry_url = cmd_install._resolve_app_registry()
- assert registry_url == "https://lightning.ai/v1/apps"
-
- # load the component resource
- component_entry = cmd_install._resolve_resource(registry_url, name=name, version_arg="latest", resource_type="app")
-
- source_url, git_url, folder_name, git_sha = cmd_install._show_install_app_prompt(
- component_entry, app, org, True, resource_type="app"
- )
- assert folder_name == "LAI-InVideo-search-App"
- # FixMe: this needs to be updated after the release with the updated org rename
- assert source_url == "https://github.com/Lightning-AI/LAI-InVideo-search-App"
- assert "#ref" not in git_url
- assert git_sha
-
-
-@_RunIf(skip_windows=True)
-def test_install_app_shows_error(tmpdir):
- app_folder_dir = Path(tmpdir / "some_random_directory").absolute()
- app_folder_dir.mkdir()
-
- with pytest.raises(SystemExit, match=f"Folder {str(app_folder_dir)} exists, please delete it and try again."):
- cmd_install._install_app_from_source(
- source_url=mock.ANY, git_url=mock.ANY, folder_name=str(app_folder_dir), overwrite=False
- )
-
-
-# def test_env_creation(tmpdir):
-# cwd = os.getcwd()
-# os.chdir(tmpdir)
-
-# # install app
-# cmd_install.app("lightning/install-app", True, cwd=tmpdir)
-
-# # assert app folder is installed with venv
-# assert "python" in set(os.listdir(os.path.join(tmpdir, "install-app/bin")))
-
-# # assert the deps are in the env
-# env_output = subprocess.check_output("source bin/activate && pip freeze", shell=True)
-# non_env_output = subprocess.check_output("pip freeze", shell=True)
-
-# # assert envs are not the same
-# assert env_output != non_env_output
-
-# # assert the reqs are in the env created and NOT in the non env
-# reqs = open(os.path.join(tmpdir, "install-app/requirements.txt")).read()
-# assert reqs in str(env_output) and reqs not in str(non_env_output)
-
-# # setup.py installs numpy
-# assert "numpy" in str(env_output)
-
-# # run the python script to make sure the file works (in a folder)
-# app_file = os.path.join(tmpdir, "install-app/src/app.py")
-# app_output = subprocess.check_output(f"source bin/activate && python {app_file}", shell=True)
-# assert "b'printed a\\ndeps loaded\\n'" == str(app_output)
-
-# # run the python script to make sure the file works (in root)
-# app_file = os.path.join(tmpdir, "install-app/app_b.py")
-# app_output = subprocess.check_output(f"source bin/activate && python {app_file}", shell=True)
-# assert "b'printed a\\n'" == str(app_output)
-
-# # reset dir
-# os.chdir(cwd)
-
-
-def test_app_and_component_gallery_app(monkeypatch):
- monkeypatch.setattr(cmd_install, "_install_app_from_source", mock.MagicMock())
- path = cmd_install.gallery_apps_and_components("lightning/flashy", True, "latest")
- assert path == os.path.join(os.getcwd(), "app.py")
-
-
-def test_app_and_component_gallery_component(monkeypatch):
- monkeypatch.setattr(cmd_install, "_install_app_from_source", mock.MagicMock())
- path = cmd_install.gallery_apps_and_components("lightning/lit-jupyter", True, "latest")
- assert path == os.path.join(os.getcwd(), "app.py")
-
-
-@mock.patch.dict(os.environ, {"LIGHTNING_APP_REGISTRY": "https://TODO/other_non_PL_registry"})
-def test_private_app_registry():
- registry = cmd_install._resolve_app_registry()
- assert registry == "https://TODO/other_non_PL_registry"
-
-
-def test_public_app_registry():
- registry = cmd_install._resolve_app_registry()
- assert registry == "https://lightning.ai/v1/apps"
-
-
-def test_public_component_registry():
- registry = cmd_install._resolve_component_registry()
- assert registry == "https://lightning.ai/v1/components"
-
-
-@mock.patch.dict(os.environ, {"LIGHTNING_COMPONENT_REGISTRY": "https://TODO/other_non_PL_registry"})
-def test_private_component_registry():
- registry = cmd_install._resolve_component_registry()
- assert registry == "https://TODO/other_non_PL_registry"
-
-
-@mock.patch("lightning.app.cli.cmd_install.subprocess")
-@mock.patch("lightning.app.cli.cmd_install.os.chdir", mock.MagicMock())
-@pytest.mark.parametrize(
- ("source_url", "git_url", "git_sha"),
- [
- (
- "https://github.com/PyTorchLightning/lightning-quick-start",
- "https://@github.com/PyTorchLightning/lightning-quick-start",
- None,
- ),
- (
- "https://github.com/PyTorchLightning/lightning-quick-start",
- "https://@github.com/PyTorchLightning/lightning-quick-start",
- "git_sha",
- ),
- ],
-)
-def test_install_app_process(subprocess_mock, source_url, git_url, git_sha, tmpdir):
- app_folder_dir = Path(tmpdir / "some_random_directory").absolute()
- app_folder_dir.mkdir()
-
- cmd_install._install_app_from_source(
- source_url, git_url, folder_name=str(app_folder_dir), overwrite=True, git_sha=git_sha
- )
- assert subprocess_mock.check_output.call_args_list[0].args == (["git", "clone", git_url],)
- if git_sha:
- assert subprocess_mock.check_output.call_args_list[1].args == (["git", "checkout", git_sha],)
- assert subprocess_mock.call.call_args_list[0].args == ("pip install -r requirements.txt",)
- assert subprocess_mock.call.call_args_list[1].args == ("pip install -e .",)
diff --git a/tests/tests_app/cli/test_cmd_launch.py b/tests/tests_app/cli/test_cmd_launch.py
deleted file mode 100644
index 4b75c08de5dce..0000000000000
--- a/tests/tests_app/cli/test_cmd_launch.py
+++ /dev/null
@@ -1,328 +0,0 @@
-import os
-import signal
-import time
-from functools import partial
-from multiprocessing import Process
-from pathlib import Path
-from unittest import mock
-from unittest.mock import ANY, MagicMock, Mock
-
-from click.testing import CliRunner
-from lightning.app.cli.lightning_cli_launch import run_flow, run_flow_and_servers, run_frontend, run_server
-from lightning.app.core.queues import QueuingSystem
-from lightning.app.frontend.web import StaticWebFrontend
-from lightning.app.launcher import launcher
-from lightning.app.runners.runtime import load_app_from_file
-from lightning.app.testing.helpers import EmptyWork, _RunIf
-from lightning.app.utilities.app_commands import run_app_commands
-from lightning.app.utilities.network import find_free_network_port
-
-from tests_app import _PROJECT_ROOT
-
-_FILE_PATH = os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/app_metadata.py")
-
-
-def test_run_frontend(monkeypatch):
- """Test that the CLI can be used to start the frontend server of a particular LightningFlow using the cloud
- dispatcher.
-
- This CLI call is made by Lightning AI and is not meant to be invoked by the user directly.
-
- """
- runner = CliRunner()
-
- port = find_free_network_port()
-
- start_server_mock = Mock()
- monkeypatch.setattr(StaticWebFrontend, "start_server", start_server_mock)
-
- result = runner.invoke(
- run_frontend,
- [
- str(Path(__file__).parent / "launch_data" / "app_v0" / "app.py"),
- "--flow-name",
- "root.aas",
- "--host",
- "localhost",
- "--port",
- port,
- ],
- )
- assert result.exit_code == 0
- start_server_mock.assert_called_once()
- start_server_mock.assert_called_with("localhost", port)
-
-
-class MockRedisQueue:
- _MOCKS = {}
-
- def __init__(self, name: str, default_timeout: float):
- self.name = name
- self.default_timeout = default_timeout
- self.queue = [None] # adding a dummy element.
-
- self._MOCKS[name] = MagicMock()
-
- def put(self, item):
- self._MOCKS[self.name].put(item)
- self.queue.put(item)
-
- def get(self, timeout: int = None):
- self._MOCKS[self.name].get(timeout=timeout)
- return self.queue.pop(0)
-
- @property
- def is_running(self):
- self._MOCKS[self.name].is_running()
- return True
-
-
-@mock.patch("lightning.app.core.queues.RedisQueue", MockRedisQueue)
-@mock.patch("lightning.app.launcher.launcher.check_if_redis_running", MagicMock(return_value=True))
-@mock.patch("lightning.app.launcher.launcher.start_server")
-def test_run_server(start_server_mock):
- runner = CliRunner()
- result = runner.invoke(
- run_server,
- [
- _FILE_PATH,
- "--queue-id",
- "1",
- "--host",
- "http://127.0.0.1:7501/view",
- "--port",
- "6000",
- ],
- catch_exceptions=False,
- )
- assert result.exit_code == 0
- start_server_mock.assert_called_once_with(
- host="http://127.0.0.1:7501/view",
- port=6000,
- api_publish_state_queue=ANY,
- api_delta_queue=ANY,
- api_response_queue=ANY,
- spec=ANY,
- apis=ANY,
- )
- kwargs = start_server_mock._mock_call_args.kwargs
- assert isinstance(kwargs["api_publish_state_queue"], MockRedisQueue)
- assert kwargs["api_publish_state_queue"].name.startswith("1")
- assert isinstance(kwargs["api_delta_queue"], MockRedisQueue)
- assert kwargs["api_delta_queue"].name.startswith("1")
-
-
-def mock_server(should_catch=False, sleep=1000):
- if should_catch:
-
- def _sigterm_handler(*_):
- time.sleep(100)
-
- signal.signal(signal.SIGTERM, _sigterm_handler)
-
- time.sleep(sleep)
-
-
-def run_forever_process():
- while True:
- time.sleep(1)
-
-
-def run_for_2_seconds_and_raise():
- time.sleep(2)
- raise RuntimeError("existing")
-
-
-def exit_successfully_immediately():
- return
-
-
-def start_servers(should_catch=False, sleep=1000):
- processes = [
- (
- "p1",
- launcher.start_server_in_process(target=partial(mock_server, should_catch=should_catch, sleep=sleep)),
- ),
- (
- "p2",
- launcher.start_server_in_process(target=partial(mock_server, sleep=sleep)),
- ),
- (
- "p3",
- launcher.start_server_in_process(target=partial(mock_server, sleep=sleep)),
- ),
- ]
-
- launcher.manage_server_processes(processes)
-
-
-@_RunIf(skip_windows=True)
-def test_manage_server_processes():
- p = Process(target=partial(start_servers, sleep=0.5))
- p.start()
- p.join()
-
- assert p.exitcode == 0
-
- p = Process(target=start_servers)
- p.start()
- p.join(0.5)
- p.terminate()
- p.join()
-
- assert p.exitcode in [-15, 0]
-
- p = Process(target=partial(start_servers, should_catch=True))
- p.start()
- p.join(0.5)
- p.terminate()
- p.join()
-
- assert p.exitcode in [-15, 1]
-
-
-def start_processes(**functions):
- processes = []
- for name, fn in functions.items():
- processes.append((name, launcher.start_server_in_process(fn)))
- launcher.manage_server_processes(processes)
-
-
-@_RunIf(skip_windows=True)
-def test_manage_server_processes_one_process_gets_killed(capfd):
- functions = {"p1": run_forever_process, "p2": run_for_2_seconds_and_raise}
- p = Process(target=start_processes, kwargs=functions)
- p.start()
-
- for _ in range(40):
- time.sleep(1)
- if p.exitcode is not None:
- break
- assert p.exitcode == 1
- captured = capfd.readouterr()
- assert (
- "Found dead components with non-zero exit codes, exiting execution!!! Components: \n"
- "| Name | Exit Code |\n|------|-----------|\n| p2 | 1 |\n" in captured.out
- )
-
-
-@_RunIf(skip_windows=True)
-def test_manage_server_processes_all_processes_exits_with_zero_exitcode(capfd):
- functions = {
- "p1": exit_successfully_immediately,
- "p2": exit_successfully_immediately,
- }
- p = Process(target=start_processes, kwargs=functions)
- p.start()
-
- for _ in range(40):
- time.sleep(1)
- if p.exitcode is not None:
- break
- assert p.exitcode == 0
- captured = capfd.readouterr()
- assert "All the components are inactive with exitcode 0. Exiting execution!!!" in captured.out
-
-
-@mock.patch("lightning.app.launcher.launcher.StorageOrchestrator", MagicMock())
-@mock.patch("lightning.app.core.queues.RedisQueue", MockRedisQueue)
-@mock.patch("lightning.app.launcher.launcher.manage_server_processes", Mock())
-def test_run_flow_and_servers(monkeypatch):
- runner = CliRunner()
-
- start_server_mock = Mock()
- monkeypatch.setattr(launcher, "start_server_in_process", start_server_mock)
-
- runner.invoke(
- run_flow_and_servers,
- [
- str(Path(__file__).parent / "launch_data" / "app_v0" / "app.py"),
- "--base-url",
- "https://some.url",
- "--queue-id",
- "1",
- "--host",
- "http://127.0.0.1:7501/view",
- "--port",
- 6000,
- "--flow-port",
- "root.aas",
- 6001,
- "--flow-port",
- "root.bbs",
- 6002,
- ],
- catch_exceptions=False,
- )
-
- start_server_mock.assert_called()
- assert start_server_mock.call_count == 4
-
-
-@mock.patch("lightning.app.core.queues.RedisQueue", MockRedisQueue)
-@mock.patch("lightning.app.launcher.launcher.WorkRunner")
-def test_run_work(mock_work_runner, monkeypatch):
- run_app_commands(_FILE_PATH)
- app = load_app_from_file(_FILE_PATH)
- names = [w.name for w in app.works]
-
- mocked_queue = MagicMock()
- mocked_queue.get.return_value = EmptyWork()
- monkeypatch.setattr(
- QueuingSystem,
- "get_work_queue",
- MagicMock(return_value=mocked_queue),
- )
-
- assert names == [
- "root.flow_a_1.work_a",
- "root.flow_a_2.work_a",
- "root.flow_b.work_b",
- ]
-
- for name in names:
- launcher.run_lightning_work(
- file=_FILE_PATH,
- work_name=name,
- queue_id="1",
- )
- kwargs = mock_work_runner._mock_call_args.kwargs
- assert isinstance(kwargs["work"], EmptyWork)
- assert kwargs["work_name"] == name
- assert isinstance(kwargs["caller_queue"], MockRedisQueue)
- assert kwargs["caller_queue"].name.startswith("1")
- assert isinstance(kwargs["delta_queue"], MockRedisQueue)
- assert kwargs["delta_queue"].name.startswith("1")
- assert isinstance(kwargs["readiness_queue"], MockRedisQueue)
- assert kwargs["readiness_queue"].name.startswith("1")
- assert isinstance(kwargs["error_queue"], MockRedisQueue)
- assert kwargs["error_queue"].name.startswith("1")
- assert isinstance(kwargs["request_queue"], MockRedisQueue)
- assert kwargs["request_queue"].name.startswith("1")
- assert isinstance(kwargs["response_queue"], MockRedisQueue)
- assert kwargs["response_queue"].name.startswith("1")
- assert isinstance(kwargs["copy_request_queue"], MockRedisQueue)
- assert kwargs["copy_request_queue"].name.startswith("1")
- assert isinstance(kwargs["copy_response_queue"], MockRedisQueue)
- assert kwargs["copy_response_queue"].name.startswith("1")
-
- MockRedisQueue._MOCKS["healthz"].is_running.assert_called()
-
-
-@mock.patch("lightning.app.core.queues.QueuingSystem", MagicMock())
-@mock.patch("lightning.app.launcher.launcher.StorageOrchestrator", MagicMock())
-@mock.patch("lightning.app.LightningApp._run")
-@mock.patch("lightning.app.launcher.launcher.CloudBackend")
-def test_run_flow(mock_cloud_backend, mock_lightning_app_run):
- runner = CliRunner()
-
- base_url = "https://lightning.ai/me/apps"
-
- result = runner.invoke(
- run_flow,
- [_FILE_PATH, "--queue-id=1", f"--base-url={base_url}"],
- catch_exceptions=False,
- )
- assert result.exit_code == 0
- mock_lightning_app_run.assert_called_once()
- assert len(mock_cloud_backend._mock_mock_calls) == 13
diff --git a/tests/tests_app/cli/test_cmd_pl_init.py b/tests/tests_app/cli/test_cmd_pl_init.py
deleted file mode 100644
index 192811a56f5a4..0000000000000
--- a/tests/tests_app/cli/test_cmd_pl_init.py
+++ /dev/null
@@ -1,125 +0,0 @@
-import os
-import sys
-from unittest import mock
-from unittest.mock import Mock
-
-import pytest
-from click.testing import CliRunner
-from lightning.app.cli import lightning_cli
-from lightning.app.cli.cmd_pl_init import _can_encode_icon, download_frontend, pl_app
-
-
-def test_pl_app_input_paths_do_not_exist(tmp_path):
- """Test that the CLI prints an error message if the code directory or the script path does not exist."""
- runner = CliRunner()
-
- source_dir = tmp_path / "code"
- script_file = tmp_path / "code" / "script.py"
-
- result = runner.invoke(lightning_cli.init_pl_app, (str(source_dir), str(script_file)))
- assert result.exit_code == 1
- assert "The given source directory does not exist:" in result.output
-
- source_dir.mkdir(parents=True)
-
- result = runner.invoke(lightning_cli.init_pl_app, (str(source_dir), str(script_file)))
- assert result.exit_code == 1
- assert "The given script path does not exist:" in result.output
-
- script_file_as_folder = tmp_path / "code" / "folder"
- script_file_as_folder.mkdir(parents=True)
- result = runner.invoke(lightning_cli.init_pl_app, (str(source_dir), str(script_file_as_folder)))
- assert result.exit_code == 1
- assert "The given script path must be a file, you passed:" in result.output
-
-
-def test_pl_app_script_path_not_subpath(tmp_path):
- """Test that the CLI prints an error message if the provided script path is not a subpath of the source dir."""
- runner = CliRunner()
-
- source_dir = tmp_path / "code"
- script_file = tmp_path / "not_code" / "script.py"
-
- source_dir.mkdir(parents=True)
- script_file.parent.mkdir(parents=True)
- script_file.touch()
-
- result = runner.invoke(lightning_cli.init_pl_app, (str(source_dir), str(script_file)), catch_exceptions=False)
- assert result.exit_code == 1
- assert "The given script path must be a subpath of the source directory." in result.output
-
-
-def test_pl_app_destination_app_already_exists(tmp_path, monkeypatch):
- """Test that the CLI prints an error message if an app with the same name already exists."""
- runner = CliRunner()
- monkeypatch.chdir(tmp_path)
-
- source_dir = tmp_path / "code"
- script_file = source_dir / "script.py"
- source_dir.mkdir(parents=True)
- script_file.parent.mkdir(parents=True, exist_ok=True)
- script_file.touch()
-
- # monkeypatch.chdir(tmp_path)
- app_folder = tmp_path / "existing-app"
- app_folder.mkdir(parents=True)
-
- result = runner.invoke(lightning_cli.init_pl_app, (str(source_dir), str(script_file), "--name", "existing-app"))
- assert result.exit_code == 1
- assert "There is already an app with the name existing-app in the current working directory" in result.output
-
-
-def test_pl_app_incorrect_number_of_arguments(tmp_path):
- """Test that the CLI prints an error message if more than two input arguments for the source are provided."""
- runner = CliRunner()
- result = runner.invoke(lightning_cli.init_pl_app, ("one", "two", "three"))
- assert result.exit_code == 1
- assert "Incorrect number of arguments. You passed (one, two, three) but only either one argument" in result.output
-
-
-def test_pl_app_download_frontend(tmp_path):
- build_dir = tmp_path / "app" / "ui" / "build"
- download_frontend(build_dir)
- contents = os.listdir(build_dir)
- assert "index.html" in contents
- assert "static" in contents
-
-
-def test_pl_app_encode_icon(monkeypatch):
- stdout_mock = Mock(wraps=sys.stdout)
- monkeypatch.setattr(sys, "stdout", stdout_mock)
-
- stdout_mock.encoding = "utf-8"
- assert _can_encode_icon("📂")
- assert _can_encode_icon("📄")
-
- stdout_mock.encoding = "ascii"
- assert not _can_encode_icon("📂")
- assert not _can_encode_icon("📄")
-
-
-@pytest.mark.parametrize(
- ("cwd", "source_dir", "script_path"),
- [
- ("./", "./", "train.py"),
- ("./", "./code", "./code/train.py"),
- ],
-)
-@mock.patch("lightning.app.cli.cmd_pl_init.project_file_from_template")
-@mock.patch("lightning.app.cli.cmd_pl_init.download_frontend")
-def test_pl_app_relative_paths(_, __, cwd, source_dir, script_path, tmp_path, monkeypatch):
- source_dir = tmp_path / source_dir
- source_dir.mkdir(parents=True, exist_ok=True)
- script_path = tmp_path / script_path
- script_path.parent.mkdir(parents=True, exist_ok=True)
- script_path.touch()
- cwd = tmp_path / cwd
- monkeypatch.chdir(cwd)
-
- pl_app(source_dir=str(source_dir), script_path=str(script_path), name="app-name", overwrite=False)
- assert (cwd / "app-name").is_dir()
-
- expected_source_files = set(os.listdir(source_dir))
- if cwd == source_dir:
- expected_source_files.remove("app-name")
- assert set(os.listdir(cwd / "app-name" / "source")) == expected_source_files
diff --git a/tests/tests_app/cli/test_cmd_react_ui_init.py b/tests/tests_app/cli/test_cmd_react_ui_init.py
deleted file mode 100644
index 0f28420e291ba..0000000000000
--- a/tests/tests_app/cli/test_cmd_react_ui_init.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import os
-
-import lightning.app as la
-import pytest
-from lightning.app.cli import cmd_init, cmd_react_ui_init
-from lightning.app.testing.helpers import _RunIf
-
-
-@pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is None, reason="not running in GH actions.")
-@pytest.mark.xfail(strict=False, reason="need to figure out how to mock not having npm")
-def test_missing_npm():
- with pytest.raises(SystemExit, match="This machine is missing 'npm'"):
- cmd_react_ui_init._check_react_prerequisites()
-
-
-@pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is None, reason="not running in GH actions.")
-@pytest.mark.xfail(strict=False, reason="need to figure out how to mock not having node")
-def test_missing_nodejs():
- with pytest.raises(SystemExit, match="This machine is missing 'node'"):
- cmd_react_ui_init._check_react_prerequisites()
-
-
-@pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is None, reason="not running in GH actions")
-@pytest.mark.xfail(strict=False, reason="need to figure out how to mock not having yarn")
-def test_missing_yarn():
- with pytest.raises(SystemExit, match="This machine is missing 'yarn'"):
- cmd_react_ui_init._check_react_prerequisites()
-
-
-@_RunIf(skip_windows=True)
-def test_copy_and_setup_react_ui(tmpdir):
- dest_dir = os.path.join(tmpdir, "react-ui")
- os.system(f"lightning_app init react-ui --dest_dir={dest_dir}")
-
- # make sure package is minimal
- files = sorted(f for f in os.listdir(dest_dir) if f != "__pycache__")
- assert len(files) == 3, "should only be 3 objects: readme.md, example_app.py and ui dir"
-
- # make sure index.html has the vite app placeholder
- with open(dest_dir + "/ui/dist/index.html") as fo:
- index_content = fo.read()
- assert "Vite App " in index_content
-
- # read the compiled js file
- js_file = [x for x in os.listdir(os.path.join(dest_dir, "ui", "dist", "assets")) if ".js" in x]
- js_file = os.path.join(dest_dir, f"ui/dist/assets/{js_file[0]}")
- with open(js_file) as fo:
- index_content = fo.read()
-
- # if this is in the compiled file, the compilation worked and the app will work
- assert "Total number of prints in your terminal:" in index_content, "react app was not compiled properly"
- assert "LightningState.subscribe" in index_content, "react app was not compiled properly"
-
-
-@pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is None, reason="not running in GH actions")
-def test_correct_num_react_template_files():
- template_dir = os.path.join(la.__path__[0], "cli/react-ui-template")
- files = cmd_init._ls_recursively(template_dir)
- # TODO: remove lock file!!!
- assert len(files) == 16, "react-ui template files must be minimal... do not add nice to haves"
diff --git a/tests/tests_app/cli/test_cmd_show_logs.py b/tests/tests_app/cli/test_cmd_show_logs.py
deleted file mode 100644
index 80d30d5cd12a2..0000000000000
--- a/tests/tests_app/cli/test_cmd_show_logs.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from unittest import mock
-
-from click.testing import CliRunner
-from lightning.app.cli.lightning_cli import show
-
-
-@mock.patch("lightning.app.cli.commands.logs.LightningClient")
-@mock.patch("lightning.app.cli.commands.logs._get_project")
-def test_show_logs_errors(_, client):
- """Test that the CLI prints the errors for the show logs command."""
- runner = CliRunner()
-
- # Response prep
- app = mock.MagicMock()
- app.name = "My-FakeApp"
- app.display_name = "My_FakeApp"
- work = mock.MagicMock()
- work.name = "MyFakeWork"
- flow = mock.MagicMock()
- flow.name = "MyFakeFlow"
-
- # No apps ever run
- apps = {}
- client.return_value.lightningapp_instance_service_list_lightningapp_instances.return_value.lightningapps = apps
-
- result = runner.invoke(show.commands["logs"], ["NonExistentApp"])
-
- assert result.exit_code == 1
- assert "Error: You don't have any application in the cloud" in result.output
-
- # App not specified
- apps = {app}
- client.return_value.lightningapp_instance_service_list_lightningapp_instances.return_value.lightningapps = apps
-
- result = runner.invoke(show.commands["logs"])
-
- assert result.exit_code == 1
- assert "Please select one of the following: [My_FakeApp]" in str(result.output)
-
- # App does not exit
- apps = {app}
- client.return_value.lightningapp_instance_service_list_lightningapp_instances.return_value.lightningapps = apps
-
- result = runner.invoke(show.commands["logs"], ["ThisAppDoesNotExist"])
-
- assert result.exit_code == 1
- assert "The Lightning App 'ThisAppDoesNotExist' does not exist." in str(result.output)
-
- # Component does not exist
- apps = {app}
- works = {work}
- flows = {flow}
- client.return_value.lightningapp_instance_service_list_lightningapp_instances.return_value.lightningapps = apps
- client.return_value.lightningwork_service_list_lightningwork.return_value.lightningworks = works
- app.spec.flow_servers = flows
-
- result = runner.invoke(show.commands["logs"], ["My_FakeApp", "NonExistentComponent"])
-
- assert result.exit_code == 1
- assert "Component 'root.NonExistentComponent' does not exist in app My_FakeApp." in result.output
diff --git a/tests/tests_app/cli/test_connect.py b/tests/tests_app/cli/test_connect.py
deleted file mode 100644
index 458238a7826d5..0000000000000
--- a/tests/tests_app/cli/test_connect.py
+++ /dev/null
@@ -1,171 +0,0 @@
-import json
-import os
-from unittest.mock import MagicMock
-
-import click
-import psutil
-import pytest
-from lightning.app import _PROJECT_ROOT
-from lightning.app.cli.connect.app import (
- _list_app_commands,
- _resolve_command_path,
- _retrieve_connection_to_an_app,
- connect_app,
- disconnect_app,
-)
-from lightning.app.utilities import cli_helpers
-from lightning.app.utilities.commands import base
-
-
-def monkeypatch_connection(monkeypatch, tmpdir, ppid):
- connection_path = os.path.join(tmpdir, ppid)
- monkeypatch.setattr("lightning.app.cli.connect.app._clean_lightning_connection", MagicMock())
- monkeypatch.setattr("lightning.app.cli.connect.app._PPID", ppid)
- monkeypatch.setattr("lightning.app.cli.connect.app._LIGHTNING_CONNECTION", tmpdir)
- monkeypatch.setattr("lightning.app.cli.connect.app._LIGHTNING_CONNECTION_FOLDER", connection_path)
- return connection_path
-
-
-@pytest.mark.flaky(reruns=3, reruns_delay=2)
-def test_connect_disconnect_local(tmpdir, monkeypatch):
- disconnect_app()
-
- with pytest.raises(Exception, match="Connection wasn't successful. Is your app localhost running ?"):
- connect_app("localhost")
-
- with open(os.path.join(os.path.dirname(__file__), "jsons/connect_1.json")) as f:
- data = json.load(f)
-
- data["paths"]["/command/command_with_client"]["post"]["cls_path"] = os.path.join(
- _PROJECT_ROOT,
- data["paths"]["/command/command_with_client"]["post"]["cls_path"],
- )
-
- messages = []
-
- disconnect_app()
-
- def fn(msg):
- messages.append(msg)
-
- monkeypatch.setattr(click, "echo", fn)
-
- response = MagicMock()
- response.status_code = 200
- response.json.return_value = data
- monkeypatch.setattr(cli_helpers.requests, "get", MagicMock(return_value=response))
- connect_app("localhost")
- assert _retrieve_connection_to_an_app() == ("localhost", None)
- command_path = _resolve_command_path("nested_command")
- assert not os.path.exists(command_path)
- command_path = _resolve_command_path("command_with_client")
- assert os.path.exists(command_path)
- messages = []
- connect_app("localhost")
- assert messages == ["You are connected to the local Lightning App."]
-
- messages = []
- disconnect_app()
- assert messages == ["You are disconnected from the local Lightning App."]
- messages = []
- disconnect_app()
- assert messages == [
- "You aren't connected to any Lightning App."
- " Please use `lightning_app connect app_name_or_id` to connect to one."
- ]
-
- assert _retrieve_connection_to_an_app() == (None, None)
-
-
-def test_connect_disconnect_cloud(tmpdir, monkeypatch):
- disconnect_app()
-
- ppid_1 = str(psutil.Process(os.getpid()).ppid())
- ppid_2 = "222"
-
- target_file = _resolve_command_path("command_with_client")
-
- if os.path.exists(target_file):
- os.remove(target_file)
-
- with open(os.path.join(os.path.dirname(__file__), "jsons/connect_1.json")) as f:
- data = json.load(f)
-
- data["paths"]["/command/command_with_client"]["post"]["cls_path"] = os.path.join(
- _PROJECT_ROOT,
- data["paths"]["/command/command_with_client"]["post"]["cls_path"],
- )
-
- messages = []
-
- def fn(msg):
- messages.append(msg)
-
- monkeypatch.setattr(click, "echo", fn)
-
- response = MagicMock()
- response.status_code = 200
- response.json.return_value = data
- monkeypatch.setattr(cli_helpers.requests, "get", MagicMock(return_value=response))
- project = MagicMock()
- project.project_id = "custom_project_name"
- monkeypatch.setattr(cli_helpers, "_get_project", MagicMock(return_value=project))
- client = MagicMock()
- lightningapps = MagicMock()
-
- app = MagicMock()
- app.display_name = "example"
- app.id = "1234"
-
- lightningapps.lightningapps = [app]
- client.lightningapp_instance_service_list_lightningapp_instances.return_value = lightningapps
- monkeypatch.setattr(cli_helpers, "LightningClient", MagicMock(return_value=client))
-
- monkeypatch.setattr(base, "_get_project", MagicMock(return_value=project))
-
- artifact = MagicMock()
- artifact.filename = "commands/command_with_client.py"
- artifacts = MagicMock()
- artifacts.artifacts = [artifact]
- client.lightningapp_instance_service_list_lightningapp_instance_artifacts.return_value = artifacts
- monkeypatch.setattr(base, "LightningClient", MagicMock(return_value=client))
-
- with open(data["paths"]["/command/command_with_client"]["post"]["cls_path"], "rb") as f:
- response.content = f.read()
-
- connect_app("example")
- assert _retrieve_connection_to_an_app() == ("example", "1234")
- commands = _list_app_commands()
- assert commands == ["command with client", "command without client", "nested command"]
- command_path = _resolve_command_path("nested_command")
- assert not os.path.exists(command_path)
- command_path = _resolve_command_path("command_with_client")
- assert os.path.exists(command_path)
- messages = []
- connect_app("example")
- assert messages == ["You are already connected to the cloud Lightning App: example."]
-
- _ = monkeypatch_connection(monkeypatch, tmpdir, ppid=ppid_2)
-
- messages = []
- connect_app("example")
- assert "The lightning App CLI now responds to app commands" in messages[0]
-
- messages = []
- disconnect_app()
- assert messages == ["You are disconnected from the cloud Lightning App: example."]
-
- _ = monkeypatch_connection(monkeypatch, tmpdir, ppid=ppid_1)
-
- messages = []
- disconnect_app()
- assert "You aren't connected to any Lightning App" in messages[0]
-
- messages = []
- disconnect_app()
- assert messages == [
- "You aren't connected to any Lightning App."
- " Please use `lightning_app connect app_name_or_id` to connect to one."
- ]
-
- assert _retrieve_connection_to_an_app() == (None, None)
diff --git a/tests/tests_app/cli/test_connect_data.py b/tests/tests_app/cli/test_connect_data.py
deleted file mode 100644
index 5778b12fa8855..0000000000000
--- a/tests/tests_app/cli/test_connect_data.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import sys
-from unittest.mock import MagicMock
-
-import pytest
-from lightning.app.cli.connect import data
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="lightning connect data isn't supported on windows")
-def test_connect_data_no_project(monkeypatch):
- from lightning_cloud.openapi import V1ListMembershipsResponse, V1Membership
-
- client = MagicMock()
- client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(memberships=[])
- monkeypatch.setattr(data, "LightningClient", MagicMock(return_value=client))
-
- _error_and_exit = MagicMock()
- monkeypatch.setattr(data, "_error_and_exit", _error_and_exit)
-
- _get_project = MagicMock()
- _get_project.return_value = V1Membership(name="project-0", project_id="project-id-0")
- monkeypatch.setattr(data, "_get_project", _get_project)
-
- data.connect_data("imagenet", region="us-east-1", source="imagenet", destination="", project_name="project-0")
-
- _get_project.assert_called()
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="lightning connect data isn't supported on windows")
-def test_connect_data(monkeypatch):
- from lightning_cloud.openapi import Create, V1AwsDataConnection, V1ListMembershipsResponse, V1Membership
-
- client = MagicMock()
- client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[
- V1Membership(name="project-0", project_id="project-id-0"),
- V1Membership(name="project-1", project_id="project-id-1"),
- V1Membership(name="project 2", project_id="project-id-2"),
- ]
- )
- monkeypatch.setattr(data, "LightningClient", MagicMock(return_value=client))
-
- _error_and_exit = MagicMock()
- monkeypatch.setattr(data, "_error_and_exit", _error_and_exit)
- data.connect_data("imagenet", region="us-east-1", source="imagenet", destination="", project_name="project-0")
-
- _error_and_exit.assert_called_with(
- "Only public S3 folders are supported for now. Please, open a Github issue with your use case."
- )
-
- data.connect_data("imagenet", region="us-east-1", source="s3://imagenet", destination="", project_name="project-0")
-
- client.data_connection_service_create_data_connection.assert_called_with(
- project_id="project-id-0",
- body=Create(
- name="imagenet",
- aws=V1AwsDataConnection(destination="", region="us-east-1", source="s3://imagenet", secret_arn_name=""),
- ),
- )
diff --git a/tests/tests_app/cli/test_cp.py b/tests/tests_app/cli/test_cp.py
deleted file mode 100644
index 303d93968a8c1..0000000000000
--- a/tests/tests_app/cli/test_cp.py
+++ /dev/null
@@ -1,242 +0,0 @@
-import os
-import sys
-from pathlib import PosixPath
-from unittest.mock import MagicMock
-
-import pytest
-from lightning.app.cli.commands import cp
-from lightning.app.cli.commands.cd import _CD_FILE, cd
-from lightning_cloud.openapi import (
- Externalv1Cluster,
- Externalv1LightningappInstance,
- V1CloudSpace,
- V1ClusterDriver,
- V1ClusterSpec,
- V1KubernetesClusterDriver,
- V1LightningappInstanceArtifact,
- V1LightningappInstanceSpec,
- V1ListCloudSpacesResponse,
- V1ListClustersResponse,
- V1ListLightningappInstanceArtifactsResponse,
- V1ListLightningappInstancesResponse,
- V1ListMembershipsResponse,
- V1ListProjectClusterBindingsResponse,
- V1Membership,
- V1ProjectClusterBinding,
- V1UploadProjectArtifactResponse,
-)
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows yet")
-def test_cp_local_to_remote(tmpdir, monkeypatch):
- error_and_exit = MagicMock()
- monkeypatch.setattr(cp, "_error_and_exit", error_and_exit)
-
- client = MagicMock()
- client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="project-0")]
- )
-
- client.lightningapp_instance_service_list_lightningapp_instances.return_value = V1ListLightningappInstancesResponse(
- lightningapps=[Externalv1LightningappInstance(name="app-name-0", id="app-id-0")]
- )
-
- client.projects_service_list_project_cluster_bindings.return_value = V1ListProjectClusterBindingsResponse(
- clusters=[V1ProjectClusterBinding(cluster_id="my-cluster", cluster_name="my-cluster")]
- )
-
- result = MagicMock()
- result.get.return_value = V1UploadProjectArtifactResponse(urls=["http://foo.bar"])
- client.lightningapp_instance_service_upload_project_artifact.return_value = result
-
- monkeypatch.setattr(cp, "LightningClient", MagicMock(return_value=client))
-
- assert cd("/", verify=False) == "/"
- cp.cp(str(tmpdir), "r:.")
- assert error_and_exit._mock_call_args_list[0].args[0] == "Uploading files at the project level isn't allowed yet."
-
- assert cd("/project-0/app-name-0", verify=False) == "/project-0/app-name-0"
- with open(f"{tmpdir}/a.txt", "w") as f:
- f.write("hello world !")
-
- file_uploader = MagicMock()
- monkeypatch.setattr(cp, "FileUploader", file_uploader)
-
- cp.cp(str(tmpdir), "r:.")
- assert file_uploader._mock_call_args[1]["name"] == f"{tmpdir}/a.txt"
-
- os.remove(_CD_FILE)
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows yet")
-def test_cp_cloud_to_local(tmpdir, monkeypatch):
- error_and_exit = MagicMock()
- monkeypatch.setattr(cp, "_error_and_exit", error_and_exit)
-
- client = MagicMock()
- client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="project-0")]
- )
-
- clusters = MagicMock()
- clusters.clusters = [MagicMock()]
- client.projects_service_list_project_cluster_bindings.return_value = clusters
-
- client.lightningapp_instance_service_list_lightningapp_instances.return_value = V1ListLightningappInstancesResponse(
- lightningapps=[Externalv1LightningappInstance(name="app-name-0", id="app-id-0")]
- )
-
- artifacts = [
- V1LightningappInstanceArtifact(filename=".file_1.txt", url="http://foo.bar/file_1.txt", size_bytes=123),
- V1LightningappInstanceArtifact(
- filename=".folder_1/file_2.txt", url="http://foo.bar/folder_1/file_2.txt", size_bytes=123
- ),
- V1LightningappInstanceArtifact(
- filename=".folder_2/folder_3/file_3.txt", url="http://foo.bar/folder_2/folder_3/file_3.txt", size_bytes=123
- ),
- V1LightningappInstanceArtifact(
- filename=".folder_4/file_4.txt", url="http://foo.bar/folder_4/file_4.txt", size_bytes=123
- ),
- ]
-
- client.lightningapp_instance_service_list_project_artifacts.return_value = (
- V1ListLightningappInstanceArtifactsResponse(artifacts=artifacts)
- )
-
- monkeypatch.setattr(cp, "LightningClient", MagicMock(return_value=client))
-
- assert cd("/", verify=False) == "/"
- cp.cp(str(tmpdir), "r:.")
- assert error_and_exit._mock_call_args_list[0].args[0] == "Uploading files at the project level isn't allowed yet."
-
- assert cd("/project-0/app-name-0", verify=False) == "/project-0/app-name-0"
-
- download_file = MagicMock()
- monkeypatch.setattr(cp, "_download_file", download_file)
-
- cp.cp("r:.", str(tmpdir))
-
- assert len(download_file.call_args_list) == 4
- for i, call in enumerate(download_file.call_args_list):
- assert call.args[0] == PosixPath(tmpdir / artifacts[i].filename)
- assert call.args[1] == artifacts[i].url
-
- # cleanup
- os.remove(_CD_FILE)
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows yet")
-def test_sanitize_path():
- path, is_remote = cp._sanitize_path("r:default-project", "/")
- assert path == "/default-project"
- assert is_remote
-
- path, _ = cp._sanitize_path("r:foo", "/default-project")
- assert path == "/default-project/foo"
-
- path, _ = cp._sanitize_path("foo", "/default-project")
- assert path == "foo"
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows yet")
-def test_cp_zip_arg_order(monkeypatch):
- assert cd("/", verify=False) == "/"
-
- error_and_exit = MagicMock()
- monkeypatch.setattr(cp, "_error_and_exit", error_and_exit)
- monkeypatch.setattr(cp, "LightningClient", MagicMock(return_value=MagicMock()))
- cp.cp("./my-resource", "r:./my-resource", zip=True)
- error_and_exit.assert_called_once()
- assert "Zipping uploads isn't supported yet" in error_and_exit.call_args_list[0].args[0]
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows yet")
-def test_cp_zip_src_path_too_short(monkeypatch):
- error_and_exit = MagicMock()
- monkeypatch.setattr(cp, "_error_and_exit", error_and_exit)
- monkeypatch.setattr(cp, "LightningClient", MagicMock(return_value=MagicMock()))
- cp.cp("r:/my-project", ".", zip=True)
- error_and_exit.assert_called_once()
- assert "The source path must be at least two levels deep" in error_and_exit.call_args_list[0].args[0]
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows yet")
-def test_cp_zip_remote_to_local_cloudspace_artifact(monkeypatch):
- assert cd("/", verify=False) == "/"
-
- token_getter = MagicMock()
- token_getter._get_api_token.return_value = "my-token"
- monkeypatch.setattr(cp, "_AuthTokenGetter", MagicMock(return_value=token_getter))
-
- client = MagicMock()
- client.cluster_service_list_clusters.return_value = V1ListClustersResponse(
- default_cluster="my-cluster",
- clusters=[
- Externalv1Cluster(
- id="my-cluster",
- spec=V1ClusterSpec(
- driver=V1ClusterDriver(kubernetes=V1KubernetesClusterDriver(root_domain_name="my-domain"))
- ),
- )
- ],
- )
- client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="my-project", project_id="my-project-id")]
- )
- client.cloud_space_service_list_cloud_spaces.return_value = V1ListCloudSpacesResponse(
- cloudspaces=[V1CloudSpace(name="my-cloudspace", id="my-cloudspace-id")],
- )
- monkeypatch.setattr(cp, "LightningClient", MagicMock(return_value=client))
-
- download_file = MagicMock()
- monkeypatch.setattr(cp, "_download_file", download_file)
-
- cloudspace_artifact = "r:/my-project/my-cloudspace/my-artifact"
- cp.cp(cloudspace_artifact, ".", zip=True)
-
- download_file.assert_called_once()
- assert download_file.call_args_list[0].args[0] == "./my-artifact.zip"
- assert (
- download_file.call_args_list[0].args[1]
- == "https://storage.my-domain/v1/projects/my-project-id/artifacts/download"
- + "?prefix=/cloudspaces/my-cloudspace-id/my-artifact&token=my-token"
- )
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows yet")
-def test_cp_zip_remote_to_local_app_artifact(monkeypatch):
- assert cd("/", verify=False) == "/"
-
- token_getter = MagicMock()
- token_getter._get_api_token.return_value = "my-token"
- monkeypatch.setattr(cp, "_AuthTokenGetter", MagicMock(return_value=token_getter))
-
- client = MagicMock()
- client.cluster_service_get_cluster.return_value = Externalv1Cluster(
- spec=V1ClusterSpec(driver=V1ClusterDriver(kubernetes=V1KubernetesClusterDriver(root_domain_name="my-domain")))
- )
- client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="my-project", project_id="my-project-id")]
- )
- client.lightningapp_instance_service_list_lightningapp_instances.return_value = V1ListLightningappInstancesResponse(
- lightningapps=[
- Externalv1LightningappInstance(
- name="my-app", id="my-app-id", spec=V1LightningappInstanceSpec(cluster_id="my-cluster")
- )
- ]
- )
- monkeypatch.setattr(cp, "LightningClient", MagicMock(return_value=client))
-
- download_file = MagicMock()
- monkeypatch.setattr(cp, "_download_file", download_file)
-
- app_artifact = "r:/my-project/my-app/my-artifact"
- cp.cp(app_artifact, ".", zip=True)
-
- download_file.assert_called_once()
- assert download_file.call_args_list[0].args[0] == "./my-artifact.zip"
- assert (
- download_file.call_args_list[0].args[1]
- == "https://storage.my-domain/v1/projects/my-project-id/artifacts/download"
- + "?prefix=/lightningapps/my-app-id/my-artifact&token=my-token"
- )
diff --git a/tests/tests_app/cli/test_ls.py b/tests/tests_app/cli/test_ls.py
deleted file mode 100644
index 06b18fff5d5ae..0000000000000
--- a/tests/tests_app/cli/test_ls.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import os
-import sys
-from unittest.mock import MagicMock
-
-import pytest
-from lightning.app.cli.commands import cd, ls
-from lightning_cloud.openapi import (
- Externalv1LightningappInstance,
- V1LightningappInstanceArtifact,
- V1ListCloudSpacesResponse,
- V1ListLightningappInstanceArtifactsResponse,
- V1ListLightningappInstancesResponse,
- V1ListMembershipsResponse,
- V1Membership,
-)
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows yet")
-def test_ls(monkeypatch):
- """This test validates ls behaves as expected."""
- if os.path.exists(cd._CD_FILE):
- os.remove(cd._CD_FILE)
-
- client = MagicMock()
- client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[
- V1Membership(name="project-0", project_id="project-id-0"),
- V1Membership(name="project-1", project_id="project-id-1"),
- V1Membership(name="project 2", project_id="project-id-2"),
- ]
- )
-
- client.lightningapp_instance_service_list_lightningapp_instances().get.return_value = (
- V1ListLightningappInstancesResponse(
- lightningapps=[
- Externalv1LightningappInstance(name="app-name-0", id="app-id-0"),
- Externalv1LightningappInstance(name="app-name-1", id="app-id-1"),
- Externalv1LightningappInstance(name="app name 2", id="app-id-1"),
- ]
- )
- )
-
- client.cloud_space_service_list_cloud_spaces().get.return_value = V1ListCloudSpacesResponse(cloudspaces=[])
-
- clusters = MagicMock()
- clusters.clusters = [MagicMock()]
- client.projects_service_list_project_cluster_bindings.return_value = clusters
-
- def fn(*args, prefix, **kwargs):
- splits = [split for split in prefix.split("/") if split != ""]
- if len(splits) == 2:
- return V1ListLightningappInstanceArtifactsResponse(
- artifacts=[
- V1LightningappInstanceArtifact(filename="file_1.txt"),
- V1LightningappInstanceArtifact(filename="folder_1/file_2.txt"),
- V1LightningappInstanceArtifact(filename="folder_2/folder_3/file_3.txt"),
- V1LightningappInstanceArtifact(filename="folder_2/file_4.txt"),
- ]
- )
- if splits[-1] == "folder_1":
- return V1ListLightningappInstanceArtifactsResponse(
- artifacts=[V1LightningappInstanceArtifact(filename="file_2.txt")]
- )
- if splits[-1] == "folder_2":
- return V1ListLightningappInstanceArtifactsResponse(
- artifacts=[
- V1LightningappInstanceArtifact(filename="folder_3/file_3.txt"),
- V1LightningappInstanceArtifact(filename="file_4.txt"),
- ]
- )
- if splits[-1] == "folder_3":
- return V1ListLightningappInstanceArtifactsResponse(
- artifacts=[
- V1LightningappInstanceArtifact(filename="file_3.txt"),
- ]
- )
- return None
-
- client.lightningapp_instance_service_list_project_artifacts = fn
-
- monkeypatch.setattr(ls, "LightningClient", MagicMock(return_value=client))
-
- assert ls.ls() == ["project-0", "project-1", "project 2"]
- assert cd.cd("project-0", verify=False) == "/project-0"
-
- assert ls.ls() == ["app name 2", "app-name-0", "app-name-1"]
- assert f"/project-0{os.sep}app-name-1" == cd.cd("app-name-1", verify=False)
- assert ls.ls() == ["file_1.txt", "folder_1", "folder_2"]
- assert f"/project-0{os.sep}app-name-1{os.sep}folder_1" == cd.cd("folder_1", verify=False)
- assert ls.ls() == ["file_2.txt"]
- assert f"/project-0{os.sep}app-name-1{os.sep}folder_2" == cd.cd("../folder_2", verify=False)
- assert ls.ls() == ["folder_3", "file_4.txt"]
- assert f"/project-0{os.sep}app-name-1{os.sep}folder_2{os.sep}folder_3" == cd.cd("folder_3", verify=False)
- assert ls.ls() == ["file_3.txt"]
-
- assert cd.cd("/project 2", verify=False) == "/project 2"
- assert ls.ls() == ["app name 2", "app-name-0", "app-name-1"]
- assert cd.cd("app name 2", verify=False) == "/project 2/app name 2"
- assert ls.ls() == ["file_1.txt", "folder_1", "folder_2"]
-
- os.remove(cd._CD_FILE)
diff --git a/tests/tests_app/cli/test_rm.py b/tests/tests_app/cli/test_rm.py
deleted file mode 100644
index 9751f9f81d873..0000000000000
--- a/tests/tests_app/cli/test_rm.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import os
-import sys
-from unittest.mock import MagicMock
-
-import pytest
-from lightning.app.cli.commands import cd, ls, rm
-from lightning_cloud.openapi import (
- Externalv1LightningappInstance,
- V1LightningappInstanceArtifact,
- V1ListCloudSpacesResponse,
- V1ListLightningappInstanceArtifactsResponse,
- V1ListLightningappInstancesResponse,
- V1ListMembershipsResponse,
- V1Membership,
-)
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows yet")
-def test_rm(monkeypatch):
- """This test validates rm behaves as expected."""
- if os.path.exists(cd._CD_FILE):
- os.remove(cd._CD_FILE)
-
- client = MagicMock()
- client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[
- V1Membership(name="project-0", project_id="project-id-0"),
- V1Membership(name="project-1", project_id="project-id-1"),
- V1Membership(name="project 2", project_id="project-id-2"),
- ]
- )
-
- client.lightningapp_instance_service_list_lightningapp_instances().get.return_value = (
- V1ListLightningappInstancesResponse(
- lightningapps=[
- Externalv1LightningappInstance(name="app-name-0", id="app-id-0"),
- Externalv1LightningappInstance(name="app-name-1", id="app-id-1"),
- Externalv1LightningappInstance(name="app name 2", id="app-id-1"),
- ]
- )
- )
-
- client.cloud_space_service_list_cloud_spaces().get.return_value = V1ListCloudSpacesResponse(cloudspaces=[])
-
- clusters = MagicMock()
- clusters.clusters = [MagicMock()]
- client.projects_service_list_project_cluster_bindings.return_value = clusters
-
- def fn(*args, prefix, **kwargs):
- splits = [split for split in prefix.split("/") if split != ""]
- if len(splits) == 2:
- return V1ListLightningappInstanceArtifactsResponse(
- artifacts=[
- V1LightningappInstanceArtifact(filename="file_1.txt"),
- V1LightningappInstanceArtifact(filename="folder_1/file_2.txt"),
- V1LightningappInstanceArtifact(filename="folder_2/folder_3/file_3.txt"),
- V1LightningappInstanceArtifact(filename="folder_2/file_4.txt"),
- ]
- )
- if splits[-1] == "folder_1":
- return V1ListLightningappInstanceArtifactsResponse(
- artifacts=[V1LightningappInstanceArtifact(filename="file_2.txt")]
- )
- if splits[-1] == "folder_2":
- return V1ListLightningappInstanceArtifactsResponse(
- artifacts=[
- V1LightningappInstanceArtifact(filename="folder_3/file_3.txt"),
- V1LightningappInstanceArtifact(filename="file_4.txt"),
- ]
- )
- if splits[-1] == "folder_3":
- return V1ListLightningappInstanceArtifactsResponse(
- artifacts=[
- V1LightningappInstanceArtifact(filename="file_3.txt"),
- ]
- )
- return None
-
- client.lightningapp_instance_service_list_project_artifacts = fn
-
- client.lightningapp_instance_service_delete_project_artifact = MagicMock()
-
- monkeypatch.setattr(rm, "LightningClient", MagicMock(return_value=client))
- monkeypatch.setattr(ls, "LightningClient", MagicMock(return_value=client))
-
- assert ls.ls() == ["project-0", "project-1", "project 2"]
- assert cd.cd("project-0", verify=False) == "/project-0"
-
- assert f"/project-0{os.sep}app-name-1" == cd.cd("app-name-1", verify=False)
-
- assert f"/project-0{os.sep}app-name-1{os.sep}folder_1" == cd.cd("folder_1", verify=False)
-
- rm.rm("file_2.txt")
-
- kwargs = client.lightningapp_instance_service_delete_project_artifact._mock_call_args.kwargs
- assert kwargs["project_id"] == "project-id-0"
- assert kwargs["filename"] == "/lightningapps/app-id-1/folder_1/file_2.txt"
-
- os.remove(cd._CD_FILE)
diff --git a/tests/tests_app/cli/test_run_app.py b/tests/tests_app/cli/test_run_app.py
deleted file mode 100644
index d570e618b7226..0000000000000
--- a/tests/tests_app/cli/test_run_app.py
+++ /dev/null
@@ -1,222 +0,0 @@
-import logging
-import os
-from pathlib import Path
-from unittest import mock
-
-import click
-import lightning.app.core.constants as constants
-import pytest
-from click.testing import CliRunner
-from lightning.app import LightningApp
-from lightning.app.cli.lightning_cli import _run_app, run_app
-from lightning.app.runners.runtime_type import RuntimeType
-from lightning.app.utilities.app_helpers import convert_print_to_logger_info
-
-from tests_app import _PROJECT_ROOT
-
-
-@mock.patch("click.launch")
-@pytest.mark.parametrize("open_ui", [True, False])
-def test_lightning_run_app(lauch_mock: mock.MagicMock, open_ui, caplog, monkeypatch):
- """This test validates the command is run properly and the LightningApp method is being executed."""
- monkeypatch.setattr("lightning.app._logger", logging.getLogger())
-
- original_method = LightningApp._run
-
- @convert_print_to_logger_info
- def _lightning_app_run_and_logging(self, *args, **kwargs):
- original_method(self, *args, **kwargs)
- print("1" if open_ui else "0")
- print(self)
-
- with caplog.at_level(logging.INFO):
- with mock.patch("lightning.app.LightningApp._run", _lightning_app_run_and_logging):
- runner = CliRunner()
- pytest_env = os.environ.pop("PYTEST_CURRENT_TEST")
- try:
- result = runner.invoke(
- run_app,
- [
- os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/app_metadata.py"),
- "--blocking",
- "False",
- "--open-ui",
- str(open_ui),
- ],
- catch_exceptions=False,
- )
- finally:
- os.environ["PYTEST_CURRENT_TEST"] = pytest_env
- # capture logs.
- if open_ui:
- # Get the designated port
- port = constants.APP_SERVER_PORT
-
- lauch_mock.assert_called_with(f"http://127.0.0.1:{port}/view")
- else:
- lauch_mock.assert_not_called()
- assert result.exit_code == 0
- assert len(caplog.messages) == 4
- assert bool(int(caplog.messages[0])) is open_ui
-
-
-@mock.patch.dict(os.environ, {"LIGHTNING_CLOUD_URL": "https://beta.lightning.ai"})
-@mock.patch("lightning.app.cli.lightning_cli.dispatch")
-@pytest.mark.parametrize("open_ui", [True, False])
-def test_lightning_run_app_cloud(mock_dispatch: mock.MagicMock, open_ui, caplog, monkeypatch):
- """This test validates the command has run properly when --cloud argument is passed.
-
- It tests it by checking if the click.launch is called with the right url if --open-ui was true and also checks the
- call to `dispatch` for the right arguments.
-
- """
- monkeypatch.setattr("lightning.app.runners.cloud.logger", logging.getLogger())
-
- with caplog.at_level(logging.INFO):
- _run_app(
- file=os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/app_metadata.py"),
- cloud=True,
- without_server=False,
- name="",
- blocking=False,
- open_ui=open_ui,
- no_cache=True,
- env=("FOO=bar",),
- secret=("BAR=my-secret",),
- run_app_comment_commands=False,
- enable_basic_auth="",
- )
-
- # Get the designated port
- port = constants.APP_SERVER_PORT
-
- # capture logs.
- # TODO(yurij): refactor the test, check if the actual HTTP request is being sent and that the proper admin
- # page is being opened
- mock_dispatch.assert_called_with(
- Path(os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/app_metadata.py")),
- RuntimeType.CLOUD,
- start_server=True,
- blocking=False,
- open_ui=open_ui,
- name="",
- no_cache=True,
- env_vars={"FOO": "bar"},
- secrets={"BAR": "my-secret"},
- run_app_comment_commands=False,
- enable_basic_auth="",
- port=port,
- )
-
-
-@mock.patch.dict(os.environ, {"LIGHTNING_CLOUD_URL": "https://beta.lightning.ai"})
-@mock.patch("lightning.app.cli.lightning_cli.dispatch")
-@pytest.mark.parametrize("open_ui", [True, False])
-def test_lightning_run_app_cloud_with_run_app_commands(mock_dispatch: mock.MagicMock, open_ui, caplog, monkeypatch):
- """This test validates the command has run properly when --cloud argument is passed.
-
- It tests it by checking if the click.launch is called with the right url if --open-ui was true and also checks the
- call to `dispatch` for the right arguments.
-
- """
- monkeypatch.setattr("lightning.app.runners.cloud.logger", logging.getLogger())
-
- with caplog.at_level(logging.INFO):
- _run_app(
- file=os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/app_metadata.py"),
- cloud=True,
- without_server=False,
- name="",
- blocking=False,
- open_ui=open_ui,
- no_cache=True,
- env=("FOO=bar",),
- secret=("BAR=my-secret",),
- run_app_comment_commands=True,
- enable_basic_auth="",
- )
-
- # Get the designated port
- port = constants.APP_SERVER_PORT
-
- # capture logs.
- # TODO(yurij): refactor the test, check if the actual HTTP request is being sent and that the proper admin
- # page is being opened
- mock_dispatch.assert_called_with(
- Path(os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/app_metadata.py")),
- RuntimeType.CLOUD,
- start_server=True,
- blocking=False,
- open_ui=open_ui,
- name="",
- no_cache=True,
- env_vars={"FOO": "bar"},
- secrets={"BAR": "my-secret"},
- run_app_comment_commands=True,
- enable_basic_auth="",
- port=port,
- )
-
-
-def test_lightning_run_app_secrets(monkeypatch):
- """Validates that running apps only supports the `--secrets` argument if the `--cloud` argument is passed."""
- monkeypatch.setattr("lightning.app.runners.cloud.logger", logging.getLogger())
-
- with pytest.raises(click.exceptions.ClickException):
- _run_app(
- file=os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/app_metadata.py"),
- cloud=False,
- without_server=False,
- name="",
- blocking=False,
- open_ui=False,
- no_cache=True,
- env=(),
- secret=("FOO=my-secret"),
- run_app_comment_commands=False,
- enable_basic_auth="",
- )
-
-
-@mock.patch.dict(os.environ, {"LIGHTNING_CLOUD_URL": "https://beta.lightning.ai"})
-@mock.patch("lightning.app.cli.lightning_cli.dispatch")
-def test_lightning_run_app_enable_basic_auth_passed(mock_dispatch: mock.MagicMock, caplog, monkeypatch):
- """This test just validates the command has run properly when --enable-basic-auth argument is passed.
-
- It checks the call to `dispatch` for the right arguments.
-
- """
- monkeypatch.setattr("lightning.app.runners.cloud.logger", logging.getLogger())
-
- with caplog.at_level(logging.INFO):
- _run_app(
- file=os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/app_metadata.py"),
- cloud=True,
- without_server=False,
- name="",
- blocking=False,
- open_ui=False,
- no_cache=True,
- env=("FOO=bar",),
- secret=("BAR=my-secret",),
- run_app_comment_commands=False,
- enable_basic_auth="username:password",
- )
-
- # Get the designated port
- port = constants.APP_SERVER_PORT
-
- mock_dispatch.assert_called_with(
- Path(os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/app_metadata.py")),
- RuntimeType.CLOUD,
- start_server=True,
- blocking=False,
- open_ui=False,
- name="",
- no_cache=True,
- env_vars={"FOO": "bar"},
- secrets={"BAR": "my-secret"},
- run_app_comment_commands=False,
- enable_basic_auth="username:password",
- port=port,
- )
diff --git a/tests/tests_app/components/__init__.py b/tests/tests_app/components/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/components/database/test_client_server.py b/tests/tests_app/components/database/test_client_server.py
deleted file mode 100644
index decc610130f40..0000000000000
--- a/tests/tests_app/components/database/test_client_server.py
+++ /dev/null
@@ -1,202 +0,0 @@
-import os
-import sys
-import tempfile
-import time
-import traceback
-from pathlib import Path
-from time import sleep
-from typing import List, Optional
-from uuid import uuid4
-
-import pytest
-from lightning.app import LightningApp, LightningFlow, LightningWork
-from lightning.app.components.database import Database, DatabaseClient
-from lightning.app.components.database.utilities import _GeneralModel, _pydantic_column_type
-from lightning.app.runners import MultiProcessRuntime
-from lightning.app.utilities.imports import _is_sqlmodel_available
-
-if _is_sqlmodel_available():
- from sqlalchemy import Column
- from sqlmodel import Field, SQLModel
-
- class Secret(SQLModel):
- name: str
- value: str
-
- class TestConfig(SQLModel, table=True):
- __table_args__ = {"extend_existing": True}
-
- id: Optional[int] = Field(default=None, primary_key=True)
- name: str
- secrets: List[Secret] = Field(..., sa_column=Column(_pydantic_column_type(List[Secret])))
-
-
-class Work(LightningWork):
- def __init__(self):
- super().__init__(parallel=True)
- self.done = False
-
- def run(self, client: DatabaseClient):
- rows = client.select_all()
- while len(rows) == 0:
- print(rows)
- sleep(0.1)
- rows = client.select_all()
- self.done = True
-
-
-@pytest.mark.skipif(not _is_sqlmodel_available(), reason="sqlmodel is required for this test.")
-def test_client_server():
- database_path = Path("database.db").resolve()
- if database_path.exists():
- os.remove(database_path)
-
- secrets = [Secret(name="example", value="secret")]
-
- general = _GeneralModel.from_obj(TestConfig(name="name", secrets=secrets), token="a")
- assert general.cls_name == "TestConfig"
- assert general.data == '{"id": null, "name": "name", "secrets": [{"name": "example", "value": "secret"}]}'
-
- class Flow(LightningFlow):
- def __init__(self):
- super().__init__()
- self._token = str(uuid4())
- self.db = Database(models=[TestConfig])
- self._client = None
- self.tracker = None
- self.work = Work()
-
- def run(self):
- self.db.run(token=self._token)
-
- if not self.db.alive():
- return
-
- if not self._client:
- self._client = DatabaseClient(model=TestConfig, db_url=self.db.url, token=self._token)
-
- assert self._client
-
- self.work.run(self._client)
-
- if self.tracker is None:
- self._client.insert(TestConfig(name="name", secrets=secrets))
- elem = self._client.select_all(TestConfig)[0]
- assert elem.name == "name"
- self.tracker = "update"
- assert isinstance(elem.secrets[0], Secret)
- assert elem.secrets[0].name == "example"
- assert elem.secrets[0].value == "secret"
-
- elif self.tracker == "update":
- elem = self._client.select_all(TestConfig)[0]
- elem.name = "new_name"
- self._client.update(elem)
-
- elem = self._client.select_all(TestConfig)[0]
- assert elem.name == "new_name"
- self.tracker = "delete"
-
- elif self.tracker == "delete" and self.work.done:
- self.work.stop()
-
- elem = self._client.select_all(TestConfig)[0]
- elem = self._client.delete(elem)
-
- assert not self._client.select_all(TestConfig)
- self._client.insert(TestConfig(name="name", secrets=secrets))
-
- assert self._client.select_all(TestConfig)
- self.stop()
-
- app = LightningApp(Flow())
- MultiProcessRuntime(app, start_server=False).dispatch()
-
- database_path = Path("database.db").resolve()
- if database_path.exists():
- os.remove(database_path)
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.")
-@pytest.mark.skipif(not _is_sqlmodel_available(), reason="sqlmodel is required for this test.")
-def test_work_database_restart():
- id = str(uuid4()).split("-")[0]
-
- class Flow(LightningFlow):
- def __init__(self, db_root=".", restart=False):
- super().__init__()
- self._db_filename = os.path.join(db_root, id)
- self.db = Database(db_filename=self._db_filename, models=[TestConfig])
- self._client = None
- self.restart = restart
-
- def run(self):
- self.db.run()
-
- if not self.db.alive():
- return
- if not self._client:
- self._client = DatabaseClient(self.db.db_url, None, model=TestConfig)
-
- if not self.restart:
- self._client.insert(TestConfig(name="echo", secrets=[Secret(name="example", value="secret")]))
- self.stop()
- else:
- assert os.path.exists(self._db_filename)
- assert len(self._client.select_all()) == 1
- self.stop()
-
- with tempfile.TemporaryDirectory() as tmpdir:
- app = LightningApp(Flow(db_root=tmpdir))
- MultiProcessRuntime(app).dispatch()
-
- # Note: Waiting for SIGTERM signal to be handled
- sleep(2)
-
- app = LightningApp(Flow(db_root=tmpdir, restart=True))
- MultiProcessRuntime(app).dispatch()
-
- # Note: Waiting for SIGTERM signal to be handled
- sleep(2)
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.")
-@pytest.mark.skipif(not _is_sqlmodel_available(), reason="sqlmodel is required for this test.")
-def test_work_database_periodic_store():
- id = str(uuid4()).split("-")[0]
-
- class Flow(LightningFlow):
- def __init__(self, db_root="."):
- super().__init__()
- self._db_filename = os.path.join(db_root, id)
- self.db = Database(db_filename=self._db_filename, models=[TestConfig], store_interval=1)
- self._client = None
- self._start_time = None
- self.counter = 0
-
- def run(self):
- self.counter += 1
-
- self.db.run()
-
- if not self.db.alive():
- return
-
- if not self._client:
- self._client = DatabaseClient(self.db.db_url, None, model=TestConfig)
-
- if self._start_time is None:
- self._client.insert(TestConfig(name="echo", secrets=[Secret(name="example", value="secret")]))
- self._start_time = time.time()
-
- elif (time.time() - self._start_time) > 2:
- assert os.path.exists(self._db_filename)
- assert len(self._client.select_all()) == 1
- self.stop()
-
- try:
- with tempfile.TemporaryDirectory() as tmpdir:
- app = LightningApp(Flow(tmpdir))
- MultiProcessRuntime(app).dispatch()
- except Exception:
- print(traceback.print_exc())
diff --git a/tests/tests_app/components/multi_node/__init__.py b/tests/tests_app/components/multi_node/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/components/multi_node/test_base.py b/tests/tests_app/components/multi_node/test_base.py
deleted file mode 100644
index a39802f855378..0000000000000
--- a/tests/tests_app/components/multi_node/test_base.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from re import escape
-from unittest import mock
-
-import pytest
-from lightning.app import CloudCompute, LightningWork
-from lightning.app.components import MultiNode
-from lightning_utilities.test.warning import no_warning_call
-
-
-def test_multi_node_warn_running_locally():
- class Work(LightningWork):
- def run(self):
- pass
-
- with pytest.warns(UserWarning, match=escape("You set MultiNode(num_nodes=2, ...)` but ")):
- MultiNode(Work, num_nodes=2, cloud_compute=CloudCompute("gpu"))
-
- with no_warning_call(UserWarning, match=escape("You set MultiNode(num_nodes=1, ...)` but ")):
- MultiNode(Work, num_nodes=1, cloud_compute=CloudCompute("gpu"))
-
-
-@mock.patch("lightning.app.components.multi_node.base.is_running_in_cloud", mock.Mock(return_value=True))
-def test_multi_node_separate_cloud_computes():
- class Work(LightningWork):
- def run(self):
- pass
-
- m = MultiNode(Work, num_nodes=2, cloud_compute=CloudCompute("gpu"))
-
- assert len({w.cloud_compute._internal_id for w in m.ws}) == len(m.ws)
diff --git a/tests/tests_app/components/multi_node/test_fabric.py b/tests/tests_app/components/multi_node/test_fabric.py
deleted file mode 100644
index 5fbcd20123a5e..0000000000000
--- a/tests/tests_app/components/multi_node/test_fabric.py
+++ /dev/null
@@ -1,97 +0,0 @@
-import os
-from copy import deepcopy
-from functools import partial
-from unittest import mock
-
-import lightning.fabric as lf
-import pytest
-from lightning.app.components.multi_node.fabric import _FabricRunExecutor
-from lightning_utilities.core.imports import module_available
-from lightning_utilities.test.warning import no_warning_call
-
-
-class DummyFabric(lf.Fabric):
- def run(self):
- pass
-
-
-def dummy_callable(**kwargs):
- fabric = DummyFabric(**kwargs)
- return fabric._all_passed_kwargs
-
-
-def dummy_init(self, **kwargs):
- self._all_passed_kwargs = kwargs
-
-
-def _get_args_after_tracer_injection(**kwargs):
- with mock.patch.object(lf.Fabric, "__init__", dummy_init):
- ret_val = _FabricRunExecutor.run(
- local_rank=0,
- work_run=partial(dummy_callable, **kwargs),
- main_address="1.2.3.4",
- main_port=5,
- node_rank=6,
- num_nodes=7,
- nprocs=8,
- )
- env_vars = deepcopy(os.environ)
- return ret_val, env_vars
-
-
-def check_lightning_fabric_mps():
- if module_available("lightning.fabric"):
- return lf.accelerators.MPSAccelerator.is_available()
- return False
-
-
-@pytest.mark.skipif(not check_lightning_fabric_mps(), reason="Fabric not available or mps not available")
-@pytest.mark.parametrize(
- ("accelerator_given", "accelerator_expected"), [("cpu", "cpu"), ("auto", "cpu"), ("gpu", "cpu")]
-)
-def test_fabric_run_executor_mps_forced_cpu(accelerator_given, accelerator_expected):
- warning_str = r"Forcing `accelerator=cpu` as MPS does not support distributed training."
- if accelerator_expected != accelerator_given:
- warning_context = pytest.warns(UserWarning, match=warning_str)
- else:
- warning_context = no_warning_call(match=warning_str + "*")
-
- with warning_context:
- ret_val, _ = _get_args_after_tracer_injection(accelerator=accelerator_given)
- assert ret_val["accelerator"] == accelerator_expected
-
-
-@pytest.mark.parametrize(
- ("args_given", "args_expected"),
- [
- ({"devices": 1, "num_nodes": 1, "accelerator": "gpu"}, {"devices": 8, "num_nodes": 7, "accelerator": "auto"}),
- ({"strategy": "ddp_spawn"}, {"strategy": "ddp"}),
- ({"strategy": "ddp_sharded_spawn"}, {"strategy": "ddp_sharded"}),
- ],
-)
-@pytest.mark.skipif(not module_available("lightning"), reason="Lightning is required for this test")
-def test_trainer_run_executor_arguments_choices(args_given: dict, args_expected: dict):
- # ddp with mps devices not available (tested separately, just patching here for cross-os testing of other args)
- if lf.accelerators.MPSAccelerator.is_available():
- args_expected["accelerator"] = "cpu"
-
- ret_val, env_vars = _get_args_after_tracer_injection(**args_given)
-
- for k, v in args_expected.items():
- assert ret_val[k] == v
-
- assert env_vars["MASTER_ADDR"] == "1.2.3.4"
- assert env_vars["MASTER_PORT"] == "5"
- assert env_vars["GROUP_RANK"] == "6"
- assert env_vars["RANK"] == str(0 + 6 * 8)
- assert env_vars["LOCAL_RANK"] == "0"
- assert env_vars["WORLD_SIZE"] == str(7 * 8)
- assert env_vars["LOCAL_WORLD_SIZE"] == "8"
- assert env_vars["TORCHELASTIC_RUN_ID"] == "1"
- assert env_vars["LT_CLI_USED"] == "1"
-
-
-@pytest.mark.skipif(not module_available("lightning"), reason="Lightning not available")
-def test_run_executor_invalid_strategy_instances():
- with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
- _, _ = _get_args_after_tracer_injection(strategy=lf.strategies.DDPStrategy(start_method="spawn"))
diff --git a/tests/tests_app/components/multi_node/test_trainer.py b/tests/tests_app/components/multi_node/test_trainer.py
deleted file mode 100644
index 1258cbe0176e0..0000000000000
--- a/tests/tests_app/components/multi_node/test_trainer.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import os
-from copy import deepcopy
-from functools import partial
-from unittest import mock
-
-import pytest
-import pytorch_lightning as pl
-from lightning.app.components.multi_node.trainer import _LightningTrainerRunExecutor
-from lightning_utilities.core.imports import module_available
-from lightning_utilities.test.warning import no_warning_call
-
-
-def dummy_callable(**kwargs):
- t = pl.Trainer(**kwargs)
- return t._all_passed_kwargs
-
-
-def dummy_init(self, **kwargs):
- self._all_passed_kwargs = kwargs
-
-
-def _get_args_after_tracer_injection(**kwargs):
- with mock.patch.object(pl.Trainer, "__init__", dummy_init):
- ret_val = _LightningTrainerRunExecutor.run(
- local_rank=0,
- work_run=partial(dummy_callable, **kwargs),
- main_address="1.2.3.4",
- main_port=5,
- node_rank=6,
- num_nodes=7,
- nprocs=8,
- )
- env_vars = deepcopy(os.environ)
- return ret_val, env_vars
-
-
-def check_lightning_pytorch_and_mps():
- if module_available("pytorch_lightning"):
- return pl.accelerators.MPSAccelerator.is_available()
- return False
-
-
-@pytest.mark.skipif(not check_lightning_pytorch_and_mps(), reason="pytorch_lightning and mps are required")
-@pytest.mark.parametrize(
- ("accelerator_given", "accelerator_expected"), [("cpu", "cpu"), ("auto", "cpu"), ("gpu", "cpu")]
-)
-def test_trainer_run_executor_mps_forced_cpu(accelerator_given, accelerator_expected):
- warning_str = r"Forcing `accelerator=cpu` as MPS does not support distributed training."
- if accelerator_expected != accelerator_given:
- warning_context = pytest.warns(UserWarning, match=warning_str)
- else:
- warning_context = no_warning_call(match=warning_str + "*")
-
- with warning_context:
- ret_val, _ = _get_args_after_tracer_injection(accelerator=accelerator_given)
- assert ret_val["accelerator"] == accelerator_expected
-
-
-@pytest.mark.parametrize(
- ("args_given", "args_expected"),
- [
- ({"devices": 1, "num_nodes": 1, "accelerator": "gpu"}, {"devices": 8, "num_nodes": 7, "accelerator": "auto"}),
- ({"strategy": "ddp_spawn"}, {"strategy": "ddp"}),
- ({"strategy": "ddp_sharded_spawn"}, {"strategy": "ddp_sharded"}),
- ],
-)
-@pytest.mark.skipif(not module_available("torch"), reason="PyTorch is not available")
-def test_trainer_run_executor_arguments_choices(
- args_given: dict,
- args_expected: dict,
-):
- if pl.accelerators.MPSAccelerator.is_available():
- args_expected.pop("accelerator", None) # Cross platform tests -> MPS is tested separately
-
- ret_val, env_vars = _get_args_after_tracer_injection(**args_given)
-
- for k, v in args_expected.items():
- assert ret_val[k] == v
-
- assert env_vars["MASTER_ADDR"] == "1.2.3.4"
- assert env_vars["MASTER_PORT"] == "5"
- assert env_vars["GROUP_RANK"] == "6"
- assert env_vars["RANK"] == str(0 + 6 * 8)
- assert env_vars["LOCAL_RANK"] == "0"
- assert env_vars["WORLD_SIZE"] == str(7 * 8)
- assert env_vars["LOCAL_WORLD_SIZE"] == "8"
- assert env_vars["TORCHELASTIC_RUN_ID"] == "1"
-
-
-@pytest.mark.skipif(not module_available("lightning"), reason="lightning not available")
-def test_trainer_run_executor_invalid_strategy_instances():
- with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
- _, _ = _get_args_after_tracer_injection(strategy=pl.strategies.DDPStrategy(start_method="spawn"))
diff --git a/tests/tests_app/components/python/scripts/a.py b/tests/tests_app/components/python/scripts/a.py
deleted file mode 100644
index 414be73ce4a51..0000000000000
--- a/tests/tests_app/components/python/scripts/a.py
+++ /dev/null
@@ -1 +0,0 @@
-print("Hello World !")
diff --git a/tests/tests_app/components/python/scripts/b.py b/tests/tests_app/components/python/scripts/b.py
deleted file mode 100644
index 53254da11906b..0000000000000
--- a/tests/tests_app/components/python/scripts/b.py
+++ /dev/null
@@ -1,3 +0,0 @@
-import sys
-
-print(sys.argv)
diff --git a/tests/tests_app/components/python/scripts/c.py b/tests/tests_app/components/python/scripts/c.py
deleted file mode 100644
index eb56e6de70a61..0000000000000
--- a/tests/tests_app/components/python/scripts/c.py
+++ /dev/null
@@ -1,4 +0,0 @@
-import os
-
-if __name__ == "__main__":
- assert int(os.environ["VARIABLE"]) == 0
diff --git a/tests/tests_app/components/python/test_python.py b/tests/tests_app/components/python/test_python.py
deleted file mode 100644
index f0d14ebe518d8..0000000000000
--- a/tests/tests_app/components/python/test_python.py
+++ /dev/null
@@ -1,148 +0,0 @@
-import os
-import tarfile
-
-import pytest
-from lightning.app.components.python import PopenPythonScript, TracerPythonScript
-from lightning.app.components.python.tracer import Code
-from lightning.app.storage.drive import Drive
-from lightning.app.testing.helpers import _RunIf
-from lightning.app.testing.testing import run_work_isolated
-from lightning.app.utilities.component import _set_work_context
-from lightning.app.utilities.enum import CacheCallsKeys
-from tests_app import _PROJECT_ROOT
-
-COMPONENTS_SCRIPTS_FOLDER = str(os.path.join(_PROJECT_ROOT, "tests/tests_app/components/python/scripts/"))
-
-
-def test_non_existing_python_script():
- match = "tests/components/python/scripts/0.py"
- with pytest.raises(FileNotFoundError, match=match):
- python_script = PopenPythonScript(match)
- run_work_isolated(python_script)
- assert not python_script.has_started
-
- python_script = TracerPythonScript(match, raise_exception=False)
- run_work_isolated(python_script)
- assert python_script.has_failed
-
-
-def test_simple_python_script():
- python_script = PopenPythonScript(COMPONENTS_SCRIPTS_FOLDER + "a.py")
- run_work_isolated(python_script)
- assert python_script.has_succeeded
-
- python_script = TracerPythonScript(COMPONENTS_SCRIPTS_FOLDER + "a.py")
- run_work_isolated(python_script)
- assert python_script.has_succeeded
-
-
-def test_simple_popen_python_script_with_kwargs():
- python_script = PopenPythonScript(
- COMPONENTS_SCRIPTS_FOLDER + "b.py",
- script_args="--arg_0=hello --arg_1=world",
- )
- run_work_isolated(python_script)
- assert python_script.has_succeeded
-
-
-@_RunIf(skip_windows=True)
-def test_popen_python_script_failure():
- python_script = PopenPythonScript(
- COMPONENTS_SCRIPTS_FOLDER + "c.py",
- env={"VARIABLE": "1"},
- raise_exception=False,
- )
- run_work_isolated(python_script)
- assert python_script.has_failed
- assert "Exception(self.exit_code)" in python_script.status.message
-
-
-def test_tracer_python_script_with_kwargs():
- python_script = TracerPythonScript(
- COMPONENTS_SCRIPTS_FOLDER + "b.py",
- script_args="--arg_0=hello --arg_1=world",
- raise_exception=False,
- )
- run_work_isolated(python_script)
- assert python_script.has_succeeded
-
- python_script = TracerPythonScript(
- COMPONENTS_SCRIPTS_FOLDER + "c.py",
- env={"VARIABLE": "1"},
- raise_exception=False,
- )
- run_work_isolated(python_script)
- assert python_script.has_failed
-
-
-def test_tracer_component_with_code():
- """This test ensures the Tracer Component gets the latest code from the code object that is provided and arguments
- are cleaned."""
-
- drive = Drive("lit://code")
- drive.component_name = "something"
- code = Code(drive=drive, name="sample.tar.gz")
-
- with open("file.py", "w") as f:
- f.write('raise Exception("An error")')
-
- with tarfile.open("sample.tar.gz", "w:gz") as tar:
- tar.add("file.py")
-
- drive.put("sample.tar.gz")
- os.remove("file.py")
- os.remove("sample.tar.gz")
-
- python_script = TracerPythonScript("file.py", script_args=["--b=1"], raise_exception=False, code=code)
- run_work_isolated(python_script, params={"--a": "1"}, restart_count=0)
- assert "An error" in python_script.status.message
-
- with open("file.py", "w") as f:
- f.write("import sys\n")
- f.write("print(sys.argv)\n")
-
- with tarfile.open("sample.tar.gz", "w:gz") as tar:
- tar.add("file.py")
-
- _set_work_context()
- drive.put("sample.tar.gz")
- os.remove("file.py")
- os.remove("sample.tar.gz")
-
- with open("file.py", "w") as f:
- f.write('raise Exception("An error")')
-
- call_hash = python_script._calls[CacheCallsKeys.LATEST_CALL_HASH]
- python_script._calls[call_hash]["statuses"].pop(-1)
- python_script._calls[call_hash]["statuses"].pop(-1)
-
- run_work_isolated(python_script, params={"--a": "1"}, restart_count=1)
- assert python_script.has_succeeded
- assert python_script.script_args == ["--b=1", "--a=1"]
- os.remove("file.py")
- os.remove("sample.tar.gz")
-
-
-def test_tracer_component_with_code_in_dir(tmp_path):
- """This test ensures the Tracer Component gets the latest code from the code object that is provided and arguments
- are cleaned."""
-
- drive = Drive("lit://code")
- drive.component_name = "something"
- code = Code(drive=drive, name="sample.tar.gz")
-
- with open("file.py", "w") as f:
- f.write('raise Exception("An error")')
-
- with tarfile.open("sample.tar.gz", "w:gz") as tar:
- tar.add("file.py")
-
- drive.put("sample.tar.gz")
- os.remove("file.py")
- os.remove("sample.tar.gz")
-
- python_script = TracerPythonScript("file.py", script_args=["--b=1"], raise_exception=False, code=code)
- run_work_isolated(python_script, params={"--a": "1"}, restart_count=0, code_dir=str(tmp_path))
- assert "An error" in python_script.status.message
-
- assert os.path.exists(os.path.join(str(tmp_path), "file.py"))
diff --git a/tests/tests_app/components/sample_package_repo/external_lightning_component_package/__init__.py b/tests/tests_app/components/sample_package_repo/external_lightning_component_package/__init__.py
deleted file mode 100644
index f818186c96f55..0000000000000
--- a/tests/tests_app/components/sample_package_repo/external_lightning_component_package/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from lightning.app import LightningFlow, LightningWork
-
-
-class MyCustomLightningWork(LightningWork):
- @staticmethod
- def special_method():
- return "Hi, I'm an external lightning work component and can be added to any lightning project."
-
-
-class MyCustomLightningFlow(LightningFlow):
- @staticmethod
- def special_method():
- return "Hi, I'm an external lightning flow component and can be added to any lightning project."
-
-
-def exported_lightning_components():
- return [MyCustomLightningWork, MyCustomLightningFlow]
diff --git a/tests/tests_app/components/sample_package_repo/setup.py b/tests/tests_app/components/sample_package_repo/setup.py
deleted file mode 100644
index fdfeb663fe7a2..0000000000000
--- a/tests/tests_app/components/sample_package_repo/setup.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import json
-import os
-
-from setuptools import find_packages, setup
-from setuptools.command.install import install
-
-LIGHTNING_COMPONENT_INFO = {
- "package": "external_lightning_component_package",
- "version": "0.0.1",
- "entry_point": "myorg.lightning_modules",
-}
-
-
-class PostInstallCommand(install):
- def run(self):
- install.run(self)
- os.system(f"echo Installed lightning component package: {json.dumps(json.dumps(LIGHTNING_COMPONENT_INFO))}")
-
-
-setup(
- name=LIGHTNING_COMPONENT_INFO["package"],
- version=LIGHTNING_COMPONENT_INFO["version"],
- description="example of an external lightning package that contains lightning components",
- author="manskx",
- author_email="mansy@grid.ai",
- url="grid.ai",
- download_url="https://github.com/Lightning-AI/lightning",
- license="TBD",
- packages=find_packages(exclude=["tests", "docs"]),
- long_description="example of an external lightning package that contains lightning components",
- long_description_content_type="text/markdown",
- include_package_data=True,
- zip_safe=False,
- keywords=["deep learning", "pytorch", "AI"],
- python_requires=">=3.6",
- entry_points={
- "lightning.app.external_components": [
- f"{LIGHTNING_COMPONENT_INFO['entry_point']}= "
- f"{LIGHTNING_COMPONENT_INFO['package']}:exported_lightning_components",
- ],
- },
- cmdclass={
- "install": PostInstallCommand,
- },
- setup_requires=["wheel"],
-)
diff --git a/tests/tests_app/components/serve/test_auto_scaler.py b/tests/tests_app/components/serve/test_auto_scaler.py
deleted file mode 100644
index 0da489907b69f..0000000000000
--- a/tests/tests_app/components/serve/test_auto_scaler.py
+++ /dev/null
@@ -1,222 +0,0 @@
-import time
-import uuid
-from unittest import mock
-from unittest.mock import patch
-
-import pytest
-from fastapi import HTTPException
-from lightning.app import CloudCompute, LightningWork
-from lightning.app.components import AutoScaler, ColdStartProxy, Text
-from lightning.app.components.serve.auto_scaler import _LoadBalancer
-
-
-class EmptyWork(LightningWork):
- def run(self):
- pass
-
-
-class AutoScaler1(AutoScaler):
- def scale(self, replicas: int, metrics) -> int:
- # only upscale
- return replicas + 1
-
-
-class AutoScaler2(AutoScaler):
- def scale(self, replicas: int, metrics) -> int:
- # only downscale
- return replicas - 1
-
-
-@patch("uvicorn.run")
-@patch("lightning.app.components.serve.auto_scaler._LoadBalancer.url")
-@patch("lightning.app.components.serve.auto_scaler.AutoScaler.num_pending_requests")
-def test_num_replicas_not_above_max_replicas(*_):
- """Test self.num_replicas doesn't exceed max_replicas."""
- max_replicas = 6
- auto_scaler = AutoScaler1(
- EmptyWork,
- min_replicas=1,
- max_replicas=max_replicas,
- scale_out_interval=0.001,
- scale_in_interval=0.001,
- )
-
- for _ in range(max_replicas + 1):
- time.sleep(0.002)
- auto_scaler.run()
-
- assert auto_scaler.num_replicas == max_replicas
-
-
-@patch("uvicorn.run")
-@patch("lightning.app.components.serve.auto_scaler._LoadBalancer.url")
-@patch("lightning.app.components.serve.auto_scaler.AutoScaler.num_pending_requests")
-def test_num_replicas_not_below_min_replicas(*_):
- """Test self.num_replicas doesn't exceed max_replicas."""
- min_replicas = 1
- auto_scaler = AutoScaler2(
- EmptyWork,
- min_replicas=min_replicas,
- max_replicas=4,
- scale_out_interval=0.001,
- scale_in_interval=0.001,
- )
-
- for _ in range(3):
- time.sleep(0.002)
- auto_scaler.run()
-
- assert auto_scaler.num_replicas == min_replicas
-
-
-@pytest.mark.parametrize(
- ("replicas", "metrics", "expected_replicas"),
- [
- pytest.param(1, {"pending_requests": 1, "pending_works": 0}, 2, id="increase if no pending work"),
- pytest.param(1, {"pending_requests": 1, "pending_works": 1}, 1, id="dont increase if pending works"),
- pytest.param(8, {"pending_requests": 1, "pending_works": 0}, 7, id="reduce if requests < 25% capacity"),
- pytest.param(8, {"pending_requests": 2, "pending_works": 0}, 8, id="dont reduce if requests >= 25% capacity"),
- ],
-)
-def test_scale(replicas, metrics, expected_replicas):
- """Test `scale()`, the default scaling strategy."""
- auto_scaler = AutoScaler(
- EmptyWork,
- min_replicas=1,
- max_replicas=8,
- max_batch_size=1,
- )
-
- assert auto_scaler.scale(replicas, metrics) == expected_replicas
-
-
-def test_scale_from_zero_min_replica():
- auto_scaler = AutoScaler(
- EmptyWork,
- min_replicas=0,
- max_replicas=2,
- max_batch_size=10,
- )
-
- resp = auto_scaler.scale(0, {"pending_requests": 0, "pending_works": 0})
- assert resp == 0
-
- resp = auto_scaler.scale(0, {"pending_requests": 1, "pending_works": 0})
- assert resp == 1
-
- resp = auto_scaler.scale(0, {"pending_requests": 1, "pending_works": 1})
- assert resp <= 0
-
-
-def test_create_work_cloud_compute_cloned():
- """Test CloudCompute is cloned to avoid creating multiple works in a single machine."""
- cloud_compute = CloudCompute("gpu")
- auto_scaler = AutoScaler(EmptyWork, cloud_compute=cloud_compute)
- _ = auto_scaler.create_work()
- assert auto_scaler._work_kwargs["cloud_compute"] is not cloud_compute
-
-
-fastapi_mock = mock.MagicMock()
-mocked_fastapi_creater = mock.MagicMock(return_value=fastapi_mock)
-
-
-@patch("lightning.app.components.serve.auto_scaler._create_fastapi", mocked_fastapi_creater)
-@patch("lightning.app.components.serve.auto_scaler.uvicorn.run", mock.MagicMock())
-def test_API_ACCESS_ENDPOINT_creation():
- auto_scaler = AutoScaler(EmptyWork, input_type=Text, output_type=Text)
- assert auto_scaler.load_balancer._api_name == "EmptyWork"
-
- auto_scaler.load_balancer.run()
- fastapi_mock.mount.assert_called_once_with("/endpoint-info", mock.ANY, name="static")
-
-
-def test_autoscaler_scale_up(monkeypatch):
- monkeypatch.setattr(AutoScaler, "num_pending_works", 0)
- monkeypatch.setattr(AutoScaler, "num_pending_requests", 100)
- monkeypatch.setattr(AutoScaler, "scale", mock.MagicMock(return_value=1))
- monkeypatch.setattr(AutoScaler, "create_work", mock.MagicMock())
- monkeypatch.setattr(AutoScaler, "add_work", mock.MagicMock())
-
- auto_scaler = AutoScaler(EmptyWork, min_replicas=0, max_replicas=4, scale_out_interval=0.001)
-
- # Mocking the attributes
- auto_scaler._last_autoscale = time.time() - 100000
- auto_scaler.num_replicas = 0
-
- # triggering scale up
- auto_scaler.autoscale()
- auto_scaler.scale.assert_called_once()
- auto_scaler.create_work.assert_called_once()
- auto_scaler.add_work.assert_called_once()
-
-
-def test_autoscaler_scale_down(monkeypatch):
- monkeypatch.setattr(AutoScaler, "num_pending_works", 0)
- monkeypatch.setattr(AutoScaler, "num_pending_requests", 0)
- monkeypatch.setattr(AutoScaler, "scale", mock.MagicMock(return_value=0))
- monkeypatch.setattr(AutoScaler, "remove_work", mock.MagicMock())
- monkeypatch.setattr(AutoScaler, "workers", mock.MagicMock())
-
- auto_scaler = AutoScaler(EmptyWork, min_replicas=0, max_replicas=4, scale_in_interval=0.001)
-
- # Mocking the attributes
- auto_scaler._last_autoscale = time.time() - 100000
- auto_scaler.num_replicas = 1
- auto_scaler.__dict__["load_balancer"] = mock.MagicMock()
-
- # triggering scale up
- auto_scaler.autoscale()
- auto_scaler.scale.assert_called_once()
- auto_scaler.remove_work.assert_called_once()
-
-
-class TestLoadBalancerProcessRequest:
- @pytest.mark.asyncio()
- async def test_workers_not_ready_with_cold_start_proxy(self, monkeypatch):
- monkeypatch.setattr(ColdStartProxy, "handle_request", mock.AsyncMock())
- load_balancer = _LoadBalancer(
- input_type=Text, output_type=Text, endpoint="/predict", cold_start_proxy=ColdStartProxy("url")
- )
- req_id = uuid.uuid4().hex
- await load_balancer.process_request("test", req_id)
- load_balancer._cold_start_proxy.handle_request.assert_called_once_with("test")
-
- @pytest.mark.asyncio()
- async def test_workers_not_ready_without_cold_start_proxy(self, monkeypatch):
- load_balancer = _LoadBalancer(
- input_type=Text,
- output_type=Text,
- endpoint="/predict",
- )
- req_id = uuid.uuid4().hex
- # populating the responses so the while loop exists
- load_balancer._responses = {req_id: "Dummy"}
- with pytest.raises(HTTPException):
- await load_balancer.process_request("test", req_id)
-
- @pytest.mark.asyncio()
- async def test_workers_have_no_capacity_with_cold_start_proxy(self, monkeypatch):
- monkeypatch.setattr(ColdStartProxy, "handle_request", mock.AsyncMock())
- load_balancer = _LoadBalancer(
- input_type=Text, output_type=Text, endpoint="/predict", cold_start_proxy=ColdStartProxy("url")
- )
- load_balancer._fastapi_app = mock.MagicMock()
- load_balancer._fastapi_app.num_current_requests = 1000
- load_balancer.servers.append(mock.MagicMock())
- req_id = uuid.uuid4().hex
- await load_balancer.process_request("test", req_id)
- load_balancer._cold_start_proxy.handle_request.assert_called_once_with("test")
-
- @pytest.mark.asyncio()
- async def test_workers_are_free(self):
- load_balancer = _LoadBalancer(
- input_type=Text,
- output_type=Text,
- endpoint="/predict",
- )
- load_balancer.servers.append(mock.MagicMock())
- req_id = uuid.uuid4().hex
- # populating the responses so the while loop exists
- load_balancer._responses = {req_id: "Dummy"}
- await load_balancer.process_request("test", req_id)
- assert load_balancer._batch == [(req_id, "test")]
diff --git a/tests/tests_app/components/serve/test_model_inference_api.py b/tests/tests_app/components/serve/test_model_inference_api.py
deleted file mode 100644
index 64c11445386b0..0000000000000
--- a/tests/tests_app/components/serve/test_model_inference_api.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import base64
-import multiprocessing as mp
-import os
-from unittest.mock import ANY, MagicMock
-
-import pytest
-from lightning.app.components.serve import serve
-from lightning.app.testing.helpers import _RunIf
-from lightning.app.utilities.imports import _is_numpy_available, _is_torch_available
-from lightning.app.utilities.network import _configure_session, find_free_network_port
-from tests_app import _PROJECT_ROOT
-
-if _is_numpy_available():
- import numpy as np
-
-if _is_torch_available():
- import torch
-
-
-class ImageServer(serve.ModelInferenceAPI):
- def build_model(self):
- return lambda x: x
-
- def predict(self, image):
- image = self.model(image)
- return torch.from_numpy(np.asarray(image))
-
-
-def target_fn(port, workers):
- image_server = ImageServer(input="image", output="image", port=port, workers=workers)
- image_server.run()
-
-
-@pytest.mark.xfail(strict=False, reason="test has been ignored for a while and seems not to be working :(")
-@pytest.mark.skipif(not (_is_torch_available() and _is_numpy_available()), reason="Missing torch and numpy")
-@pytest.mark.parametrize("workers", [0])
-# avoid the error: Failed to establish a new connection: [WinError 10061] No connection could be made because the
-# target machine actively refused it
-@_RunIf(skip_windows=True)
-def test_model_inference_api(workers):
- port = find_free_network_port()
- process = mp.Process(target=target_fn, args=(port, workers))
- process.start()
-
- image_path = os.path.join(_PROJECT_ROOT, "docs/source-app/_static/images/logo.png")
- with open(image_path, "rb") as f:
- imgstr = base64.b64encode(f.read()).decode("UTF-8")
-
- session = _configure_session()
- res = session.post(f"http://127.0.0.1:{port}/predict", params={"data": imgstr})
- process.terminate()
- # TODO: Investigate why this doesn't match exactly `imgstr`.
- assert res.json()
- process.kill()
-
-
-class EmptyServer(serve.ModelInferenceAPI):
- def build_model(self):
- return lambda x: x
-
- def serialize(self, x):
- return super().serialize(x)
-
- def deserialize(self, x):
- return super().deserialize(x)
-
- def predict(self, x):
- return super().predict(x)
-
-
-def test_model_inference_api_mock(monkeypatch):
- monkeypatch.setattr(serve, "uvicorn", MagicMock())
- comp = EmptyServer()
- comp.run()
- serve.uvicorn.run.assert_called_once_with(app=ANY, host=comp.host, port=comp.port, log_level="error")
-
- with pytest.raises(Exception, match="Only input in"):
- EmptyServer(input="something")
-
- with pytest.raises(Exception, match="Only output in"):
- EmptyServer(output="something")
diff --git a/tests/tests_app/components/serve/test_python_server.py b/tests/tests_app/components/serve/test_python_server.py
deleted file mode 100644
index 9d353f98ef33e..0000000000000
--- a/tests/tests_app/components/serve/test_python_server.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import multiprocessing as mp
-
-from lightning.app.components import Category, Image, Number, PythonServer, Text
-from lightning.app.utilities.network import _configure_session, find_free_network_port
-
-
-class SimpleServer(PythonServer):
- def __init__(self, port):
- super().__init__(port=port)
- self._model = None
-
- def setup(self):
- self._model = lambda x: x
-
- def predict(self, data):
- return {"prediction": self._model(data.payload)}
-
-
-def target_fn(port):
- image_server = SimpleServer(port=port)
- image_server.run()
-
-
-def test_python_server_component():
- port = find_free_network_port()
- process = mp.Process(target=target_fn, args=(port,))
- process.start()
- session = _configure_session()
- res = session.post(f"http://127.0.0.1:{port}/predict", json={"payload": "test"})
- process.terminate()
- assert res.json()["prediction"] == "test"
- process.kill()
-
-
-def test_image_sample_data():
- data = Image().get_sample_data()
- assert isinstance(data, dict)
- assert "image" in data
- assert len(data["image"]) > 100
-
-
-def test_text_sample_data():
- data = Text().get_sample_data()
- assert isinstance(data, dict)
- assert "text" in data
- assert len(data["text"]) > 20
-
-
-def test_number_sample_data():
- data = Number().get_sample_data()
- assert isinstance(data, dict)
- assert "prediction" in data
- assert data["prediction"] == 463
-
-
-def test_category_sample_data():
- data = Category().get_sample_data()
- assert isinstance(data, dict)
- assert "category" in data
- assert data["category"] == 463
diff --git a/tests/tests_app/components/serve/test_streamlit.py b/tests/tests_app/components/serve/test_streamlit.py
deleted file mode 100644
index 29eff20d3206a..0000000000000
--- a/tests/tests_app/components/serve/test_streamlit.py
+++ /dev/null
@@ -1,113 +0,0 @@
-import os
-import sys
-from unittest import mock
-
-import lightning.app
-import pytest
-from lightning.app.components.serve.streamlit import ServeStreamlit, _build_model, _PatchedWork
-from lightning_utilities.core.imports import RequirementCache
-
-_STREAMLIT_AVAILABLE = RequirementCache("streamlit")
-
-
-class ServeStreamlitTest(ServeStreamlit):
- def __init__(self):
- super().__init__()
-
- self.test_variable = -1
-
- @property
- def test_property(self):
- return self.test_variable
-
- def test_method(self):
- return "test_method"
-
- @staticmethod
- def test_staticmethod():
- return "test_staticmethod"
-
- def build_model(self):
- return "model"
-
- def render():
- pass
-
-
-@pytest.mark.skipif(not _STREAMLIT_AVAILABLE, reason="requires streamlit")
-@mock.patch("lightning.app.components.serve.streamlit.subprocess")
-def test_streamlit_start_stop_server(subprocess_mock):
- """Test that `ServeStreamlit.run()` invokes subprocess.Popen with the right parameters."""
- work = ServeStreamlitTest()
- work._name = "test_work"
- work._host = "hostname"
- work._port = 1111
-
- work.run()
-
- subprocess_mock.Popen.assert_called_once()
-
- env_variables = subprocess_mock.method_calls[0].kwargs["env"]
- call_args = subprocess_mock.method_calls[0].args[0]
- assert call_args == [
- sys.executable,
- "-m",
- "streamlit",
- "run",
- lightning.app.components.serve.streamlit.__file__,
- "--server.address",
- "hostname",
- "--server.port",
- "1111",
- "--server.headless",
- "true",
- ]
-
- assert env_variables["LIGHTNING_COMPONENT_NAME"] == "test_work"
- assert env_variables["LIGHTNING_WORK"] == "ServeStreamlitTest"
- assert env_variables["LIGHTNING_WORK_MODULE_FILE"] == __file__
-
- assert "LIGHTNING_COMPONENT_NAME" not in os.environ
- assert "LIGHTNING_WORK" not in os.environ
- assert "LIGHTNING_WORK_MODULE_FILE" not in os.environ
-
- work.on_exit()
- subprocess_mock.Popen().kill.assert_called_once()
-
-
-def test_patched_work():
- class TestState:
- test_variable = 1
-
- patched_work = _PatchedWork(TestState(), ServeStreamlitTest)
-
- assert patched_work.test_variable == 1
- assert patched_work.test_property == 1
- assert patched_work.test_method() == "test_method"
- assert patched_work.test_staticmethod() == "test_staticmethod"
-
-
-@pytest.mark.skipif(not _STREAMLIT_AVAILABLE, reason="requires streamlit")
-def test_build_model():
- import streamlit as st
-
- st.session_state = {}
- st.spinner = mock.MagicMock()
-
- class TestState:
- test_variable = 1
-
- patched_work = _PatchedWork(TestState(), ServeStreamlitTest)
- patched_work.build_model = mock.MagicMock(return_value="test_model")
-
- _build_model(patched_work)
-
- assert st.session_state["_model"] == "test_model"
- assert patched_work.model == "test_model"
- patched_work.build_model.assert_called_once()
-
- patched_work.build_model.reset_mock()
-
- _build_model(patched_work)
-
- patched_work.build_model.assert_not_called()
diff --git a/tests/tests_app/conftest.py b/tests/tests_app/conftest.py
deleted file mode 100644
index e6e87112ac8e1..0000000000000
--- a/tests/tests_app/conftest.py
+++ /dev/null
@@ -1,140 +0,0 @@
-import contextlib
-import os
-import shutil
-import signal
-import threading
-from datetime import datetime
-from pathlib import Path
-from threading import Thread
-
-import psutil
-import py
-import pytest
-from lightning.app.core import constants
-from lightning.app.utilities.app_helpers import _collect_child_process_pids
-from lightning.app.utilities.component import _set_context
-from lightning.app.utilities.packaging import cloud_compute
-from lightning.app.utilities.packaging.app_config import _APP_CONFIG_FILENAME
-from lightning.app.utilities.state import AppState
-
-os.environ["LIGHTNING_DISPATCHED"] = "1"
-
-original_method = Thread._wait_for_tstate_lock
-
-
-def fn(self, *args, timeout=None, **kwargs):
- original_method(self, *args, timeout=1, **kwargs)
-
-
-Thread._wait_for_tstate_lock = fn
-
-
-def pytest_sessionfinish(session, exitstatus):
- """Pytest hook that get called after whole test run finished, right before returning the exit status to the
- system."""
- # kill all the processes and threads created by parent
- # TODO this isn't great. We should have each test do its own cleanup
- current_process = psutil.Process()
- for child in current_process.children(recursive=True):
- with contextlib.suppress(psutil.NoSuchProcess):
- params = child.as_dict() or {}
- cmd_lines = params.get("cmdline", [])
- # we shouldn't kill the resource tracker from multiprocessing. If we do,
- # `atexit` will throw as it uses resource tracker to try to clean up
- if cmd_lines and "resource_tracker" in cmd_lines[-1]:
- continue
- child.kill()
-
- main_thread = threading.current_thread()
- for t in threading.enumerate():
- if t is not main_thread:
- t.join(0)
-
- for child_pid in _collect_child_process_pids(os.getpid()):
- os.kill(child_pid, signal.SIGTERM)
-
-
-@pytest.fixture(autouse=True)
-def cleanup():
- from lightning.app.utilities.app_helpers import _LightningAppRef
-
- yield
- _LightningAppRef._app_instance = None
- shutil.rmtree("./storage", ignore_errors=True)
- shutil.rmtree("./.storage", ignore_errors=True)
- shutil.rmtree("./.shared", ignore_errors=True)
- if os.path.isfile(_APP_CONFIG_FILENAME):
- os.remove(_APP_CONFIG_FILENAME)
- _set_context(None)
-
-
-@pytest.fixture(autouse=True)
-def clear_app_state_state_variables():
- """Resets global variables in order to prevent interference between tests."""
- yield
- import lightning.app.utilities.state
-
- lightning.app.utilities.state._STATE = None
- lightning.app.utilities.state._LAST_STATE = None
- AppState._MY_AFFILIATION = ()
- if hasattr(cloud_compute, "_CLOUD_COMPUTE_STORE"):
- cloud_compute._CLOUD_COMPUTE_STORE.clear()
-
-
-@pytest.fixture()
-def another_tmpdir(tmp_path: Path) -> py.path.local:
- random_dir = datetime.now().strftime("%m-%d-%Y-%H-%M-%S")
- tmp_path = os.path.join(tmp_path, random_dir)
- return py.path.local(tmp_path)
-
-
-@pytest.fixture()
-def caplog(caplog):
- """Workaround for https://github.com/pytest-dev/pytest/issues/3697.
-
- Setting ``filterwarnings`` with pytest breaks ``caplog`` when ``not logger.propagate``.
-
- """
- import logging
-
- root_logger = logging.getLogger()
- root_propagate = root_logger.propagate
- root_logger.propagate = True
-
- propagation_dict = {
- name: logging.getLogger(name).propagate
- for name in logging.root.manager.loggerDict
- if name.startswith("lightning.app")
- }
- for name in propagation_dict:
- logging.getLogger(name).propagate = True
-
- yield caplog
-
- root_logger.propagate = root_propagate
- for name, propagate in propagation_dict.items():
- logging.getLogger(name).propagate = propagate
-
-
-@pytest.fixture()
-def patch_constants(request):
- """This fixture can be used with indirect parametrization to patch values in `lightning.app.core.constants` for the
- duration of a test.
-
- Example::
-
- @pytest.mark.parametrize("patch_constants", [{"LIGHTNING_CLOUDSPACE_HOST": "any"}], indirect=True)
- def test_my_stuff(patch_constants):
- ...
-
- """
- # Set constants
- old_constants = {}
- for constant, value in request.param.items():
- old_constants[constant] = getattr(constants, constant)
- setattr(constants, constant, value)
-
- yield
-
- for constant, value in old_constants.items():
- setattr(constants, constant, value)
diff --git a/tests/tests_app/core/__init__.py b/tests/tests_app/core/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/core/lightning_app/__init__.py b/tests/tests_app/core/lightning_app/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/core/lightning_app/test_configure_layout.py b/tests/tests_app/core/lightning_app/test_configure_layout.py
deleted file mode 100644
index 757cf42738b65..0000000000000
--- a/tests/tests_app/core/lightning_app/test_configure_layout.py
+++ /dev/null
@@ -1,243 +0,0 @@
-from re import escape
-from unittest import mock
-from unittest.mock import Mock
-
-import pytest
-from lightning.app import LightningApp, LightningFlow
-from lightning.app.frontend.stream_lit import StreamlitFrontend
-from lightning.app.frontend.web import StaticWebFrontend
-from lightning.app.runners import MultiProcessRuntime
-from lightning.app.testing.helpers import EmptyFlow
-from lightning.app.utilities.imports import _IS_WINDOWS
-
-
-@pytest.mark.parametrize("return_val", [1, None, set(), "string"])
-def test_invalid_layout(return_val):
- class Root(EmptyFlow):
- def configure_layout(self):
- return return_val
-
- root = Root()
- with pytest.raises(TypeError, match=escape("The return value of configure_layout() in `Root`")):
- LightningApp(root)
-
-
-def test_invalid_layout_missing_content_key():
- class Root(EmptyFlow):
- def configure_layout(self):
- return [{"name": "one"}]
-
- root = Root()
- with pytest.raises(
- ValueError, match=escape("A dictionary returned by `Root.configure_layout()` is missing a key 'content'.")
- ):
- LightningApp(root)
-
-
-def test_invalid_layout_unsupported_content_value():
- class Root(EmptyFlow):
- def configure_layout(self):
- return [{"name": "one", "content": [1, 2, 3]}]
-
- root = Root()
-
- with pytest.raises(
- ValueError,
- match=escape("A dictionary returned by `Root.configure_layout()"),
- ):
- LightningApp(root)
-
-
-class StreamlitFrontendFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.counter = 0
-
- def run(self):
- if self.counter > 2:
- self.stop()
- self.counter += 1
-
- def configure_layout(self):
- frontend = StreamlitFrontend(render_fn=_render_streamlit_fn)
- frontend.start_server = Mock()
- frontend.stop_server = Mock()
- return frontend
-
-
-def _render_streamlit_fn():
- pass
-
-
-class StaticWebFrontendFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.counter = 0
-
- def run(self):
- if self.counter > 2:
- self.stop()
- self.counter += 1
-
- def configure_layout(self):
- frontend = StaticWebFrontend(serve_dir="a/b/c")
- frontend.start_server = Mock()
- frontend.stop_server = Mock()
- return frontend
-
-
-@pytest.mark.skipif(_IS_WINDOWS, reason="strange TimeOut exception")
-@pytest.mark.xfail(strict=False, reason="hanging... need to be fixed") # fixme
-@pytest.mark.parametrize("flow", [StaticWebFrontendFlow, StreamlitFrontendFlow])
-@mock.patch("lightning.app.runners.multiprocess.find_free_network_port")
-def test_layout_leaf_node(find_ports_mock, flow):
- find_ports_mock.side_effect = lambda: 100
- flow = flow()
- app = LightningApp(flow)
- assert flow._layout == {}
- # we copy the dict here because after we dispatch the dict will get updated with new instances
- # as the layout gets updated during the loop.
- frontends = app.frontends.copy()
- MultiProcessRuntime(app).dispatch()
- assert flow.counter == 3
-
- # The target url is available for the frontend after we started the servers in dispatch
- assert flow._layout == {"target": "http://localhost:100/root"}
- assert app.frontends[flow.name].flow is flow
-
- # we start the servers for the frontends that we collected at the time of app instantiation
- frontends[flow.name].start_server.assert_called_once()
-
- # leaf layout nodes can't be changed, they stay the same from when they first got configured
- assert app.frontends[flow.name] == frontends[flow.name]
-
-
-def test_default_content_layout():
- class SimpleFlow(EmptyFlow):
- def configure_layout(self):
- frontend = StaticWebFrontend(serve_dir="a/b/c")
- frontend.start_server = Mock()
- return frontend
-
- class TestContentComponent(EmptyFlow):
- def __init__(self):
- super().__init__()
- self.component0 = SimpleFlow()
- self.component1 = SimpleFlow()
- self.component2 = SimpleFlow()
-
- root = TestContentComponent()
- LightningApp(root)
- assert root._layout == [
- {"name": "root.component0", "content": "root.component0"},
- {"name": "root.component1", "content": "root.component1"},
- {"name": "root.component2", "content": "root.component2"},
- ]
-
-
-def test_url_content_layout():
- class TestContentComponent(EmptyFlow):
- def __init__(self):
- super().__init__()
- self.component0 = EmptyFlow()
- self.component1 = EmptyFlow()
-
- def configure_layout(self):
- return [
- {"name": "one", "content": self.component0},
- {"name": "url", "content": "https://lightning.ai"},
- {"name": "two", "content": self.component1},
- ]
-
- root = TestContentComponent()
- LightningApp(root)
- assert root._layout == [
- {"name": "one", "content": "root.component0"},
- {"name": "url", "content": "https://lightning.ai", "target": "https://lightning.ai"},
- {"name": "two", "content": "root.component1"},
- ]
-
-
-def test_single_content_layout():
- """Test that returning a single dict also works (does not have to be returned in a list)."""
-
- class TestContentComponent(EmptyFlow):
- def __init__(self):
- super().__init__()
- self.component0 = EmptyFlow()
- self.component1 = EmptyFlow()
-
- def configure_layout(self):
- return {"name": "single", "content": self.component1}
-
- root = TestContentComponent()
- LightningApp(root)
- assert root._layout == [{"name": "single", "content": "root.component1"}]
-
-
-class DynamicContentComponent(EmptyFlow):
- def __init__(self):
- super().__init__()
- self.component0 = EmptyFlow()
- self.component1 = EmptyFlow()
- self.counter = 0
- self.configure_layout_called = 0
-
- def run(self):
- self.run_assertion()
- self.counter += 1
- if self.counter == 3:
- self.stop()
-
- def configure_layout(self):
- self.configure_layout_called += 1
- tabs = [
- {"name": "one", "content": self.component0},
- {"name": f"{self.counter}", "content": self.component1},
- ]
- # reverse the order of the two tabs every time the counter is odd
- if self.counter % 2 != 0:
- tabs = tabs[::-1]
- return tabs
-
- def run_assertion(self):
- """Assert that the layout changes as the counter changes its value."""
- layout_even = [
- {"name": "one", "content": "root.component0"},
- {"name": f"{self.counter}", "content": "root.component1"},
- ]
- layout_odd = layout_even[::-1]
- assert (
- self.counter % 2 == 0
- and self._layout == layout_even
- or self.counter % 2 == 1
- and self._layout == layout_odd
- )
-
-
-@pytest.mark.skipif(_IS_WINDOWS, reason="strange TimeOut exception")
-@pytest.mark.xfail(strict=False, reason="hanging... need to be fixed") # fixme
-def test_dynamic_content_layout_update():
- """Test that the `configure_layout()` gets called as part of the loop and can return new layouts."""
- flow = DynamicContentComponent()
- app = LightningApp(flow)
- MultiProcessRuntime(app).dispatch()
- assert flow.configure_layout_called == 5
-
-
-@mock.patch("lightning.app.utilities.layout.is_running_in_cloud", return_value=True)
-def test_http_url_warning(*_):
- class Root(EmptyFlow):
- def configure_layout(self):
- return [
- {"name": "warning expected", "content": "http://github.com/very/long/link/to/display"},
- {"name": "no warning expected", "content": "https://github.com"},
- ]
-
- root = Root()
-
- with pytest.warns(
- UserWarning,
- match=escape("You configured an http link http://github.com/very/long/link... but it won't be accessible"),
- ):
- LightningApp(root)
diff --git a/tests/tests_app/core/scripts/app_metadata.py b/tests/tests_app/core/scripts/app_metadata.py
deleted file mode 100644
index 7194ab2c2a649..0000000000000
--- a/tests/tests_app/core/scripts/app_metadata.py
+++ /dev/null
@@ -1,61 +0,0 @@
-from lightning.app.core.app import LightningApp
-from lightning.app.core.flow import LightningFlow
-from lightning.app.core.work import LightningWork
-from lightning.app.frontend.web import StaticWebFrontend
-from lightning.app.utilities.packaging.cloud_compute import CloudCompute
-
-
-class WorkA(LightningWork):
- def __init__(self):
- """WorkA."""
- super().__init__()
-
- def run(self):
- pass
-
-
-class WorkB(LightningWork):
- def __init__(self):
- """WorkB."""
- super().__init__(cloud_compute=CloudCompute("gpu"))
-
- def run(self):
- pass
-
-
-class FlowA(LightningFlow):
- def __init__(self):
- """FlowA Component."""
- super().__init__()
- self.work_a = WorkA()
-
- def run(self):
- pass
-
-
-class FlowB(LightningFlow):
- def __init__(self):
- """FlowB."""
- super().__init__()
- self.work_b = WorkB()
-
- def run(self):
- pass
-
- def configure_layout(self):
- return StaticWebFrontend(serve_dir=".")
-
-
-class RootFlow(LightningFlow):
- def __init__(self):
- """RootFlow."""
- super().__init__()
- self.flow_a_1 = FlowA()
- self.flow_a_2 = FlowA()
- self.flow_b = FlowB()
-
- def run(self):
- self.stop()
-
-
-app = LightningApp(RootFlow())
diff --git a/tests/tests_app/core/scripts/app_with_env.py b/tests/tests_app/core/scripts/app_with_env.py
deleted file mode 100644
index 6493a1918bb20..0000000000000
--- a/tests/tests_app/core/scripts/app_with_env.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import os
-
-from lightning.app import CloudCompute, LightningApp, LightningWork
-
-
-class MyWork(LightningWork):
- def __init__(self):
- super().__init__(cloud_compute=CloudCompute(name=os.environ.get("COMPUTE_NAME", "default")))
-
- def run(self):
- pass
-
-
-app = LightningApp(MyWork())
diff --git a/tests/tests_app/core/scripts/app_with_local_import.py b/tests/tests_app/core/scripts/app_with_local_import.py
deleted file mode 100644
index 38ff38df54419..0000000000000
--- a/tests/tests_app/core/scripts/app_with_local_import.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from app_metadata import RootFlow
-from lightning.app.core.app import LightningApp
-
-app = LightningApp(RootFlow())
diff --git a/tests/tests_app/core/scripts/empty.py b/tests/tests_app/core/scripts/empty.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/core/scripts/example_1.py b/tests/tests_app/core/scripts/example_1.py
deleted file mode 100644
index 486a0566ff80d..0000000000000
--- a/tests/tests_app/core/scripts/example_1.py
+++ /dev/null
@@ -1 +0,0 @@
-from numbers import Rational # noqa F401
diff --git a/tests/tests_app/core/scripts/example_2.py b/tests/tests_app/core/scripts/example_2.py
deleted file mode 100644
index 54e6f30908c85..0000000000000
--- a/tests/tests_app/core/scripts/example_2.py
+++ /dev/null
@@ -1 +0,0 @@
-from lightning.app import LightningApp # noqa F401
diff --git a/tests/tests_app/core/scripts/lightning_cli.py b/tests/tests_app/core/scripts/lightning_cli.py
deleted file mode 100644
index 486747206b334..0000000000000
--- a/tests/tests_app/core/scripts/lightning_cli.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from lightning.app.utilities.imports import _is_pytorch_lightning_available, _is_torch_available
-
-if _is_torch_available():
- import torch
- from torch.utils.data import DataLoader, Dataset
-
-if _is_pytorch_lightning_available():
- from pytorch_lightning import LightningDataModule, LightningModule, cli
-
-if __name__ == "__main__":
-
- class RandomDataset(Dataset):
- def __init__(self, size, length):
- self.len = length
- self.data = torch.randn(length, size)
-
- def __getitem__(self, index):
- return self.data[index]
-
- def __len__(self):
- return self.len
-
- class BoringDataModule(LightningDataModule):
- def train_dataloader(self):
- return DataLoader(RandomDataset(32, 64), batch_size=2)
-
- def val_dataloader(self):
- return DataLoader(RandomDataset(32, 64), batch_size=2)
-
- def test_dataloader(self):
- return DataLoader(RandomDataset(32, 64), batch_size=2)
-
- def predict_dataloader(self):
- return DataLoader(RandomDataset(32, 64), batch_size=2)
-
- class BoringModel(LightningModule):
- def __init__(self):
- super().__init__()
- self.layer = torch.nn.Linear(32, 2)
-
- def forward(self, x):
- return self.layer(x)
-
- def training_step(self, batch, batch_idx):
- loss = self(batch).sum()
- self.log("train_loss", loss)
- return {"loss": loss}
-
- def validation_step(self, batch, batch_idx):
- loss = self(batch).sum()
- self.log("valid_loss", loss)
-
- def test_step(self, batch, batch_idx):
- loss = self(batch).sum()
- self.log("test_loss", loss)
-
- def configure_optimizers(self):
- return torch.optim.SGD(self.layer.parameters(), lr=0.1)
-
- cli.LightningCLI(BoringModel, BoringDataModule)
diff --git a/tests/tests_app/core/scripts/lightning_overrides.py b/tests/tests_app/core/scripts/lightning_overrides.py
deleted file mode 100644
index 8cb69c0fc727b..0000000000000
--- a/tests/tests_app/core/scripts/lightning_overrides.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from lightning.app.utilities.imports import _is_pytorch_lightning_available, _is_torch_available
-
-if _is_torch_available():
- from torch.utils.data import Dataset
-
-if _is_pytorch_lightning_available():
- from lightning.fabric import Fabric
- from pytorch_lightning import LightningDataModule, LightningModule, Trainer
- from pytorch_lightning.accelerators.accelerator import Accelerator
- from pytorch_lightning.callbacks import Callback
- from pytorch_lightning.loggers import Logger
- from pytorch_lightning.plugins import PrecisionPlugin
- from pytorch_lightning.profilers import Profiler
- from torchmetrics import Metric
-
-
-if __name__ == "__main__":
-
- class RandomDataset(Dataset):
- pass
-
- class BoringDataModule(LightningDataModule):
- pass
-
- class BoringModel(LightningModule):
- pass
-
- class BoringTrainer(Trainer):
- pass
-
- class BoringPrecisionPlugin(PrecisionPlugin):
- pass
-
- class BoringAccelerator(Accelerator):
- pass
-
- class BoringCallback(Callback):
- pass
-
- class BoringLogger(Logger):
- pass
-
- class BoringMetric(Metric):
- pass
-
- class BoringFabric(Fabric):
- pass
-
- class BoringProfiler(Profiler):
- pass
diff --git a/tests/tests_app/core/scripts/lightning_trainer.py b/tests/tests_app/core/scripts/lightning_trainer.py
deleted file mode 100644
index 2b9e92fabc799..0000000000000
--- a/tests/tests_app/core/scripts/lightning_trainer.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import argparse
-
-from lightning.app.utilities.imports import _is_pytorch_lightning_available, _is_torch_available
-
-if _is_torch_available():
- import torch
- from torch.utils.data import DataLoader, Dataset
-
-if _is_pytorch_lightning_available():
- import pytorch_lightning as pl
- from pytorch_lightning import LightningDataModule, LightningModule
-
-
-def main():
- parser = argparse.ArgumentParser()
- parser.add_argument("--max_epochs", type=int, default=10)
- args = parser.parse_args()
-
- class RandomDataset(Dataset):
- def __init__(self, size, length):
- self.len = length
- self.data = torch.randn(length, size)
-
- def __getitem__(self, index):
- return self.data[index]
-
- def __len__(self):
- return self.len
-
- class BoringDataModule(LightningDataModule):
- def train_dataloader(self):
- return DataLoader(RandomDataset(32, 64), batch_size=2)
-
- def val_dataloader(self):
- return DataLoader(RandomDataset(32, 64), batch_size=2)
-
- def test_dataloader(self):
- return DataLoader(RandomDataset(32, 64), batch_size=2)
-
- def predict_dataloader(self):
- return DataLoader(RandomDataset(32, 64), batch_size=2)
-
- class BoringModel(LightningModule):
- def __init__(self):
- super().__init__()
- self.layer = torch.nn.Linear(32, 2)
-
- def forward(self, x):
- return self.layer(x)
-
- def training_step(self, batch, batch_idx):
- loss = self(batch).sum()
- self.log("train_loss", loss)
- return {"loss": loss}
-
- def validation_step(self, batch, batch_idx):
- loss = self(batch).sum()
- self.log("valid_loss", loss)
-
- def test_step(self, batch, batch_idx):
- loss = self(batch).sum()
- self.log("test_loss", loss)
-
- def configure_optimizers(self):
- return torch.optim.SGD(self.layer.parameters(), lr=0.1)
-
- model = BoringModel()
- datamodule = BoringDataModule()
- trainer = pl.Trainer(**vars(args))
- trainer.fit(model, datamodule)
-
-
-if __name__ == "__main__":
- main()
diff --git a/tests/tests_app/core/scripts/registry.py b/tests/tests_app/core/scripts/registry.py
deleted file mode 100644
index 0ae13b05af9fb..0000000000000
--- a/tests/tests_app/core/scripts/registry.py
+++ /dev/null
@@ -1,102 +0,0 @@
-from lightning.app.utilities.imports import _is_pytorch_lightning_available
-
-if _is_pytorch_lightning_available():
- import torch
- from pytorch_lightning import LightningDataModule, LightningModule
- from pytorch_lightning.cli import LightningCLI
- from torch.utils.data import DataLoader, Dataset
-
- class RandomDataset(Dataset):
- def __init__(self, size, length):
- self.len = length
- self.data = torch.randn(length, size)
-
- def __getitem__(self, index):
- return self.data[index]
-
- def __len__(self):
- return self.len
-
- class BoringDataModule(LightningDataModule):
- def __init__(self, root_folder: str = "./", batch_size: int = 32):
- super().__init__()
-
- def train_dataloader(self):
- return DataLoader(RandomDataset(32, 64), batch_size=2)
-
- def val_dataloader(self):
- return DataLoader(RandomDataset(32, 64), batch_size=2)
-
- def test_dataloader(self):
- return DataLoader(RandomDataset(32, 64), batch_size=2)
-
- def predict_dataloader(self):
- return DataLoader(RandomDataset(32, 64), batch_size=2)
-
- class BoringDataModule2(LightningDataModule):
- def __init__(self, root_folder: str = "./", batch_size: int = 32, num_workers: int = 6):
- super().__init__()
-
- def train_dataloader(self):
- return DataLoader(RandomDataset(32, 64), batch_size=2)
-
- def val_dataloader(self):
- return DataLoader(RandomDataset(32, 64), batch_size=2)
-
- def test_dataloader(self):
- return DataLoader(RandomDataset(32, 64), batch_size=2)
-
- def predict_dataloader(self):
- return DataLoader(RandomDataset(32, 64), batch_size=2)
-
- class BoringModel(LightningModule):
- def __init__(self, hidden_size: int = 16):
- super().__init__()
- self.layer = torch.nn.Linear(32, 2)
-
- def forward(self, x):
- return self.layer(x)
-
- def training_step(self, batch, batch_idx):
- loss = self(batch).sum()
- self.log("train_loss", loss)
- return {"loss": loss}
-
- def validation_step(self, batch, batch_idx):
- loss = self(batch).sum()
- self.log("valid_loss", loss)
-
- def test_step(self, batch, batch_idx):
- loss = self(batch).sum()
- self.log("test_loss", loss)
-
- def configure_optimizers(self):
- return torch.optim.SGD(self.layer.parameters(), lr=0.1)
-
- class BoringModel2(LightningModule):
- def __init__(self, hidden_size: int = 16, batch_norm: bool = False):
- super().__init__()
- self.layer = torch.nn.Linear(32, 2)
-
- def forward(self, x):
- return self.layer(x)
-
- def training_step(self, batch, batch_idx):
- loss = self(batch).sum()
- self.log("train_loss", loss)
- return {"loss": loss}
-
- def validation_step(self, batch, batch_idx):
- loss = self(batch).sum()
- self.log("valid_loss", loss)
-
- def test_step(self, batch, batch_idx):
- loss = self(batch).sum()
- self.log("test_loss", loss)
-
- def configure_optimizers(self):
- return torch.optim.SGD(self.layer.parameters(), lr=0.1)
-
-
-if __name__ == "__main__":
- LightningCLI()
diff --git a/tests/tests_app/core/scripts/script_with_error.py b/tests/tests_app/core/scripts/script_with_error.py
deleted file mode 100644
index b3a669f13fcf3..0000000000000
--- a/tests/tests_app/core/scripts/script_with_error.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from lightning.app import LightningApp, LightningFlow
-
-
-class EmptyFlow(LightningFlow):
- def run(self):
- pass
-
-
-if __name__ == "__main__":
- # trigger a Python exception `IndexError: list index out of range` before we can load the app
- _ = [1, 2, 3][4]
-
- app = LightningApp(EmptyFlow())
diff --git a/tests/tests_app/core/scripts/two_apps.py b/tests/tests_app/core/scripts/two_apps.py
deleted file mode 100644
index 944e11b67d67d..0000000000000
--- a/tests/tests_app/core/scripts/two_apps.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from lightning.app import LightningApp, LightningFlow
-
-
-class EmptyFlow(LightningFlow):
- def run(self):
- pass
-
-
-app_1 = LightningApp(EmptyFlow())
-app_2 = LightningApp(EmptyFlow())
diff --git a/tests/tests_app/core/test_constants.py b/tests/tests_app/core/test_constants.py
deleted file mode 100644
index 489334a06e87e..0000000000000
--- a/tests/tests_app/core/test_constants.py
+++ /dev/null
@@ -1,9 +0,0 @@
-import os
-from unittest import mock
-
-from lightning.app.core.constants import get_lightning_cloud_url
-
-
-@mock.patch.dict(os.environ, {"LIGHTNING_CLOUD_URL": "https://beta.lightning.ai"})
-def test_defaults():
- assert get_lightning_cloud_url() == "https://beta.lightning.ai"
diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py
deleted file mode 100644
index 65ac6fcab2bf7..0000000000000
--- a/tests/tests_app/core/test_lightning_api.py
+++ /dev/null
@@ -1,595 +0,0 @@
-import asyncio
-import contextlib
-import json
-import logging
-import multiprocessing as mp
-import os
-import sys
-from copy import deepcopy
-from multiprocessing import Process
-from pathlib import Path
-from time import sleep, time
-from unittest import mock
-
-import aiohttp
-import lightning.app
-import pytest
-import requests
-from deepdiff import DeepDiff, Delta
-from fastapi import HTTPException, Request
-from httpx import AsyncClient
-from lightning.app import LightningApp, LightningFlow, LightningWork
-from lightning.app.api.http_methods import Post
-from lightning.app.core import api
-from lightning.app.core.api import (
- UIRefresher,
- fastapi_service,
- global_app_state_store,
- register_global_routes,
- start_server,
-)
-from lightning.app.core.constants import APP_SERVER_PORT
-from lightning.app.runners import MultiProcessRuntime
-from lightning.app.storage.drive import Drive
-from lightning.app.testing.helpers import _MockQueue
-from lightning.app.utilities.app_status import AppStatus
-from lightning.app.utilities.component import _set_frontend_context, _set_work_context
-from lightning.app.utilities.enum import AppStage
-from lightning.app.utilities.load_app import extract_metadata_from_app
-from lightning.app.utilities.redis import check_if_redis_running
-from lightning.app.utilities.state import AppState, headers_for
-from pydantic import BaseModel
-
-register_global_routes()
-
-
-class WorkA(LightningWork):
- def __init__(self):
- super().__init__(parallel=True, start_with_flow=False)
- self.var_a = 0
- self.drive = Drive("lit://test_app_state_api")
-
- def run(self):
- state = AppState()
- assert state._my_affiliation == ("work_a",)
- # this would download and push data to the REST API.
- assert state.var_a == 0
- assert isinstance(state.drive, Drive)
- assert state.drive.component_name == "root.work_a"
-
- with open("test_app_state_api.txt", "w") as f:
- f.write("here")
- state.drive.put("test_app_state_api.txt")
- state.var_a = -1
-
-
-class _A(LightningFlow):
- def __init__(self):
- super().__init__()
- self.work_a = WorkA()
-
- def run(self):
- if self.work_a.var_a == -1:
- self.stop()
- self.work_a.run()
-
-
-@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="too slow on Windows or macOs")
-def test_app_state_api():
- """This test validates the AppState can properly broadcast changes from work within its own process."""
- app = LightningApp(_A(), log_level="debug")
- MultiProcessRuntime(app, start_server=True).dispatch()
- assert app.root.work_a.var_a == -1
- _set_work_context()
- assert app.root.work_a.drive.list(".") == ["test_app_state_api.txt"]
- _set_frontend_context()
- assert app.root.work_a.drive.list(".") == ["test_app_state_api.txt"]
- os.remove("test_app_state_api.txt")
-
-
-class A2(LightningFlow):
- def __init__(self):
- super().__init__()
- self.var_a = 0
- self.a = _A()
-
- def update_state(self):
- state = AppState()
- # this would download and push data to the REST API.
- assert state.a.work_a.var_a == 0
- assert state.var_a == 0
- state.var_a = -1
-
- def run(self):
- if self.var_a == 0:
- self.update_state()
- elif self.var_a == -1:
- self.stop()
-
-
-@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="too slow on Windows or macOs")
-def test_app_state_api_with_flows():
- """This test validates the AppState can properly broadcast changes from flows."""
- app = LightningApp(A2(), log_level="debug")
- MultiProcessRuntime(app, start_server=True).dispatch()
- assert app.root.var_a == -1
-
-
-class NestedFlow(LightningFlow):
- def run(self):
- pass
-
- def configure_layout(self):
- return {"name": "main", "content": "https://te"}
-
-
-class FlowA(LightningFlow):
- def __init__(self):
- super().__init__()
- self.counter = 0
- self.flow = NestedFlow()
- self.dict = lightning.app.structures.Dict(**{"0": NestedFlow()})
- self.list = lightning.app.structures.List(*[NestedFlow()])
-
- def run(self):
- self.counter += 1
- if self.counter >= 3:
- self.stop()
-
- def configure_layout(self):
- return [
- {"name": "main_1", "content": "https://te"},
- {"name": "main_2", "content": self.flow},
- {"name": "main_3", "content": self.dict["0"]},
- {"name": "main_4", "content": self.list[0]},
- ]
-
-
-class AppStageTestingApp(LightningApp):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.counter_running = 0
- self.counter_stopped = 0
- self.counter = 0
-
- def _change_stage(self, enum):
- previous_state = deepcopy(self.state)
- current_state = self.state
- current_state["app_state"]["stage"] = enum.value
- deep_diff = DeepDiff(previous_state, current_state, verbose_level=2)
- self.api_delta_queue.put(Delta(deep_diff))
-
- def maybe_apply_changes(self):
- if self.counter_stopped == 1 and self.counter_running == 1:
- if self.counter == 0:
- self._change_stage(AppStage.RUNNING)
- self.counter += 1
- if self.counter == 3:
- self._change_stage(AppStage.STOPPING)
-
- # emulate pending from the UI.
- elif self.stage == AppStage.BLOCKING:
- self._change_stage(AppStage.RUNNING)
- self.counter_running += 1
-
- elif self.root.counter == 2:
- self._change_stage(AppStage.RESTARTING)
- self.counter_stopped += 1
-
- super().maybe_apply_changes()
-
-
-# FIXME: This test doesn't assert anything
-@pytest.mark.xfail(strict=False, reason="TODO: Resolve flaky test.")
-def test_app_stage_from_frontend():
- """This test validates that delta from the `api_delta_queue` manipulating the ['app_state']['stage'] would start
- and stop the app."""
- app = AppStageTestingApp(FlowA(), log_level="debug")
- app.stage = AppStage.BLOCKING
- MultiProcessRuntime(app, start_server=True).dispatch()
-
-
-def test_update_publish_state_and_maybe_refresh_ui():
- """This test checks that the method properly:
-
- - receives the state from the `publish_state_queue` and populates the app_state_store
- - receives a notification to refresh the UI and makes a GET Request (streamlit).
-
- """
- app = AppStageTestingApp(FlowA(), log_level="debug")
- publish_state_queue = _MockQueue("publish_state_queue")
- api_response_queue = _MockQueue("api_response_queue")
-
- publish_state_queue.put((app.state_with_changes, None))
-
- thread = UIRefresher(publish_state_queue, api_response_queue)
- thread.run_once()
-
- assert global_app_state_store.get_app_state("1234") == app.state_with_changes
- global_app_state_store.remove("1234")
- global_app_state_store.add("1234")
-
-
-@pytest.mark.parametrize("x_lightning_type", ["DEFAULT", "STREAMLIT"])
-@pytest.mark.anyio()
-async def test_start_server(x_lightning_type, monkeypatch):
- """This test relies on FastAPI TestClient and validates that the REST API properly provides:
-
- - the state on GET /api/v1/state
- - push a delta when making a POST request to /api/v1/state
-
- """
-
- class InfiniteQueue(_MockQueue):
- def get(self, timeout: int = 0):
- return self._queue[0]
-
- app = AppStageTestingApp(FlowA(), log_level="debug")
- app._update_layout()
- app.stage = AppStage.BLOCKING
- publish_state_queue = InfiniteQueue("publish_state_queue")
- change_state_queue = _MockQueue("change_state_queue")
- has_started_queue = _MockQueue("has_started_queue")
- api_response_queue = _MockQueue("api_response_queue")
- state = app.state_with_changes
- publish_state_queue.put((state, AppStatus(is_ui_ready=True, work_statuses={})))
- spec = extract_metadata_from_app(app)
- ui_refresher = start_server(
- publish_state_queue,
- change_state_queue,
- api_response_queue,
- has_started_queue=has_started_queue,
- uvicorn_run=False,
- spec=spec,
- )
- headers = headers_for({"type": x_lightning_type})
-
- async with AsyncClient(app=fastapi_service, base_url="http://test") as client:
- with pytest.raises(Exception, match="X-Lightning-Session-UUID"):
- await client.get("/api/v1/spec")
-
- with pytest.raises(Exception, match="X-Lightning-Session-ID"):
- await client.get("/api/v1/spec", headers={"X-Lightning-Session-UUID": headers["X-Lightning-Session-UUID"]})
-
- response = await client.get("/api/v1/spec", headers=headers)
- assert response.json() == spec
-
- with pytest.raises(Exception, match="X-Lightning-Session-UUID"):
- await client.get("/api/v1/state")
-
- with pytest.raises(Exception, match="X-Lightning-Session-ID"):
- await client.get("/api/v1/state", headers={"X-Lightning-Session-UUID": headers["X-Lightning-Session-UUID"]})
-
- response = await client.get("/api/v1/state", headers=headers)
- assert response.json() == state
- assert response.status_code == 200
-
- new_state = deepcopy(state)
- new_state["vars"]["counter"] += 1
-
- with pytest.raises(Exception, match="X-Lightning-Session-UUID"):
- await client.post("/api/v1/state")
-
- with pytest.raises(Exception, match="X-Lightning-Session-ID"):
- await client.post(
- "/api/v1/state", headers={"X-Lightning-Session-UUID": headers["X-Lightning-Session-UUID"]}
- )
-
- response = await client.post("/api/v1/state", json={"stage": "running"}, headers=headers)
- assert change_state_queue._queue[0].to_dict() == {
- "values_changed": {"root['app_state']['stage']": {"new_value": "running"}}
- }
- assert response.status_code == 200
-
- response = await client.get("/api/v1/layout")
- assert json.loads(response.json()) == [
- {"name": "main_1", "content": "https://te", "target": "https://te"},
- {"name": "main_2", "content": "https://te"},
- {"name": "main_3", "content": "https://te"},
- {"name": "main_4", "content": "https://te"},
- ]
-
- response = await client.get("/api/v1/status")
- assert response.json() == {"is_ui_ready": True, "work_statuses": {}}
-
- response = await client.post("/api/v1/state", json={"state": new_state}, headers=headers)
- assert change_state_queue._queue[1].to_dict() == {
- "values_changed": {"root['vars']['counter']": {"new_value": 1}}
- }
- assert response.status_code == 200
-
- response = await client.post(
- "/api/v1/delta",
- json={
- "delta": {
- "values_changed": {"root['flows']['video_search']['vars']['should_process']": {"new_value": True}}
- }
- },
- headers=headers,
- )
- assert change_state_queue._queue[2].to_dict() == {
- "values_changed": {"root['flows']['video_search']['vars']['should_process']": {"new_value": True}}
- }
- assert response.status_code == 200
-
- monkeypatch.setattr(api, "ENABLE_PULLING_STATE_ENDPOINT", False)
-
- response = await client.get("/api/v1/state", headers=headers)
- assert response.status_code == 405
-
- response = await client.post("/api/v1/state", json={"state": new_state}, headers=headers)
- assert response.status_code == 200
-
- monkeypatch.setattr(api, "ENABLE_PUSHING_STATE_ENDPOINT", False)
-
- response = await client.post("/api/v1/state", json={"state": new_state}, headers=headers)
- assert response.status_code == 405
-
- response = await client.post(
- "/api/v1/delta",
- json={
- "delta": {
- "values_changed": {"root['flows']['video_search']['vars']['should_process']": {"new_value": True}}
- }
- },
- headers=headers,
- )
- assert change_state_queue._queue[2].to_dict() == {
- "values_changed": {"root['flows']['video_search']['vars']['should_process']": {"new_value": True}}
- }
- assert response.status_code == 405
-
- # Clean the app_state_store for the following test.
- global_app_state_store.remove("1234")
- global_app_state_store.add("1234")
-
- del client
- ui_refresher.join(0)
-
-
-@pytest.mark.parametrize(
- ("path", "expected_status_code"), [("/api/v1", 404), ("/api/v1/asdf", 404), ("/api/asdf", 404), ("/api", 404)]
-)
-@pytest.mark.anyio()
-async def test_state_api_routes(path, expected_status_code):
- async with AsyncClient(app=fastapi_service, base_url="http://test") as client:
- response = await client.get(path)
- assert response.status_code == expected_status_code
-
-
-@pytest.mark.skipif(not check_if_redis_running(), reason="redis not running")
-@pytest.mark.anyio()
-async def test_health_endpoint_success():
- global_app_state_store.store = {}
- global_app_state_store.add("1234")
- async with AsyncClient(app=fastapi_service, base_url="http://test") as client:
- # responds 500 while the app state is empty
- response = await client.get("/healthz")
- assert response.status_code == 500
- assert response.json() == {"status": "failure", "reason": "State is empty {}"}
- global_app_state_store.set_app_state("1234", {"state": None})
- response = await client.get("/healthz")
- assert response.status_code == 200
- assert response.json() == {"status": "ok"}
- global_app_state_store.remove("1234")
- global_app_state_store.store = {}
- global_app_state_store.add("1234")
-
-
-@pytest.mark.skipif(
- check_if_redis_running(), reason="this is testing the failure condition " "for which the redis should not run"
-)
-@pytest.mark.anyio()
-async def test_health_endpoint_failure(monkeypatch):
- monkeypatch.setenv("LIGHTNING_APP_STATE_URL", "http://someurl") # adding this to make is_running_in_cloud pass
- monkeypatch.setitem(os.environ, "LIGHTNING_CLOUD_QUEUE_TYPE", "redis")
- async with AsyncClient(app=fastapi_service, base_url="http://test") as client:
- # will respond 500 if redis is not running
- response = await client.get("/healthz")
- assert response.status_code == 500
-
-
-@pytest.mark.parametrize(
- ("path", "expected_status_code"),
- [
- ("/", 200),
- ("/asdf", 200),
- ("/view/component_a", 200),
- ],
-)
-@pytest.mark.anyio()
-async def test_frontend_routes(path, expected_status_code):
- async with AsyncClient(app=fastapi_service, base_url="http://test") as client:
- response = await client.get(path)
- assert response.status_code == expected_status_code
-
-
-@pytest.mark.xfail(sys.platform == "linux", reason="No idea why... need to be fixed") # fixme
-def test_start_server_started():
- """This test ensures has_started_queue receives a signal when the REST API has started."""
- api_publish_state_queue = mp.Queue()
- api_delta_queue = mp.Queue()
- has_started_queue = mp.Queue()
- api_response_queue = mp.Queue()
- kwargs = {
- "api_publish_state_queue": api_publish_state_queue,
- "api_delta_queue": api_delta_queue,
- "has_started_queue": has_started_queue,
- "api_response_queue": api_response_queue,
- "port": 1111,
- "root_path": "",
- }
-
- server_proc = mp.Process(target=start_server, kwargs=kwargs)
- server_proc.start()
- # requires to wait for the UI to be clicked on.
-
- # wait for server to be ready
- assert has_started_queue.get() == "SERVER_HAS_STARTED"
- server_proc.kill()
-
-
-@mock.patch("uvicorn.run")
-@mock.patch("lightning.app.core.api.UIRefresher")
-@pytest.mark.parametrize("host", ["http://0.0.0.1", "0.0.0.1"])
-def test_start_server_info_message(ui_refresher, uvicorn_run, caplog, monkeypatch, host):
- api_publish_state_queue = _MockQueue()
- api_delta_queue = _MockQueue()
- has_started_queue = _MockQueue()
- api_response_queue = _MockQueue()
- kwargs = {
- "host": host,
- "port": 1111,
- "api_publish_state_queue": api_publish_state_queue,
- "api_delta_queue": api_delta_queue,
- "has_started_queue": has_started_queue,
- "api_response_queue": api_response_queue,
- "root_path": "test",
- }
-
- monkeypatch.setattr(api, "logger", logging.getLogger())
-
- with caplog.at_level(logging.INFO):
- start_server(**kwargs)
-
- assert "Your app has started. View it in your browser: http://0.0.0.1:1111/view" in caplog.text
-
- ui_refresher.assert_called_once()
- uvicorn_run.assert_called_once_with(host="0.0.0.1", port=1111, log_level="error", app=mock.ANY, root_path="test")
-
-
-class InputRequestModel(BaseModel):
- index: int
- name: str
-
-
-class OutputRequestModel(BaseModel):
- name: str
- counter: int
-
-
-async def handler():
- print("Has been called")
- return "Hello World !"
-
-
-class FlowAPI(LightningFlow):
- def __init__(self):
- super().__init__()
- self.counter = 0
-
- def run(self):
- if self.counter == 501:
- self.stop()
-
- def request(self, config: InputRequestModel, request: Request) -> OutputRequestModel:
- self.counter += 1
- if config.index % 5 == 0:
- raise HTTPException(status_code=400, detail="HERE")
- assert request.body()
- assert request.json()
- assert request.headers
- assert request.method
- return OutputRequestModel(name=config.name, counter=self.counter)
-
- def configure_api(self):
- return [Post("/api/v1/request", self.request), Post("/api/v1/handler", handler)]
-
-
-def target():
- app = LightningApp(FlowAPI())
- MultiProcessRuntime(app).dispatch()
-
-
-async def async_request(url: str, data: InputRequestModel):
- async with aiohttp.ClientSession() as session, session.post(url, json=data.dict()) as result:
- return await result.json()
-
-
-@pytest.mark.xfail(strict=False, reason="No idea why... need to be fixed") # fixme
-def test_configure_api():
- # Setup
- process = Process(target=target)
- process.start()
- time_left = 15
- while time_left > 0:
- try:
- requests.get(f"http://localhost:{APP_SERVER_PORT}/healthz")
- break
- except requests.exceptions.ConnectionError:
- sleep(0.1)
- time_left -= 0.1
-
- # Test Upload File
- with open(__file__, "rb") as fo:
- files = {"uploaded_file": fo}
-
- response = requests.put(f"http://localhost:{APP_SERVER_PORT}/api/v1/upload_file/test", files=files)
- assert response.json() == "Successfully uploaded 'test' to the Drive"
-
- url = f"http://localhost:{APP_SERVER_PORT}/api/v1/request"
-
- N = 500
- coros = []
- for index in range(N):
- coros.append(async_request(url, InputRequestModel(index=index, name="hello")))
-
- t0 = time()
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- results = loop.run_until_complete(asyncio.gather(*coros))
- response_time = time() - t0
- print(f"RPS: {N / response_time}")
- assert response_time < 10
- assert len(results) == N
- assert all(r.get("detail", None) == ("HERE" if i % 5 == 0 else None) for i, r in enumerate(results))
-
- response = requests.post(f"http://localhost:{APP_SERVER_PORT}/api/v1/handler")
- assert response.status_code == 200
-
- # Stop the Application
- with contextlib.suppress(Exception):
- response = requests.post(url, json=InputRequestModel(index=0, name="hello").dict())
-
- # Teardown
- time_left = 5
- while time_left > 0:
- if process.exitcode == 0:
- break
- sleep(0.1)
- time_left -= 0.1
- assert process.exitcode == 0
- process.kill()
-
-
-@pytest.mark.anyio()
-@mock.patch("lightning.app.core.api.UIRefresher", mock.MagicMock())
-async def test_get_annotations(tmpdir):
- cwd = os.getcwd()
- os.chdir(tmpdir)
-
- Path("lightning-annotations.json").write_text('[{"test": 3}]')
-
- try:
- app = AppStageTestingApp(FlowA(), log_level="debug")
- app._update_layout()
- app.stage = AppStage.BLOCKING
- change_state_queue = _MockQueue("change_state_queue")
- has_started_queue = _MockQueue("has_started_queue")
- api_response_queue = _MockQueue("api_response_queue")
- spec = extract_metadata_from_app(app)
- start_server(
- None,
- change_state_queue,
- api_response_queue,
- has_started_queue=has_started_queue,
- uvicorn_run=False,
- spec=spec,
- )
-
- async with AsyncClient(app=fastapi_service, base_url="http://test") as client:
- response = await client.get("/api/v1/annotations")
- assert response.json() == [{"test": 3}]
- finally:
- # Cleanup
- os.chdir(cwd)
diff --git a/tests/tests_app/core/test_lightning_app.py b/tests/tests_app/core/test_lightning_app.py
deleted file mode 100644
index 49d6bdbf44f4e..0000000000000
--- a/tests/tests_app/core/test_lightning_app.py
+++ /dev/null
@@ -1,1190 +0,0 @@
-import contextlib
-import logging
-import os
-import pickle
-from re import escape
-from time import sleep, time
-from unittest import mock
-
-import pytest
-from deepdiff import Delta
-from lightning.app import CloudCompute, LightningApp, LightningFlow, LightningWork # F401
-from lightning.app.api.request_types import _DeltaRequest
-from lightning.app.core.constants import (
- FLOW_DURATION_SAMPLES,
- FLOW_DURATION_THRESHOLD,
- REDIS_QUEUES_READ_DEFAULT_TIMEOUT,
- STATE_UPDATE_TIMEOUT,
-)
-from lightning.app.core.queues import BaseQueue, MultiProcessQueue, RedisQueue
-from lightning.app.frontend import StreamlitFrontend
-from lightning.app.runners import MultiProcessRuntime
-from lightning.app.storage.path import Path, _storage_root_dir
-from lightning.app.testing.helpers import _RunIf
-from lightning.app.testing.testing import LightningTestApp
-from lightning.app.utilities.app_helpers import affiliation
-from lightning.app.utilities.enum import AppStage, WorkStageStatus, WorkStopReasons
-from lightning.app.utilities.imports import _IS_WINDOWS
-from lightning.app.utilities.packaging import cloud_compute
-from lightning.app.utilities.redis import check_if_redis_running
-from lightning.app.utilities.warnings import LightningFlowWarning
-from lightning_utilities.core.imports import RequirementCache
-from pympler import asizeof
-
-from tests_app import _PROJECT_ROOT
-
-_STREAMLIT_AVAILABLE = RequirementCache("streamlit")
-
-logger = logging.getLogger()
-
-
-def test_lightning_app_requires_root_run_method():
- """Test that a useful exception is raised if the root flow does not override the run method."""
- with pytest.raises(
- TypeError, match=escape("The root flow passed to `LightningApp` does not override the `run()` method")
- ):
- LightningApp(LightningFlow())
-
- class FlowWithoutRun(LightningFlow):
- pass
-
- with pytest.raises(
- TypeError, match=escape("The root flow passed to `LightningApp` does not override the `run()` method")
- ):
- LightningApp(FlowWithoutRun())
-
- class FlowWithRun(LightningFlow):
- def run(self):
- pass
-
- LightningApp(FlowWithRun()) # no error
-
-
-class B1(LightningFlow):
- def __init__(self):
- super().__init__()
-
- def run(self):
- pass
-
-
-class A1(LightningFlow):
- def __init__(self):
- super().__init__()
- self.b = B1()
-
- def run(self):
- pass
-
-
-class Work(LightningWork):
- def __init__(self, cache_calls: bool = True):
- super().__init__(cache_calls=cache_calls)
- self.counter = 0
- self.has_finished = False
-
- def run(self):
- self.counter = self.counter + 1
- if self.cache_calls or self.counter >= 3:
- self.has_finished = True
-
-
-class SimpleFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.work_a = Work(cache_calls=True)
- self.work_b = Work(cache_calls=False)
-
- def run(self):
- if self.work_a.has_finished and self.work_b.has_finished:
- self.stop()
- self.work_a.run()
- self.work_b.run()
-
-
-def test_simple_app(tmpdir):
- comp = SimpleFlow()
- app = LightningApp(comp, log_level="debug")
- assert app.root == comp
- expected = {
- "app_state": mock.ANY,
- "vars": {"_layout": mock.ANY, "_paths": {}},
- "calls": {},
- "flows": {},
- "structures": {},
- "works": {
- "work_b": {
- "vars": {
- "has_finished": False,
- "counter": 0,
- "_cloud_compute": mock.ANY,
- "_host": mock.ANY,
- "_url": "",
- "_future_url": "",
- "_internal_ip": "",
- "_public_ip": "",
- "_paths": {},
- "_port": None,
- "_restarting": False,
- "_display_name": "",
- },
- "calls": {"latest_call_hash": None},
- "changes": {},
- },
- "work_a": {
- "vars": {
- "has_finished": False,
- "counter": 0,
- "_cloud_compute": mock.ANY,
- "_host": mock.ANY,
- "_url": "",
- "_future_url": "",
- "_internal_ip": "",
- "_public_ip": "",
- "_paths": {},
- "_port": None,
- "_restarting": False,
- "_display_name": "",
- },
- "calls": {"latest_call_hash": None},
- "changes": {},
- },
- },
- "changes": {},
- }
- assert app.state == expected
- MultiProcessRuntime(app, start_server=False).dispatch()
-
- assert comp.work_a.has_finished
- assert comp.work_b.has_finished
- # it is possible that `work_a` takes forever to
- # start while `work_b` has already completed multiple iterations.
- assert comp.work_a.counter == 1
- assert comp.work_b.counter >= 3
-
-
-class WorkCounter(LightningWork):
- def __init__(self):
- super().__init__()
- self.c = 0
-
- def run(self):
- self.c = 1
-
-
-class E(LightningFlow):
- def __init__(self):
- super().__init__()
- self.w_e = WorkCounter()
-
- def run(self):
- self.w_e.run()
-
-
-class D(LightningFlow):
- def __init__(self):
- super().__init__()
- self.w_d = WorkCounter()
- self.e = E()
-
- def run(self):
- self.w_d.run()
- self.e.run()
-
-
-class C(LightningFlow):
- def __init__(self):
- super().__init__()
- self.w_c = WorkCounter()
- self.d = D()
-
- def run(self):
- self.w_c.run()
- self.d.run()
-
-
-class B(LightningFlow):
- def __init__(self):
- super().__init__()
- self.w_b = WorkCounter()
- self.c = C()
-
- def run(self):
- self.w_b.run()
- self.c.run()
-
-
-class A(LightningFlow):
- def __init__(self):
- super().__init__()
- self.w_a = WorkCounter()
- self.b = B()
-
- def run(self):
- self.w_a.run()
- self.b.run()
- if self.b.c.d.e.w_e.c == 1:
- self.stop()
-
-
-def test_nested_component_names():
- root = A()
- assert root.name == "root"
- assert root.w_a.name == "root.w_a"
- assert root.b.name == "root.b"
- assert root.b.w_b.name == "root.b.w_b"
- assert root.b.c.name == "root.b.c"
- assert root.b.c.w_c.name == "root.b.c.w_c"
- assert root.b.c.d.name == "root.b.c.d"
- assert root.b.c.d.e.name == "root.b.c.d.e"
- assert root.b.c.d.e.w_e.name == "root.b.c.d.e.w_e"
- assert root.b.c.d.w_d.name == "root.b.c.d.w_d"
-
-
-def test_get_component_by_name():
- app = LightningApp(A())
- assert app.root in app.flows
- assert app.get_component_by_name("root") is app.root
- assert app.get_component_by_name("root.b") is app.root.b
- assert app.get_component_by_name("root.w_a") is app.root.w_a
- assert app.get_component_by_name("root.b.w_b") is app.root.b.w_b
- assert app.get_component_by_name("root.b.c.d.e") is app.root.b.c.d.e
-
-
-def test_get_component_by_name_raises():
- app = LightningApp(A())
-
- for name in ("", "ro", "roott"):
- with pytest.raises(ValueError, match=f"Invalid component name {name}."):
- app.get_component_by_name(name)
-
- with pytest.raises(AttributeError, match="Component 'root' has no child component with name ''"):
- app.get_component_by_name("root.")
-
- with pytest.raises(AttributeError, match="Component 'root' has no child component with name 'x'"):
- app.get_component_by_name("root.x")
-
- with pytest.raises(AttributeError, match="Component 'root.b' has no child component with name 'x'"):
- app.get_component_by_name("root.b.x")
-
- with pytest.raises(AttributeError, match="Component 'root.b.w_b' has no child component with name 'c'"):
- app.get_component_by_name("root.b.w_b.c")
-
-
-def test_nested_component():
- app = LightningApp(A(), log_level="debug")
- MultiProcessRuntime(app, start_server=False).dispatch()
- assert app.root.w_a.c == 1
- assert app.root.b.w_b.c == 1
- assert app.root.b.c.w_c.c == 1
- assert app.root.b.c.d.w_d.c == 1
- assert app.root.b.c.d.e.w_e.c == 1
-
-
-class WorkCCC(LightningWork):
- def run(self):
- pass
-
-
-class CC(LightningFlow):
- def __init__(self):
- super().__init__()
- self.work_cc = WorkCCC()
-
- def run(self):
- pass
-
-
-class BB(LightningFlow):
- def __init__(self):
- super().__init__()
- self.c1 = CC()
- self.c2 = CC()
-
- def run(self):
- pass
-
-
-class AA(LightningFlow):
- def __init__(self):
- super().__init__()
- self.b = BB()
-
- def run(self):
- pass
-
-
-def test_component_affiliation():
- app = LightningApp(AA())
- a_affiliation = affiliation(app.root)
- assert a_affiliation == ()
- b_affiliation = affiliation(app.root.b)
- assert b_affiliation == ("b",)
- c1_affiliation = affiliation(app.root.b.c1)
- assert c1_affiliation == ("b", "c1")
- c2_affiliation = affiliation(app.root.b.c2)
- assert c2_affiliation == ("b", "c2")
- work_cc_affiliation = affiliation(app.root.b.c2.work_cc)
- assert work_cc_affiliation == ("b", "c2", "work_cc")
-
-
-class Work4(LightningWork):
- def __init__(self):
- super().__init__(parallel=True)
- self.var_a = 0
- self.has_finished = False
-
- def run(self):
- self.var_a = 1
- sleep(2)
- # This is never reached because the app exits first
- self.has_finished = True
-
-
-class A4(LightningFlow):
- def __init__(self):
- super().__init__()
- self.work = Work4()
-
- def run(self):
- self.work.run()
- if self.work.var_a == 1:
- self.stop()
-
-
-@pytest.mark.parametrize("runtime_cls", [MultiProcessRuntime])
-def test_setattr_multiprocessing(runtime_cls, tmpdir):
- app = LightningApp(A4())
- runtime_cls(app, start_server=False).dispatch()
- assert app.root.work.var_a == 1
- assert not app.root.work.has_finished
-
-
-class CounterFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.counter = 0
-
- def run(self):
- self.counter += 1
-
-
-class SimpleApp2(LightningApp):
- def run_once(self):
- if self.root.counter == 5:
- self.stage = AppStage.RESTARTING
- return super().run_once()
-
- def _apply_restarting(self):
- super()._apply_restarting()
- assert self.stage == AppStage.BLOCKING
- return True
-
-
-def test_app_restarting_move_to_blocking(tmpdir):
- """Validates that setting the stage to RESTARTING moves the app back to BLOCKING."""
- app = SimpleApp2(CounterFlow(), log_level="debug")
- MultiProcessRuntime(app, start_server=False).dispatch()
-
-
-class FlowWithFrontend(LightningFlow):
- def run(self):
- pass
-
- def configure_layout(self):
- return StreamlitFrontend(render_fn=lambda _: None)
-
-
-class AppWithFrontend(LightningApp):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.run_once_call_count = 0
-
- def run_once(self):
- # by the time run_once gets called the first time, the target_url for the frontend should be set
- # and be present in both the LightningApp.state and the LightningApp._original_state
- assert self.state["vars"]["_layout"]["target"].startswith("http://localhost")
- assert self._original_state["vars"]["_layout"]["target"].startswith("http://localhost")
- assert self.run_once_call_count or self.state == self._original_state
-
- self.run_once_call_count += 1
- if self.run_once_call_count == 3:
- return True, 0.0
- return super().run_once()
-
-
-@pytest.mark.skipif(not _STREAMLIT_AVAILABLE, reason="requires streamlit")
-@mock.patch("lightning.app.frontend.stream_lit.StreamlitFrontend.start_server")
-@mock.patch("lightning.app.frontend.stream_lit.StreamlitFrontend.stop_server")
-def test_app_starts_with_complete_state_copy(_, __):
- """Test that the LightningApp captures the initial state in a separate copy when _run() gets called."""
- app = AppWithFrontend(FlowWithFrontend(), log_level="debug")
- MultiProcessRuntime(app, start_server=False).dispatch()
- assert app.run_once_call_count == 3
-
-
-class EmptyFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.counter = 0
-
- def run(self):
- pass
-
-
-@pytest.mark.parametrize(
- ("queue_type_cls", "default_timeout"),
- [
- (MultiProcessQueue, STATE_UPDATE_TIMEOUT),
- pytest.param(
- RedisQueue,
- REDIS_QUEUES_READ_DEFAULT_TIMEOUT,
- marks=pytest.mark.skipif(not check_if_redis_running(), reason="Redis is not running"),
- ),
- ],
-)
-@pytest.mark.parametrize(
- ("sleep_time", "expect"),
- [
- (0, 9),
- pytest.param(9, 10.0, marks=pytest.mark.xfail(strict=False, reason="failing...")), # fixme
- ],
-)
-@pytest.mark.flaky(reruns=5)
-def test_lightning_app_aggregation_speed(default_timeout, queue_type_cls: BaseQueue, sleep_time, expect):
- """This test validates that `_collect_deltas_from_ui_and_work_queues` can aggregate multiple deltas within a
- time window."""
-
- class SlowQueue(queue_type_cls):
- def batch_get(self, timeout, count):
- out = super().get(timeout)
- sleep(sleep_time)
- return [out]
-
- app = LightningApp(EmptyFlow())
-
- app.delta_queue = SlowQueue("api_delta_queue", default_timeout)
- if queue_type_cls is RedisQueue:
- app.delta_queue.clear()
-
- def make_delta(i):
- return _DeltaRequest(Delta({"values_changed": {"root['vars']['counter']": {"new_value": i}}}))
-
- # flood the queue with mocked deltas
- for i in range(expect + 10):
- app.delta_queue.put(make_delta(i))
-
- # Wait for a bit because multiprocessing.Queue doesn't run in the same thread and takes some time for writes
- sleep(0.001)
-
- delta = app._collect_deltas_from_ui_and_work_queues()[-1]
- generated = delta.to_dict()["values_changed"]["root['vars']['counter']"]["new_value"]
- if sleep_time:
- assert generated == expect, (generated, expect)
- else:
- # validate that the flow aggregated more than `expect` deltas.
- assert generated > expect
-
-
-def test_lightning_app_aggregation_empty():
- """Verify the while loop exits before `state_accumulate_wait` is reached if no deltas are found."""
-
- class SlowQueue(MultiProcessQueue):
- def get(self, timeout):
- return super().get(timeout)
-
- app = LightningApp(EmptyFlow())
- app.delta_queue = SlowQueue("api_delta_queue", 0)
- t0 = time()
- assert app._collect_deltas_from_ui_and_work_queues() == []
- delta = time() - t0
- assert delta < app.state_accumulate_wait + 0.01, delta
-
-
-class SimpleFlow2(LightningFlow):
- def __init__(self):
- super().__init__()
- self.counter = 0
-
- def run(self):
- if self.counter < 2:
- self.counter += 1
-
-
-def test_maybe_apply_changes_from_flow():
- """This test validates that the app's `_has_updated` is set to True only if the state was changed in the flow."""
- app = LightningApp(SimpleFlow2())
- app.delta_queue = MultiProcessQueue("a", 0)
- assert app._has_updated
- app.maybe_apply_changes()
- app.root.run()
- app.maybe_apply_changes()
- assert app._has_updated
- app._has_updated = False
- app.root.run()
- app.maybe_apply_changes()
- assert app._has_updated
- app._has_updated = False
- app.root.run()
- app.maybe_apply_changes()
- assert not app._has_updated
-
-
-class SimpleWork(LightningWork):
- def __init__(self):
- super().__init__(cache_calls=False, parallel=True)
- self.counter = 0
-
- def run(self):
- self.counter += 1
-
-
-class FlowA(LightningFlow):
- def __init__(self):
- super().__init__()
- self.work_a = SimpleWork()
- self.work_b = SimpleWork()
-
- def run(self):
- if self.work_a.counter == self.work_b.counter == 0:
- self.work_a.run()
- self.work_b.run()
-
-
-class SuccessException(Exception):
- pass
-
-
-class CheckpointLightningApp(LightningApp):
- def _dump_checkpoint(self):
- super()._dump_checkpoint()
- raise SuccessException
-
-
-@pytest.mark.flaky(reruns=3)
-def test_snap_shotting():
- with contextlib.suppress(SuccessException):
- app = CheckpointLightningApp(FlowA())
- app.checkpointing = True
- MultiProcessRuntime(app, start_server=False).dispatch()
-
- checkpoint_dir = os.path.join(_storage_root_dir(), "checkpoints")
- checkpoints = os.listdir(checkpoint_dir)
- assert len(checkpoints) == 1
- with open(os.path.join(checkpoint_dir, checkpoints[0]), "rb") as f:
- state = pickle.load(f)
- assert state["works"]["work_a"]["vars"]["counter"] == 1
- assert state["works"]["work_b"]["vars"]["counter"] == 1
-
-
-class CounterWork(LightningWork):
- def __init__(self, parallel: bool, cache_calls: bool):
- super().__init__(parallel=parallel, cache_calls=cache_calls)
- self.counter = 0
-
- def run(self, counter=0):
- self.counter += 1
-
-
-class WaitForAllFlow(LightningFlow):
- def __init__(self, use_same_args):
- super().__init__()
- counter = 0
- self.use_same_args = use_same_args
- for parallel in [False, True]:
- for cache_calls in [False, True]:
- work = CounterWork(parallel=parallel, cache_calls=cache_calls)
- setattr(self, f"work_{counter}", work)
- counter += 1
- self.c = 0
-
- def run(self):
- next_c = self.c + 1
- for work in self.experimental_iterate(self.works(), run_once=False):
- if work.num_successes < (next_c):
- if not self.use_same_args:
- work.run(self.c)
- else:
- work.run(None)
-
- expected = 1 if self.use_same_args else next_c
-
- if not all(w.num_successes == (expected if w.cache_calls else next_c) for w in self.works()):
- return
-
- self.c += 1
- assert [w.counter for w in self.works()] == [self.c, expected, self.c, expected]
- if self.c > 3:
- self.stop()
-
-
-# TODO (tchaton) Resolve this test.
-@pytest.mark.skipif(_IS_WINDOWS, reason="timeout with system crash")
-@pytest.mark.xfail(strict=False, reason="flaky test which never terminates")
-@pytest.mark.parametrize("runtime_cls", [MultiProcessRuntime])
-@pytest.mark.parametrize("use_same_args", [True])
-# todo: removed test_state_wait_for_all_all_works[False-MultiProcessRuntime] as it hangs
-def test_state_wait_for_all_all_works(tmpdir, runtime_cls, use_same_args):
- app = LightningApp(WaitForAllFlow(use_same_args))
- runtime_cls(app, start_server=False).dispatch()
-
-
-class CheckpointCounter(LightningWork):
- def __init__(self):
- super().__init__(cache_calls=False)
- self.counter = 0
-
- def run(self):
- self.counter += 1
-
-
-class CheckpointFlow(LightningFlow):
- def __init__(self, work: CheckpointCounter, depth=0):
- super().__init__()
- self.depth = depth
-
- if depth == 0:
- self.counter = 0
-
- if depth >= 10:
- self.work = work
- else:
- self.flow = CheckpointFlow(work, depth + 1)
-
- def run(self):
- if self.works()[0].counter == 5:
- self.stop()
-
- if self.depth >= 10:
- self.work.run()
- else:
- self.flow.run()
-
-
-@pytest.mark.skipif(True, reason="reloading isn't properly supported")
-def test_lightning_app_checkpointing_with_nested_flows():
- work = CheckpointCounter()
- app = LightningApp(CheckpointFlow(work))
- app.checkpointing = True
- MultiProcessRuntime(app, start_server=False).dispatch()
-
- assert app.root.flow.flow.flow.flow.flow.flow.flow.flow.flow.flow.work.counter == 5
-
- work = CheckpointCounter()
- app = LightningApp(CheckpointFlow(work))
- assert app.root.flow.flow.flow.flow.flow.flow.flow.flow.flow.flow.work.counter == 0
-
- app.load_state_dict_from_checkpoint_dir(app.checkpoint_dir)
- # The counter was incremented to 6 after the latest checkpoint was created.
- assert app.root.flow.flow.flow.flow.flow.flow.flow.flow.flow.flow.work.counter == 5
-
-
-@pytest.mark.xfail(strict=False, reason="test is skipped because CI was blocking all the PRs.")
-def test_load_state_dict_from_checkpoint_dir(tmpdir):
- work = CheckpointCounter()
- app = LightningApp(CheckpointFlow(work))
-
- checkpoints = []
- num_checkpoints = 11
- # generate 11 checkpoints.
- for _ in range(num_checkpoints):
- checkpoints.append(app._dump_checkpoint())
- app.root.counter += 1
-
- app.load_state_dict_from_checkpoint_dir(app.checkpoint_dir)
- assert app.root.counter == (num_checkpoints - 1)
-
- for version in range(num_checkpoints):
- app.load_state_dict_from_checkpoint_dir(app.checkpoint_dir, version=version)
- assert app.root.counter == version
-
- with pytest.raises(FileNotFoundError, match="The provided directory"):
- app.load_state_dict_from_checkpoint_dir("./random_folder/")
-
- with pytest.raises(Exception, match="No checkpoints where found"):
- app.load_state_dict_from_checkpoint_dir(str(os.path.join(_PROJECT_ROOT, "tests/tests_app/")))
-
- # delete 2 checkpoints
- os.remove(os.path.join(checkpoints[4]))
- os.remove(os.path.join(checkpoints[7]))
-
- app.load_state_dict_from_checkpoint_dir(app.checkpoint_dir)
- assert app.root.counter == (num_checkpoints - 1)
-
- app.load_state_dict_from_checkpoint_dir(app.checkpoint_dir, version=5)
- checkpoint_path = app._dump_checkpoint()
-
- assert os.path.basename(checkpoint_path).startswith("v_11")
-
-
-class PicklableObject:
- pass
-
-
-class PickleableReturnWork(LightningWork):
- def __init__(self):
- super().__init__()
-
- def run(self):
- return PicklableObject()
-
-
-class PickleableReturnFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.work = PickleableReturnWork()
-
- def run(self):
- self.work.run()
-
-
-def test_pickleable_return_from_work():
- """Test that any object that is pickleable can be returned from the run method in LightningWork."""
- with pytest.raises(SystemExit, match="1"):
- app = LightningApp(PickleableReturnFlow())
- MultiProcessRuntime(app, start_server=False).dispatch()
-
-
-class WorkDD(LightningWork):
- def __init__(self):
- super().__init__(parallel=True)
- self.total = 10
- self.counter = 1
-
- def run(self):
- should_wait = self.counter == 1
- start_counter = self.total - self.counter
- for _ in range(start_counter):
- if should_wait:
- sleep(0.5)
- self.counter += 1
-
-
-class FlowCCTolerance(LightningFlow):
- def __init__(self):
- super().__init__()
- self.work = WorkDD()
-
- def run(self):
- self.work.run()
- if self.work.counter == 10:
- self.stop()
-
-
-class FaultToleranceLightningTestApp(LightningTestApp):
- def on_after_run_once(self):
- if self.root.work.status.reason == WorkStopReasons.SIGTERM_SIGNAL_HANDLER:
- assert self.root.work.counter < 10
- self.restart_work("root.work")
- elif self.root.work.counter == 2:
- self.kill_work("root.work")
- return True, 0.0
- return super().on_after_run_once()
-
-
-# TODO (tchaton) Resolve this test with Resumable App.
-@_RunIf(skip_windows=True)
-def test_fault_tolerance_work():
- app = FaultToleranceLightningTestApp(FlowCCTolerance())
- MultiProcessRuntime(app, start_server=False).dispatch()
- assert app.root.work.counter == 2
-
-
-class ProtectedAttributesWork(LightningWork):
- def __init__(self):
- super().__init__()
- # a public attribute, this should show up in the state
- self.done = False
- # a protected and a private attribute, these should NOT show up in the state
- self._protected = 1
- self.__private = 2
-
- def run(self):
- self.done = True
- self._protected = 10
- self.__private = 20
-
-
-class ProtectedAttributesFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- # a public attribute, this should show up in the state
- self.done = False
- # a protected and a private attribute, these should NOT show up in the state
- self._protected = 1
- self.__private = 2
-
- self.protected_work = ProtectedAttributesWork()
-
- def run(self):
- flow_variables = self.state_vars["vars"]
- assert "done" in flow_variables
- assert "_protected" not in flow_variables
- assert "__private" not in flow_variables
- self.done = True
-
- self.protected_work.run()
- if self.protected_work.done:
- work_variables = self.protected_work.state_vars["vars"]
- assert "done" in work_variables
- assert "_protected" not in work_variables
- assert "__private" not in work_variables
-
- # TODO: getattr and setattr access outside the Work should raise an error in the future
- _ = self.protected_work._protected
- self.protected_work._protected = 1
-
- if self.done and self.protected_work.done:
- self.stop()
-
-
-def test_protected_attributes_not_in_state():
- flow = ProtectedAttributesFlow()
- MultiProcessRuntime(LightningApp(flow), start_server=False).dispatch()
-
-
-class WorkExit(LightningWork):
- def __init__(self):
- super().__init__(raise_exception=False)
- self.counter = 0
-
- def run(self):
- self.counter += 1
- raise Exception("Hello")
-
-
-class FlowExit(LightningFlow):
- def __init__(self):
- super().__init__()
- self.work = WorkExit()
-
- def run(self):
- if self.work.counter == 1:
- self.stop()
- self.work.run()
-
-
-def test_lightning_app_exit():
- app = LightningApp(FlowExit())
- MultiProcessRuntime(app, start_server=False).dispatch()
- assert app.root.work.status.stage == WorkStageStatus.STOPPED
-
-
-class CounterWork2(LightningWork):
- def __init__(self):
- super().__init__(parallel=True)
- self.counter = 0
-
- def run(self):
- self.counter += 1
-
-
-class FlowStop(LightningFlow):
- def __init__(self):
- super().__init__()
- self.w = CounterWork2()
-
- def run(self):
- if self.w.status.stage == WorkStageStatus.STOPPED:
- self.stop()
- if self.w.counter == 1:
- self.w.stop()
- self.w.run()
-
-
-@_RunIf(skip_windows=True)
-def test_lightning_stop():
- app = LightningApp(FlowStop())
- MultiProcessRuntime(app, start_server=False).dispatch()
-
-
-class SleepyFlow(LightningFlow):
- def __init__(self, sleep_interval, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.counter = 0
- self.sleep_interval = sleep_interval
-
- def run(self):
- if self.counter == 2 * FLOW_DURATION_SAMPLES:
- self.stop()
- sleep(self.sleep_interval)
- self.counter += 1
-
-
-class SleepyWork(LightningWork):
- def __init__(self, sleep_interval, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.sleep_interval = sleep_interval
-
- def run(self):
- sleep(self.sleep_interval)
-
-
-class SleepyFlowWithWork(LightningFlow):
- def __init__(self, sleep_interval, work_sleep_interval, parallel, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.counter = 0
- self.sleep_interval = sleep_interval
- self.work = SleepyWork(work_sleep_interval, parallel=parallel)
-
- def run(self):
- if self.counter == 2 * FLOW_DURATION_SAMPLES:
- self.stop()
- self.work.run()
- sleep(self.sleep_interval)
- self.counter += 1
-
-
-def test_slow_flow():
- app0 = LightningApp(SleepyFlow(sleep_interval=0.5 * FLOW_DURATION_THRESHOLD))
-
- MultiProcessRuntime(app0, start_server=False).dispatch()
-
- app1 = LightningApp(SleepyFlow(sleep_interval=2 * FLOW_DURATION_THRESHOLD))
-
- with pytest.warns(LightningFlowWarning):
- MultiProcessRuntime(app1, start_server=False).dispatch()
-
- app0 = LightningApp(
- SleepyFlowWithWork(
- sleep_interval=0.5 * FLOW_DURATION_THRESHOLD,
- work_sleep_interval=2 * FLOW_DURATION_THRESHOLD,
- parallel=False,
- )
- )
-
- MultiProcessRuntime(app0, start_server=False).dispatch()
-
- app1 = LightningApp(
- SleepyFlowWithWork(
- sleep_interval=0.5 * FLOW_DURATION_THRESHOLD, work_sleep_interval=2 * FLOW_DURATION_THRESHOLD, parallel=True
- )
- )
-
- MultiProcessRuntime(app1, start_server=False).dispatch()
-
-
-class SizeWork(LightningWork):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- self.counter = 0
-
- def run(self, signal: int):
- self.counter += 1
- assert len(self._calls) == 2
-
-
-class SizeFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.work0 = SizeWork(parallel=True, cache_calls=True)
- self._state_sizes = {}
-
- def run(self):
- for idx in range(self.work0.counter + 2):
- self.work0.run(idx)
-
- self._state_sizes[self.work0.counter] = asizeof.asizeof(self.state)
-
- if self.work0.counter >= 20:
- self.stop()
-
-
-def test_state_size_constant_growth():
- app = LightningApp(SizeFlow())
- MultiProcessRuntime(app, start_server=False).dispatch()
- assert app.root._state_sizes[0] <= 8380
- assert app.root._state_sizes[20] <= 26999
-
-
-class FlowUpdated(LightningFlow):
- def run(self):
- logger.info("Hello World")
-
-
-class NonUpdatedLightningTestApp(LightningTestApp):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.counter = 0
-
- def on_after_run_once(self):
- self.counter += 1
- if not self._has_updated and self.counter > 2:
- return True
- return super().on_after_run_once()
-
-
-def test_non_updated_flow(caplog):
- """Validate that the app can run 3 times and calls the flow only once."""
- app = NonUpdatedLightningTestApp(FlowUpdated())
- runtime = MultiProcessRuntime(app, start_server=False)
- with caplog.at_level(logging.INFO):
- runtime.dispatch()
- assert caplog.messages == [
- "Hello World",
- "Your Lightning App is being stopped. This won't take long.",
- "Your Lightning App has been stopped successfully!",
- ]
- assert app.counter == 3
-
-
-def test_debug_mode_logging():
- """This test validates that DEBUG messages are collected when activated by LightningApp(log_level="debug") and
- cleaned up once finished."""
-
- from lightning.app.core.app import _console
-
- app = LightningApp(A4(), log_level="debug")
- assert _console.level == logging.DEBUG
- assert os.getenv("LIGHTNING_DEBUG") == "2"
-
- MultiProcessRuntime(app, start_server=False).dispatch()
-
- assert os.getenv("LIGHTNING_DEBUG") is None
- assert _console.level == logging.INFO
-
- app = LightningApp(A4())
- assert _console.level == logging.INFO
- MultiProcessRuntime(app, start_server=False).dispatch()
-
-
-class WorkPath(LightningWork):
- def __init__(self):
- super().__init__()
- self.path = None
-
- def run(self):
- self.path = Path(__file__)
-
-
-class FlowPath(LightningFlow):
- def __init__(self):
- super().__init__()
- self.w = WorkPath()
-
- def run(self):
- self.w.run()
-
-
-class TestLightningHasUpdatedApp(LightningApp):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.counter = 0
-
- def run_once(self):
- res = super().run_once()
-
- if self.root.w.has_succeeded:
- self.counter += 1
-
- # TODO: Resolve bug where it should work with self.counter == 2
- if self.counter > 5:
- assert not self._has_updated
- return True
- return res
-
-
-def test_lightning_app_has_updated():
- app = TestLightningHasUpdatedApp(FlowPath())
- MultiProcessRuntime(app, start_server=False).dispatch()
-
-
-class WorkCC(LightningWork):
- def run(self):
- pass
-
-
-class FlowCC(LightningFlow):
- def __init__(self):
- super().__init__()
- self.cloud_compute = CloudCompute(name="gpu", _internal_id="a")
- self.work_a = WorkCC(cloud_compute=self.cloud_compute)
- self.work_b = WorkCC(cloud_compute=self.cloud_compute)
- self.work_c = WorkCC()
- assert self.work_a.cloud_compute._internal_id == self.work_b.cloud_compute._internal_id
-
- def run(self):
- self.work_d = WorkCC()
-
-
-class FlowWrapper(LightningFlow):
- def __init__(self, flow):
- super().__init__()
- self.w = flow
-
-
-def test_cloud_compute_binding():
- cloud_compute.ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER = True
-
- assert {} == cloud_compute._CLOUD_COMPUTE_STORE
- flow = FlowCC()
- assert len(cloud_compute._CLOUD_COMPUTE_STORE) == 2
- assert cloud_compute._CLOUD_COMPUTE_STORE["default"].component_names == ["root.work_c"]
- assert cloud_compute._CLOUD_COMPUTE_STORE["a"].component_names == ["root.work_a", "root.work_b"]
-
- wrapper = FlowWrapper(flow)
- assert cloud_compute._CLOUD_COMPUTE_STORE["default"].component_names == ["root.w.work_c"]
- assert cloud_compute._CLOUD_COMPUTE_STORE["a"].component_names == ["root.w.work_a", "root.w.work_b"]
-
- _ = FlowWrapper(wrapper)
- assert cloud_compute._CLOUD_COMPUTE_STORE["default"].component_names == ["root.w.w.work_c"]
- assert cloud_compute._CLOUD_COMPUTE_STORE["a"].component_names == ["root.w.w.work_a", "root.w.w.work_b"]
-
- assert flow.state["vars"]["cloud_compute"]["type"] == "__cloud_compute__"
- assert flow.work_a.state["vars"]["_cloud_compute"]["type"] == "__cloud_compute__"
- assert flow.work_b.state["vars"]["_cloud_compute"]["type"] == "__cloud_compute__"
- assert flow.work_c.state["vars"]["_cloud_compute"]["type"] == "__cloud_compute__"
- work_a_id = flow.work_a.state["vars"]["_cloud_compute"]["_internal_id"]
- work_b_id = flow.work_b.state["vars"]["_cloud_compute"]["_internal_id"]
- work_c_id = flow.work_c.state["vars"]["_cloud_compute"]["_internal_id"]
- assert work_a_id == work_b_id
- assert work_a_id != work_c_id
- assert work_c_id == "default"
-
- flow.work_a.cloud_compute = CloudCompute(name="something_else")
- assert cloud_compute._CLOUD_COMPUTE_STORE["a"].component_names == ["root.w.w.work_b"]
-
- flow.set_state(flow.state)
- assert isinstance(flow.cloud_compute, CloudCompute)
- assert isinstance(flow.work_a.cloud_compute, CloudCompute)
- assert isinstance(flow.work_c.cloud_compute, CloudCompute)
-
- cloud_compute.ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER = False
-
- with pytest.raises(Exception, match="A Cloud Compute can be assigned only to a single Work"):
- FlowCC()
-
-
-class FlowValue(LightningFlow):
- def __init__(self):
- super().__init__()
- self._value = None
-
- @property
- def value(self):
- return self._value
-
- @value.setter
- def value(self, value):
- self._value = value
-
- def run(self):
- self.value = True
-
-
-def test_lightning_flow_properties():
- """Validates setting properties to the LightningFlow properly calls property.fset."""
- flow = FlowValue()
- assert flow._value is None
- flow.run()
- assert flow._value is True
-
-
-class SimpleWork2(LightningWork):
- def run(self):
- pass
-
-
-def test_lightning_work_stopped():
- app = LightningApp(SimpleWork2())
- MultiProcessRuntime(app, start_server=False).dispatch()
diff --git a/tests/tests_app/core/test_lightning_flow.py b/tests/tests_app/core/test_lightning_flow.py
deleted file mode 100644
index e7e8ef77603c1..0000000000000
--- a/tests/tests_app/core/test_lightning_flow.py
+++ /dev/null
@@ -1,965 +0,0 @@
-import contextlib
-import os
-import pickle
-from collections import Counter
-from copy import deepcopy
-from dataclasses import dataclass
-from functools import partial
-from time import time
-from unittest.mock import ANY
-
-import lightning.app
-import pytest
-from deepdiff import DeepDiff, Delta
-from lightning.app import CloudCompute, LightningApp
-from lightning.app.core.flow import LightningFlow, _RootFlow
-from lightning.app.core.work import LightningWork
-from lightning.app.runners import MultiProcessRuntime
-from lightning.app.storage.path import Path, _storage_root_dir
-from lightning.app.structures import Dict as LDict
-from lightning.app.structures import List as LList
-from lightning.app.testing.helpers import EmptyFlow, EmptyWork, _MockQueue
-from lightning.app.utilities.app_helpers import (
- _delta_to_app_state_delta,
- _LightningAppRef,
- _load_state_dict,
- _state_dict,
-)
-from lightning.app.utilities.enum import CacheCallsKeys
-from lightning.app.utilities.exceptions import ExitAppException
-from lightning.app.utilities.imports import _IS_WINDOWS
-
-
-def test_empty_component():
- class A(LightningFlow):
- def run(self):
- pass
-
- empty_component = A()
- assert empty_component.state == {
- "vars": {"_layout": ANY, "_paths": {}},
- "calls": {},
- "flows": {},
- "structures": {},
- "changes": {},
- "works": {},
- }
-
-
-@dataclass
-class CustomDataclass:
- x: int = 1
- y: tuple = (3, 2, 1)
-
-
-@pytest.mark.parametrize("attribute", [{3, 2, 1}, lambda _: 5, CustomDataclass()])
-@pytest.mark.parametrize("cls", [LightningWork, LightningFlow])
-def test_unsupported_attribute_types(cls, attribute):
- class Component(cls):
- def __init__(self):
- super().__init__()
- self.x = attribute
-
- def run(self):
- pass
-
- with pytest.raises(AttributeError, match="Only JSON-serializable attributes are currently supported"):
- Component()
-
-
-@pytest.mark.parametrize(
- ("name", "value"),
- [
- ("x", 1),
- ("f", EmptyFlow()),
- ("w", EmptyWork()),
- ],
-)
-def test_unsupported_attribute_declaration_outside_init_or_run(name, value):
- """Test that LightningFlow attributes (with a few exceptions) are not allowed to be declared outside __init__."""
- flow = EmptyFlow()
- with pytest.raises(AttributeError, match=f"Cannot set attributes that were not defined in __init__: {name}"):
- setattr(flow, name, value)
- assert not hasattr(flow, name)
- assert name not in flow.state["vars"]
- assert name not in flow._works
- assert name not in flow._flows
-
- # no error for protected attributes, since they don't contribute to the state
- setattr(flow, "_" + name, value)
- assert hasattr(flow, "_" + name)
-
-
-@pytest.mark.parametrize(
- ("name", "value"),
- [
- ("x", 1),
- ("f", EmptyFlow()),
- ("w", EmptyWork()),
- ],
-)
-@pytest.mark.parametrize("defined", [False, True])
-def test_unsupported_attribute_declaration_inside_run(defined, name, value):
- """Test that a LightningFlow can set LightningFlow or LightningWork attributes inside its run method, but
- everything else needs to be defined in the __init__ method."""
-
- class Flow(LightningFlow):
- def __init__(self):
- super().__init__()
- if defined:
- setattr(self, name, None)
-
- def run(self):
- if not defined and not isinstance(value, (LightningFlow, LightningWork)):
- with pytest.raises(
- AttributeError, match=f"Cannot set attributes that were not defined in __init__: {name}"
- ):
- setattr(self, name, value)
- assert name not in self.state["vars"]
- assert name not in self._works
- assert name not in self._flows
- else:
- setattr(self, name, value)
- if isinstance(value, LightningFlow):
- assert name in self._flows
- elif isinstance(value, LightningWork):
- assert name in self._works
- else:
- assert name in self.state["vars"]
-
- flow = Flow()
- flow.run()
-
-
-@pytest.mark.parametrize("value", [EmptyFlow(), EmptyWork()])
-def test_name_gets_removed_from_state_when_defined_as_flow_works(value):
- """Test that LightningFlow attributes are removed from the state."""
-
- class EmptyFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.value = None
-
- def run(self):
- self.value = value
-
- flow = EmptyFlow()
- flow.run()
- if isinstance(value, LightningFlow):
- assert "value" not in flow.state["vars"]
- assert "value" in flow._flows
- else:
- assert "value" not in flow.state["vars"]
- assert "value" in flow._works
-
-
-@pytest.mark.parametrize(
- ("name", "value"),
- [
- ("_name", "name"),
- ("_changes", {"change": 1}),
- ],
-)
-def test_supported_attribute_declaration_outside_init(name, value):
- """Test the custom LightningFlow setattr implementation for the few reserved attributes that are allowed to be set
- from outside __init__."""
- flow = EmptyFlow()
- setattr(flow, name, value)
- assert getattr(flow, name) == value
-
-
-def test_supported_attribute_declaration_inside_init():
- """Test that the custom LightningFlow setattr can identify the __init__ call in the stack frames above."""
-
- class Flow(EmptyFlow):
- def __init__(self):
- super().__init__()
- self.directly_in_init = "init"
- self.method_under_init()
-
- def method_under_init(self):
- self.attribute = "test"
- self.subflow = EmptyFlow()
-
- flow = Flow()
- assert flow.directly_in_init == "init"
- assert flow.state["vars"]["directly_in_init"] == "init"
- assert flow.attribute == "test"
- assert flow.state["vars"]["attribute"] == "test"
- assert isinstance(flow.subflow, EmptyFlow)
- assert flow.state["flows"]["subflow"] == flow.subflow.state
-
-
-def test_setattr_outside_run_context():
- """Test that it is allowed to update attributes outside `run` as long as the attribute is already declared."""
-
- class Flow(EmptyFlow):
- def __init__(self):
- super().__init__()
- self.attribute = ""
-
- def outside_run(self):
- # reading is allowed, and so is updating an attribute that was already declared in __init__
- self.attribute = "allowed"
- return super().configure_layout()
-
- flow = Flow()
- flow.outside_run()
- assert flow.attribute == "allowed"
- assert flow.state["vars"]["attribute"] == "allowed"
-
-
-def _run_state_transformation(tmpdir, attribute, update_fn, inplace=False):
- """This helper function defines a flow, assigns an attribute and performs a transformation on the state."""
-
- class StateTransformationTest(LightningFlow):
- def __init__(self):
- super().__init__()
- self.x = attribute
- self.finished = False
-
- def run(self):
- if self.finished:
- self.stop()
-
- x = update_fn(self.x)
- if not inplace:
- self.x = x
- self.finished = True
-
- flow = StateTransformationTest()
- assert flow.x == attribute
- app = LightningApp(flow)
- MultiProcessRuntime(app, start_server=False).dispatch()
- return app.state["vars"]["x"]
-
-
-@pytest.mark.parametrize(
- ("attribute", "update_fn", "expected"),
- [
- (1, lambda x: x + 1, 2),
- (0.5, lambda x: x + 0.5, 1.0),
- (True, lambda x: not x, False),
- ("cocofruit", lambda x: x + "s", "cocofruits"),
- ({"a": 1, "b": 2}, lambda x: {"a": 1, "b": 3}, {"a": 1, "b": 3}),
- ([1, 2], lambda x: [1, 2, 3], [1, 2, 3]),
- ((4, 5), lambda x: (4, 5, 6), (4, 5, 6)),
- ],
-)
-def test_attribute_state_change(attribute, update_fn, expected, tmpdir):
- """Test that state changes get recorded on all supported data types."""
- assert _run_state_transformation(tmpdir, attribute, update_fn, inplace=False) == expected
-
-
-def test_inplace_attribute_state_change(tmpdir):
- """Test that in-place modifications on containers get captured as a state change."""
-
- # inplace modification of a nested dict
- def transform(x):
- x["b"]["c"] += 1
-
- value = {"a": 1, "b": {"c": 2}}
- expected = {"a": 1, "b": {"c": 3}}
- assert _run_state_transformation(tmpdir, value, transform, inplace=True) == expected
-
- # inplace modification of nested list
- def transform(x):
- x[2].append(3.0)
-
- value = ["a", 1, [2.0]]
- expected = ["a", 1, [2.0, 3.0]]
- assert _run_state_transformation(tmpdir, value, transform, inplace=True) == expected
-
- # inplace modification of a custom dict
- def transform(x):
- x.update("baa")
-
- value = Counter("abab")
- expected = Counter(a=4, b=3)
- assert _run_state_transformation(tmpdir, value, transform, inplace=True) == expected
-
-
-def test_lightning_flow_and_work():
- class Work(LightningWork):
- def __init__(self, cache_calls: bool = True, port=None):
- super().__init__(cache_calls=cache_calls, port=port)
- self.counter = 0
-
- def run(self, *args, **kwargs):
- self.counter += 1
-
- class Flow_A(LightningFlow):
- def __init__(self):
- super().__init__()
- self.counter = 0
- self.work_a = Work(cache_calls=True, port=8000)
- self.work_b = Work(cache_calls=False, port=8001)
-
- def run(self):
- if self.counter < 5:
- self.work_a.run()
- self.work_b.run()
- self.counter += 1
- else:
- self.stop()
-
- flow_a = Flow_A()
- assert flow_a.named_works() == [("root.work_a", flow_a.work_a), ("root.work_b", flow_a.work_b)]
- assert flow_a.works() == [flow_a.work_a, flow_a.work_b]
- state = {
- "vars": {"counter": 0, "_layout": ANY, "_paths": {}},
- "calls": {},
- "flows": {},
- "structures": {},
- "works": {
- "work_b": {
- "vars": {
- "counter": 0,
- "_url": "",
- "_future_url": "",
- "_port": 8001,
- "_host": "127.0.0.1",
- "_paths": {},
- "_restarting": False,
- "_internal_ip": "",
- "_public_ip": "",
- "_display_name": "",
- "_cloud_compute": {
- "type": "__cloud_compute__",
- "name": "cpu-small",
- "disk_size": 0,
- "idle_timeout": None,
- "mounts": None,
- "shm_size": 0,
- "_internal_id": "default",
- "interruptible": False,
- "colocation_group_id": None,
- },
- },
- "calls": {CacheCallsKeys.LATEST_CALL_HASH: None},
- "changes": {},
- },
- "work_a": {
- "vars": {
- "counter": 0,
- "_url": "",
- "_future_url": "",
- "_port": 8000,
- "_host": "127.0.0.1",
- "_paths": {},
- "_restarting": False,
- "_internal_ip": "",
- "_public_ip": "",
- "_display_name": "",
- "_cloud_compute": {
- "type": "__cloud_compute__",
- "name": "cpu-small",
- "disk_size": 0,
- "idle_timeout": None,
- "mounts": None,
- "shm_size": 0,
- "_internal_id": "default",
- "interruptible": False,
- "colocation_group_id": None,
- },
- },
- "calls": {CacheCallsKeys.LATEST_CALL_HASH: None},
- "changes": {},
- },
- },
- "changes": {},
- }
- assert flow_a.state == state
- with contextlib.suppress(ExitAppException):
- while True:
- flow_a.run()
-
- state = {
- "vars": {"counter": 5, "_layout": ANY, "_paths": {}},
- "calls": {},
- "flows": {},
- "structures": {},
- "works": {
- "work_b": {
- "vars": {
- "counter": 5,
- "_url": "",
- "_future_url": "",
- "_port": 8001,
- "_host": "127.0.0.1",
- "_paths": {},
- "_restarting": False,
- "_internal_ip": "",
- "_public_ip": "",
- "_display_name": "",
- "_cloud_compute": {
- "type": "__cloud_compute__",
- "name": "cpu-small",
- "disk_size": 0,
- "idle_timeout": None,
- "mounts": None,
- "shm_size": 0,
- "_internal_id": "default",
- "interruptible": False,
- "colocation_group_id": None,
- },
- },
- "calls": {CacheCallsKeys.LATEST_CALL_HASH: None},
- "changes": {},
- },
- "work_a": {
- "vars": {
- "counter": 1,
- "_url": "",
- "_future_url": "",
- "_port": 8000,
- "_host": "127.0.0.1",
- "_paths": {},
- "_restarting": False,
- "_internal_ip": "",
- "_public_ip": "",
- "_display_name": "",
- "_cloud_compute": {
- "type": "__cloud_compute__",
- "name": "cpu-small",
- "disk_size": 0,
- "idle_timeout": None,
- "mounts": None,
- "shm_size": 0,
- "_internal_id": "default",
- "interruptible": False,
- "colocation_group_id": None,
- },
- },
- "calls": {
- CacheCallsKeys.LATEST_CALL_HASH: None,
- "fe3fa0f": {
- "ret": None,
- },
- },
- "changes": {},
- },
- },
- "changes": {},
- }
- assert flow_a.state == state
-
-
-def test_populate_changes():
- class WorkA(LightningWork):
- def __init__(self):
- super().__init__()
- self.counter = 0
-
- def run(self):
- pass
-
- class A(LightningFlow):
- def __init__(self):
- super().__init__()
- self.work = WorkA()
-
- def run(self):
- pass
-
- flow_a = A()
- flow_state = flow_a.state
- work_state = flow_a.work.state
- flow_a.work.counter = 1
- work_state_2 = flow_a.work.state
- delta = Delta(DeepDiff(work_state, work_state_2, verbose_level=2))
- delta = _delta_to_app_state_delta(flow_a, flow_a.work, delta)
- new_flow_state = LightningApp.populate_changes(flow_state, flow_state + delta)
- flow_a.set_state(new_flow_state)
- assert flow_a.work.counter == 1
- assert new_flow_state["works"]["work"]["changes"] == {"counter": {"from": 0, "to": 1}}
- assert flow_a.work._changes == {"counter": {"from": 0, "to": 1}}
-
-
-def test_populate_changes_status_removed():
- """Regression test for https://github.com/Lightning-AI/lightning/issues/342."""
- last_state = {
- "vars": {},
- "calls": {},
- "flows": {},
- "works": {
- "work": {
- "vars": {},
- "calls": {
- CacheCallsKeys.LATEST_CALL_HASH: "run:fe3f",
- "run:fe3f": {
- "statuses": [
- {"stage": "requesting", "message": None, "reason": None, "timestamp": 1},
- {"stage": "starting", "message": None, "reason": None, "timestamp": 2},
- {"stage": "requesting", "message": None, "reason": None, "timestamp": 3},
- ],
- },
- },
- "changes": {},
- },
- },
- "changes": {},
- }
- new_state = deepcopy(last_state)
- call = new_state["works"]["work"]["calls"]["run:fe3f"]
- call["statuses"] = call["statuses"][:-1] # pretend that a status was removed from the list
- new_state_before = deepcopy(new_state)
- new_state = LightningApp.populate_changes(last_state, new_state)
- assert new_state == new_state_before
-
-
-class CFlow(LightningFlow):
- def __init__(self, run_once):
- super().__init__()
- self.looping = 0
- self.tracker = 0
- self.restarting = False
- self.run_once = run_once
-
- def run(self):
- for idx in self.experimental_iterate(range(0, 10), run_once=self.run_once):
- if not self.restarting and (idx + 1) == 5:
- _LightningAppRef.get_current()._dump_checkpoint()
- self.stop()
- self.tracker += 1
- self.looping += 1
- if self.looping == 2:
- self.stop()
-
-
-@pytest.mark.xfail(strict=False, reason="flaky")
-@pytest.mark.parametrize("run_once", [False, True])
-def test_lightning_flow_iterate(tmpdir, run_once):
- app = LightningApp(CFlow(run_once))
- MultiProcessRuntime(app, start_server=False).dispatch()
- assert app.root.looping == 0
- assert app.root.tracker == 4
- call_hash = [v for v in app.root._calls if "experimental_iterate" in v][0]
- iterate_call = app.root._calls[call_hash]
- assert iterate_call["counter"] == 4
- assert not iterate_call["has_finished"]
-
- checkpoint_dir = os.path.join(_storage_root_dir(), "checkpoints")
- app = LightningApp(CFlow(run_once))
- app.load_state_dict_from_checkpoint_dir(checkpoint_dir)
- app.root.restarting = True
- assert app.root.looping == 0
- assert app.root.tracker == 4
- MultiProcessRuntime(app, start_server=False).dispatch()
- assert app.root.looping == 2
- assert app.root.tracker == (10 if run_once else 20)
- iterate_call = app.root._calls[call_hash]
- assert iterate_call["has_finished"]
-
-
-class FlowCounter(LightningFlow):
- def __init__(self):
- super().__init__()
- self.counter = 0
-
- def run(self):
- if self.counter >= 3:
- self.stop()
- self.counter += 1
-
-
-@pytest.mark.xfail(strict=False, reason="flaky")
-def test_lightning_flow_counter(tmpdir):
- app = LightningApp(FlowCounter())
- app.checkpointing = True
- MultiProcessRuntime(app, start_server=False).dispatch()
- assert app.root.counter == 3
-
- checkpoint_dir = os.path.join(_storage_root_dir(), "checkpoints")
- checkpoints = os.listdir(checkpoint_dir)
- assert len(checkpoints) == 4
- for checkpoint in checkpoints:
- checkpoint_path = os.path.join(checkpoint_dir, checkpoint)
- with open(checkpoint_path, "rb") as f:
- app = LightningApp(FlowCounter())
- app.set_state(pickle.load(f))
- MultiProcessRuntime(app, start_server=False).dispatch()
- assert app.root.counter == 3
-
-
-def test_flow_iterate_method():
- class Flow(LightningFlow):
- def run(self):
- pass
-
- flow = Flow()
- with pytest.raises(TypeError, match="An iterable should be provided"):
- next(flow.experimental_iterate(1))
-
-
-def test_flow_path_assignment():
- """Test that paths in the lit format lit:// get converted to a proper lightning.app.storage.Path object."""
-
- class Flow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.no_path = "a/b/c"
- self.path = Path("lit://x/y/z")
- self.lit_path = "lit://x/y/z"
-
- flow = Flow()
- assert isinstance(flow.no_path, str)
- assert isinstance(flow.path, Path)
- assert isinstance(flow.lit_path, Path)
- assert flow.path == flow.lit_path
-
-
-@pytest.mark.skipif(_IS_WINDOWS, reason="timeout with system crash")
-@pytest.mark.xfail(strict=False, reason="Timeout") # fixme
-def test_flow_state_change_with_path():
- """Test that type changes to a Path attribute are properly reflected within the state."""
-
- class Flow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.none_to_path = None
- self.path_to_none = Path()
- self.path_to_path = Path()
-
- def run(self):
- self.none_to_path = "lit://none/to/path"
- self.path_to_none = None
- self.path_to_path = "lit://path/to/path"
- self.stop()
-
- flow = Flow()
- MultiProcessRuntime(LightningApp(flow)).dispatch()
- assert flow.none_to_path == Path("lit://none/to/path")
- assert flow.path_to_none is None
- assert flow.path_to_path == Path("lit://path/to/path")
-
- assert "path_to_none" not in flow._paths
- assert "path_to_none" in flow._state
- assert flow._paths["none_to_path"] == Path("lit://none/to/path").to_dict()
- assert flow._paths["path_to_path"] == Path("lit://path/to/path").to_dict()
- assert flow.state["vars"]["none_to_path"] == Path("lit://none/to/path")
- assert flow.state["vars"]["path_to_none"] is None
- assert flow.state["vars"]["path_to_path"] == Path("lit://path/to/path")
-
-
-class FlowSchedule(LightningFlow):
- def __init__(self):
- super().__init__()
- self._last_times = []
- self.target = 3
- self.seconds = ",".join([str(v) for v in range(0, 60, self.target)])
-
- def run(self):
- if self.schedule(f"* * * * * {self.seconds}"):
- if len(self._last_times) < 3:
- self._last_times.append(time())
- else:
- # TODO: Resolve scheduling
- assert abs((time() - self._last_times[-1]) - self.target) < 20
- self.stop()
-
-
-def test_scheduling_api():
- app = LightningApp(FlowSchedule())
- MultiProcessRuntime(app, start_server=False).dispatch()
-
-
-def test_lightning_flow():
- class Flow(LightningFlow):
- def run(self):
- if self.schedule("midnight"):
- pass
- if self.schedule("hourly"):
- pass
- if self.schedule("@hourly"):
- pass
- if self.schedule("daily"):
- pass
- if self.schedule("weekly"):
- pass
- if self.schedule("monthly"):
- pass
- if self.schedule("yearly"):
- pass
- if self.schedule("annually"):
- pass
- assert len(self._calls["scheduling"]) == 8
-
- Flow().run()
-
-
-class WorkReload(LightningWork):
- def __init__(self):
- super().__init__(cache_calls=False)
- self.counter = 0
-
- def run(self):
- self.counter += 1
-
-
-class FlowReload(LightningFlow):
- def __init__(self):
- super().__init__()
- self.counter = 0
-
- def run(self):
- if not getattr(self, "w", None):
- self.w = WorkReload()
-
- self.counter += 1
- self.w.run()
-
- def load_state_dict(self, flow_state, children_states, strict) -> None:
- self.w = WorkReload()
- super().load_state_dict(flow_state, children_states, strict=strict)
-
-
-class FlowReload2(LightningFlow):
- def __init__(self, random_value: str):
- super().__init__()
- self.random_value = random_value
- self.counter = 0
-
- def run(self):
- if not getattr(self, "w", None):
- self.w = WorkReload()
- self.w.run()
- self.counter += 1
-
- def load_state_dict(self, flow_state, children_states, strict) -> None:
- self.w = WorkReload()
- super().load_state_dict(flow_state, children_states, strict=strict)
-
-
-class RootFlowReload(LightningFlow):
- def __init__(self):
- super().__init__()
- self.flow = FlowReload()
- self.counter = 0
-
- def run(self):
- if not getattr(self, "flow_2", None):
- self.flow_2 = FlowReload2("something")
- self.flow.run()
- self.flow_2.run()
- self.counter += 1
-
- def load_state_dict(self, flow_state, children_states, strict) -> None:
- self.flow_2 = FlowReload2(children_states["flow_2"]["vars"]["random_value"])
- super().load_state_dict(flow_state, children_states, strict=strict)
-
-
-class RootFlowReload2(RootFlowReload):
- def load_state_dict(self, flow_state, children_states, strict) -> None:
- LightningFlow.load_state_dict(self, flow_state, children_states, strict=strict)
-
-
-def test_lightning_flow_reload():
- flow = RootFlowReload()
-
- assert flow.counter == 0
- assert flow.flow.counter == 0
-
- flow.run()
-
- assert flow.flow.w.counter == 1
- assert flow.counter == 1
- assert flow.flow.counter == 1
- assert flow.flow_2.counter == 1
- assert flow.flow_2.w.counter == 1
-
- state = _state_dict(flow)
- flow = RootFlowReload()
- _load_state_dict(flow, state)
-
- assert flow.flow.w.counter == 1
- assert flow.counter == 1
- assert flow.flow.counter == 1
- assert flow.flow_2.counter == 1
- assert flow.flow_2.w.counter == 1
-
- flow.run()
-
- assert flow.flow.w.counter == 2
- assert flow.counter == 2
- assert flow.flow.counter == 2
- assert flow.flow_2.counter == 2
- assert flow.flow_2.w.counter == 2
-
- flow = RootFlowReload2()
- flow.run()
- state = _state_dict(flow)
- flow = RootFlowReload2()
- with pytest.raises(ValueError, match="The component flow_2 wasn't instantiated for the component root"):
- _load_state_dict(flow, state)
-
-
-class NestedFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.flows_dict = LDict(**{"a": EmptyFlow()})
- self.flows_list = LList(*[EmptyFlow()])
- self.flow = EmptyFlow()
- assert list(self.flows) == ["root.flow", "root.flows_dict.a", "root.flows_list.0"]
- self.w = EmptyWork()
-
- def run(self):
- pass
-
-
-class FlowNested2(LightningFlow):
- def __init__(self):
- super().__init__()
- self.flow3 = EmptyFlow()
- self.w = EmptyWork()
-
- def run(self):
- pass
-
-
-class FlowCollection(LightningFlow):
- def __init__(self):
- super().__init__()
- self.flow = EmptyFlow()
- assert self.flow.name == "root.flow"
- self.flow2 = FlowNested2()
- assert list(self.flow2.flows) == ["root.flow2.flow3"]
- self.flows_dict = LDict(**{"a": NestedFlow()})
- assert list(self.flows_dict.flows) == [
- "root.flows_dict.a",
- "root.flows_dict.a.flow",
- "root.flows_dict.a.flows_dict.a",
- "root.flows_dict.a.flows_list.0",
- ]
- self.flows_list = LList(*[NestedFlow()])
- assert list(self.flows_list.flows) == [
- "root.flows_list.0",
- "root.flows_list.0.flow",
- "root.flows_list.0.flows_dict.a",
- "root.flows_list.0.flows_list.0",
- ]
- self.w = EmptyWork()
-
- def run(self):
- pass
-
-
-def test_lightning_flow_flows_and_works():
- flow = FlowCollection()
- app = LightningApp(flow)
-
- assert list(app.root.flows.keys()) == [
- "root.flow",
- "root.flow2",
- "root.flow2.flow3",
- "root.flows_dict.a",
- "root.flows_dict.a.flow",
- "root.flows_dict.a.flows_dict.a",
- "root.flows_dict.a.flows_list.0",
- "root.flows_list.0",
- "root.flows_list.0.flow",
- "root.flows_list.0.flows_dict.a",
- "root.flows_list.0.flows_list.0",
- ]
-
- assert [w[0] for w in app.root.named_works()] == [
- "root.w",
- "root.flow2.w",
- "root.flows_dict.a.w",
- "root.flows_list.0.w",
- ]
-
-
-class WorkReady(LightningWork):
- def __init__(self):
- super().__init__(parallel=True)
- self.ready = False
-
- def run(self):
- self.ready = True
-
-
-class FlowReady(LightningFlow):
- def __init__(self):
- super().__init__()
- self.w = WorkReady()
-
- @property
- def ready(self) -> bool:
- return self.w.has_succeeded
-
- def run(self):
- self.w.run()
-
- if self.ready:
- self.stop()
-
-
-class RootFlowReady(_RootFlow):
- def __init__(self):
- super().__init__(WorkReady())
-
-
-@pytest.mark.parametrize("flow", [FlowReady, RootFlowReady])
-def test_flow_ready(flow):
- """This test validates that the app status queue is populated correctly."""
- mock_queue = _MockQueue("api_publish_state_queue")
-
- def run_patch(method):
- app.should_publish_changes_to_api = True
- app.api_publish_state_queue = mock_queue
- method()
-
- state = {"done": False}
-
- def lagged_run_once(method):
- """Ensure that the full loop is run after the app exits."""
- new_done = method()
- if state["done"]:
- return True
- state["done"] = new_done
- return False
-
- app = LightningApp(flow())
- app._run = partial(run_patch, method=app._run)
- app.run_once = partial(lagged_run_once, method=app.run_once)
- MultiProcessRuntime(app, start_server=False).dispatch()
-
- _, first_status = mock_queue.get()
- assert not first_status.is_ui_ready
-
- _, last_status = mock_queue.get()
- while len(mock_queue) > 0:
- _, last_status = mock_queue.get()
- assert last_status.is_ui_ready
-
-
-def test_structures_register_work_cloudcompute():
- class MyDummyWork(LightningWork):
- def run(self):
- return
-
- class MyDummyFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.w_list = LList(*[MyDummyWork(cloud_compute=CloudCompute("gpu")) for i in range(5)])
- self.w_dict = LDict(**{str(i): MyDummyWork(cloud_compute=CloudCompute("gpu")) for i in range(5)})
-
- def run(self):
- for w in self.w_list:
- w.run()
-
- for w in self.w_dict.values():
- w.run()
-
- MyDummyFlow()
- assert len(lightning.app.utilities.packaging.cloud_compute._CLOUD_COMPUTE_STORE) == 10
- for v in lightning.app.utilities.packaging.cloud_compute._CLOUD_COMPUTE_STORE.values():
- assert len(v.component_names) == 1
- assert v.component_names[0][:-1] in ("root.w_list.", "root.w_dict.")
- assert v.component_names[0][-1].isdigit()
-
-
-def test_deprecation_warning_exit():
- with pytest.raises(ExitAppException), pytest.warns(DeprecationWarning, match="*Use LightningFlow.stop instead"):
- RootFlowReady()._exit()
diff --git a/tests/tests_app/core/test_lightning_work.py b/tests/tests_app/core/test_lightning_work.py
deleted file mode 100644
index 443851d97990f..0000000000000
--- a/tests/tests_app/core/test_lightning_work.py
+++ /dev/null
@@ -1,420 +0,0 @@
-import contextlib
-from queue import Empty
-from re import escape
-from unittest.mock import MagicMock, Mock
-
-import pytest
-from lightning.app import LightningApp
-from lightning.app.core.flow import LightningFlow
-from lightning.app.core.work import LightningWork
-from lightning.app.runners import MultiProcessRuntime
-from lightning.app.storage.path import Path
-from lightning.app.testing.helpers import EmptyFlow, EmptyWork, _MockQueue
-from lightning.app.testing.testing import LightningTestApp
-from lightning.app.utilities.enum import WorkStageStatus, make_status
-from lightning.app.utilities.exceptions import LightningWorkException
-from lightning.app.utilities.imports import _IS_WINDOWS
-from lightning.app.utilities.packaging.build_config import BuildConfig
-from lightning.app.utilities.proxies import ProxyWorkRun, WorkRunner
-
-
-def test_lightning_work_run_method_required():
- """Test that a helpful exception is raised when the user did not implement the `LightningWork.run()` method."""
- with pytest.raises(TypeError, match=escape("The work `LightningWork` is missing the `run()` method")):
- LightningWork()
-
- class WorkWithoutRun(LightningWork):
- def __init__(self):
- super().__init__()
- self.started = False
-
- with pytest.raises(TypeError, match=escape("The work `WorkWithoutRun` is missing the `run()` method")):
- WorkWithoutRun()
-
- class WorkWithRun(WorkWithoutRun):
- def run(self, *args, **kwargs):
- self.started = True
-
- work = WorkWithRun()
- work.run()
- assert work.started
-
-
-def test_lightning_work_no_children_allowed():
- """Test that a LightningWork can't have any children (work or flow)."""
-
- class ChildWork(EmptyWork):
- pass
-
- class ParentWork(LightningWork):
- def __init__(self):
- super().__init__()
- self.work_b = ChildWork()
-
- def run(self, *args, **kwargs):
- pass
-
- with pytest.raises(LightningWorkException, match="isn't allowed to take any children such as"):
- ParentWork()
-
- class ParentWork(LightningWork):
- def __init__(self):
- super().__init__()
- self.flow = LightningFlow()
-
- def run(self, *args, **kwargs):
- pass
-
- with pytest.raises(LightningWorkException, match="LightningFlow"):
- ParentWork()
-
-
-def test_forgot_to_call_init():
- """This test validates that the error message for a user registering state without calling __init__ is comprehensible."""
-
- class W(LightningWork):
- def __init__(self):
- self.var_a = None
-
- def run(self):
- pass
-
- with pytest.raises(AttributeError, match="Did you forget to call"):
- W()
-
-
-@pytest.mark.parametrize(
- ("name", "value"),
- [
- ("x", 1),
- ("f", EmptyFlow()),
- ("w", EmptyWork()),
- ("run", lambda _: _),
- ],
-)
-def test_unsupported_attribute_declaration_outside_init(name, value):
- """Test that LightningWork attributes (with a few exceptions) are not allowed to be set outside __init__."""
- flow = EmptyFlow()
- with pytest.raises(AttributeError, match=f"Cannot set attributes that were not defined in __init__: {name}"):
- setattr(flow, name, value)
- assert name == "run" or not hasattr(flow, name)
-
-
-@pytest.mark.parametrize(
- ("name", "value"),
- [
- ("_name", "name"),
- ("_changes", {"change": 1}),
- ("run", ProxyWorkRun(work_run=Mock(), work_name="any", work=Mock(), caller_queue=Mock())),
- ],
-)
-def test_supported_attribute_declaration_outside_init(name, value):
- """Test the custom LightningWork setattr implementation for the few reserved attributes that are allowed to be set
- from outside __init__."""
- flow = EmptyWork()
- setattr(flow, name, value)
- assert getattr(flow, name) == value
-
-
-def test_supported_attribute_declaration_inside_init():
- """Test that the custom LightningWork setattr can identify the __init__ call in the stack frames above."""
-
- class Work(EmptyWork):
- def __init__(self):
- super().__init__()
- self.directly_in_init = "init"
- self.method_under_init()
-
- def method_under_init(self):
- self.attribute = "test"
-
- work = Work()
- assert work.directly_in_init == "init"
- assert work.attribute == "test"
-
-
-@pytest.mark.parametrize("replacement", [EmptyFlow(), EmptyWork(), None])
-def test_fixing_flows_and_works(replacement):
- class FlowFixed(LightningFlow):
- def run(self):
- self.empty_flow = EmptyFlow()
- self.empty_flow = replacement
-
- with pytest.raises(AttributeError, match="Cannot set attributes as"):
- FlowFixed().run()
-
-
-@pytest.mark.parametrize("enable_exception", [False, True])
-@pytest.mark.parametrize("raise_exception", [False, True])
-def test_lightning_status(enable_exception, raise_exception):
- class Work(EmptyWork):
- def __init__(self, raise_exception, enable_exception=True):
- super().__init__(raise_exception=raise_exception)
- self.enable_exception = enable_exception
- self.dummy_path = Path("test")
-
- def run(self):
- if self.enable_exception:
- raise Exception("Custom Exception")
-
- work = Work(raise_exception, enable_exception=enable_exception)
- work._name = "root.w"
- assert work.status.stage == WorkStageStatus.NOT_STARTED
- caller_queue = _MockQueue("caller_queue")
- delta_queue = _MockQueue("delta_queue")
- readiness_queue = _MockQueue("readiness_queue")
- error_queue = _MockQueue("error_queue")
- request_queue = _MockQueue("request_queue")
- response_queue = _MockQueue("response_queue")
- copy_request_queue = _MockQueue("copy_request_queue")
- copy_response_queue = _MockQueue("copy_response_queue")
- call_hash = "fe3fa0f"
- work._calls[call_hash] = {
- "args": (),
- "kwargs": {},
- "call_hash": call_hash,
- "run_started_counter": 1,
- "statuses": [],
- }
- caller_queue.put({
- "args": (),
- "kwargs": {},
- "call_hash": call_hash,
- "state": work.state,
- })
- work_runner = WorkRunner(
- work,
- work.name,
- caller_queue,
- delta_queue,
- readiness_queue,
- error_queue,
- request_queue,
- response_queue,
- copy_request_queue,
- copy_response_queue,
- )
- with contextlib.suppress(Exception, Empty):
- work_runner()
-
- res = delta_queue._queue[0].delta.to_dict()["iterable_item_added"]
- L = len(delta_queue._queue) - 1
- if enable_exception:
- exception_cls = Exception if raise_exception else Empty
- assert isinstance(error_queue._queue[0], exception_cls)
- res_end = delta_queue._queue[L].delta.to_dict()["iterable_item_added"]
- res_end[f"root['calls']['{call_hash}']['statuses'][1]"]["stage"] == "failed"
- res_end[f"root['calls']['{call_hash}']['statuses'][1]"]["message"] == "Custom Exception"
- else:
- assert res[f"root['calls']['{call_hash}']['statuses'][0]"]["stage"] == "running"
- key = f"root['calls']['{call_hash}']['statuses'][1]"
- while L >= 0:
- res_end = delta_queue._queue[L].delta.to_dict()["iterable_item_added"]
- if key in res_end and res_end[key]["stage"] == "succeeded":
- break
- L -= 1
-
- # Stop blocking and let the thread join
- work_runner.copier.join()
-
-
-def test_lightning_work_url():
- class ExposedWork(LightningWork):
- def run(self):
- pass
-
- work = ExposedWork(port=8000)
- work._name = "root.work"
- assert work.state["vars"]["_url"] == ""
-
-
-def test_work_path_assignment():
- """Test that paths in the lit format lit:// get converted to a proper lightning.app.storage.Path object."""
-
- class Work(LightningWork):
- def __init__(self):
- super().__init__()
- self.no_path = "a/b/c"
- self.path = Path("lit://x/y/z")
- self.lit_path = "lit://x/y/z"
-
- def run(self):
- pass
-
- work = Work()
- assert isinstance(work.no_path, str)
- assert isinstance(work.path, Path)
- assert isinstance(work.lit_path, Path)
- assert work.path == work.lit_path
-
-
-@pytest.mark.skipif(_IS_WINDOWS, reason="strange TimeOut exception")
-@pytest.mark.xfail(strict=False, reason="Timeout") # fixme
-def test_work_state_change_with_path():
- """Test that type changes to a Path attribute are properly reflected within the state."""
-
- class Work(LightningFlow):
- def __init__(self):
- super().__init__()
- self.none_to_path = None
- self.path_to_none = Path()
- self.path_to_path = Path()
-
- def run(self):
- self.none_to_path = "lit://none/to/path"
- self.path_to_none = None
- self.path_to_path = "lit://path/to/path"
-
- class Flow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.work = Work()
-
- def run(self):
- self.work.run()
- self.stop()
-
- flow = Flow()
- MultiProcessRuntime(LightningApp(flow)).dispatch()
- assert flow.work.none_to_path == Path("lit://none/to/path")
- assert flow.work.path_to_none is None
- assert flow.work.path_to_path == Path("lit://path/to/path")
-
- assert "path_to_none" not in flow.work._paths
- assert "path_to_none" in flow.work._state
- assert flow.work._paths["none_to_path"] == Path("lit://none/to/path").to_dict()
- assert flow.work._paths["path_to_path"] == Path("lit://path/to/path").to_dict()
- assert flow.work.state["vars"]["none_to_path"] == Path("lit://none/to/path")
- assert flow.work.state["vars"]["path_to_none"] is None
- assert flow.work.state["vars"]["path_to_path"] == Path("lit://path/to/path")
-
-
-def test_lightning_work_calls():
- class W(LightningWork):
- def run(self, *args, **kwargs):
- pass
-
- w = W()
- assert len(w._calls) == 1
- w.run(1, [2], (3, 4), {"1": "3"})
- assert len(w._calls) == 2
- assert w._calls["0d824f7"] == {"ret": None}
-
-
-def test_work_cloud_build_config_provided():
- assert isinstance(LightningWork.cloud_build_config, property)
- assert LightningWork.cloud_build_config.fset is not None
-
- class Work(LightningWork):
- def __init__(self):
- super().__init__()
- self.cloud_build_config = BuildConfig(image="ghcr.io/gridai/base-images:v1.8-cpu")
-
- def run(self, *args, **kwargs):
- pass
-
- w = Work()
- w.run()
-
-
-def test_work_local_build_config_provided():
- assert isinstance(LightningWork.local_build_config, property)
- assert LightningWork.local_build_config.fset is not None
-
- class Work(LightningWork):
- def __init__(self):
- super().__init__()
- self.local_build_config = BuildConfig(image="ghcr.io/gridai/base-images:v1.8-cpu")
-
- def run(self, *args, **kwargs):
- pass
-
- w = Work()
- w.run()
-
-
-class WorkCounter(LightningWork):
- def run(self):
- pass
-
-
-class LightningTestAppWithWork(LightningTestApp):
- def on_before_run_once(self):
- if self.root.work.has_succeeded:
- return True
- return super().on_before_run_once()
-
-
-def test_lightning_app_with_work():
- app = LightningTestAppWithWork(WorkCounter())
- MultiProcessRuntime(app, start_server=False).dispatch()
-
-
-class WorkStart(LightningWork):
- def __init__(self, cache_calls, parallel):
- super().__init__(cache_calls=cache_calls, parallel=parallel)
- self.counter = 0
-
- def run(self):
- self.counter += 1
-
-
-class FlowStart(LightningFlow):
- def __init__(self, cache_calls, parallel):
- super().__init__()
- self.w = WorkStart(cache_calls, parallel)
- self.finish = False
-
- def run(self):
- if self.finish:
- self.stop()
- if self.w.status.stage == WorkStageStatus.STOPPED:
- with pytest.raises(Exception, match="A work can be started only once for now."):
- self.w.start()
- self.finish = True
- if self.w.status.stage == WorkStageStatus.NOT_STARTED:
- self.w.start()
- if self.w.status.stage == WorkStageStatus.STARTED:
- self.w.run()
- if self.w.counter == 1:
- self.w.stop()
-
-
-@pytest.mark.parametrize("cache_calls", [False, True])
-@pytest.mark.parametrize("parallel", [False, True])
-def test_lightning_app_work_start(cache_calls, parallel):
- app = LightningApp(FlowStart(cache_calls, parallel))
- MultiProcessRuntime(app, start_server=False).dispatch()
-
-
-def test_lightning_work_delete():
- work = WorkCounter()
-
- with pytest.raises(Exception, match="Can't delete the work"):
- work.delete()
-
- mock = MagicMock()
- work._backend = mock
- work.delete()
- assert work == mock.delete_work._mock_call_args_list[0].args[1]
-
-
-class WorkDisplay(LightningWork):
- def __init__(self):
- super().__init__()
-
- def run(self):
- pass
-
-
-def test_lightning_work_display_name():
- work = WorkDisplay()
- assert work.state_vars["vars"]["_display_name"] == ""
- work.display_name = "Hello"
- assert work.state_vars["vars"]["_display_name"] == "Hello"
-
- work._calls["latest_call_hash"] = "test"
- work._calls["test"] = {"statuses": [make_status(WorkStageStatus.PENDING)]}
- with pytest.raises(RuntimeError, match="The display name can be set only before the work has started."):
- work.display_name = "HELLO"
- work.display_name = "Hello"
diff --git a/tests/tests_app/core/test_queues.py b/tests/tests_app/core/test_queues.py
deleted file mode 100644
index 0f68d8aa1ff98..0000000000000
--- a/tests/tests_app/core/test_queues.py
+++ /dev/null
@@ -1,284 +0,0 @@
-import base64
-import multiprocessing
-import pickle
-import queue
-import time
-from unittest import mock
-
-import pytest
-import requests_mock
-from lightning.app import LightningFlow
-from lightning.app.core import queues
-from lightning.app.core.constants import HTTP_QUEUE_URL, STATE_UPDATE_TIMEOUT
-from lightning.app.core.queues import (
- READINESS_QUEUE_CONSTANT,
- BaseQueue,
- HTTPQueue,
- QueuingSystem,
- RateLimitedQueue,
- RedisQueue,
-)
-from lightning.app.utilities.imports import _is_redis_available
-from lightning.app.utilities.redis import check_if_redis_running
-
-
-@pytest.mark.skipif(not check_if_redis_running(), reason="Redis is not running")
-@pytest.mark.parametrize("queue_type", [QueuingSystem.REDIS, QueuingSystem.MULTIPROCESS])
-def test_queue_api(queue_type, monkeypatch):
- """Test the Queue API.
-
- This test runs all the Queue implementations, but we monkeypatch the Redis queues to avoid external interaction.
-
- """
- import redis
-
- blpop_out = (b"entry-id", pickle.dumps("test_entry"))
-
- monkeypatch.setattr(redis.Redis, "blpop", lambda *args, **kwargs: blpop_out)
- monkeypatch.setattr(redis.Redis, "rpush", lambda *args, **kwargs: None)
- monkeypatch.setattr(redis.Redis, "set", lambda *args, **kwargs: None)
- monkeypatch.setattr(redis.Redis, "get", lambda *args, **kwargs: None)
-
- test_queue = queue_type.get_readiness_queue()
- assert test_queue.name == READINESS_QUEUE_CONSTANT
- assert isinstance(test_queue, BaseQueue)
- test_queue.put("test_entry")
- assert test_queue.get() == "test_entry"
-
-
-@pytest.mark.skipif(not check_if_redis_running(), reason="Redis is not running")
-def test_redis_queue():
- queue_id = int(time.time())
- queue1 = QueuingSystem.REDIS.get_readiness_queue(queue_id=str(queue_id))
- queue2 = QueuingSystem.REDIS.get_readiness_queue(queue_id=str(queue_id + 1))
- queue1.put("test_entry1")
- queue2.put("test_entry2")
- assert queue1.get() == "test_entry1"
- assert queue2.get() == "test_entry2"
- with pytest.raises(queue.Empty):
- queue2.get(timeout=1)
- queue1.put("test_entry1")
- assert queue1.length() == 1
- queue1.clear()
- with pytest.raises(queue.Empty):
- queue1.get(timeout=1)
-
-
-@pytest.mark.skipif(not check_if_redis_running(), reason="Redis is not running")
-def test_redis_health_check_success():
- redis_queue = QueuingSystem.REDIS.get_readiness_queue()
- assert redis_queue.is_running
-
- redis_queue = RedisQueue(name="test_queue", default_timeout=1)
- assert redis_queue.is_running
-
-
-@pytest.mark.skipif(not _is_redis_available(), reason="redis is required for this test.")
-@pytest.mark.skipif(check_if_redis_running(), reason="This is testing the failure case when redis is not running")
-def test_redis_health_check_failure():
- redis_queue = RedisQueue(name="test_queue", default_timeout=1)
- assert not redis_queue.is_running
-
-
-@pytest.mark.skipif(not _is_redis_available(), reason="redis isn't installed.")
-def test_redis_credential(monkeypatch):
- monkeypatch.setattr(queues, "REDIS_HOST", "test-host")
- monkeypatch.setattr(queues, "REDIS_PORT", "test-port")
- monkeypatch.setattr(queues, "REDIS_PASSWORD", "test-password")
- redis_queue = QueuingSystem.REDIS.get_readiness_queue()
- assert redis_queue.redis.connection_pool.connection_kwargs["host"] == "test-host"
- assert redis_queue.redis.connection_pool.connection_kwargs["port"] == "test-port"
- assert redis_queue.redis.connection_pool.connection_kwargs["password"] == "test-password"
-
-
-@pytest.mark.skipif(not _is_redis_available(), reason="redis isn't installed.")
-@mock.patch("lightning.app.core.queues.redis.Redis")
-def test_redis_queue_read_timeout(redis_mock):
- redis_mock.return_value.blpop.return_value = (b"READINESS_QUEUE", pickle.dumps("test_entry"))
- redis_queue = QueuingSystem.REDIS.get_readiness_queue()
-
- # default timeout
- assert redis_queue.get(timeout=0) == "test_entry"
- assert redis_mock.return_value.blpop.call_args_list[0] == mock.call(["READINESS_QUEUE"], timeout=0.005)
-
- # custom timeout
- assert redis_queue.get(timeout=2) == "test_entry"
- assert redis_mock.return_value.blpop.call_args_list[1] == mock.call(["READINESS_QUEUE"], timeout=2)
-
- # blocking timeout
- assert redis_queue.get() == "test_entry"
- assert redis_mock.return_value.blpop.call_args_list[2] == mock.call(["READINESS_QUEUE"], timeout=0)
-
-
-@pytest.mark.parametrize(
- ("queue_type", "queue_process_mock"),
- [(QueuingSystem.MULTIPROCESS, multiprocessing)],
-)
-def test_process_queue_read_timeout(queue_type, queue_process_mock, monkeypatch):
- context = mock.MagicMock()
- queue_mocked = mock.MagicMock()
- context.Queue = queue_mocked
- monkeypatch.setattr(queue_process_mock, "get_context", mock.MagicMock(return_value=context))
- my_queue = queue_type.get_readiness_queue()
-
- # default timeout
- my_queue.get(timeout=0)
- assert queue_mocked.return_value.get.call_args_list[0] == mock.call(timeout=0.001, block=False)
-
- # custom timeout
- my_queue.get(timeout=2)
- assert queue_mocked.return_value.get.call_args_list[1] == mock.call(timeout=2, block=False)
-
- # blocking timeout
- my_queue.get()
- assert queue_mocked.return_value.get.call_args_list[2] == mock.call(timeout=None, block=True)
-
-
-@pytest.mark.skipif(not check_if_redis_running(), reason="Redis is not running")
-@mock.patch("lightning.app.core.queues.WARNING_QUEUE_SIZE", 2)
-def test_redis_queue_warning():
- my_queue = QueuingSystem.REDIS.get_api_delta_queue(queue_id="test_redis_queue_warning")
- my_queue.clear()
- with pytest.warns(UserWarning, match="is larger than the"):
- my_queue.put(None)
- my_queue.put(None)
- my_queue.put(None)
-
-
-@pytest.mark.skipif(not check_if_redis_running(), reason="Redis is not running")
-@mock.patch("lightning.app.core.queues.redis.Redis")
-def test_redis_raises_error_if_failing(redis_mock):
- import redis
-
- my_queue = QueuingSystem.REDIS.get_api_delta_queue(queue_id="test_redis_queue_warning")
- redis_mock.return_value.rpush.side_effect = redis.exceptions.ConnectionError("EROOOR")
- redis_mock.return_value.llen.side_effect = redis.exceptions.ConnectionError("EROOOR")
-
- with pytest.raises(ConnectionError, match="Your app failed because it couldn't connect to Redis."):
- redis_mock.return_value.blpop.side_effect = redis.exceptions.ConnectionError("EROOOR")
- my_queue.get()
-
- with pytest.raises(ConnectionError, match="Your app failed because it couldn't connect to Redis."):
- redis_mock.return_value.rpush.side_effect = redis.exceptions.ConnectionError("EROOOR")
- redis_mock.return_value.llen.return_value = 1
- my_queue.put(1)
-
- with pytest.raises(ConnectionError, match="Your app failed because it couldn't connect to Redis."):
- redis_mock.return_value.llen.side_effect = redis.exceptions.ConnectionError("EROOOR")
- my_queue.length()
-
-
-class TestHTTPQueue:
- def test_http_queue_failure_on_queue_name(self):
- test_queue = HTTPQueue("test", STATE_UPDATE_TIMEOUT)
- with pytest.raises(ValueError, match="App ID couldn't be extracted"):
- test_queue.put("test")
-
- with pytest.raises(ValueError, match="App ID couldn't be extracted"):
- test_queue.get()
-
- with pytest.raises(ValueError, match="App ID couldn't be extracted"):
- test_queue.length()
-
- def test_http_queue_put(self, monkeypatch):
- monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
- test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
- test_obj = LightningFlow()
-
- # mocking requests and responses
- adapter = requests_mock.Adapter()
- test_queue.client.session.mount("http://", adapter)
- adapter.register_uri(
- "GET",
- f"{HTTP_QUEUE_URL}/v1/test/http_queue/length",
- request_headers={"Authorization": "Bearer test-token"},
- status_code=200,
- content=b"1",
- )
- adapter.register_uri(
- "POST",
- f"{HTTP_QUEUE_URL}/v1/test/http_queue?action=push",
- status_code=201,
- additional_matcher=lambda req: pickle.dumps(test_obj) == req._request.body,
- request_headers={"Authorization": "Bearer test-token"},
- content=b"data pushed",
- )
-
- test_queue.put(test_obj)
-
- def test_http_queue_get(self, monkeypatch):
- monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
- test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
- adapter = requests_mock.Adapter()
- test_queue.client.session.mount("http://", adapter)
-
- adapter.register_uri(
- "POST",
- f"{HTTP_QUEUE_URL}/v1/test/http_queue?action=pop",
- request_headers={"Authorization": "Bearer test-token"},
- status_code=200,
- content=pickle.dumps("test"),
- )
- assert test_queue.get() == "test"
-
- def test_http_queue_batch_get(self, monkeypatch):
- monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
- test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
- adapter = requests_mock.Adapter()
- test_queue.client.session.mount("http://", adapter)
-
- adapter.register_uri(
- "POST",
- f"{HTTP_QUEUE_URL}/v1/test/http_queue?action=popCount",
- request_headers={"Authorization": "Bearer test-token"},
- status_code=200,
- json=[
- base64.b64encode(pickle.dumps("test")).decode("utf-8"),
- base64.b64encode(pickle.dumps("test2")).decode("utf-8"),
- ],
- )
- assert test_queue.batch_get() == ["test", "test2"]
-
-
-def test_unreachable_queue(monkeypatch):
- monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
-
- test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
-
- resp1 = mock.MagicMock()
- resp1.status_code = 204
-
- resp2 = mock.MagicMock()
- resp2.status_code = 201
-
- test_queue.client = mock.MagicMock()
- test_queue.client.post = mock.Mock(side_effect=[resp1, resp1, resp2])
-
- with pytest.raises(queue.Empty):
- test_queue._get()
-
- # Test backoff on queue.put
- test_queue.put("foo")
- assert test_queue.client.post.call_count == 3
-
-
-@mock.patch("lightning.app.core.queues.time.sleep")
-def test_rate_limited_queue(mock_sleep):
- sleeps = []
- mock_sleep.side_effect = lambda sleep_time: sleeps.append(sleep_time)
-
- mock_queue = mock.MagicMock()
-
- mock_queue.name = "inner_queue"
- mock_queue.default_timeout = 10.0
-
- rate_limited_queue = RateLimitedQueue(mock_queue, requests_per_second=1)
-
- assert rate_limited_queue.name == "inner_queue"
- assert rate_limited_queue.default_timeout == 10.0
-
- timeout = time.perf_counter() + 1
- while time.perf_counter() + sum(sleeps) < timeout:
- rate_limited_queue.get()
-
- assert mock_queue.get.call_count == 2
diff --git a/tests/tests_app/frontend/__init__.py b/tests/tests_app/frontend/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/frontend/conftest.py b/tests/tests_app/frontend/conftest.py
deleted file mode 100644
index 9b1d9f174396d..0000000000000
--- a/tests/tests_app/frontend/conftest.py
+++ /dev/null
@@ -1,74 +0,0 @@
-"""Test configuration."""
-
-# pylint: disable=protected-access
-from unittest import mock
-
-import pytest
-
-FLOW_SUB = "lit_flow"
-FLOW = f"root.{FLOW_SUB}"
-PORT = 61896
-
-FLOW_STATE = {
- "vars": {
- "_paths": {},
- "_layout": {"target": f"http://localhost:{PORT}/{FLOW}"},
- },
- "calls": {},
- "flows": {},
- "works": {},
- "structures": {},
- "changes": {},
-}
-
-APP_STATE = {
- "vars": {"_paths": {}, "_layout": [{"name": "home", "content": FLOW}]},
- "calls": {},
- "flows": {
- FLOW_SUB: FLOW_STATE,
- },
- "works": {},
- "structures": {},
- "changes": {},
- "app_state": {"stage": "running"},
-}
-
-
-def _request_state(self):
- _state = APP_STATE
- self._store_state(_state)
-
-
-@pytest.fixture()
-def flow():
- return FLOW
-
-
-@pytest.fixture(autouse=True, scope="module")
-def mock_request_state():
- """Avoid requests to the api."""
- with mock.patch("lightning.app.utilities.state.AppState._request_state", _request_state):
- yield
-
-
-def do_nothing():
- """Be lazy!"""
-
-
-@pytest.fixture(autouse=True, scope="module")
-def mock_start_websocket():
- """Avoid starting the websocket."""
- with mock.patch("lightning.app.frontend.panel.app_state_comm._start_websocket", do_nothing):
- yield
-
-
-@pytest.fixture()
-def app_state_state():
- """Returns an AppState dict."""
- return APP_STATE.copy()
-
-
-@pytest.fixture()
-def flow_state_state():
- """Returns an AppState dict scoped to the flow."""
- return FLOW_STATE.copy()
diff --git a/tests/tests_app/frontend/just_py/test_just_py.py b/tests/tests_app/frontend/just_py/test_just_py.py
deleted file mode 100644
index f273e64d5f30a..0000000000000
--- a/tests/tests_app/frontend/just_py/test_just_py.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import os
-import os.path as osp
-import sys
-from typing import Callable
-from unittest.mock import MagicMock
-
-import lightning.app
-from lightning.app.frontend import JustPyFrontend
-from lightning.app.frontend.just_py import just_py
-from lightning.app.frontend.just_py.just_py_base import _main, _webpage
-
-
-def render_fn(get_state: Callable) -> Callable:
- return _webpage
-
-
-def test_justpy_frontend(monkeypatch):
- justpy = MagicMock()
- popen = MagicMock()
- monkeypatch.setitem(sys.modules, "justpy", justpy)
- monkeypatch.setattr(just_py, "Popen", popen)
-
- frontend = JustPyFrontend(render_fn=render_fn)
- flow = MagicMock()
- flow.name = "c"
- frontend.flow = flow
- frontend.start_server("a", 90)
-
- path = osp.join(osp.dirname(lightning.app.frontend.just_py.__file__), "just_py_base.py")
-
- assert popen._mock_call_args[0][0] == f"{sys.executable} {path}"
- env = popen._mock_call_args[1]["env"]
- assert env["LIGHTNING_FLOW_NAME"] == "c"
- assert env["LIGHTNING_RENDER_FUNCTION"] == "render_fn"
- assert env["LIGHTNING_HOST"] == "a"
- assert env["LIGHTNING_PORT"] == "90"
-
- monkeypatch.setattr(os, "environ", env)
-
- _main()
-
- assert justpy.app._mock_mock_calls[0].args[0] == "/c"
- assert justpy.app._mock_mock_calls[0].args[1] == _webpage
-
- assert justpy.justpy._mock_mock_calls[0].args[0] == _webpage
- assert justpy.justpy._mock_mock_calls[0].kwargs == {"host": "a", "port": 90}
diff --git a/tests/tests_app/frontend/panel/__init__.py b/tests/tests_app/frontend/panel/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/frontend/panel/app_panel.py b/tests/tests_app/frontend/panel/app_panel.py
deleted file mode 100644
index 3d8c056cc4492..0000000000000
--- a/tests/tests_app/frontend/panel/app_panel.py
+++ /dev/null
@@ -1,4 +0,0 @@
-if __name__ == "__main__":
- import panel as pn
-
- pn.pane.Markdown("# Panel App").servable()
diff --git a/tests/tests_app/frontend/panel/test_app_state_comm.py b/tests/tests_app/frontend/panel/test_app_state_comm.py
deleted file mode 100644
index a2501db24fde1..0000000000000
--- a/tests/tests_app/frontend/panel/test_app_state_comm.py
+++ /dev/null
@@ -1,40 +0,0 @@
-"""The watch_app_state function enables us to trigger a callback function whenever the App state changes."""
-
-import os
-from unittest import mock
-
-from lightning.app.core.constants import APP_SERVER_PORT
-from lightning.app.frontend.panel.app_state_comm import _get_ws_url, _run_callbacks, _watch_app_state
-
-FLOW_SUB = "lit_flow"
-FLOW = f"root.{FLOW_SUB}"
-
-
-def do_nothing():
- """Be lazy!"""
-
-
-def test_get_ws_url_when_local():
- """The websocket uses port APP_SERVER_PORT when local."""
- assert _get_ws_url() == f"ws://localhost:{APP_SERVER_PORT}/api/v1/ws"
-
-
-@mock.patch.dict(os.environ, {"LIGHTNING_APP_STATE_URL": "some_url"})
-def test_get_ws_url_when_cloud():
- """The websocket uses port 8080 when LIGHTNING_APP_STATE_URL is set."""
- assert _get_ws_url() == "ws://localhost:8080/api/v1/ws"
-
-
-@mock.patch.dict(os.environ, {"LIGHTNING_FLOW_NAME": "FLOW"})
-def test_watch_app_state():
- """We can watch the App state and a callback function will be run when it changes."""
- callback = mock.MagicMock()
- # When
- _watch_app_state(callback)
-
- # Here we would like to send messages using the web socket
- # For testing the web socket is not started. See conftest.py
- # So we need to manually trigger _run_callbacks here
- _run_callbacks()
- # Then
- callback.assert_called_once()
diff --git a/tests/tests_app/frontend/panel/test_app_state_watcher.py b/tests/tests_app/frontend/panel/test_app_state_watcher.py
deleted file mode 100644
index f12b989654141..0000000000000
--- a/tests/tests_app/frontend/panel/test_app_state_watcher.py
+++ /dev/null
@@ -1,95 +0,0 @@
-"""The AppStateWatcher enables a Frontend to.
-
-- subscribe to App state changes.
-- to access and change the App state.
-
-This is particularly useful for the PanelFrontend, but can be used by other Frontends too.
-
-"""
-
-# pylint: disable=protected-access
-import os
-from unittest import mock
-
-import pytest
-from lightning.app.frontend.panel.app_state_watcher import AppStateWatcher
-from lightning.app.utilities.state import AppState
-from lightning_utilities.core.imports import RequirementCache
-
-_PARAM_AVAILABLE = RequirementCache("param")
-
-FLOW_SUB = "lit_flow"
-FLOW = f"root.{FLOW_SUB}"
-PORT = 61896
-
-
-@pytest.fixture(autouse=True)
-def mock_settings_env_vars():
- """Set the LIGHTNING environment variables."""
- with mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_FLOW_NAME": FLOW,
- "LIGHTNING_RENDER_ADDRESS": "localhost",
- "LIGHTNING_RENDER_PORT": f"{PORT}",
- },
- ):
- yield
-
-
-@pytest.mark.skipif(not _PARAM_AVAILABLE, reason="requires param")
-def test_init(flow_state_state: dict):
- """We can instantiate the AppStateWatcher.
-
- - the .state is set
- - the .state is scoped to the flow state
-
- """
- # When
- app = AppStateWatcher()
- # Needed as AppStateWatcher is singleton and might have been
- # instantiated and the state changed in other tests
- app._update_flow_state()
-
- # Then
- assert isinstance(app.state, AppState)
- assert app.state._state == flow_state_state
-
-
-@pytest.mark.skipif(not _PARAM_AVAILABLE, reason="requires param")
-def test_update_flow_state(flow_state_state: dict):
- """We can update the state.
-
- - the .state is scoped to the flow state
-
- """
- app = AppStateWatcher()
- org_state = app.state
- app._update_flow_state()
- assert app.state is not org_state
- assert app.state._state == flow_state_state
-
-
-@pytest.mark.skipif(not _PARAM_AVAILABLE, reason="requires param")
-def test_is_singleton():
- """The AppStateWatcher is a singleton for efficiency reasons.
-
- Its key that __new__ and __init__ of AppStateWatcher is only called once. See
- https://github.com/holoviz/param/issues/643
-
- """
- # When
- app1 = AppStateWatcher()
- name1 = app1.name
- state1 = app1.state
-
- app2 = AppStateWatcher()
- name2 = app2.name
- state2 = app2.state
-
- # Then
- assert app1 is app2
- assert name1 == name2
- assert app1.name == name2
- assert state1 is state2
- assert app1.state is state2
diff --git a/tests/tests_app/frontend/panel/test_panel_frontend.py b/tests/tests_app/frontend/panel/test_panel_frontend.py
deleted file mode 100644
index 4b6aa41a12104..0000000000000
--- a/tests/tests_app/frontend/panel/test_panel_frontend.py
+++ /dev/null
@@ -1,168 +0,0 @@
-"""The PanelFrontend wraps your Panel code in your LightningFlow."""
-
-# pylint: disable=protected-access, too-few-public-methods
-import os
-import runpy
-import sys
-from unittest import mock
-from unittest.mock import Mock
-
-import pytest
-from lightning.app import LightningFlow
-from lightning.app.frontend.panel import PanelFrontend, panel_serve_render_fn
-from lightning.app.frontend.panel.panel_frontend import _has_panel_autoreload
-from lightning.app.utilities.state import AppState
-
-
-@pytest.mark.skipif(True, reason="broken")
-def test_stop_server_not_running():
- """If the server is not running but stopped an Exception should be raised."""
- frontend = PanelFrontend(entry_point=Mock())
- with pytest.raises(RuntimeError, match="Server is not running."):
- frontend.stop_server()
-
-
-def _noop_render_fn(_):
- pass
-
-
-class MockFlow(LightningFlow):
- """Test Flow."""
-
- @property
- def name(self):
- """Return name."""
- return "root.my.flow"
-
- def run(self): # pylint: disable=arguments-differ
- """Be lazy!"""
-
-
-@mock.patch("lightning.app.frontend.panel.panel_frontend.subprocess")
-@pytest.mark.skipif(True, reason="broken")
-def test_panel_frontend_start_stop_server(subprocess_mock):
- """Test that `PanelFrontend.start_server()` invokes subprocess.Popen with the right parameters."""
- # Given
- frontend = PanelFrontend(entry_point=_noop_render_fn)
- frontend.flow = MockFlow()
- # When
- frontend.start_server(host="hostname", port=1111)
- # Then
- subprocess_mock.Popen.assert_called_once()
-
- env_variables = subprocess_mock.method_calls[0].kwargs["env"]
- call_args = subprocess_mock.method_calls[0].args[0]
- assert call_args == [
- sys.executable,
- "-m",
- "panel",
- "serve",
- panel_serve_render_fn.__file__,
- "--port",
- "1111",
- "--address",
- "hostname",
- "--prefix",
- "root.my.flow",
- "--allow-websocket-origin",
- "*",
- ]
-
- assert env_variables["LIGHTNING_FLOW_NAME"] == "root.my.flow"
- assert env_variables["LIGHTNING_RENDER_ADDRESS"] == "hostname"
- assert env_variables["LIGHTNING_RENDER_FUNCTION"] == "_noop_render_fn"
- assert env_variables["LIGHTNING_RENDER_MODULE_FILE"] == __file__
- assert env_variables["LIGHTNING_RENDER_PORT"] == "1111"
-
- assert "LIGHTNING_FLOW_NAME" not in os.environ
- assert "LIGHTNING_RENDER_FUNCTION" not in os.environ
- assert "LIGHTNING_RENDER_MODULE_FILE" not in os.environ
- assert "LIGHTNING_RENDER_MODULE_PORT" not in os.environ
- assert "LIGHTNING_RENDER_MODULE_ADDRESS" not in os.environ
- # When
- frontend.stop_server()
- # Then
- subprocess_mock.Popen().kill.assert_called_once()
-
-
-def _call_me(state):
- assert isinstance(state, AppState)
- print(state)
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_FLOW_NAME": "root",
- "LIGHTNING_RENDER_FUNCTION": "_call_me",
- "LIGHTNING_RENDER_MODULE_FILE": __file__,
- "LIGHTNING_RENDER_ADDRESS": "127.0.0.1",
- "LIGHTNING_RENDER_PORT": "61896",
- },
-)
-def test_panel_wrapper_calls_entry_point(*_):
- """Run the panel_serve_entry_point."""
- runpy.run_module("lightning.app.frontend.panel.panel_serve_render_fn")
-
-
-@pytest.mark.skipif(True, reason="broken")
-def test_method_exception():
- """The PanelFrontend does not support entry_point being a method and should raise an Exception."""
-
- class _DummyClass:
- def _render_fn(self):
- pass
-
- with pytest.raises(TypeError, match="being a method"):
- PanelFrontend(entry_point=_DummyClass()._render_fn)
-
-
-@pytest.mark.skipif(True, reason="broken")
-def test_open_close_log_files():
- """We can open and close the log files."""
- frontend = PanelFrontend(_noop_render_fn)
- assert not frontend._log_files
- # When
- frontend._open_log_files()
- # Then
- stdout = frontend._log_files["stdout"]
- stderr = frontend._log_files["stderr"]
- assert not stdout.closed
- assert not stderr.closed
-
- # When
- frontend._close_log_files()
- # Then
- assert not frontend._log_files
- assert stdout.closed
- assert stderr.closed
-
- # We can close even if not open
- frontend._close_log_files()
-
-
-@pytest.mark.parametrize(
- ("value", "expected"),
- [
- ("Yes", True),
- ("yes", True),
- ("YES", True),
- ("Y", True),
- ("y", True),
- ("True", True),
- ("true", True),
- ("TRUE", True),
- ("No", False),
- ("no", False),
- ("NO", False),
- ("N", False),
- ("n", False),
- ("False", False),
- ("false", False),
- ("FALSE", False),
- ],
-)
-def test_has_panel_autoreload(value, expected):
- """We can get and set autoreload using the environment variable PANEL_AUTORELOAD."""
- with mock.patch.dict(os.environ, {"PANEL_AUTORELOAD": value}):
- assert _has_panel_autoreload() == expected
diff --git a/tests/tests_app/frontend/panel/test_panel_serve_render_fn.py b/tests/tests_app/frontend/panel/test_panel_serve_render_fn.py
deleted file mode 100644
index 6605373b2f3b6..0000000000000
--- a/tests/tests_app/frontend/panel/test_panel_serve_render_fn.py
+++ /dev/null
@@ -1,86 +0,0 @@
-"""The panel_serve_render_fn_or_file file gets run by Python to launch a Panel Server with Lightning.
-
-These tests are for serving a render_fn function.
-
-"""
-
-import inspect
-import os
-from unittest import mock
-
-import pytest
-from lightning.app.frontend.panel.app_state_watcher import AppStateWatcher
-from lightning.app.frontend.panel.panel_serve_render_fn import _get_render_fn, _get_render_fn_from_environment
-from lightning_utilities.core.imports import RequirementCache
-
-_PARAM_AVAILABLE = RequirementCache("param")
-
-
-@pytest.fixture(autouse=True)
-def _mock_settings_env_vars():
- with mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_FLOW_NAME": "root.lit_flow",
- "LIGHTNING_RENDER_ADDRESS": "localhost",
- "LIGHTNING_RENDER_MODULE_FILE": __file__,
- "LIGHTNING_RENDER_PORT": "61896",
- },
- ):
- yield
-
-
-def render_fn(app):
- """Test render_fn function with app args."""
- return app
-
-
-@pytest.mark.skipif(not _PARAM_AVAILABLE, reason="requires param")
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_RENDER_FUNCTION": "render_fn",
- },
-)
-def test_get_view_fn_args():
- """We have a helper get_view_fn function that create a function for our view.
-
- If the render_fn provides an argument an AppStateWatcher is provided as argument
-
- """
- result = _get_render_fn()
- assert isinstance(result(), AppStateWatcher)
-
-
-def render_fn_no_args():
- """Test function with no arguments."""
- return "no_args"
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_RENDER_FUNCTION": "render_fn_no_args",
- },
-)
-def test_get_view_fn_no_args():
- """We have a helper get_view_fn function that create a function for our view.
-
- If the render_fn provides an argument an AppStateWatcher is provided as argument
-
- """
- result = _get_render_fn()
- assert result() == "no_args"
-
-
-def render_fn_2():
- """Do nothing."""
-
-
-def test_get_render_fn_from_environment():
- """We have a method to get the render_fn from the environment."""
- # When
- result = _get_render_fn_from_environment("render_fn_2", __file__)
- # Then
- assert result.__name__ == render_fn_2.__name__
- assert inspect.getmodule(result).__file__ == __file__
diff --git a/tests/tests_app/frontend/test_stream_lit.py b/tests/tests_app/frontend/test_stream_lit.py
deleted file mode 100644
index 76a3252f8f832..0000000000000
--- a/tests/tests_app/frontend/test_stream_lit.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import os
-import runpy
-import sys
-from unittest import mock
-from unittest.mock import ANY, Mock
-
-import pytest
-from lightning.app import LightningFlow
-from lightning.app.frontend.stream_lit import StreamlitFrontend
-from lightning.app.utilities.state import AppState
-from lightning_utilities.core.imports import RequirementCache
-
-_STREAMLIT_AVAILABLE = RequirementCache("streamlit")
-
-
-@pytest.mark.skipif(not _STREAMLIT_AVAILABLE, reason="requires streamlit")
-def test_stop_server_not_running():
- frontend = StreamlitFrontend(render_fn=Mock())
- with pytest.raises(RuntimeError, match="Server is not running."):
- frontend.stop_server()
-
-
-def _noop_render_fn(_):
- pass
-
-
-class MockFlow(LightningFlow):
- @property
- def name(self):
- return "root.my.flow"
-
- def run(self):
- pass
-
-
-@pytest.mark.skipif(not _STREAMLIT_AVAILABLE, reason="requires streamlit")
-@mock.patch("lightning.app.frontend.stream_lit.subprocess")
-def test_streamlit_frontend_start_stop_server(subprocess_mock):
- """Test that `StreamlitFrontend.start_server()` invokes subprocess.Popen with the right parameters."""
- frontend = StreamlitFrontend(render_fn=_noop_render_fn)
- frontend.flow = MockFlow()
- frontend.start_server(host="hostname", port=1111)
- subprocess_mock.Popen.assert_called_once()
-
- env_variables = subprocess_mock.method_calls[0].kwargs["env"]
- call_args = subprocess_mock.method_calls[0].args[0]
- assert call_args == [
- sys.executable,
- "-m",
- "streamlit",
- "run",
- ANY,
- "--server.address",
- "hostname",
- "--server.port",
- "1111",
- "--server.baseUrlPath",
- "root.my.flow",
- "--server.headless",
- "true",
- "--server.enableXsrfProtection",
- "false",
- ]
-
- assert env_variables["LIGHTNING_FLOW_NAME"] == "root.my.flow"
- assert env_variables["LIGHTNING_RENDER_FUNCTION"] == "_noop_render_fn"
- assert env_variables["LIGHTNING_RENDER_MODULE_FILE"] == __file__
-
- assert "LIGHTNING_FLOW_NAME" not in os.environ
- assert "LIGHTNING_RENDER_FUNCTION" not in os.environ
- assert "LIGHTNING_RENDER_MODULE_FILE" not in os.environ
-
- frontend.stop_server()
- subprocess_mock.Popen().kill.assert_called_once()
-
-
-def _streamlit_call_me(state):
- assert isinstance(state, AppState)
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_FLOW_NAME": "root",
- "LIGHTNING_RENDER_FUNCTION": "_streamlit_call_me",
- "LIGHTNING_RENDER_MODULE_FILE": __file__,
- },
-)
-def test_streamlit_wrapper_calls_render_fn(*_):
- runpy.run_module("lightning.app.frontend.streamlit_base")
- # TODO: find a way to assert that _streamlit_call_me got called
-
-
-@pytest.mark.skipif(not _STREAMLIT_AVAILABLE, reason="requires streamlit")
-def test_method_exception():
- class A:
- def render_fn(self):
- pass
-
- with pytest.raises(TypeError, match="being a method"):
- StreamlitFrontend(render_fn=A().render_fn)
diff --git a/tests/tests_app/frontend/test_utils.py b/tests/tests_app/frontend/test_utils.py
deleted file mode 100644
index 367b95e32bcca..0000000000000
--- a/tests/tests_app/frontend/test_utils.py
+++ /dev/null
@@ -1,42 +0,0 @@
-"""We have some utility functions that can be used across frontends."""
-
-from lightning.app.frontend.utils import _get_flow_state, _get_frontend_environment
-from lightning.app.utilities.state import AppState
-
-
-def test_get_flow_state(flow_state_state: dict, flow):
- """We have a method to get an AppState scoped to the Flow state."""
- # When
- flow_state = _get_flow_state(flow)
- # Then
- assert isinstance(flow_state, AppState)
- assert flow_state._state == flow_state_state # pylint: disable=protected-access
-
-
-def some_fn(_):
- """Be lazy!"""
-
-
-def test_get_frontend_environment_fn():
- """We have a utility function to get the frontend render_fn environment."""
- # When
- env = _get_frontend_environment(flow="root.lit_frontend", render_fn_or_file=some_fn, host="myhost", port=1234)
- # Then
- assert env["LIGHTNING_FLOW_NAME"] == "root.lit_frontend"
- assert env["LIGHTNING_RENDER_ADDRESS"] == "myhost"
- assert env["LIGHTNING_RENDER_FUNCTION"] == "some_fn"
- assert env["LIGHTNING_RENDER_MODULE_FILE"] == __file__
- assert env["LIGHTNING_RENDER_PORT"] == "1234"
-
-
-def test_get_frontend_environment_file():
- """We have a utility function to get the frontend render_fn environment."""
- # When
- env = _get_frontend_environment(
- flow="root.lit_frontend", render_fn_or_file="app_panel.py", host="myhost", port=1234
- )
- # Then
- assert env["LIGHTNING_FLOW_NAME"] == "root.lit_frontend"
- assert env["LIGHTNING_RENDER_ADDRESS"] == "myhost"
- assert env["LIGHTNING_RENDER_FILE"] == "app_panel.py"
- assert env["LIGHTNING_RENDER_PORT"] == "1234"
diff --git a/tests/tests_app/frontend/test_web.py b/tests/tests_app/frontend/test_web.py
deleted file mode 100644
index 9e990ec911295..0000000000000
--- a/tests/tests_app/frontend/test_web.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import os
-from unittest import mock
-from unittest.mock import ANY, MagicMock
-
-import lightning.app
-import pytest
-from lightning.app import LightningFlow
-from lightning.app.frontend.web import StaticWebFrontend, _healthz
-from lightning.app.storage.path import _storage_root_dir
-
-
-def test_stop_server_not_running():
- frontend = StaticWebFrontend(serve_dir=".")
- with pytest.raises(RuntimeError, match="Server is not running."):
- frontend.stop_server()
-
-
-class MockFlow(LightningFlow):
- @property
- def name(self):
- return "root.my.flow"
-
- def run(self):
- pass
-
-
-@mock.patch("lightning.app.frontend.web.mp.Process")
-def test_start_stop_server_through_frontend(process_mock):
- frontend = StaticWebFrontend(serve_dir=".")
- frontend.flow = MockFlow()
- frontend.start_server("localhost", 5000)
- log_file_root = _storage_root_dir()
- process_mock.assert_called_once_with(
- target=lightning.app.frontend.web._start_server,
- kwargs={
- "host": "localhost",
- "port": 5000,
- "serve_dir": ".",
- "path": "/root.my.flow",
- "log_file": os.path.join(log_file_root, "frontend", "logs.log"),
- "root_path": "",
- },
- )
- process_mock().start.assert_called_once()
- frontend.stop_server()
- process_mock().kill.assert_called_once()
-
-
-@mock.patch("lightning.app.frontend.web.uvicorn")
-@pytest.mark.parametrize("root_path", ["", "/base"])
-def test_start_server_through_function(uvicorn_mock, tmpdir, monkeypatch, root_path):
- FastAPIMock = MagicMock()
- FastAPIMock.mount = MagicMock()
- FastAPIGetDecoratorMock = MagicMock()
- FastAPIMock.get.return_value = FastAPIGetDecoratorMock
- monkeypatch.setattr(lightning.app.frontend.web, "FastAPI", MagicMock(return_value=FastAPIMock))
-
- lightning.app.frontend.web._start_server(
- serve_dir=tmpdir, host="myhost", port=1000, path="/test-flow", root_path=root_path
- )
- uvicorn_mock.run.assert_called_once_with(app=ANY, host="myhost", port=1000, log_config=ANY, root_path=root_path)
-
- FastAPIMock.mount.assert_called_once_with(root_path or "/test-flow", ANY, name="static")
- FastAPIMock.get.assert_called_once_with("/test-flow/healthz", status_code=200)
-
- FastAPIGetDecoratorMock.assert_called_once_with(_healthz)
-
- # path has default value "/"
- FastAPIMock.mount = MagicMock()
- lightning.app.frontend.web._start_server(serve_dir=tmpdir, host="myhost", port=1000, root_path=root_path)
- FastAPIMock.mount.assert_called_once_with(root_path or "/", ANY, name="static")
-
-
-def test_healthz():
- assert _healthz() == {"status": "ok"}
-
-
-@mock.patch("lightning.app.frontend.web.uvicorn")
-def test_start_server_find_free_port(uvicorn_mock, tmpdir):
- lightning.app.frontend.web._start_server(serve_dir=tmpdir, host="myhost")
- assert uvicorn_mock.run.call_args_list[0].kwargs["port"] > 0
diff --git a/tests/tests_app/frontend/utilities/__init__.py b/tests/tests_app/frontend/utilities/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/helpers/__init__.py b/tests/tests_app/helpers/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/launcher/test_lightning_backend.py b/tests/tests_app/launcher/test_lightning_backend.py
deleted file mode 100644
index 5e60d7930bd7f..0000000000000
--- a/tests/tests_app/launcher/test_lightning_backend.py
+++ /dev/null
@@ -1,808 +0,0 @@
-import json
-import os
-from copy import copy
-from datetime import datetime
-from unittest import mock
-from unittest.mock import ANY, MagicMock, Mock
-
-import pytest
-from lightning.app import BuildConfig, CloudCompute, LightningWork
-from lightning.app.launcher.lightning_backend import CloudBackend
-from lightning.app.storage import Drive, Mount
-from lightning.app.testing.helpers import EmptyWork
-from lightning.app.utilities.enum import WorkFailureReasons, WorkStageStatus
-from lightning.app.utilities.exceptions import LightningPlatformException
-from lightning_cloud.openapi import Body5, V1DriveType, V1LightningworkState, V1SourceType
-from lightning_cloud.openapi.rest import ApiException
-
-
-class WorkWithDrive(LightningWork):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.drive = None
-
- def run(self):
- pass
-
-
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_no_update_when_no_works(client_mock):
- cloud_backend = CloudBackend("")
- cloud_backend._get_cloud_work_specs = Mock()
- client_mock.assert_called_once()
- cloud_backend.update_work_statuses(works=[])
- cloud_backend._get_cloud_work_specs.assert_not_called()
-
-
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_no_update_when_all_work_has_started(client_mock):
- cloud_backend = CloudBackend("")
- cloud_backend._get_cloud_work_specs = MagicMock()
- client_mock.assert_called_once()
- started_mock = MagicMock()
- started_mock.has_started = True
-
- # all works have started
- works = [started_mock, started_mock]
- cloud_backend.update_work_statuses(works=works)
- cloud_backend._get_cloud_work_specs.assert_called_once()
-
-
-@mock.patch("lightning.app.launcher.lightning_backend.monotonic")
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_no_update_within_interval(client_mock, monotonic_mock):
- cloud_backend = CloudBackend("", status_update_interval=2)
- cloud_backend._get_cloud_work_specs = Mock()
- client_mock.assert_called_once()
- cloud_backend._last_time_updated = 1
- monotonic_mock.return_value = 2
-
- stopped_mock = Mock()
- stopped_mock.has_started = False
-
- # not all works have started
- works = [stopped_mock, stopped_mock]
-
- cloud_backend.update_work_statuses(works=works)
- cloud_backend._get_cloud_work_specs.assert_not_called()
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_CLOUD_PROJECT_ID": "project_id",
- "LIGHTNING_CLOUD_APP_ID": "app_id",
- },
-)
-@mock.patch("lightning.app.launcher.lightning_backend.monotonic")
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_update_within_interval(client_mock, monotonic_mock):
- cloud_backend = CloudBackend("", status_update_interval=2)
- cloud_backend._last_time_updated = 1
- # pretend a lot of time has passed since the last update
- monotonic_mock.return_value = 8
-
- stopped_mock1 = Mock()
- stopped_mock1.has_started = False
- stopped_mock1.name = "root.mock1"
- stopped_mock2 = Mock()
- stopped_mock2.has_started = False
- stopped_mock2.name = "root.mock2"
-
- spec1 = Mock()
- spec1.name = "root.mock1"
- spec2 = Mock()
- spec2.name = "root.mock2"
-
- # not all works have started
- works = [stopped_mock1, stopped_mock2]
-
- cloud_backend.update_work_statuses(works=works)
- client_mock().lightningwork_service_list_lightningwork.assert_called_with(project_id="project_id", app_id="app_id")
-
- # TODO: assert calls on the work mocks
- # ...
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_CLOUD_PROJECT_ID": "project_id",
- "LIGHTNING_CLOUD_APP_ID": "app_id",
- },
-)
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_stop_all_works(mock_client):
- work_a = EmptyWork()
- work_a._name = "root.work_a"
- work_a._calls = {
- "latest_call_hash": "some_call_hash",
- "some_call_hash": {
- "statuses": [
- {
- "stage": WorkStageStatus.FAILED,
- "timestamp": int(datetime.now().timestamp()),
- "reason": WorkFailureReasons.USER_EXCEPTION,
- },
- ]
- },
- }
-
- work_b = EmptyWork()
- work_b._name = "root.work_b"
- work_b._calls = {
- "latest_call_hash": "some_call_hash",
- "some_call_hash": {
- "statuses": [{"stage": WorkStageStatus.RUNNING, "timestamp": int(datetime.now().timestamp()), "reason": ""}]
- },
- }
-
- cloud_backend = CloudBackend("")
-
- spec1 = Mock()
- spec1.name = "root.work_a"
- spec1.spec.desired_state = V1LightningworkState.RUNNING
- spec1.status.phase = V1LightningworkState.FAILED
- spec2 = Mock()
- spec2.name = "root.work_b"
- spec2.spec.desired_state = V1LightningworkState.RUNNING
-
- class BackendMock:
- def __init__(self):
- self.called = 0
-
- def _get_cloud_work_specs(self, *_):
- value = [spec1, spec2] if not self.called else []
- self.called += 1
- return value
-
- cloud_backend._get_cloud_work_specs = BackendMock()._get_cloud_work_specs
- cloud_backend.stop_all_works([work_a, work_b])
-
- mock_client().lightningwork_service_update_lightningwork.assert_called_with(
- project_id="project_id",
- id=ANY,
- spec_lightningapp_instance_id="app_id",
- body=ANY,
- )
- assert spec1.spec.desired_state == V1LightningworkState.RUNNING
- assert spec2.spec.desired_state == V1LightningworkState.STOPPED
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_CLOUD_PROJECT_ID": "project_id",
- "LIGHTNING_CLOUD_APP_ID": "app_id",
- },
-)
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_stop_work(mock_client):
- work = EmptyWork()
- work._name = "root.work"
- work._calls = {
- "latest_call_hash": "some_call_hash",
- "some_call_hash": {
- "statuses": [
- {
- "stage": WorkStageStatus.RUNNING,
- "timestamp": int(datetime.now().timestamp()),
- "reason": "",
- },
- ]
- },
- }
-
- cloud_backend = CloudBackend("")
- spec1 = Mock()
- spec1.name = "root.work"
- spec1.spec.desired_state = V1LightningworkState.RUNNING
-
- spec2 = Mock()
- spec2.name = "root.work_b"
- spec2.spec.desired_state = V1LightningworkState.RUNNING
-
- class BackendMock:
- def __init__(self):
- self.called = 0
-
- def _get_cloud_work_specs(self, *_):
- value = [spec1, spec2] if not self.called else []
- self.called += 1
- return value
-
- cloud_backend._get_cloud_work_specs = BackendMock()._get_cloud_work_specs
- cloud_backend.stop_work(MagicMock(), work)
-
- mock_client().lightningwork_service_update_lightningwork.assert_called_with(
- project_id="project_id",
- id=ANY,
- spec_lightningapp_instance_id="app_id",
- body=ANY,
- )
- assert spec1.spec.desired_state == V1LightningworkState.STOPPED
- assert spec2.spec.desired_state == V1LightningworkState.RUNNING
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_CLOUD_PROJECT_ID": "project_id",
- "LIGHTNING_CLOUD_APP_ID": "app_id",
- },
-)
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_create_work_where_work_does_not_exists(mock_client):
- cloud_backend = CloudBackend("")
- non_matching_spec = Mock()
- app = MagicMock()
- work = EmptyWork(port=1111)
- work._name = "name"
-
- def lightningwork_service_create_lightningwork(
- project_id: str = None,
- spec_lightningapp_instance_id: str = None,
- body: "Body5" = None,
- ):
- assert project_id == "project_id"
- assert spec_lightningapp_instance_id == "app_id"
- assert len(body.spec.network_config) == 1
- assert body.spec.network_config[0].port == 1111
- assert not body.spec.network_config[0].host
- body.spec.network_config[0].host = "x.lightning.ai"
- return body
-
- response_mock = Mock()
- response_mock.lightningworks = [non_matching_spec]
- mock_client().lightningwork_service_list_lightningwork.return_value = response_mock
- mock_client().lightningwork_service_create_lightningwork = lightningwork_service_create_lightningwork
-
- cloud_backend.create_work(app, work)
- assert work._future_url == "https://x.lightning.ai"
- app.work_queues["name"].put.assert_called_once_with(work)
-
- # testing whether the exception is raised correctly when the backend throws on work creation
- http_resp = MagicMock()
- error_message = "exception generated from test_create_work_where_work_does_not_exists test case"
- http_resp.data = json.dumps({"message": error_message})
- mock_client().lightningwork_service_create_lightningwork = MagicMock()
- mock_client().lightningwork_service_create_lightningwork.side_effect = ApiException(http_resp=http_resp)
- with pytest.raises(LightningPlatformException, match=error_message):
- cloud_backend.create_work(app, work)
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_CLOUD_PROJECT_ID": "project_id",
- "LIGHTNING_CLOUD_APP_ID": "app_id",
- },
-)
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_create_work_with_drives_where_work_does_not_exists(mock_client, tmpdir):
- cloud_backend = CloudBackend("")
- non_matching_spec = Mock()
- app = MagicMock()
-
- mocked_drive = MagicMock(spec=Drive)
- setattr(mocked_drive, "id", "foobar")
- setattr(mocked_drive, "protocol", "lit://")
- setattr(mocked_drive, "component_name", "test-work")
- setattr(mocked_drive, "allow_duplicates", False)
- setattr(mocked_drive, "root_folder", tmpdir)
- # deepcopy on a MagicMock instance will return an empty magicmock instance. To
- # overcome this we set the __deepcopy__ method `return_value` to equal what
- # should be the results of the deepcopy operation (an instance of the original class)
- mocked_drive.__deepcopy__.return_value = copy(mocked_drive)
-
- work = WorkWithDrive(port=1111)
- work._name = "test-work-name"
- work.drive = mocked_drive
-
- def lightningwork_service_create_lightningwork(
- project_id: str = None,
- spec_lightningapp_instance_id: str = None,
- body: "Body5" = None,
- ):
- assert project_id == "project_id"
- assert spec_lightningapp_instance_id == "app_id"
- assert len(body.spec.network_config) == 1
- assert body.spec.network_config[0].port == 1111
- assert not body.spec.network_config[0].host
- body.spec.network_config[0].host = "x.lightning.ai"
- assert len(body.spec.drives) == 1
- assert body.spec.drives[0].drive.spec.drive_type == V1DriveType.NO_MOUNT_S3
- assert body.spec.drives[0].drive.spec.source_type == V1SourceType.S3
- assert body.spec.drives[0].drive.spec.source == "lit://foobar"
- assert body.spec.drives[0].drive.metadata.name == "test-work-name.drive"
- for v in body.spec.drives[0].drive.status.to_dict().values():
- assert v is None
-
- return body
-
- response_mock = Mock()
- response_mock.lightningworks = [non_matching_spec]
- mock_client().lightningwork_service_list_lightningwork.return_value = response_mock
- mock_client().lightningwork_service_create_lightningwork = lightningwork_service_create_lightningwork
-
- cloud_backend.create_work(app, work)
- assert work._future_url == "https://x.lightning.ai"
- app.work_queues["test-work-name"].put.assert_called_once_with(work)
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_CLOUD_PROJECT_ID": "project_id",
- "LIGHTNING_CLOUD_APP_ID": "app_id",
- "LIGHTNING_PROXY_SCHEME": "http",
- },
-)
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_create_work_proxy_http(mock_client, tmpdir):
- cloud_backend = CloudBackend("")
- non_matching_spec = Mock()
- app = MagicMock()
-
- mocked_drive = MagicMock(spec=Drive)
- setattr(mocked_drive, "id", "foobar")
- setattr(mocked_drive, "protocol", "lit://")
- setattr(mocked_drive, "component_name", "test-work")
- setattr(mocked_drive, "allow_duplicates", False)
- setattr(mocked_drive, "root_folder", tmpdir)
- # deepcopy on a MagicMock instance will return an empty magicmock instance. To
- # overcome this we set the __deepcopy__ method `return_value` to equal what
- # should be the results of the deepcopy operation (an instance of the original class)
- mocked_drive.__deepcopy__.return_value = copy(mocked_drive)
-
- work = WorkWithDrive(port=1111)
- work._name = "test-work-name"
- work.drive = mocked_drive
-
- def lightningwork_service_create_lightningwork(
- project_id: str = None,
- spec_lightningapp_instance_id: str = None,
- body: "Body5" = None,
- ):
- assert project_id == "project_id"
- assert spec_lightningapp_instance_id == "app_id"
- assert len(body.spec.network_config) == 1
- assert body.spec.network_config[0].port == 1111
- assert not body.spec.network_config[0].host
- body.spec.network_config[0].host = "x.lightning.ai"
- assert len(body.spec.drives) == 1
- assert body.spec.drives[0].drive.spec.drive_type == V1DriveType.NO_MOUNT_S3
- assert body.spec.drives[0].drive.spec.source_type == V1SourceType.S3
- assert body.spec.drives[0].drive.spec.source == "lit://foobar"
- assert body.spec.drives[0].drive.metadata.name == "test-work-name.drive"
- for v in body.spec.drives[0].drive.status.to_dict().values():
- assert v is None
-
- return body
-
- response_mock = Mock()
- response_mock.lightningworks = [non_matching_spec]
- mock_client().lightningwork_service_list_lightningwork.return_value = response_mock
- mock_client().lightningwork_service_create_lightningwork = lightningwork_service_create_lightningwork
-
- cloud_backend.create_work(app, work)
- assert work._future_url == "http://x.lightning.ai"
- app.work_queues["test-work-name"].put.assert_called_once_with(work)
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_CLOUD_PROJECT_ID": "project_id",
- "LIGHTNING_CLOUD_APP_ID": "app_id",
- },
-)
-@mock.patch("lightning.app.utilities.network.create_swagger_client", MagicMock())
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_update_work_with_changed_compute_config_with_mounts(mock_client):
- cloud_backend = CloudBackend("")
- matching_spec = Mock()
- app = MagicMock()
- work = EmptyWork(cloud_compute=CloudCompute("default"), cloud_build_config=BuildConfig(image="image1"))
- work._name = "work_name"
-
- matching_spec.spec = cloud_backend._work_to_spec(work)
- matching_spec.spec.desired_state = V1LightningworkState.STOPPED
- matching_spec.name = "work_name"
-
- response_mock = Mock()
- response_mock.lightningworks = [matching_spec]
- mock_client().lightningwork_service_list_lightningwork.return_value = response_mock
-
- cloud_backend.create_work(app, work)
- assert (
- cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[
- "body"
- ].spec.desired_state
- == V1LightningworkState.RUNNING
- )
- assert (
- cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[
- "body"
- ].spec.user_requested_compute_config.name
- == "cpu-small"
- )
- assert (
- cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[
- "body"
- ].spec.build_spec.image
- == "image1"
- )
-
- # resetting the values changed in the previous step
- matching_spec.spec.desired_state = V1LightningworkState.STOPPED
- cloud_backend.client.lightningwork_service_update_lightningwork.reset_mock()
-
- # new work with same name but different compute config
- mount = Mount(source="s3://foo/", mount_path="/foo")
- work = EmptyWork(cloud_compute=CloudCompute("gpu", mounts=mount), cloud_build_config=BuildConfig(image="image2"))
- work._name = "work_name"
- cloud_backend.create_work(app, work)
- assert (
- cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[
- "body"
- ].spec.desired_state
- == V1LightningworkState.RUNNING
- )
- assert (
- cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[
- "body"
- ].spec.user_requested_compute_config.name
- == "gpu"
- )
- assert (
- cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs["body"]
- .spec.drives[0]
- .mount_location
- == "/foo"
- )
- assert (
- cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs["body"]
- .spec.drives[0]
- .drive.spec.source
- == "s3://foo/"
- )
- assert (
- cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[
- "body"
- ].spec.build_spec.image
- == "image2"
- )
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_CLOUD_PROJECT_ID": "project_id",
- "LIGHTNING_CLOUD_APP_ID": "app_id",
- },
-)
-@mock.patch("lightning.app.utilities.network.create_swagger_client", MagicMock())
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_create_work_where_work_already_exists(mock_client):
- cloud_backend = CloudBackend("")
- matching_spec = Mock()
- app = MagicMock()
- work = EmptyWork(port=1111)
- work._name = "work_name"
- work._backend = cloud_backend
-
- matching_spec.spec = cloud_backend._work_to_spec(work)
- matching_spec.spec.network_config[0].host = "x.lightning.ai"
- matching_spec.spec.desired_state = V1LightningworkState.STOPPED
- matching_spec.name = "work_name"
-
- response_mock = Mock()
- response_mock.lightningworks = [matching_spec]
- mock_client().lightningwork_service_list_lightningwork.return_value = response_mock
-
- cloud_backend.create_work(app, work)
- assert (
- cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[
- "body"
- ].spec.desired_state
- == V1LightningworkState.RUNNING
- )
- assert (
- cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs["body"]
- .spec.network_config[0]
- .port
- == 1111
- )
- assert work._future_url == "https://x.lightning.ai"
- app.work_queues["work_name"].put.assert_called_once_with(work)
-
- # resetting the values changed in the previous step
- matching_spec.spec.desired_state = V1LightningworkState.STOPPED
- cloud_backend.client.lightningwork_service_update_lightningwork.reset_mock()
- app.work_queues["work_name"].put.reset_mock()
-
- # changing the port
- work._port = 2222
- cloud_backend.create_work(app, work)
- assert (
- cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs["body"]
- .spec.network_config[0]
- .port
- == 2222
- )
- app.work_queues["work_name"].put.assert_called_once_with(work)
-
- # testing whether the exception is raised correctly when the backend throws on work creation
- # resetting the values changed in the previous step
- matching_spec.spec.desired_state = V1LightningworkState.STOPPED
- http_resp = MagicMock()
- error_message = "exception generated from test_create_work_where_work_already_exists test case"
- http_resp.data = json.dumps({"message": error_message})
- mock_client().lightningwork_service_update_lightningwork = MagicMock()
- mock_client().lightningwork_service_update_lightningwork.side_effect = ApiException(http_resp=http_resp)
- with pytest.raises(LightningPlatformException, match=error_message):
- cloud_backend.create_work(app, work)
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_CLOUD_PROJECT_ID": "project_id",
- "LIGHTNING_CLOUD_APP_ID": "app_id",
- },
-)
-@mock.patch("lightning.app.utilities.network.create_swagger_client", MagicMock())
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_create_work_will_have_none_backend(mockclient):
- def queue_put_mock(work):
- # because we remove backend before pushing to queue
- assert work._backend is None
-
- cloud_backend = CloudBackend("")
- app = MagicMock()
- work = EmptyWork()
- # attaching backend - this will be removed by the queue
- work._backend = cloud_backend
- app.work_queues["work_name"].put = queue_put_mock
- cloud_backend.create_work(app, work)
- # make sure the work still have the backend attached
- assert work._backend == cloud_backend
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_CLOUD_PROJECT_ID": "project_id",
- "LIGHTNING_CLOUD_APP_ID": "app_id",
- },
-)
-@mock.patch("lightning.app.utilities.network.create_swagger_client", MagicMock())
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_update_work_with_changed_compute_config_and_build_spec(mock_client):
- cloud_backend = CloudBackend("")
- matching_spec = Mock()
- app = MagicMock()
- work = EmptyWork(cloud_compute=CloudCompute("default"), cloud_build_config=BuildConfig(image="image1"))
- work._name = "work_name"
-
- matching_spec.spec = cloud_backend._work_to_spec(work)
- matching_spec.spec.desired_state = V1LightningworkState.STOPPED
- matching_spec.name = "work_name"
-
- response_mock = Mock()
- response_mock.lightningworks = [matching_spec]
- mock_client().lightningwork_service_list_lightningwork.return_value = response_mock
-
- cloud_backend.create_work(app, work)
- assert (
- cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[
- "body"
- ].spec.desired_state
- == V1LightningworkState.RUNNING
- )
- assert (
- cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[
- "body"
- ].spec.user_requested_compute_config.name
- == "cpu-small"
- )
- assert (
- cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[
- "body"
- ].spec.build_spec.image
- == "image1"
- )
-
- # resetting the values changed in the previous step
- matching_spec.spec.desired_state = V1LightningworkState.STOPPED
- cloud_backend.client.lightningwork_service_update_lightningwork.reset_mock()
-
- # new work with same name but different compute config
- work = EmptyWork(cloud_compute=CloudCompute("gpu"), cloud_build_config=BuildConfig(image="image2"))
- work._name = "work_name"
- cloud_backend.create_work(app, work)
- assert (
- cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[
- "body"
- ].spec.desired_state
- == V1LightningworkState.RUNNING
- )
- assert (
- cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[
- "body"
- ].spec.user_requested_compute_config.name
- == "gpu"
- )
- assert (
- cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[
- "body"
- ].spec.build_spec.image
- == "image2"
- )
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_CLOUD_PROJECT_ID": "project_id",
- "LIGHTNING_CLOUD_APP_ID": "app_id",
- },
-)
-@mock.patch("lightning.app.utilities.network.create_swagger_client", MagicMock())
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_update_work_with_changed_spec_while_work_running(mock_client):
- cloud_backend = CloudBackend("")
- matching_spec = Mock()
- app = MagicMock()
- work = EmptyWork(cloud_compute=CloudCompute("default"), cloud_build_config=BuildConfig(image="image1"))
- work._name = "work_name"
-
- matching_spec.spec = cloud_backend._work_to_spec(work)
- matching_spec.spec.desired_state = V1LightningworkState.RUNNING
- matching_spec.name = "work_name"
-
- response_mock = Mock()
- response_mock.lightningworks = [matching_spec]
- mock_client().lightningwork_service_list_lightningwork.return_value = response_mock
-
- cloud_backend.create_work(app, work)
-
- # asserting the method is not called
- cloud_backend.client.lightningwork_service_update_lightningwork.assert_not_called()
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_CLOUD_PROJECT_ID": "project_id",
- "LIGHTNING_CLOUD_APP_ID": "app_id",
- },
-)
-@mock.patch("lightning.app.utilities.network.create_swagger_client", MagicMock())
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_update_lightning_app_frontend_new_frontends(mock_client):
- cloud_backend = CloudBackend("")
- cloud_backend.client = mock_client
- mocked_app = MagicMock()
- mocked_app.frontends.keys.return_value = ["frontend2", "frontend1"]
- app_instance_mock = MagicMock()
- app_instance_mock.spec.flow_servers = []
- update_lightning_app_instance_mock = MagicMock()
- mock_client.lightningapp_instance_service_get_lightningapp_instance.return_value = app_instance_mock
- mock_client.lightningapp_instance_service_update_lightningapp_instance.return_value = (
- update_lightning_app_instance_mock
- )
- cloud_backend.update_lightning_app_frontend(mocked_app)
- assert mock_client.lightningapp_instance_service_update_lightningapp_instance.call_count == 1
-
- # frontends should be sorted
- assert (
- mock_client.lightningapp_instance_service_update_lightningapp_instance.call_args.kwargs["body"]
- .spec.flow_servers[0]
- .name
- == "frontend1"
- )
- assert (
- mock_client.lightningapp_instance_service_update_lightningapp_instance.call_args.kwargs["body"]
- .spec.flow_servers[1]
- .name
- == "frontend2"
- )
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_CLOUD_PROJECT_ID": "project_id",
- "LIGHTNING_CLOUD_APP_ID": "app_id",
- },
-)
-@mock.patch("lightning.app.utilities.network.create_swagger_client", MagicMock())
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_update_lightning_app_frontend_existing_frontends(mock_client):
- cloud_backend = CloudBackend("")
- cloud_backend.client = mock_client
- mocked_app = MagicMock()
- mocked_app.frontends.keys.return_value = ["frontend2", "frontend1"]
- app_instance_mock = MagicMock()
- app_instance_mock.spec.flow_servers = ["frontend2", "frontend1"]
- update_lightning_app_instance_mock = MagicMock()
- mock_client.lightningapp_instance_service_get_lightningapp_instance.return_value = app_instance_mock
- mock_client.lightningapp_instance_service_update_lightningapp_instance.return_value = (
- update_lightning_app_instance_mock
- )
- cloud_backend.update_lightning_app_frontend(mocked_app)
-
- # the app spec already has the frontends, so no update should be called
- assert mock_client.lightningapp_instance_service_update_lightningapp_instance.call_count == 0
- assert mock_client.lightningapp_instance_service_update_lightningapp_instance.call_count == 0
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_CLOUD_PROJECT_ID": "project_id",
- "LIGHTNING_CLOUD_APP_ID": "app_id",
- },
-)
-@mock.patch("lightning.app.utilities.network.create_swagger_client", MagicMock())
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_stop_app(mock_client):
- cloud_backend = CloudBackend("")
- external_spec = MagicMock()
- mock_client.lightningapp_instance_service_get_lightningapp_instance.return_value = external_spec
- cloud_backend.client = mock_client
- mocked_app = MagicMock()
- cloud_backend.stop_app(mocked_app)
- spec = mock_client.lightningapp_instance_service_update_lightningapp_instance._mock_call_args.kwargs["body"].spec
- assert spec.desired_state == "LIGHTNINGAPP_INSTANCE_STATE_STOPPED"
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_CLOUD_PROJECT_ID": "project_id",
- "LIGHTNING_CLOUD_APP_ID": "app_id",
- },
-)
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_failed_works_during_pending(client_mock):
- cloud_backend = CloudBackend("")
- cloud_work = MagicMock()
- cloud_work.name = "a"
- cloud_work.status.phase = V1LightningworkState.FAILED
- cloud_backend._get_cloud_work_specs = MagicMock(return_value=[cloud_work])
-
- local_work = MagicMock()
- local_work.status.stage = "pending"
- local_work.name = "a"
- local_work._raise_exception = True
-
- with pytest.raises(Exception, match="The work a failed during pending phase."):
- # all works have started
- cloud_backend.update_work_statuses(works=[local_work])
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_CLOUD_PROJECT_ID": "project_id",
- "LIGHTNING_CLOUD_APP_ID": "app_id",
- },
-)
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_work_delete(client_mock):
- cloud_backend = CloudBackend("")
- cloud_work = MagicMock()
- cloud_work.name = "a"
- cloud_work.status.phase = V1LightningworkState.RUNNING
- cloud_backend._get_cloud_work_specs = MagicMock(return_value=[cloud_work])
-
- local_work = MagicMock()
- local_work.status.stage = "running"
- local_work.name = "a"
- local_work._raise_exception = True
- cloud_backend.delete_work(None, local_work)
- call = cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args_list[0]
- assert call.kwargs["body"].spec.desired_state == V1LightningworkState.DELETED
diff --git a/tests/tests_app/launcher/test_lightning_hydrid.py b/tests/tests_app/launcher/test_lightning_hydrid.py
deleted file mode 100644
index 695d3216a2316..0000000000000
--- a/tests/tests_app/launcher/test_lightning_hydrid.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from unittest import mock
-
-from lightning.app import CloudCompute
-from lightning.app.launcher.lightning_hybrid_backend import CloudHybridBackend
-
-
-@mock.patch("lightning.app.launcher.lightning_backend.LightningClient")
-def test_backend_selection(client_mock):
- cloud_backend = CloudHybridBackend("", queue_id="")
- work = mock.MagicMock()
- work.cloud_compute = CloudCompute()
- assert cloud_backend._get_backend(work) == cloud_backend.backends["multiprocess"]
- work.cloud_compute = CloudCompute("gpu")
- assert cloud_backend._get_backend(work) == cloud_backend.backends["cloud"]
diff --git a/tests/tests_app/launcher/test_running_flow.py b/tests/tests_app/launcher/test_running_flow.py
deleted file mode 100644
index 228047f0b0b8a..0000000000000
--- a/tests/tests_app/launcher/test_running_flow.py
+++ /dev/null
@@ -1,133 +0,0 @@
-import logging
-import os
-import signal
-import sys
-from unittest import mock
-from unittest.mock import MagicMock, Mock
-
-import pytest
-import requests
-from lightning.app.launcher import launcher, lightning_backend
-from lightning.app.utilities.app_helpers import convert_print_to_logger_info
-from lightning.app.utilities.enum import AppStage
-from lightning.app.utilities.exceptions import ExitAppException
-
-
-def _make_mocked_network_config(key, host):
- network_config = Mock()
- network_config.name = key
- network_config.host = host
- return network_config
-
-
-@mock.patch("lightning.app.core.queues.QueuingSystem", mock.MagicMock())
-@mock.patch("lightning.app.launcher.launcher.check_if_redis_running", MagicMock(return_value=True))
-def test_running_flow(monkeypatch):
- app = MagicMock()
- flow = MagicMock()
- work = MagicMock()
- work.run.__name__ = "run"
- flow._layout = {}
- flow.name = "flowname"
- work.name = "workname"
- app.flows = [flow]
- flow.works.return_value = [work]
-
- def load_app_from_file(file):
- assert file == "file.py"
- return app
-
- class BackendMock:
- def __init__(self, return_value):
- self.called = 0
- self.return_value = return_value
-
- def _get_cloud_work_specs(self, *_):
- value = self.return_value if not self.called else []
- self.called += 1
- return value
-
- cloud_work_spec = Mock()
- cloud_work_spec.name = "workname"
- cloud_work_spec.spec.network_config = [
- _make_mocked_network_config("key1", "x.lightning.ai"),
- ]
- monkeypatch.setattr(launcher, "load_app_from_file", load_app_from_file)
- monkeypatch.setattr(launcher, "start_server", MagicMock())
- monkeypatch.setattr(lightning_backend, "LightningClient", MagicMock())
- lightning_backend.CloudBackend._get_cloud_work_specs = BackendMock(
- return_value=[cloud_work_spec]
- )._get_cloud_work_specs
- monkeypatch.setattr(lightning_backend.CloudBackend, "_get_project_id", MagicMock())
- monkeypatch.setattr(lightning_backend.CloudBackend, "_get_app_id", MagicMock())
- queue_system = MagicMock()
- queue_system.REDIS = MagicMock()
- monkeypatch.setattr(launcher, "QueuingSystem", queue_system)
- monkeypatch.setattr(launcher, "StorageOrchestrator", MagicMock())
-
- response = MagicMock()
- response.status_code = 200
- monkeypatch.setattr(requests, "get", MagicMock(return_value=response))
-
- # testing with correct base URL
- with pytest.raises(SystemExit, match="0"):
- launcher.run_lightning_flow("file.py", queue_id="", base_url="http://localhost:8080")
- assert flow._layout["target"] == "http://localhost:8080/flowname/"
-
- app._run.assert_called_once()
-
- # testing with invalid base URL
- with pytest.raises(ValueError, match="Base URL doesn't have a valid scheme"):
- launcher.run_lightning_flow("file.py", queue_id="", base_url="localhost:8080")
-
- app.flows = []
-
- def run_patch():
- raise Exception
-
- app._run = run_patch
-
- with pytest.raises(SystemExit, match="1"):
- launcher.run_lightning_flow("file.py", queue_id="", base_url="localhost:8080")
-
- def run_patch():
- app.stage = AppStage.FAILED
-
- app._run = run_patch
-
- with pytest.raises(SystemExit, match="1"):
- launcher.run_lightning_flow("file.py", queue_id="", base_url="localhost:8080")
-
- def run_patch():
- raise ExitAppException
-
- if sys.platform == "win32":
- return
-
- app.stage = AppStage.STOPPING
-
- app._run = run_patch
- with pytest.raises(SystemExit, match="0"):
- launcher.run_lightning_flow("file.py", queue_id="", base_url="localhost:8080")
-
- def run_method():
- os.kill(os.getpid(), signal.SIGTERM)
-
- app._run = run_method
- monkeypatch.setattr(lightning_backend.CloudBackend, "resolve_url", MagicMock())
- with pytest.raises(SystemExit, match="0"):
- launcher.run_lightning_flow("file.py", queue_id="", base_url="localhost:8080")
- assert app.stage == AppStage.STOPPING
-
-
-def test_replace_print_to_info(caplog, monkeypatch):
- monkeypatch.setattr("lightning.app._logger", logging.getLogger())
-
- @convert_print_to_logger_info
- def fn_captured(value):
- print(value)
-
- with caplog.at_level(logging.INFO):
- fn_captured(1)
-
- assert caplog.messages == ["1"]
diff --git a/tests/tests_app/plugin/__init__.py b/tests/tests_app/plugin/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/plugin/test_plugin.py b/tests/tests_app/plugin/test_plugin.py
deleted file mode 100644
index c4867af04dd2b..0000000000000
--- a/tests/tests_app/plugin/test_plugin.py
+++ /dev/null
@@ -1,221 +0,0 @@
-import io
-import json
-import sys
-import tarfile
-from dataclasses import dataclass
-from pathlib import Path
-from unittest import mock
-
-import pytest
-from fastapi import status
-from fastapi.testclient import TestClient
-from lightning.app.plugin.plugin import _Run, _start_plugin_server
-from lightning_cloud.openapi import Externalv1LightningappInstance
-
-
-@pytest.fixture()
-@mock.patch("lightning.app.plugin.plugin.uvicorn")
-def mock_plugin_server(mock_uvicorn) -> TestClient:
- """This fixture returns a `TestClient` for the plugin server."""
- test_client = {}
-
- def create_test_client(app, **_):
- test_client["client"] = TestClient(app)
-
- mock_uvicorn.run.side_effect = create_test_client
-
- _start_plugin_server(8888)
-
- return test_client["client"]
-
-
-@dataclass
-class _MockResponse:
- content: bytes
-
- def raise_for_status(self):
- pass
-
-
-def mock_requests_get(valid_url, return_value):
- """Used to replace `requests.get` with a function that returns the given value for the given valid URL and raises
- otherwise."""
-
- def inner(url):
- if url == valid_url:
- return _MockResponse(return_value)
- raise RuntimeError
-
- return inner
-
-
-def as_tar_bytes(file_name, content):
- """Utility to encode the given string as a gzipped tar and return the bytes."""
- tar_fileobj = io.BytesIO()
- with tarfile.open(fileobj=tar_fileobj, mode="w|gz") as tar:
- content = content.encode("utf-8")
- tf = tarfile.TarInfo(file_name)
- tf.size = len(content)
- tar.addfile(tf, io.BytesIO(content))
- tar_fileobj.seek(0)
- return tar_fileobj.read()
-
-
-_plugin_with_internal_error = """
-from lightning.app.plugin.plugin import LightningPlugin
-
-class TestPlugin(LightningPlugin):
- def run(self):
- raise RuntimeError("Internal Error")
-
-plugin = TestPlugin()
-"""
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="the plugin server is only intended to run on linux.")
-@pytest.mark.parametrize(
- ("body", "message", "tar_file_name", "content"),
- [
- (
- _Run(
- plugin_entrypoint="test",
- source_code_url="this_url_does_not_exist",
- project_id="any",
- cloudspace_id="any",
- cluster_id="any",
- plugin_arguments={},
- source_app="any",
- keep_machines_after_stop=False,
- ),
- "Error downloading plugin source:",
- None,
- b"",
- ),
- (
- _Run(
- plugin_entrypoint="test",
- source_code_url="http://test.tar.gz",
- project_id="any",
- cloudspace_id="any",
- cluster_id="any",
- plugin_arguments={},
- source_app="any",
- keep_machines_after_stop=False,
- ),
- "Error extracting plugin source:",
- None,
- b"this is not a tar",
- ),
- (
- _Run(
- plugin_entrypoint="plugin.py",
- source_code_url="http://test.tar.gz",
- project_id="any",
- cloudspace_id="any",
- cluster_id="any",
- plugin_arguments={},
- source_app="any",
- keep_machines_after_stop=False,
- ),
- "Error loading plugin:",
- "plugin.py",
- "this is not a plugin",
- ),
- (
- _Run(
- plugin_entrypoint="plugin.py",
- source_code_url="http://test.tar.gz",
- project_id="any",
- cloudspace_id="any",
- cluster_id="any",
- plugin_arguments={},
- source_app="any",
- keep_machines_after_stop=False,
- ),
- "Error running plugin:",
- "plugin.py",
- _plugin_with_internal_error,
- ),
- ],
-)
-@mock.patch("lightning.app.plugin.plugin.requests")
-def test_run_errors(mock_requests, mock_plugin_server, body, message, tar_file_name, content):
- if tar_file_name is not None:
- content = as_tar_bytes(tar_file_name, content)
-
- mock_requests.get.side_effect = mock_requests_get("http://test.tar.gz", content)
-
- response = mock_plugin_server.post("/v1/runs", json=body.dict(exclude_none=True))
-
- assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
- assert message in response.text
-
-
-_plugin_with_job_run = """
-from lightning.app.plugin.plugin import LightningPlugin
-
-class TestPlugin(LightningPlugin):
- def run(self, name, entrypoint):
- return self.run_job(name, entrypoint)
-
-plugin = TestPlugin()
-"""
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="the plugin server is only intended to run on linux.")
-@mock.patch("lightning.app.runners.backends.cloud.CloudBackend")
-@mock.patch("lightning.app.runners.cloud.CloudRuntime")
-@mock.patch("lightning.app.plugin.plugin.requests")
-def test_run_job(mock_requests, mock_cloud_runtime, mock_cloud_backend, mock_plugin_server):
- """Tests that running a job from a plugin calls the correct `CloudRuntime` methods with the correct arguments."""
- content = as_tar_bytes("plugin.py", _plugin_with_job_run)
- mock_requests.get.side_effect = mock_requests_get("http://test.tar.gz", content)
-
- body = _Run(
- plugin_entrypoint="plugin.py",
- source_code_url="http://test.tar.gz",
- project_id="test_project_id",
- cloudspace_id="test_cloudspace_id",
- cluster_id="test_cluster_id",
- plugin_arguments={"name": "test_name", "entrypoint": "test_entrypoint"},
- source_app="test_source_app",
- keep_machines_after_stop=True,
- )
-
- mock_app = mock.MagicMock()
- mock_cloud_runtime.load_app_from_file.return_value = mock_app
- mock_cloud_runtime.return_value.cloudspace_dispatch.return_value = Externalv1LightningappInstance(
- id="created_app_id"
- )
-
- response = mock_plugin_server.post("/v1/runs", json=body.dict(exclude_none=True))
-
- assert response.status_code == status.HTTP_200_OK, response.json()
- assert json.loads(response.text)["id"] == "created_app_id"
-
- mock_cloud_runtime.load_app_from_file.assert_called_once()
- assert "test_entrypoint" in mock_cloud_runtime.load_app_from_file.call_args[0][0]
-
- mock_cloud_runtime.assert_called_once_with(
- app=mock_app,
- entrypoint=Path("test_entrypoint"),
- start_server=True,
- env_vars={},
- secrets={},
- run_app_comment_commands=True,
- backend=mock.ANY,
- )
-
- mock_cloud_runtime().cloudspace_dispatch.assert_called_once_with(
- project_id=body.project_id,
- cloudspace_id=body.cloudspace_id,
- name="test_name",
- cluster_id=body.cluster_id,
- source_app=body.source_app,
- keep_machines_after_stop=body.keep_machines_after_stop,
- )
-
-
-def test_healthz(mock_plugin_server):
- """Smoke test for the healthz endpoint."""
- assert mock_plugin_server.get("/healthz").status_code == 200
diff --git a/tests/tests_app/runners/__init__.py b/tests/tests_app/runners/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/runners/backends/__init__.py b/tests/tests_app/runners/backends/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/runners/backends/test_mp_process.py b/tests/tests_app/runners/backends/test_mp_process.py
deleted file mode 100644
index ae22a9133124f..0000000000000
--- a/tests/tests_app/runners/backends/test_mp_process.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from unittest import mock
-from unittest.mock import MagicMock, Mock
-
-from lightning.app import LightningApp, LightningWork
-from lightning.app.runners.backends import MultiProcessingBackend
-
-
-@mock.patch("lightning.app.core.app.AppStatus")
-@mock.patch("lightning.app.runners.backends.mp_process.multiprocessing")
-def test_backend_create_work_with_set_start_method(multiprocessing_mock, *_):
- backend = MultiProcessingBackend(entrypoint_file="fake.py")
- work = Mock(spec=LightningWork)
- work._start_method = "test_start_method"
-
- app = LightningApp(work)
- app.caller_queues = MagicMock()
- app.delta_queue = MagicMock()
- app.readiness_queue = MagicMock()
- app.error_queue = MagicMock()
- app.request_queues = MagicMock()
- app.response_queues = MagicMock()
- app.copy_request_queues = MagicMock()
- app.copy_response_queues = MagicMock()
- app.flow_to_work_delta_queues = MagicMock()
-
- backend.create_work(app=app, work=work)
- multiprocessing_mock.get_context.assert_called_with("test_start_method")
- multiprocessing_mock.get_context().Process().start.assert_called_once()
diff --git a/tests/tests_app/runners/test_cloud.py b/tests/tests_app/runners/test_cloud.py
deleted file mode 100644
index 74b74c99a8049..0000000000000
--- a/tests/tests_app/runners/test_cloud.py
+++ /dev/null
@@ -1,2113 +0,0 @@
-import contextlib
-import logging
-import os
-import pathlib
-import re
-import sys
-from copy import copy
-from pathlib import Path
-from unittest import mock
-from unittest.mock import MagicMock
-
-import pytest
-from lightning.app import BuildConfig, LightningApp, LightningFlow, LightningWork
-from lightning.app.runners import CloudRuntime, backends, cloud
-from lightning.app.source_code.copytree import _copytree, _parse_lightningignore
-from lightning.app.source_code.local import LocalSourceCodeDir
-from lightning.app.storage import Drive, Mount
-from lightning.app.testing.helpers import EmptyWork
-from lightning.app.utilities.cloud import _get_project
-from lightning.app.utilities.dependency_caching import get_hash
-from lightning.app.utilities.packaging.cloud_compute import CloudCompute
-from lightning_cloud.openapi import (
- CloudspaceIdRunsBody,
- Externalv1Cluster,
- Externalv1LightningappInstance,
- Gridv1ImageSpec,
- IdGetBody1,
- ProjectIdProjectclustersbindingsBody,
- V1BuildSpec,
- V1CloudSpace,
- V1CloudSpaceInstanceConfig,
- V1ClusterSpec,
- V1ClusterType,
- V1DataConnectionMount,
- V1DependencyFileInfo,
- V1Drive,
- V1DriveSpec,
- V1DriveStatus,
- V1DriveType,
- V1EnvVar,
- V1GetUserResponse,
- V1LightningappInstanceSpec,
- V1LightningappInstanceState,
- V1LightningappInstanceStatus,
- V1LightningAuth,
- V1LightningBasicAuth,
- V1LightningRun,
- V1LightningworkDrives,
- V1LightningworkSpec,
- V1ListCloudSpacesResponse,
- V1ListClustersResponse,
- V1ListLightningappInstancesResponse,
- V1ListMembershipsResponse,
- V1ListProjectClusterBindingsResponse,
- V1Membership,
- V1Metadata,
- V1NetworkConfig,
- V1PackageManager,
- V1ProjectClusterBinding,
- V1PythonDependencyInfo,
- V1QueueServerType,
- V1SourceType,
- V1UserFeatures,
- V1UserRequestedComputeConfig,
- V1UserRequestedFlowComputeConfig,
- V1Work,
-)
-
-
-class MyWork(LightningWork):
- def run(self):
- print("my run")
-
-
-class WorkWithSingleDrive(LightningWork):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.drive = None
-
- def run(self):
- pass
-
-
-class WorkWithTwoDrives(LightningWork):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.lit_drive_1 = None
- self.lit_drive_2 = None
-
- def run(self):
- pass
-
-
-def get_cloud_runtime_request_body(**kwargs) -> "CloudspaceIdRunsBody":
- default_request_body = {
- "app_entrypoint_file": mock.ANY,
- "enable_app_server": True,
- "is_headless": True,
- "should_mount_cloudspace_content": False,
- "flow_servers": [],
- "image_spec": None,
- "works": [],
- "local_source": True,
- "dependency_cache_key": mock.ANY,
- "user_requested_flow_compute_config": V1UserRequestedFlowComputeConfig(
- name="flow-lite",
- preemptible=False,
- shm_size=0,
- ),
- }
-
- if kwargs.get("user_requested_flow_compute_config") is not None:
- default_request_body["user_requested_flow_compute_config"] = kwargs["user_requested_flow_compute_config"]
-
- return CloudspaceIdRunsBody(**default_request_body)
-
-
-@pytest.fixture()
-def cloud_backend(monkeypatch):
- cloud_backend = mock.MagicMock()
- monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
- monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
- return cloud_backend
-
-
-@pytest.fixture()
-def project_id():
- return "test-project-id"
-
-
-DEFAULT_CLUSTER = "litng-ai-03"
-
-
-class TestAppCreationClient:
- """Testing the calls made using GridRestClient to create the app."""
-
- def test_run_on_deleted_cluster(self, cloud_backend, project_id):
- app_name = "test-app"
-
- mock_client = mock.MagicMock()
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="Default Project", project_id=project_id)]
- )
-
- mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([
- Externalv1Cluster(id=DEFAULT_CLUSTER)
- ])
- cloud_backend.client = mock_client
-
- app = mock.MagicMock()
- app.flows = []
- app.frontend = {}
-
- existing_instance = MagicMock()
- existing_instance.status.phase = V1LightningappInstanceState.STOPPED
- existing_instance.spec.cluster_id = DEFAULT_CLUSTER
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=[existing_instance])
- )
-
- cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=Path("entrypoint.py"))
- cloud_runtime._check_uploaded_folder = mock.MagicMock()
-
- with pytest.raises(ValueError, match="that cluster doesn't exist"):
- cloud_runtime.dispatch(name=app_name, cluster_id="unknown-cluster")
-
- @pytest.mark.parametrize(
- ("old_cluster", "new_cluster"),
- [
- ("test", "other"),
- ("test", "test"),
- (None, None),
- (None, "litng-ai-03"),
- ("litng-ai-03", None),
- ],
- )
- def test_new_instance_on_different_cluster(self, tmpdir, cloud_backend, project_id, old_cluster, new_cluster):
- entrypoint = Path(tmpdir) / "entrypoint.py"
- entrypoint.touch()
-
- app_name = "test-app"
-
- mock_client = mock.MagicMock()
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="Default Project", project_id=project_id)]
- )
- mock_client.lightningapp_v2_service_create_lightningapp_release.return_value = V1LightningRun(
- cluster_id=new_cluster
- )
-
- # Note:
- # backend converts "None" cluster to "litng-ai-03"
- # dispatch should receive None, but API calls should return "litng-ai-03"
- mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([
- Externalv1Cluster(id=old_cluster or DEFAULT_CLUSTER),
- Externalv1Cluster(id=new_cluster or DEFAULT_CLUSTER),
- ])
-
- mock_client.projects_service_list_project_cluster_bindings.return_value = V1ListProjectClusterBindingsResponse(
- clusters=[
- V1ProjectClusterBinding(cluster_id=old_cluster or DEFAULT_CLUSTER),
- V1ProjectClusterBinding(cluster_id=new_cluster or DEFAULT_CLUSTER),
- ]
- )
-
- # Mock all clusters as global clusters
- mock_client.cluster_service_get_cluster.side_effect = lambda cluster_id: Externalv1Cluster(
- id=cluster_id, spec=V1ClusterSpec(cluster_type=V1ClusterType.GLOBAL)
- )
-
- cloud_backend.client = mock_client
-
- app = mock.MagicMock()
- app.flows = []
- app.frontend = {}
-
- existing_app = MagicMock()
- existing_app.name = app_name
- existing_app.id = "test-id"
- mock_client.cloud_space_service_list_cloud_spaces.return_value = V1ListCloudSpacesResponse(
- cloudspaces=[existing_app]
- )
-
- existing_instance = MagicMock()
- existing_instance.name = app_name
- existing_instance.status.phase = V1LightningappInstanceState.STOPPED
- existing_instance.spec.cluster_id = old_cluster or DEFAULT_CLUSTER
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=[existing_instance])
- )
-
- cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=entrypoint)
- cloud_runtime._check_uploaded_folder = mock.MagicMock()
-
- # This is the main assertion:
- # we have an existing instance on `old_cluster`
- # but we want to run this app on `new_cluster`
- cloud_runtime.dispatch(name=app_name, cluster_id=new_cluster)
-
- if new_cluster != old_cluster and None not in (old_cluster, new_cluster):
- # If we switched cluster, check that a new name was used which starts with the old name
- mock_client.cloud_space_service_create_lightning_run_instance.assert_called_once()
- args = mock_client.cloud_space_service_create_lightning_run_instance.call_args
- assert args[1]["body"].name != app_name
- assert args[1]["body"].name.startswith(app_name)
- assert args[1]["body"].cluster_id == new_cluster
-
- def test_running_deleted_app(self, tmpdir, cloud_backend, project_id):
- """Deleted apps show up in list apps but not in list instances.
-
- This tests that we don't try to recreate a previously deleted app.
-
- """
- entrypoint = Path(tmpdir) / "entrypoint.py"
- entrypoint.touch()
-
- app_name = "test-app"
-
- mock_client = mock.MagicMock()
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="Default Project", project_id=project_id)]
- )
- mock_client.lightningapp_v2_service_create_lightningapp_release.return_value = V1LightningRun(
- cluster_id=DEFAULT_CLUSTER
- )
-
- mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([
- Externalv1Cluster(id=DEFAULT_CLUSTER)
- ])
-
- mock_client.projects_service_list_project_cluster_bindings.return_value = V1ListProjectClusterBindingsResponse(
- clusters=[V1ProjectClusterBinding(cluster_id=DEFAULT_CLUSTER)]
- )
-
- # Mock all clusters as global clusters
- mock_client.cluster_service_get_cluster.side_effect = lambda cluster_id: Externalv1Cluster(
- id=cluster_id, spec=V1ClusterSpec(cluster_type=V1ClusterType.GLOBAL)
- )
-
- cloud_backend.client = mock_client
-
- app = mock.MagicMock()
- app.flows = []
- app.frontend = {}
-
- existing_app = MagicMock()
- existing_app.name = app_name
- existing_app.id = "test-id"
- mock_client.cloud_space_service_list_cloud_spaces.return_value = V1ListCloudSpacesResponse(
- cloudspaces=[existing_app]
- )
-
- # Simulate the app as deleted so no instance to return
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=[])
- )
-
- cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=entrypoint)
- cloud_runtime._check_uploaded_folder = mock.MagicMock()
-
- cloud_runtime.dispatch(name=app_name)
-
- # Check that a new name was used which starts with and does not equal the old name
- mock_client.cloud_space_service_create_lightning_run_instance.assert_called_once()
- args = mock_client.cloud_space_service_create_lightning_run_instance.call_args
- assert args[1]["body"].name != app_name
- assert args[1]["body"].name.startswith(app_name)
-
- @pytest.mark.parametrize("flow_cloud_compute", [None, CloudCompute(name="t2.medium")])
- @mock.patch("lightning.app.runners.backends.cloud.LightningClient", mock.MagicMock())
- def test_run_with_default_flow_compute_config(self, tmpdir, monkeypatch, flow_cloud_compute):
- entrypoint = Path(tmpdir) / "entrypoint.py"
- entrypoint.touch()
-
- mock_client = mock.MagicMock()
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="test-project", project_id="test-project-id")]
- )
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=[])
- )
- mock_client.lightningapp_v2_service_create_lightningapp_release.return_value = V1LightningRun(cluster_id="test")
- mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([Externalv1Cluster(id="test")])
- cloud_backend = mock.MagicMock()
- cloud_backend.client = mock_client
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
- monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
-
- dummy_flow = mock.MagicMock()
- monkeypatch.setattr(dummy_flow, "run", lambda *args, **kwargs: None)
- if flow_cloud_compute is None:
- app = LightningApp(dummy_flow)
- else:
- app = LightningApp(dummy_flow, flow_cloud_compute=flow_cloud_compute)
-
- cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=entrypoint)
- cloud_runtime._check_uploaded_folder = mock.MagicMock()
-
- cloud_runtime.dispatch()
-
- user_requested_flow_compute_config = None
- if flow_cloud_compute is not None:
- user_requested_flow_compute_config = V1UserRequestedFlowComputeConfig(
- name=flow_cloud_compute.name, preemptible=False, shm_size=0
- )
-
- body = get_cloud_runtime_request_body(user_requested_flow_compute_config=user_requested_flow_compute_config)
- cloud_runtime.backend.client.cloud_space_service_create_lightning_run.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, body=body
- )
-
- @mock.patch("lightning.app.runners.backends.cloud.LightningClient", mock.MagicMock())
- def test_run_on_byoc_cluster(self, tmpdir, monkeypatch):
- entrypoint = Path(tmpdir) / "entrypoint.py"
- entrypoint.touch()
-
- mock_client = mock.MagicMock()
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="Default Project", project_id="default-project-id")]
- )
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=[])
- )
- mock_client.cloud_space_service_create_lightning_run.return_value = V1LightningRun(cluster_id="test1234")
- mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([
- Externalv1Cluster(id="test1234")
- ])
- cloud_backend = mock.MagicMock()
- cloud_backend.client = mock_client
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
- monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
- monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
- app = mock.MagicMock()
- app.is_headless = False
- app.flows = []
- app.frontend = {}
- cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=entrypoint)
- cloud_runtime._check_uploaded_folder = mock.MagicMock()
-
- cloud_runtime.dispatch(cluster_id="test1234")
- body = CloudspaceIdRunsBody(
- cluster_id="test1234",
- app_entrypoint_file=mock.ANY,
- enable_app_server=True,
- is_headless=False,
- should_mount_cloudspace_content=False,
- flow_servers=[],
- image_spec=None,
- works=[],
- local_source=True,
- dependency_cache_key=mock.ANY,
- user_requested_flow_compute_config=mock.ANY,
- )
- cloud_runtime.backend.client.cloud_space_service_create_lightning_run.assert_called_once_with(
- project_id="default-project-id", cloudspace_id=mock.ANY, body=body
- )
- cloud_runtime.backend.client.projects_service_create_project_cluster_binding.assert_called_once_with(
- project_id="default-project-id",
- body=ProjectIdProjectclustersbindingsBody(cluster_id="test1234"),
- )
-
- @mock.patch("lightning.app.runners.backends.cloud.LightningClient", mock.MagicMock())
- def test_requirements_file(self, tmpdir, monkeypatch):
- entrypoint = Path(tmpdir) / "entrypoint.py"
- entrypoint.touch()
-
- mock_client = mock.MagicMock()
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="test-project", project_id="test-project-id")]
- )
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=[])
- )
- mock_client.cloud_space_service_create_lightning_run.return_value = V1LightningRun()
- mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([Externalv1Cluster(id="test")])
- cloud_backend = mock.MagicMock()
- cloud_backend.client = mock_client
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
- monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
- monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
- app = mock.MagicMock()
- app.is_headless = False
- app.flows = []
- app.frontend = {}
- cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=entrypoint)
- cloud_runtime._check_uploaded_folder = mock.MagicMock()
-
- # Without requirements file
- cloud_runtime.dispatch()
- body = CloudspaceIdRunsBody(
- app_entrypoint_file=mock.ANY,
- enable_app_server=True,
- is_headless=False,
- should_mount_cloudspace_content=False,
- flow_servers=[],
- image_spec=None,
- works=[],
- local_source=True,
- dependency_cache_key=mock.ANY,
- user_requested_flow_compute_config=mock.ANY,
- )
- cloud_runtime.backend.client.cloud_space_service_create_lightning_run.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, body=body
- )
-
- # with requirements file
- requirements = Path(tmpdir) / "requirements.txt"
- requirements.touch()
-
- cloud_runtime.dispatch(no_cache=True)
- body.image_spec = Gridv1ImageSpec(
- dependency_file_info=V1DependencyFileInfo(package_manager=V1PackageManager.PIP, path="requirements.txt")
- )
- cloud_runtime.backend.client.cloud_space_service_create_lightning_run.assert_called_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, body=body
- )
-
- @mock.patch("lightning.app.runners.backends.cloud.LightningClient", mock.MagicMock())
- def test_basic_auth_enabled(self, tmpdir, monkeypatch):
- entrypoint = Path(tmpdir) / "entrypoint.py"
- entrypoint.touch()
-
- mock_client = mock.MagicMock()
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="test-project", project_id="test-project-id")]
- )
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=[])
- )
- mock_client.cloud_space_service_create_lightning_run.return_value = V1LightningRun()
- mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([Externalv1Cluster(id="test")])
- cloud_backend = mock.MagicMock()
- cloud_backend.client = mock_client
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
- monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
- monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
- app = mock.MagicMock()
- app.is_headless = False
- app.flows = []
- app.frontend = {}
- cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=entrypoint)
- cloud_runtime._check_uploaded_folder = mock.MagicMock()
- # Set cloud_runtime.enable_basic_auth to be not empty:
- cloud_runtime.enable_basic_auth = "username:password"
-
- cloud_runtime.dispatch()
- mock_client = cloud_runtime.backend.client
-
- body = CloudspaceIdRunsBody(
- app_entrypoint_file=mock.ANY,
- enable_app_server=True,
- is_headless=False,
- should_mount_cloudspace_content=False,
- flow_servers=[],
- image_spec=None,
- works=[],
- local_source=True,
- dependency_cache_key=mock.ANY,
- user_requested_flow_compute_config=mock.ANY,
- )
-
- mock_client.cloud_space_service_create_lightning_run.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, body=body
- )
-
- mock_client.cloud_space_service_create_lightning_run_instance.assert_called_once_with(
- project_id="test-project-id",
- cloudspace_id=mock.ANY,
- id=mock.ANY,
- body=IdGetBody1(
- desired_state=mock.ANY,
- name=mock.ANY,
- env=mock.ANY,
- queue_server_type=mock.ANY,
- auth=V1LightningAuth(basic=V1LightningBasicAuth(username="username", password="password")),
- ),
- )
-
- @mock.patch("lightning.app.runners.backends.cloud.LightningClient", mock.MagicMock())
- def test_no_cache(self, tmpdir, monkeypatch):
- entrypoint = Path(tmpdir) / "entrypoint.py"
- entrypoint.touch()
- requirements = Path(tmpdir) / "requirements.txt"
- requirements.touch()
-
- mock_client = mock.MagicMock()
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="test-project", project_id="test-project-id")]
- )
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=[])
- )
- mock_client.cloud_space_service_create_lightning_run.return_value = V1LightningRun(cluster_id="test")
- mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([Externalv1Cluster(id="test")])
- cloud_backend = mock.MagicMock()
- cloud_backend.client = mock_client
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
- monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
- monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
- monkeypatch.setattr(cloud, "get_hash", lambda *args, **kwargs: "dummy-hash")
- app = mock.MagicMock()
- app.flows = []
- app.frontend = {}
- cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=entrypoint)
- cloud_runtime._check_uploaded_folder = mock.MagicMock()
-
- # testing with no-cache False
- cloud_runtime.dispatch(no_cache=False)
- _, _, kwargs = cloud_runtime.backend.client.cloud_space_service_create_lightning_run.mock_calls[0]
- body = kwargs["body"]
- assert body.dependency_cache_key == "dummy-hash"
-
- # testing with no-cache True
- mock_client.reset_mock()
- cloud_runtime.dispatch(no_cache=True)
- _, _, kwargs = cloud_runtime.backend.client.cloud_space_service_create_lightning_run.mock_calls[0]
- body = kwargs["body"]
- assert body.dependency_cache_key is None
-
- @mock.patch("lightning.app.runners.backends.cloud.LightningClient", mock.MagicMock())
- @pytest.mark.parametrize(
- ("lightningapps", "start_with_flow"),
- [([], False), ([MagicMock()], False), ([MagicMock()], True)],
- )
- def test_call_with_work_app(self, lightningapps, start_with_flow, monkeypatch, tmpdir):
- source_code_root_dir = Path(tmpdir / "src").absolute()
- source_code_root_dir.mkdir()
- Path(source_code_root_dir / ".lightning").write_text("name: myapp")
- requirements_file = Path(source_code_root_dir / "requirements.txt")
- Path(requirements_file).touch()
- (source_code_root_dir / "entrypoint.py").touch()
-
- mock_client = mock.MagicMock()
- if lightningapps:
- lightningapps[0].name = "myapp"
- lightningapps[0].status.phase = V1LightningappInstanceState.STOPPED
- lightningapps[0].spec.cluster_id = "test"
- mock_client.cloud_space_service_list_cloud_spaces.return_value = V1ListCloudSpacesResponse(
- cloudspaces=lightningapps
- )
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=lightningapps)
- )
- mock_client.projects_service_list_project_cluster_bindings.return_value = V1ListProjectClusterBindingsResponse(
- clusters=[V1ProjectClusterBinding(cluster_id="test")]
- )
- mock_client.cluster_service_get_cluster.side_effect = lambda cluster_id: Externalv1Cluster(
- id=cluster_id, spec=V1ClusterSpec(cluster_type=V1ClusterType.GLOBAL)
- )
- mock_client.cloud_space_service_create_lightning_run_instance.return_value = V1LightningRun()
- mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([Externalv1Cluster(id="test")])
- mock_client.cloud_space_service_create_lightning_run_instance.return_value = MagicMock()
- existing_instance = MagicMock()
- existing_instance.status.phase = V1LightningappInstanceState.STOPPED
- mock_client.lightningapp_service_get_lightningapp = MagicMock(return_value=existing_instance)
- cloud_backend = mock.MagicMock()
- cloud_backend.client = mock_client
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
- monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
- monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
- app = mock.MagicMock()
- app.is_headless = False
-
- work = MyWork(start_with_flow=start_with_flow, cloud_compute=CloudCompute("custom"))
- work._name = "test-work"
- work._cloud_build_config.build_commands = lambda: ["echo 'start'"]
- work._cloud_build_config.requirements = ["torch==1.0.0", "numpy==1.0.0"]
- work._cloud_build_config.image = "random_base_public_image"
- work._cloud_compute.disk_size = 0
- work._port = 8080
-
- app.works = [work]
- cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py"))
- monkeypatch.setattr(
- "lightning.app.runners.cloud._get_project",
- lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"),
- )
- cloud_runtime.dispatch()
-
- if lightningapps:
- expected_body = CloudspaceIdRunsBody(
- description=None,
- local_source=True,
- app_entrypoint_file="entrypoint.py",
- enable_app_server=True,
- is_headless=False,
- should_mount_cloudspace_content=False,
- flow_servers=[],
- dependency_cache_key=get_hash(requirements_file),
- user_requested_flow_compute_config=mock.ANY,
- cluster_id="test",
- image_spec=Gridv1ImageSpec(
- dependency_file_info=V1DependencyFileInfo(
- package_manager=V1PackageManager.PIP, path="requirements.txt"
- )
- ),
- )
-
- if start_with_flow:
- expected_body.works = [
- V1Work(
- name="test-work",
- display_name="",
- spec=V1LightningworkSpec(
- build_spec=V1BuildSpec(
- commands=["echo 'start'"],
- python_dependencies=V1PythonDependencyInfo(
- package_manager=V1PackageManager.PIP, packages="torch==1.0.0\nnumpy==1.0.0"
- ),
- image="random_base_public_image",
- ),
- drives=[],
- user_requested_compute_config=V1UserRequestedComputeConfig(
- name="custom",
- count=1,
- disk_size=0,
- shm_size=0,
- preemptible=False,
- ),
- network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
- data_connection_mounts=[],
- ),
- )
- ]
- else:
- expected_body.works = []
-
- mock_client.cloud_space_service_create_lightning_run.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, body=expected_body
- )
-
- # running dispatch with disabled dependency cache
- mock_client.reset_mock()
- monkeypatch.setattr(cloud, "DISABLE_DEPENDENCY_CACHE", True)
- expected_body.dependency_cache_key = None
- cloud_runtime.dispatch()
- mock_client.cloud_space_service_create_lightning_run.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, body=expected_body
- )
- else:
- mock_client.cloud_space_service_create_lightning_run_instance.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, id=mock.ANY, body=mock.ANY
- )
-
- @mock.patch("lightning.app.runners.backends.cloud.LightningClient", mock.MagicMock())
- @pytest.mark.parametrize("lightningapps", [[], [MagicMock()]])
- def test_call_with_queue_server_type_specified(self, tmpdir, lightningapps, monkeypatch):
- entrypoint = Path(tmpdir) / "entrypoint.py"
- entrypoint.touch()
-
- mock_client = mock.MagicMock()
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="test-project", project_id="test-project-id")]
- )
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=[])
- )
- mock_client.cloud_space_service_create_lightning_run.return_value = V1LightningRun()
- mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([Externalv1Cluster(id="test")])
- cloud_backend = mock.MagicMock()
- cloud_backend.client = mock_client
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
- monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
- monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
- app = mock.MagicMock()
- app.flows = []
- app.frontend = {}
- cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=entrypoint)
- cloud_runtime._check_uploaded_folder = mock.MagicMock()
-
- cloud_runtime.dispatch()
-
- # calling with no env variable set
- body = IdGetBody1(
- desired_state=V1LightningappInstanceState.STOPPED,
- env=[],
- name=mock.ANY,
- queue_server_type=V1QueueServerType.UNSPECIFIED,
- )
- client = cloud_runtime.backend.client
- client.cloud_space_service_create_lightning_run_instance.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, id=mock.ANY, body=body
- )
-
- # calling with env variable set to http
- monkeypatch.setitem(os.environ, "LIGHTNING_CLOUD_QUEUE_TYPE", "http")
- cloud_runtime.backend.client.reset_mock()
- cloud_runtime.dispatch()
- body = IdGetBody1(
- desired_state=V1LightningappInstanceState.STOPPED,
- env=mock.ANY,
- name=mock.ANY,
- queue_server_type=V1QueueServerType.HTTP,
- )
- client = cloud_runtime.backend.client
- client.cloud_space_service_create_lightning_run_instance.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, id=mock.ANY, body=body
- )
-
- @mock.patch("lightning.app.runners.backends.cloud.LightningClient", mock.MagicMock())
- @pytest.mark.parametrize("lightningapps", [[], [MagicMock()]])
- def test_call_with_work_app_and_attached_drives(self, lightningapps, monkeypatch, tmpdir):
- source_code_root_dir = Path(tmpdir / "src").absolute()
- source_code_root_dir.mkdir()
- Path(source_code_root_dir / ".lightning").write_text("name: myapp")
- requirements_file = Path(source_code_root_dir / "requirements.txt")
- Path(requirements_file).touch()
- (source_code_root_dir / "entrypoint.py").touch()
-
- mock_client = mock.MagicMock()
- if lightningapps:
- lightningapps[0].name = "myapp"
- lightningapps[0].status.phase = V1LightningappInstanceState.STOPPED
- lightningapps[0].spec.cluster_id = "test"
- mock_client.cloud_space_service_list_cloud_spaces.return_value = V1ListCloudSpacesResponse(
- cloudspaces=lightningapps
- )
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=lightningapps)
- )
- mock_client.projects_service_list_project_cluster_bindings.return_value = V1ListProjectClusterBindingsResponse(
- clusters=[V1ProjectClusterBinding(cluster_id="test")]
- )
- mock_client.cluster_service_get_cluster.side_effect = lambda cluster_id: Externalv1Cluster(
- id=cluster_id, spec=V1ClusterSpec(cluster_type=V1ClusterType.GLOBAL)
- )
- mock_client.cloud_space_service_create_lightning_run_instance.return_value = V1LightningRun()
- mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([Externalv1Cluster(id="test")])
- lit_app_instance = MagicMock()
- mock_client.cloud_space_service_create_lightning_run_instance = MagicMock(return_value=lit_app_instance)
- existing_instance = MagicMock()
- existing_instance.status.phase = V1LightningappInstanceState.STOPPED
- mock_client.lightningapp_service_get_lightningapp = MagicMock(return_value=existing_instance)
- cloud_backend = mock.MagicMock()
- cloud_backend.client = mock_client
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
- monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
- monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
- app = mock.MagicMock()
- app.is_headless = False
-
- mocked_drive = MagicMock(spec=Drive)
- setattr(mocked_drive, "id", "foobar")
- setattr(mocked_drive, "protocol", "lit://")
- setattr(mocked_drive, "component_name", "test-work")
- setattr(mocked_drive, "allow_duplicates", False)
- setattr(mocked_drive, "root_folder", tmpdir)
- # deepcopy on a MagicMock instance will return an empty magicmock instance. To
- # overcome this we set the __deepcopy__ method `return_value` to equal what
- # should be the results of the deepcopy operation (an instance of the original class)
- mocked_drive.__deepcopy__.return_value = copy(mocked_drive)
-
- work = WorkWithSingleDrive(cloud_compute=CloudCompute("custom"))
- monkeypatch.setattr(work, "drive", mocked_drive)
- monkeypatch.setattr(work, "_state", {"_port", "drive"})
- monkeypatch.setattr(work, "_name", "test-work")
- monkeypatch.setattr(work._cloud_build_config, "build_commands", lambda: ["echo 'start'"])
- monkeypatch.setattr(work._cloud_build_config, "requirements", ["torch==1.0.0", "numpy==1.0.0"])
- monkeypatch.setattr(work._cloud_build_config, "image", "random_base_public_image")
- monkeypatch.setattr(work._cloud_compute, "disk_size", 0)
- monkeypatch.setattr(work, "_port", 8080)
-
- app.works = [work]
- cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py"))
- monkeypatch.setattr(
- "lightning.app.runners.cloud._get_project",
- lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"),
- )
- cloud_runtime.dispatch()
-
- if lightningapps:
- expected_body = CloudspaceIdRunsBody(
- description=None,
- local_source=True,
- app_entrypoint_file="entrypoint.py",
- enable_app_server=True,
- is_headless=False,
- should_mount_cloudspace_content=False,
- flow_servers=[],
- dependency_cache_key=get_hash(requirements_file),
- user_requested_flow_compute_config=mock.ANY,
- cluster_id="test",
- image_spec=Gridv1ImageSpec(
- dependency_file_info=V1DependencyFileInfo(
- package_manager=V1PackageManager.PIP, path="requirements.txt"
- )
- ),
- works=[
- V1Work(
- name="test-work",
- display_name="",
- spec=V1LightningworkSpec(
- build_spec=V1BuildSpec(
- commands=["echo 'start'"],
- python_dependencies=V1PythonDependencyInfo(
- package_manager=V1PackageManager.PIP, packages="torch==1.0.0\nnumpy==1.0.0"
- ),
- image="random_base_public_image",
- ),
- drives=[
- V1LightningworkDrives(
- drive=V1Drive(
- metadata=V1Metadata(
- name="test-work.drive",
- ),
- spec=V1DriveSpec(
- drive_type=V1DriveType.NO_MOUNT_S3,
- source_type=V1SourceType.S3,
- source="lit://foobar",
- ),
- status=V1DriveStatus(),
- ),
- ),
- ],
- user_requested_compute_config=V1UserRequestedComputeConfig(
- name="custom", count=1, disk_size=0, shm_size=0, preemptible=False
- ),
- network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
- data_connection_mounts=[],
- ),
- )
- ],
- )
- mock_client.cloud_space_service_create_lightning_run.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, body=expected_body
- )
-
- # running dispatch with disabled dependency cache
- mock_client.reset_mock()
- monkeypatch.setattr(cloud, "DISABLE_DEPENDENCY_CACHE", True)
- expected_body.dependency_cache_key = None
- cloud_runtime.dispatch()
- mock_client.cloud_space_service_create_lightning_run.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, body=expected_body
- )
- else:
- mock_client.cloud_space_service_create_lightning_run_instance.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, id=mock.ANY, body=mock.ANY
- )
-
- @mock.patch("lightning.app.runners.backends.cloud.LightningClient", mock.MagicMock())
- @mock.patch("lightning.app.core.constants.ENABLE_APP_COMMENT_COMMAND_EXECUTION", True)
- @pytest.mark.parametrize("lightningapps", [[], [MagicMock()]])
- def test_call_with_work_app_and_app_comment_command_execution_set(self, lightningapps, monkeypatch, tmpdir):
- source_code_root_dir = Path(tmpdir / "src").absolute()
- source_code_root_dir.mkdir()
- Path(source_code_root_dir / ".lightning").write_text("name: myapp")
- requirements_file = Path(source_code_root_dir / "requirements.txt")
- Path(requirements_file).touch()
- (source_code_root_dir / "entrypoint.py").touch()
-
- mock_client = mock.MagicMock()
- if lightningapps:
- lightningapps[0].name = "myapp"
- lightningapps[0].status.phase = V1LightningappInstanceState.STOPPED
- lightningapps[0].spec.cluster_id = "test"
- mock_client.projects_service_list_project_cluster_bindings.return_value = (
- V1ListProjectClusterBindingsResponse(clusters=[V1ProjectClusterBinding(cluster_id="test")])
- )
- mock_client.cluster_service_get_cluster.side_effect = lambda cluster_id: Externalv1Cluster(
- id=cluster_id, spec=V1ClusterSpec(cluster_type=V1ClusterType.GLOBAL)
- )
- mock_client.cloud_space_service_list_cloud_spaces.return_value = V1ListCloudSpacesResponse(
- cloudspaces=lightningapps
- )
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=lightningapps)
- )
- mock_client.cloud_space_service_create_lightning_run_instance.return_value = V1LightningRun()
- mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([Externalv1Cluster(id="test")])
- lit_app_instance = MagicMock()
- mock_client.cloud_space_service_create_lightning_run_instance = MagicMock(return_value=lit_app_instance)
- existing_instance = MagicMock()
- existing_instance.status.phase = V1LightningappInstanceState.STOPPED
- mock_client.lightningapp_service_get_lightningapp = MagicMock(return_value=existing_instance)
- cloud_backend = mock.MagicMock()
- cloud_backend.client = mock_client
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
- monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
- monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
- app = mock.MagicMock()
- app.is_headless = False
-
- work = MyWork(cloud_compute=CloudCompute("custom"))
- work._state = {"_port"}
- work._name = "test-work"
- work._cloud_build_config.build_commands = lambda: ["echo 'start'"]
- work._cloud_build_config.requirements = ["torch==1.0.0", "numpy==1.0.0"]
- work._cloud_build_config.image = "random_base_public_image"
- work._cloud_compute.disk_size = 0
- work._port = 8080
-
- app.works = [work]
- cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py"))
- monkeypatch.setattr(
- "lightning.app.runners.cloud._get_project",
- lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"),
- )
- cloud_runtime.run_app_comment_commands = True
- cloud_runtime.dispatch()
-
- if lightningapps:
- expected_body = CloudspaceIdRunsBody(
- description=None,
- local_source=True,
- app_entrypoint_file="entrypoint.py",
- enable_app_server=True,
- is_headless=False,
- should_mount_cloudspace_content=False,
- flow_servers=[],
- dependency_cache_key=get_hash(requirements_file),
- user_requested_flow_compute_config=mock.ANY,
- cluster_id="test",
- image_spec=Gridv1ImageSpec(
- dependency_file_info=V1DependencyFileInfo(
- package_manager=V1PackageManager.PIP, path="requirements.txt"
- )
- ),
- works=[
- V1Work(
- name="test-work",
- display_name="",
- spec=V1LightningworkSpec(
- build_spec=V1BuildSpec(
- commands=["echo 'start'"],
- python_dependencies=V1PythonDependencyInfo(
- package_manager=V1PackageManager.PIP, packages="torch==1.0.0\nnumpy==1.0.0"
- ),
- image="random_base_public_image",
- ),
- drives=[],
- user_requested_compute_config=V1UserRequestedComputeConfig(
- name="custom", count=1, disk_size=0, shm_size=0, preemptible=mock.ANY
- ),
- network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
- cluster_id=mock.ANY,
- data_connection_mounts=[],
- ),
- )
- ],
- )
-
- mock_client.cloud_space_service_create_lightning_run.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, body=expected_body
- )
-
- # running dispatch with disabled dependency cache
- mock_client.reset_mock()
- monkeypatch.setattr(cloud, "DISABLE_DEPENDENCY_CACHE", True)
- expected_body.dependency_cache_key = None
- cloud_runtime.dispatch()
- mock_client.cloud_space_service_create_lightning_run.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, body=expected_body
- )
- else:
- mock_client.cloud_space_service_create_lightning_run_instance.assert_called_once_with(
- project_id="test-project-id",
- cloudspace_id=mock.ANY,
- id=mock.ANY,
- body=IdGetBody1(
- desired_state=V1LightningappInstanceState.STOPPED,
- name=mock.ANY,
- env=[V1EnvVar(name="ENABLE_APP_COMMENT_COMMAND_EXECUTION", value="1")],
- queue_server_type=mock.ANY,
- ),
- )
-
- @mock.patch("lightning.app.runners.backends.cloud.LightningClient", mock.MagicMock())
- @pytest.mark.parametrize("lightningapps", [[], [MagicMock()]])
- def test_call_with_work_app_and_multiple_attached_drives(self, lightningapps, monkeypatch, tmpdir):
- source_code_root_dir = Path(tmpdir / "src").absolute()
- source_code_root_dir.mkdir()
- Path(source_code_root_dir / ".lightning").write_text("name: myapp")
- requirements_file = Path(source_code_root_dir / "requirements.txt")
- Path(requirements_file).touch()
- (source_code_root_dir / "entrypoint.py").touch()
-
- mock_client = mock.MagicMock()
- if lightningapps:
- lightningapps[0].name = "myapp"
- lightningapps[0].status.phase = V1LightningappInstanceState.STOPPED
- lightningapps[0].spec.cluster_id = "test"
- mock_client.projects_service_list_project_cluster_bindings.return_value = (
- V1ListProjectClusterBindingsResponse(
- clusters=[
- V1ProjectClusterBinding(cluster_id="test"),
- ]
- )
- )
- mock_client.cluster_service_get_cluster.side_effect = lambda cluster_id: Externalv1Cluster(
- id=cluster_id, spec=V1ClusterSpec(cluster_type=V1ClusterType.GLOBAL)
- )
- mock_client.cloud_space_service_list_cloud_spaces.return_value = V1ListCloudSpacesResponse(
- cloudspaces=lightningapps
- )
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=lightningapps)
- )
- mock_client.cloud_space_service_create_lightning_run_instance.return_value = V1LightningRun(cluster_id="test")
- mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([Externalv1Cluster(id="test")])
- lit_app_instance = MagicMock()
- mock_client.cloud_space_service_create_lightning_run_instance = MagicMock(return_value=lit_app_instance)
- existing_instance = MagicMock()
- existing_instance.status.phase = V1LightningappInstanceState.STOPPED
- mock_client.lightningapp_service_get_lightningapp = MagicMock(return_value=existing_instance)
- cloud_backend = mock.MagicMock()
- cloud_backend.client = mock_client
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
- monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
- monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
- app = mock.MagicMock()
- app.is_headless = False
-
- mocked_lit_drive = MagicMock(spec=Drive)
- setattr(mocked_lit_drive, "id", "foobar")
- setattr(mocked_lit_drive, "protocol", "lit://")
- setattr(mocked_lit_drive, "component_name", "test-work")
- setattr(mocked_lit_drive, "allow_duplicates", False)
- setattr(mocked_lit_drive, "root_folder", tmpdir)
- # deepcopy on a MagicMock instance returns an empty MagicMock instance. To
- # work around this, we set the `return_value` of the `__deepcopy__` method to what
- # the deepcopy operation should produce (an instance of the original class)
- mocked_lit_drive.__deepcopy__.return_value = copy(mocked_lit_drive)
-
- work = WorkWithTwoDrives(cloud_compute=CloudCompute("custom"))
- work.lit_drive_1 = mocked_lit_drive
- work.lit_drive_2 = mocked_lit_drive
- work._state = {"_port", "_name", "lit_drive_1", "lit_drive_2"}
- work._name = "test-work"
- work._cloud_build_config.build_commands = lambda: ["echo 'start'"]
- work._cloud_build_config.requirements = ["torch==1.0.0", "numpy==1.0.0"]
- work._cloud_build_config.image = "random_base_public_image"
- work._cloud_compute.disk_size = 0
- work._port = 8080
-
- app.works = [work]
- cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py"))
- monkeypatch.setattr(
- "lightning.app.runners.cloud._get_project",
- lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"),
- )
- cloud_runtime.dispatch()
-
- if lightningapps:
- lit_drive_1_spec = V1LightningworkDrives(
- drive=V1Drive(
- metadata=V1Metadata(
- name="test-work.lit_drive_1",
- ),
- spec=V1DriveSpec(
- drive_type=V1DriveType.NO_MOUNT_S3,
- source_type=V1SourceType.S3,
- source="lit://foobar",
- ),
- status=V1DriveStatus(),
- ),
- )
- lit_drive_2_spec = V1LightningworkDrives(
- drive=V1Drive(
- metadata=V1Metadata(
- name="test-work.lit_drive_2",
- ),
- spec=V1DriveSpec(
- drive_type=V1DriveType.NO_MOUNT_S3,
- source_type=V1SourceType.S3,
- source="lit://foobar",
- ),
- status=V1DriveStatus(),
- ),
- )
-
- # The order of drives in the spec is non-deterministic, so there are two options
- # for the expected body value, depending on which drive comes first in the list.
-
- expected_body_option_1 = CloudspaceIdRunsBody(
- description=None,
- local_source=True,
- app_entrypoint_file="entrypoint.py",
- enable_app_server=True,
- is_headless=False,
- should_mount_cloudspace_content=False,
- flow_servers=[],
- dependency_cache_key=get_hash(requirements_file),
- user_requested_flow_compute_config=mock.ANY,
- cluster_id="test",
- image_spec=Gridv1ImageSpec(
- dependency_file_info=V1DependencyFileInfo(
- package_manager=V1PackageManager.PIP, path="requirements.txt"
- )
- ),
- works=[
- V1Work(
- name="test-work",
- display_name="",
- spec=V1LightningworkSpec(
- build_spec=V1BuildSpec(
- commands=["echo 'start'"],
- python_dependencies=V1PythonDependencyInfo(
- package_manager=V1PackageManager.PIP, packages="torch==1.0.0\nnumpy==1.0.0"
- ),
- image="random_base_public_image",
- ),
- drives=[lit_drive_2_spec, lit_drive_1_spec],
- user_requested_compute_config=V1UserRequestedComputeConfig(
- name="custom",
- count=1,
- disk_size=0,
- shm_size=0,
- preemptible=False,
- ),
- network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
- data_connection_mounts=[],
- ),
- )
- ],
- )
-
- expected_body_option_2 = CloudspaceIdRunsBody(
- description=None,
- local_source=True,
- app_entrypoint_file="entrypoint.py",
- enable_app_server=True,
- is_headless=False,
- should_mount_cloudspace_content=False,
- flow_servers=[],
- dependency_cache_key=get_hash(requirements_file),
- user_requested_flow_compute_config=mock.ANY,
- cluster_id="test",
- image_spec=Gridv1ImageSpec(
- dependency_file_info=V1DependencyFileInfo(
- package_manager=V1PackageManager.PIP, path="requirements.txt"
- )
- ),
- works=[
- V1Work(
- name="test-work",
- display_name="",
- spec=V1LightningworkSpec(
- build_spec=V1BuildSpec(
- commands=["echo 'start'"],
- python_dependencies=V1PythonDependencyInfo(
- package_manager=V1PackageManager.PIP, packages="torch==1.0.0\nnumpy==1.0.0"
- ),
- image="random_base_public_image",
- ),
- drives=[lit_drive_1_spec, lit_drive_2_spec],
- user_requested_compute_config=V1UserRequestedComputeConfig(
- name="custom",
- count=1,
- disk_size=0,
- shm_size=0,
- preemptible=False,
- ),
- network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
- data_connection_mounts=[],
- ),
- )
- ],
- )
-
- # Try both options for the expected body to avoid spurious
- # test failures caused by the non-deterministic drive order
-
- expected_body = expected_body_option_1
- try:
- mock_client.cloud_space_service_create_lightning_run.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, body=expected_body
- )
- except Exception:
- expected_body = expected_body_option_2
- mock_client.cloud_space_service_create_lightning_run.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, body=expected_body
- )
-
- # running dispatch with disabled dependency cache
- mock_client.reset_mock()
- monkeypatch.setattr(cloud, "DISABLE_DEPENDENCY_CACHE", True)
- expected_body.dependency_cache_key = None
- cloud_runtime.dispatch()
- mock_client.cloud_space_service_create_lightning_run.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, body=expected_body
- )
- else:
- mock_client.cloud_space_service_create_lightning_run_instance.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, id=mock.ANY, body=mock.ANY
- )
-
- @mock.patch("lightning.app.runners.backends.cloud.LightningClient", mock.MagicMock())
- @pytest.mark.parametrize("lightningapps", [[], [MagicMock()]])
- def test_call_with_work_app_and_attached_mount_and_drive(self, lightningapps, monkeypatch, tmpdir):
- source_code_root_dir = Path(tmpdir / "src").absolute()
- source_code_root_dir.mkdir()
- Path(source_code_root_dir / ".lightning").write_text("name: myapp")
- requirements_file = Path(source_code_root_dir / "requirements.txt")
- Path(requirements_file).touch()
- (source_code_root_dir / "entrypoint.py").touch()
-
- mock_client = mock.MagicMock()
- if lightningapps:
- lightningapps[0].name = "myapp"
- lightningapps[0].status.phase = V1LightningappInstanceState.STOPPED
- lightningapps[0].spec.cluster_id = "test"
- mock_client.projects_service_list_project_cluster_bindings.return_value = (
- V1ListProjectClusterBindingsResponse(clusters=[V1ProjectClusterBinding(cluster_id="test")])
- )
- mock_client.cluster_service_get_cluster.side_effect = lambda cluster_id: Externalv1Cluster(
- id=cluster_id, spec=V1ClusterSpec(cluster_type=V1ClusterType.GLOBAL)
- )
- mock_client.cloud_space_service_list_cloud_spaces.return_value = V1ListCloudSpacesResponse(
- cloudspaces=lightningapps
- )
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=lightningapps)
- )
- mock_client.cloud_space_service_create_lightning_run_instance.return_value = V1LightningRun(cluster_id="test")
- mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([Externalv1Cluster(id="test")])
- lit_app_instance = MagicMock()
- mock_client.cloud_space_service_create_lightning_run_instance = MagicMock(return_value=lit_app_instance)
- existing_instance = MagicMock()
- existing_instance.status.phase = V1LightningappInstanceState.STOPPED
- existing_instance.spec.cluster_id = None
- mock_client.lightningapp_service_get_lightningapp = MagicMock(return_value=existing_instance)
- cloud_backend = mock.MagicMock()
- cloud_backend.client = mock_client
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
- monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
- monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
- app = mock.MagicMock()
- app.is_headless = False
-
- mocked_drive = MagicMock(spec=Drive)
- setattr(mocked_drive, "id", "foobar")
- setattr(mocked_drive, "protocol", "lit://")
- setattr(mocked_drive, "component_name", "test-work")
- setattr(mocked_drive, "allow_duplicates", False)
- setattr(mocked_drive, "root_folder", tmpdir)
- # deepcopy on a MagicMock instance returns an empty MagicMock instance. To
- # work around this, we set the `return_value` of the `__deepcopy__` method to what
- # the deepcopy operation should produce (an instance of the original class)
- mocked_drive.__deepcopy__.return_value = copy(mocked_drive)
-
- mocked_mount = MagicMock(spec=Mount)
- setattr(mocked_mount, "source", "s3://foo/")
- setattr(mocked_mount, "mount_path", "/content/foo")
- setattr(mocked_mount, "protocol", "s3://")
-
- work = WorkWithSingleDrive(cloud_compute=CloudCompute("custom"))
- monkeypatch.setattr(work, "drive", mocked_drive)
- monkeypatch.setattr(work, "_state", {"_port", "drive"})
- monkeypatch.setattr(work, "_name", "test-work")
- monkeypatch.setattr(work._cloud_build_config, "build_commands", lambda: ["echo 'start'"])
- monkeypatch.setattr(work._cloud_build_config, "requirements", ["torch==1.0.0", "numpy==1.0.0"])
- monkeypatch.setattr(work._cloud_build_config, "image", "random_base_public_image")
- monkeypatch.setattr(work._cloud_compute, "disk_size", 0)
- monkeypatch.setattr(work._cloud_compute, "mounts", mocked_mount)
- monkeypatch.setattr(work, "_port", 8080)
-
- app.works = [work]
- cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=(source_code_root_dir / "entrypoint.py"))
- monkeypatch.setattr(
- "lightning.app.runners.cloud._get_project",
- lambda _, project_id: V1Membership(name="test-project", project_id="test-project-id"),
- )
- cloud_runtime.dispatch()
-
- if lightningapps:
- expected_body = CloudspaceIdRunsBody(
- description=None,
- local_source=True,
- app_entrypoint_file="entrypoint.py",
- enable_app_server=True,
- is_headless=False,
- should_mount_cloudspace_content=False,
- flow_servers=[],
- dependency_cache_key=get_hash(requirements_file),
- image_spec=Gridv1ImageSpec(
- dependency_file_info=V1DependencyFileInfo(
- package_manager=V1PackageManager.PIP, path="requirements.txt"
- )
- ),
- user_requested_flow_compute_config=mock.ANY,
- cluster_id="test",
- works=[
- V1Work(
- name="test-work",
- display_name="",
- spec=V1LightningworkSpec(
- build_spec=V1BuildSpec(
- commands=["echo 'start'"],
- python_dependencies=V1PythonDependencyInfo(
- package_manager=V1PackageManager.PIP, packages="torch==1.0.0\nnumpy==1.0.0"
- ),
- image="random_base_public_image",
- ),
- drives=[
- V1LightningworkDrives(
- drive=V1Drive(
- metadata=V1Metadata(
- name="test-work.drive",
- ),
- spec=V1DriveSpec(
- drive_type=V1DriveType.NO_MOUNT_S3,
- source_type=V1SourceType.S3,
- source="lit://foobar",
- ),
- status=V1DriveStatus(),
- ),
- ),
- V1LightningworkDrives(
- drive=V1Drive(
- metadata=V1Metadata(
- name="test-work",
- ),
- spec=V1DriveSpec(
- drive_type=V1DriveType.INDEXED_S3,
- source_type=V1SourceType.S3,
- source="s3://foo/",
- ),
- status=V1DriveStatus(),
- ),
- mount_location="/content/foo",
- ),
- ],
- user_requested_compute_config=V1UserRequestedComputeConfig(
- name="custom",
- count=1,
- disk_size=0,
- shm_size=0,
- preemptible=False,
- ),
- network_config=[V1NetworkConfig(name=mock.ANY, host=None, port=8080)],
- data_connection_mounts=[],
- ),
- )
- ],
- )
- mock_client.cloud_space_service_create_lightning_run.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, body=expected_body
- )
-
- # running dispatch with disabled dependency cache
- mock_client.reset_mock()
- monkeypatch.setattr(cloud, "DISABLE_DEPENDENCY_CACHE", True)
- expected_body.dependency_cache_key = None
- cloud_runtime.dispatch()
- mock_client.cloud_space_service_create_lightning_run.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, body=expected_body
- )
- else:
- mock_client.cloud_space_service_create_lightning_run_instance.assert_called_once_with(
- project_id="test-project-id", cloudspace_id=mock.ANY, id=mock.ANY, body=mock.ANY
- )
-
-
-class TestOpen:
- def test_open(self, monkeypatch):
- """Tests that the open method calls the expected API endpoints."""
- mock_client = mock.MagicMock()
- mock_client.auth_service_get_user.return_value = V1GetUserResponse(
- username="tester", features=V1UserFeatures(code_tab=True)
- )
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="test-project", project_id="test-project-id")]
- )
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=[])
- )
-
- mock_client.cloud_space_service_create_cloud_space.return_value = V1CloudSpace(id="cloudspace_id")
- mock_client.cloud_space_service_create_lightning_run.return_value = V1LightningRun(id="run_id")
-
- mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([Externalv1Cluster(id="test")])
- cloud_backend = mock.MagicMock()
- cloud_backend.client = mock_client
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
- mock_local_source = mock.MagicMock()
- monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock_local_source)
-
- cloud_runtime = cloud.CloudRuntime(entrypoint=Path("."))
-
- cloud_runtime.open("test_space")
-
- mock_client.cloud_space_service_create_cloud_space.assert_called_once_with(
- project_id="test-project-id", body=mock.ANY
- )
- mock_client.cloud_space_service_create_lightning_run.assert_called_once_with(
- project_id="test-project-id", cloudspace_id="cloudspace_id", body=mock.ANY
- )
-
- assert mock_client.cloud_space_service_create_cloud_space.call_args.kwargs["body"].name == "test_space"
-
- @pytest.mark.parametrize(
- ("path", "expected_root", "entries", "expected_filtered_entries"),
- [(".", ".", ["a.py", "b.ipynb"], ["a.py", "b.ipynb"]), ("a.py", ".", ["a.py", "b.ipynb"], ["a.py"])],
- )
- def test_open_repo(self, tmpdir, monkeypatch, path, expected_root, entries, expected_filtered_entries):
- """Tests that the local source code repo is set up with the correct path and ignore functions."""
- tmpdir = Path(tmpdir)
- for entry in entries:
- (tmpdir / entry).touch()
-
- mock_client = mock.MagicMock()
- mock_client.auth_service_get_user.return_value = V1GetUserResponse(
- username="tester", features=V1UserFeatures(code_tab=True)
- )
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="test-project", project_id="test-project-id")]
- )
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=[])
- )
- mock_client.lightningapp_v2_service_create_lightningapp_release.return_value = V1LightningRun(cluster_id="test")
- mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([Externalv1Cluster(id="test")])
- cloud_backend = mock.MagicMock()
- cloud_backend.client = mock_client
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
- mock_local_source = mock.MagicMock()
- monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock_local_source)
-
- cloud_runtime = cloud.CloudRuntime(entrypoint=tmpdir / path)
-
- cloud_runtime.open("test_space")
-
- mock_local_source.assert_called_once()
- repo_call = mock_local_source.call_args
-
- assert repo_call.kwargs["path"] == (tmpdir / expected_root).absolute()
- ignore_functions = repo_call.kwargs["ignore_functions"]
- if len(ignore_functions) > 0:
- filtered = ignore_functions[0]("", [tmpdir / entry for entry in entries])
- else:
- filtered = [tmpdir / entry for entry in entries]
-
- filtered = [entry.absolute() for entry in filtered]
- expected_filtered_entries = [(tmpdir / entry).absolute() for entry in expected_filtered_entries]
- assert filtered == expected_filtered_entries
-
- def test_reopen(self, monkeypatch, capsys):
- """Tests that the open method calls the expected API endpoints when the CloudSpace already exists."""
- mock_client = mock.MagicMock()
- mock_client.auth_service_get_user.return_value = V1GetUserResponse(
- username="tester", features=V1UserFeatures(code_tab=True)
- )
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="test-project", project_id="test-project-id")]
- )
-
- mock_client.cloud_space_service_list_cloud_spaces.return_value = V1ListCloudSpacesResponse(
- cloudspaces=[V1CloudSpace(id="cloudspace_id", name="test_space")]
- )
-
- running_instance = Externalv1LightningappInstance(
- id="instance_id",
- name="test_space",
- spec=V1LightningappInstanceSpec(cluster_id="test"),
- status=V1LightningappInstanceStatus(phase=V1LightningappInstanceState.RUNNING),
- )
-
- stopped_instance = Externalv1LightningappInstance(
- id="instance_id",
- name="test_space",
- spec=V1LightningappInstanceSpec(cluster_id="test"),
- status=V1LightningappInstanceStatus(phase=V1LightningappInstanceState.STOPPED),
- )
-
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=[running_instance])
- )
- mock_client.lightningapp_instance_service_update_lightningapp_instance.return_value = running_instance
- mock_client.lightningapp_instance_service_get_lightningapp_instance.return_value = stopped_instance
-
- mock_client.cloud_space_service_create_cloud_space.return_value = V1CloudSpace(id="cloudspace_id")
- mock_client.cloud_space_service_create_lightning_run.return_value = V1LightningRun(id="run_id")
-
- cluster = Externalv1Cluster(id="test", spec=V1ClusterSpec(cluster_type=V1ClusterType.GLOBAL))
- mock_client.projects_service_list_project_cluster_bindings.return_value = V1ListProjectClusterBindingsResponse(
- clusters=[V1ProjectClusterBinding(cluster_id="test")],
- )
- mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([cluster])
- mock_client.cluster_service_get_cluster.return_value = cluster
-
- cloud_backend = mock.MagicMock()
- cloud_backend.client = mock_client
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
- mock_local_source = mock.MagicMock()
- monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock_local_source)
-
- cloud_runtime = cloud.CloudRuntime(entrypoint=Path("."))
-
- cloud_runtime.open("test_space")
-
- mock_client.cloud_space_service_create_lightning_run_instance.assert_not_called()
- mock_client.cloud_space_service_create_cloud_space.assert_not_called()
-
- mock_client.cloud_space_service_create_lightning_run.assert_called_once_with(
- project_id="test-project-id", cloudspace_id="cloudspace_id", body=mock.ANY
- )
-
- def test_not_enabled(self, monkeypatch, capsys):
- """Tests that an error is printed and the call exits if the feature isn't enabled for the user."""
- mock_client = mock.MagicMock()
- mock_client.auth_service_get_user.return_value = V1GetUserResponse(
- username="tester",
- features=V1UserFeatures(code_tab=False),
- )
-
- cloud_backend = mock.MagicMock()
- cloud_backend.client = mock_client
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
-
- cloud_runtime = cloud.CloudRuntime(entrypoint=Path("."))
-
- monkeypatch.setattr(cloud, "Path", Path)
-
- exited = False
- try:
- cloud_runtime.open("test_space")
- except SystemExit:
- # Expected behaviour
- exited = True
-
- out, _ = capsys.readouterr()
-
- assert exited
- assert "`lightning_app open` command has not been enabled" in out
-
-
-class TestCloudspaceDispatch:
- @mock.patch.object(pathlib.Path, "exists")
- @pytest.mark.parametrize(
- ("custom_env_sync_path_value", "cloudspace"),
- [
- (None, V1CloudSpace(id="test_id", code_config=V1CloudSpaceInstanceConfig())),
- (
- Path("/tmp/sys-customizations-sync"),
- V1CloudSpace(id="test_id", code_config=V1CloudSpaceInstanceConfig()),
- ),
- (
- Path("/tmp/sys-customizations-sync"),
- V1CloudSpace(
- id="test_id",
- code_config=V1CloudSpaceInstanceConfig(data_connection_mounts=[V1DataConnectionMount(id="test")]),
- ),
- ),
- ],
- )
- def test_cloudspace_dispatch(self, custom_env_sync_root, custom_env_sync_path_value, cloudspace, monkeypatch):
- """Tests that the cloudspace_dispatch method calls the expected API endpoints."""
- mock_client = mock.MagicMock()
- mock_client.auth_service_get_user.return_value = V1GetUserResponse(
- username="tester",
- features=V1UserFeatures(),
- )
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="project", project_id="project_id")]
- )
- mock_client.cloud_space_service_create_lightning_run.return_value = V1LightningRun(id="run_id")
- mock_client.cloud_space_service_create_lightning_run_instance.return_value = Externalv1LightningappInstance(
- id="instance_id"
- )
-
- cluster = Externalv1Cluster(id="test", spec=V1ClusterSpec(cluster_type=V1ClusterType.GLOBAL))
- mock_client.projects_service_list_project_cluster_bindings.return_value = V1ListProjectClusterBindingsResponse(
- clusters=[V1ProjectClusterBinding(cluster_id="cluster_id")],
- )
- mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([cluster])
- mock_client.cluster_service_get_cluster.return_value = cluster
- mock_client.cloud_space_service_get_cloud_space.return_value = cloudspace
-
- cloud_backend = mock.MagicMock()
- cloud_backend.client = mock_client
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
- mock_repo = mock.MagicMock()
- mock_local_source = mock.MagicMock(return_value=mock_repo)
- monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock_local_source)
- custom_env_sync_root.return_value = custom_env_sync_path_value
-
- mock_app = mock.MagicMock()
- mock_app.works = [mock.MagicMock()]
- cloud_runtime = cloud.CloudRuntime(app=mock_app, entrypoint=Path("."))
-
- app = cloud_runtime.cloudspace_dispatch("project_id", "cloudspace_id", "run_name", "cluster_id")
- assert app.id == "instance_id"
-
- mock_client.cloud_space_service_get_cloud_space.assert_called_once_with(
- project_id="project_id", id="cloudspace_id"
- )
-
- mock_client.cloud_space_service_create_lightning_run.assert_called_once_with(
- project_id="project_id", cloudspace_id="cloudspace_id", body=mock.ANY
- )
-
- assert (
- mock_client.cloud_space_service_create_lightning_run.call_args.kwargs["body"]
- .works[0]
- .spec.data_connection_mounts
- == cloudspace.code_config.data_connection_mounts
- )
-
- mock_client.cloud_space_service_create_lightning_run_instance.assert_called_once_with(
- project_id="project_id", cloudspace_id="cloudspace_id", id="run_id", body=mock.ANY
- )
-
- assert mock_client.cloud_space_service_create_lightning_run_instance.call_args.kwargs["body"].name == "run_name"
-
-
-@mock.patch("lightning.app.core.queues.QueuingSystem", MagicMock())
-@mock.patch("lightning.app.runners.backends.cloud.LightningClient", MagicMock())
-def test_get_project(monkeypatch):
- mock_client = mock.MagicMock()
- monkeypatch.setattr(cloud, "CloudBackend", mock.MagicMock(return_value=mock_client))
- app = mock.MagicMock(spec=LightningApp)
- cloud.CloudRuntime(app=app, entrypoint=Path("entrypoint.py"))
-
- # No valid projects
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(memberships=[])
-
- with pytest.raises(ValueError, match="No valid projects found"):
- _get_project(mock_client)
-
- # One valid project
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="test-project", project_id="test-project-id")]
- )
- ret = _get_project(mock_client)
- assert ret.project_id == "test-project-id"
-
- # Multiple valid projects
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[
- V1Membership(name="test-project1", project_id="test-project-id1"),
- V1Membership(name="test-project2", project_id="test-project-id2"),
- ]
- )
- ret = _get_project(mock_client)
- assert ret.project_id == "test-project-id1"
-
-
-def write_file_of_size(path, size):
- os.makedirs(os.path.dirname(path), exist_ok=True)
- with open(path, "wb") as f:
- f.seek(size)
- f.write(b"\0")
-
-
-@mock.patch("lightning.app.core.queues.QueuingSystem", MagicMock())
-@mock.patch("lightning.app.runners.backends.cloud.LightningClient", MagicMock())
-def test_check_uploaded_folder(monkeypatch, tmpdir, caplog):
- app = MagicMock()
- root = Path(tmpdir)
- repo = LocalSourceCodeDir(root)
- backend = cloud.CloudRuntime(app)
- with caplog.at_level(logging.WARN):
- backend._validate_repo(root, repo)
- assert caplog.messages == []
-
- # write some files to assert the message below.
- write_file_of_size(root / "a.png", 4 * 1000 * 1000)
- write_file_of_size(root / "b.txt", 5 * 1000 * 1000)
- write_file_of_size(root / "c.jpg", 6 * 1000 * 1000)
-
- repo._non_ignored_files = None # force reset
- with caplog.at_level(logging.WARN):
- backend._validate_repo(root, repo)
- assert f"Your application folder '{root.absolute()}' is more than 2 MB" in caplog.text
- assert "The total size is 15.0 MB" in caplog.text
- assert "4 files were uploaded" in caplog.text
- assert "files:\n6.0 MB: c.jpg\n5.0 MB: b.txt\n4.0 MB: a.png\nPerhaps" in caplog.text # tests the order
- assert "adding them to `.lightningignore`." in caplog.text
- assert "lightningingore` attribute in a Flow or Work" in caplog.text
-
-
-@mock.patch("lightning.app.core.queues.QueuingSystem", MagicMock())
-@mock.patch("lightning.app.runners.backends.cloud.LightningClient", MagicMock())
-def test_project_has_sufficient_credits():
- app = mock.MagicMock(spec=LightningApp)
- cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=Path("entrypoint.py"))
- credits_and_test_value = [
- [0.3, True],
- [1, False],
- [1.1, False],
- ]
- for balance, result in credits_and_test_value:
- project = V1Membership(name="test-project1", project_id="test-project-id1", balance=balance)
- assert cloud_runtime._resolve_needs_credits(project) is result
-
-
-@pytest.mark.parametrize(
- "lines",
- [
- [
- "import this_package_is_not_real",
- "from lightning.app import LightningApp",
- "from lightning.app.testing.helpers import EmptyWork",
- "app = LightningApp(EmptyWork())",
- ],
- [
- "from this_package_is_not_real import this_module_is_not_real",
- "from lightning.app import LightningApp",
- "from lightning.app.testing.helpers import EmptyWork",
- "app = LightningApp(EmptyWork())",
- ],
- [
- "import this_package_is_not_real",
- "from this_package_is_not_real import this_module_is_not_real",
- "from lightning.app import LightningApp",
- "from lightning.app.testing.helpers import EmptyWork",
- "app = LightningApp(EmptyWork())",
- ],
- [
- "import this_package_is_not_real",
- "from lightning.app import LightningApp",
- "from lightning.app.core.flow import _RootFlow",
- "from lightning.app.testing.helpers import EmptyWork",
- "class MyFlow(_RootFlow):",
- " def configure_layout(self):",
- " return [{'name': 'test', 'content': this_package_is_not_real()}]",
- "app = LightningApp(MyFlow(EmptyWork()))",
- ],
- ],
-)
-@pytest.mark.skipif(sys.platform != "linux", reason="Causing conflicts on non-linux")
-def test_load_app_from_file_mock_imports(tmpdir, lines):
- path = copy(sys.path)
- app_file = os.path.join(tmpdir, "app.py")
-
- with open(app_file, "w") as f:
- f.write("\n".join(lines))
-
- app = CloudRuntime.load_app_from_file(app_file)
- assert isinstance(app, LightningApp)
- assert isinstance(app.root.work, EmptyWork)
-
- # Cleanup PATH to prevent conflict with other tests
- sys.path = path
- os.remove(app_file)
-
-
-def test_load_app_from_file():
- test_script_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "core", "scripts")
-
- app = CloudRuntime.load_app_from_file(
- os.path.join(test_script_dir, "app_with_env.py"),
- )
- assert app.works[0].cloud_compute.name == "cpu-small"
-
- app = CloudRuntime.load_app_from_file(
- os.path.join(test_script_dir, "app_with_env.py"),
- env_vars={"COMPUTE_NAME": "foo"},
- )
- assert app.works[0].cloud_compute.name == "foo"
-
-
-@pytest.mark.parametrize(
- ("print_format", "expected"),
- [
- (
- "web",
- [
- {
- "displayName": "",
- "name": "root.work",
- "spec": {
- "buildSpec": {
- "commands": [],
- "pythonDependencies": {"packageManager": "PACKAGE_MANAGER_PIP", "packages": ""},
- },
- "dataConnectionMounts": [],
- "drives": [],
- "networkConfig": [{"name": "*", "port": "*"}],
- "userRequestedComputeConfig": {
- "count": 1,
- "diskSize": 0,
- "name": "cpu-small",
- "preemptible": "*",
- "shmSize": 0,
- },
- },
- }
- ],
- ),
- (
- "gallery",
- [
- {
- "display_name": "",
- "name": "root.work",
- "spec": {
- "build_spec": {
- "commands": [],
- "python_dependencies": {"package_manager": "PACKAGE_MANAGER_PIP", "packages": ""},
- },
- "data_connection_mounts": [],
- "drives": [],
- "network_config": [{"name": "*", "port": "*"}],
- "user_requested_compute_config": {
- "count": 1,
- "disk_size": 0,
- "name": "cpu-small",
- "preemptible": "*",
- "shm_size": 0,
- },
- },
- }
- ],
- ),
- ],
-)
-def test_print_specs(tmpdir, caplog, monkeypatch, print_format, expected):
- entrypoint = Path(tmpdir) / "entrypoint.py"
- entrypoint.touch()
-
- mock_client = mock.MagicMock()
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="test-project", project_id="test-project-id")]
- )
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=[])
- )
- cloud_backend = mock.MagicMock(client=mock_client)
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
-
- cloud_runtime = cloud.CloudRuntime(app=LightningApp(EmptyWork()), entrypoint=entrypoint)
-
- cloud.LIGHTNING_CLOUD_PRINT_SPECS = print_format
-
- try:
- with caplog.at_level(logging.INFO), contextlib.suppress(SystemExit):
- cloud_runtime.dispatch()
-
- lines = caplog.text.split("\n")
-
- expected = re.escape(str(expected).replace("'", '"').replace(" ", "")).replace('"\\*"', "(.*)")
- expected = "INFO(.*)works: " + expected
- assert any(re.fullmatch(expected, line) for line in lines)
- finally:
- cloud.LIGHTNING_CLOUD_PRINT_SPECS = None
-
-
-def test_incompatible_cloud_compute_and_build_config(monkeypatch):
- """Test that an exception is raised when a build config has a custom image defined, but the cloud compute is the
- default.
-
- This combination is not supported by the platform.
-
- """
- mock_client = mock.MagicMock()
- cloud_backend = mock.MagicMock(client=mock_client)
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
-
- class Work(LightningWork):
- def __init__(self):
- super().__init__()
- self.cloud_compute = CloudCompute(name="default")
- # TODO: Remove me
- self.cloud_compute.name = "default"
- self.cloud_build_config = BuildConfig(image="custom")
-
- def run(self):
- pass
-
- app = MagicMock()
- app.works = [Work()]
-
- with pytest.raises(ValueError, match="You requested a custom base image for the Work with name"):
- CloudRuntime(app=app)._validate_work_build_specs_and_compute()
-
-
-def test_programmatic_lightningignore(monkeypatch, caplog, tmpdir):
- path = Path(tmpdir)
- entrypoint = path / "entrypoint.py"
- entrypoint.touch()
-
- mock_client = mock.MagicMock()
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="test-project", project_id="test-project-id")]
- )
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=[])
- )
- mock_client.cloud_space_service_create_lightning_run.return_value = V1LightningRun(cluster_id="test")
- cloud_backend = mock.MagicMock(client=mock_client)
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
-
- class MyWork(LightningWork):
- def __init__(self):
- super().__init__()
- self.lightningignore += ("foo", "lightning_logs")
-
- def run(self):
- with pytest.raises(RuntimeError, match="w.lightningignore` does not"):
- self.lightningignore += ("foobar",)
-
- class MyFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.lightningignore = ("foo",)
- self.w = MyWork()
-
- def run(self):
- with pytest.raises(RuntimeError, match="root.lightningignore` does not"):
- self.lightningignore = ("baz",)
- self.w.run()
-
- flow = MyFlow()
- app = LightningApp(flow)
-
- monkeypatch.setattr(app, "_update_index_file", mock.MagicMock())
-
- cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=entrypoint)
- monkeypatch.setattr(LocalSourceCodeDir, "upload", mock.MagicMock())
-
- # write some files
- write_file_of_size(path / "a.txt", 5 * 1000 * 1000)
- write_file_of_size(path / "foo.png", 4 * 1000 * 1000)
- write_file_of_size(path / "lightning_logs" / "foo.ckpt", 6 * 1000 * 1000)
- # also an actual .lightningignore file
- (path / ".lightningignore").write_text("foo.png")
-
- with mock.patch(
- "lightning.app.runners.cloud._parse_lightningignore", wraps=_parse_lightningignore
- ) as parse_mock, mock.patch(
- "lightning.app.source_code.local._copytree", wraps=_copytree
- ) as copy_mock, caplog.at_level(logging.WARN):
- cloud_runtime.dispatch()
-
- parse_mock.assert_called_once_with(("foo", "foo", "lightning_logs"))
- assert copy_mock.mock_calls[0].kwargs["ignore_functions"][0].args[1] == {"lightning_logs", "foo"}
-
- assert f"Your application folder '{path.absolute()}' is more than 2 MB" in caplog.text
- assert "The total size is 5.0 MB" in caplog.text
- assert "2 files were uploaded" # a.txt and .lightningignore
- assert "files:\n5.0 MB: a.txt\nPerhaps" in caplog.text # only this file appears
-
- flow.run()
-
-
-def test_default_lightningignore(monkeypatch, caplog, tmpdir):
- path = Path(tmpdir)
- entrypoint = path / "entrypoint.py"
- entrypoint.touch()
-
- mock_client = mock.MagicMock()
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(name="test-project", project_id="test-project-id")]
- )
- mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = (
- V1ListLightningappInstancesResponse(lightningapps=[])
- )
- mock_client.cloud_space_service_create_lightning_run.return_value = V1LightningRun(cluster_id="test")
- cloud_backend = mock.MagicMock(client=mock_client)
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
-
- class MyWork(LightningWork):
- def run(self):
- pass
-
- app = LightningApp(MyWork())
-
- cloud_runtime = cloud.CloudRuntime(app=app, entrypoint=entrypoint)
- monkeypatch.setattr(LocalSourceCodeDir, "upload", mock.MagicMock())
-
- # write some files
- write_file_of_size(path / "a.txt", 5 * 1000 * 1000)
- write_file_of_size(path / "venv" / "foo.txt", 4 * 1000 * 1000)
-
- assert not (path / ".lightningignore").exists()
-
- with mock.patch(
- "lightning.app.runners.cloud._parse_lightningignore", wraps=_parse_lightningignore
- ) as parse_mock, mock.patch(
- "lightning.app.source_code.local._copytree", wraps=_copytree
- ) as copy_mock, caplog.at_level(logging.WARN):
- cloud_runtime.dispatch()
-
- parse_mock.assert_called_once_with(())
- assert copy_mock.mock_calls[0].kwargs["ignore_functions"][0].args[1] == set()
-
- assert (path / ".lightningignore").exists()
-
- assert f"Your application folder '{path.absolute()}' is more than 2 MB" in caplog.text
- assert "The total size is 5.0 MB" in caplog.text
- assert "2 files were uploaded" # a.txt and .lightningignore
- assert "files:\n5.0 MB: a.txt\nPerhaps" in caplog.text # only this file appears
-
-
-@pytest.mark.parametrize(
- ("project", "run_instance", "user", "tab", "lightning_cloud_url", "expected_url"),
- [
- # Old style
- (
- V1Membership(),
- Externalv1LightningappInstance(id="test-app-id"),
- V1GetUserResponse(username="tester", features=V1UserFeatures()),
- "logs",
- "https://lightning.ai",
- "https://lightning.ai/tester/apps/test-app-id/logs",
- ),
- (
- V1Membership(),
- Externalv1LightningappInstance(id="test-app-id"),
- V1GetUserResponse(username="tester", features=V1UserFeatures()),
- "logs",
- "http://localhost:9800",
- "http://localhost:9800/tester/apps/test-app-id/logs",
- ),
- # New style
- (
- V1Membership(name="tester's project"),
- Externalv1LightningappInstance(name="test/job"),
- V1GetUserResponse(username="tester", features=V1UserFeatures(project_selector=True)),
- "logs",
- "https://lightning.ai",
- "https://lightning.ai/tester/tester%27s%20project/jobs/test%2Fjob/logs",
- ),
- (
- V1Membership(name="tester's project"),
- Externalv1LightningappInstance(name="test/job"),
- V1GetUserResponse(username="tester", features=V1UserFeatures(project_selector=True)),
- "logs",
- "https://localhost:9800",
- "https://localhost:9800/tester/tester%27s%20project/jobs/test%2Fjob/logs",
- ),
- ],
-)
-def test_get_app_url(monkeypatch, project, run_instance, user, tab, lightning_cloud_url, expected_url):
- mock_client = mock.MagicMock()
- mock_client.auth_service_get_user.return_value = user
- cloud_backend = mock.MagicMock(client=mock_client)
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
-
- runtime = CloudRuntime()
-
- with mock.patch(
- "lightning.app.runners.cloud.get_lightning_cloud_url", mock.MagicMock(return_value=lightning_cloud_url)
- ):
- assert runtime._get_app_url(project, run_instance, tab) == expected_url
-
-
-@pytest.mark.parametrize(
- ("user", "project", "cloudspace_name", "tab", "lightning_cloud_url", "expected_url"),
- [
- (
- V1GetUserResponse(username="tester", features=V1UserFeatures()),
- V1Membership(name="default-project"),
- "test/cloudspace",
- "code",
- "https://lightning.ai",
- "https://lightning.ai/tester/default-project/apps/test%2Fcloudspace/code",
- ),
- (
- V1GetUserResponse(username="tester", features=V1UserFeatures()),
- V1Membership(name="Awesome Project"),
- "The Best CloudSpace ever",
- "web-ui",
- "http://localhost:9800",
- "http://localhost:9800/tester/Awesome%20Project/apps/The%20Best%20CloudSpace%20ever/web-ui",
- ),
- ],
-)
-def test_get_cloudspace_url(monkeypatch, user, project, cloudspace_name, tab, lightning_cloud_url, expected_url):
- mock_client = mock.MagicMock()
- mock_client.auth_service_get_user.return_value = user
- cloud_backend = mock.MagicMock(client=mock_client)
- monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))
-
- runtime = CloudRuntime()
-
- with mock.patch(
- "lightning.app.runners.cloud.get_lightning_cloud_url", mock.MagicMock(return_value=lightning_cloud_url)
- ):
- assert runtime._get_cloudspace_url(project, cloudspace_name, tab) == expected_url
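Note on the removed URL tests above: the expected URLs in `test_get_app_url` and `test_get_cloudspace_url` percent-encode each path segment, including `/`, spaces and apostrophes. The following is a minimal sketch of that encoding using only the standard library. `build_url` is a hypothetical helper written for illustration, not the function the removed tests exercised.

```python
from urllib.parse import quote


def build_url(base, *segments):
    """Join path segments onto ``base``, percent-encoding every segment.

    Hypothetical helper for illustration only. ``safe=""`` makes ``quote``
    encode ``/`` as well, so a name such as "test/job" stays one segment.
    """
    encoded = [quote(segment, safe="") for segment in segments]
    return "/".join([base.rstrip("/"), *encoded])


app_url = build_url(
    "https://lightning.ai", "tester", "tester's project", "jobs", "test/job", "logs"
)
print(app_url)
```

With the default `safe="/"`, a job named `test/job` would be left unencoded and split into two path segments, which is presumably why the expected values contain `%2F`.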
diff --git a/tests/tests_app/runners/test_multiprocess.py b/tests/tests_app/runners/test_multiprocess.py
deleted file mode 100644
index ce6c41e2d56fb..0000000000000
--- a/tests/tests_app/runners/test_multiprocess.py
+++ /dev/null
@@ -1,124 +0,0 @@
-import os
-import sys
-from unittest import mock
-from unittest.mock import Mock
-
-import pytest
-from lightning.app import LightningApp, LightningFlow, LightningWork
-from lightning.app.core import constants
-from lightning.app.frontend import StaticWebFrontend, StreamlitFrontend
-from lightning.app.runners import MultiProcessRuntime
-from lightning.app.utilities.component import _get_context
-from lightning.app.utilities.imports import _IS_WINDOWS
-
-
-def _streamlit_render_fn():
- pass
-
-
-class StreamlitFlow(LightningFlow):
- def run(self):
- self.stop()
-
- def configure_layout(self):
- frontend = StreamlitFrontend(render_fn=_streamlit_render_fn)
- frontend.start_server = Mock()
- frontend.stop_server = Mock()
- return frontend
-
-
-class WebFlow(LightningFlow):
- def run(self):
- self.stop()
-
- def configure_layout(self):
- frontend = StaticWebFrontend(serve_dir="a/b/c")
- frontend.start_server = Mock()
- frontend.stop_server = Mock()
- return frontend
-
-
-class StartFrontendServersTestFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.flow0 = StreamlitFlow()
- self.flow1 = WebFlow()
-
- def run(self):
- self.stop()
-
-
-@pytest.mark.skipif(_IS_WINDOWS, reason="strange TimeOut exception")
-@pytest.mark.xfail(strict=False, reason="hanging with timeout") # fixme
-@pytest.mark.parametrize(
- ("cloudspace_host", "port", "expected_host", "expected_target"),
- [
- (None, 7000, "localhost", "http://localhost:7000"),
- ("test.lightning.ai", 7000, "0.0.0.0", "https://7000-test.lightning.ai"), # noqa: S104
- ],
-)
-@mock.patch("lightning.app.runners.multiprocess.find_free_network_port")
-def test_multiprocess_starts_frontend_servers(
- mock_find_free_network_port, monkeypatch, cloudspace_host, port, expected_host, expected_target
-):
- """Test that the MultiProcessRuntime starts the servers for the frontends in each LightningFlow."""
-
- monkeypatch.setattr(constants, "LIGHTNING_CLOUDSPACE_HOST", cloudspace_host)
- mock_find_free_network_port.return_value = port
-
- root = StartFrontendServersTestFlow()
- app = LightningApp(root)
- MultiProcessRuntime(app).dispatch()
-
- app.frontends[root.flow0.name].start_server.assert_called_once()
- assert app.frontends[root.flow0.name].start_server.call_args.kwargs["host"] == expected_host
-
- app.frontends[root.flow1.name].start_server.assert_called_once()
- assert app.frontends[root.flow1.name].start_server.call_args.kwargs["host"] == expected_host
-
- assert app.frontends[root.flow0.name].flow._layout["target"] == f"{expected_target}/{root.flow0.name}"
- assert app.frontends[root.flow1.name].flow._layout["target"] == f"{expected_target}/{root.flow1.name}"
-
- app.frontends[root.flow0.name].stop_server.assert_called_once()
- app.frontends[root.flow1.name].stop_server.assert_called_once()
-
-
-class ContextWork(LightningWork):
- def __init__(self):
- super().__init__()
-
- def run(self):
- assert _get_context().value == "work"
-
-
-class ContextFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.work = ContextWork()
- assert _get_context() is None
-
- def run(self):
- assert _get_context().value == "flow"
- self.work.run()
- assert _get_context().value == "flow"
- self.stop()
-
-
-@pytest.mark.skipif(_IS_WINDOWS, reason="strange TimeOut exception")
-@pytest.mark.xfail(strict=False, reason="hanging with timeout") # fixme
-def test_multiprocess_runtime_sets_context():
- """Test that the runtime sets the global variable COMPONENT_CONTEXT in Flow and Work."""
- MultiProcessRuntime(LightningApp(ContextFlow())).dispatch()
-
-
-@pytest.mark.parametrize(
- ("env", "expected_url"),
- [
- ({}, "http://127.0.0.1:7501/view"),
- ({"APP_SERVER_HOST": "http://test"}, "http://test"),
- ],
-)
-@pytest.mark.skipif(sys.platform == "win32", reason="hanging with timeout")
-def test_get_app_url(env, expected_url):
- with mock.patch.dict(os.environ, env):
- assert MultiProcessRuntime._get_app_url() == expected_url
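Note on the removed `test_get_app_url` above: it patches `os.environ` for the duration of a `with` block and checks a fallback URL when `APP_SERVER_HOST` is unset. A minimal sketch of that pattern follows. The `get_app_url` function here is a hypothetical stand-in whose default is taken from the test parameters; it is not the removed `MultiProcessRuntime._get_app_url`.

```python
import os
from unittest import mock


def get_app_url():
    """Hypothetical stand-in: return APP_SERVER_HOST or a local default."""
    return os.environ.get("APP_SERVER_HOST", "http://127.0.0.1:7501/view")


# The override is visible only inside the ``with`` block.
with mock.patch.dict(os.environ, {"APP_SERVER_HOST": "http://test"}):
    patched_url = get_app_url()

# ``clear=True`` empties the environment temporarily, forcing the default.
with mock.patch.dict(os.environ, {}, clear=True):
    default_url = get_app_url()
```

`mock.patch.dict` restores the original environment on exit, so the override does not leak into other tests.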
diff --git a/tests/tests_app/runners/test_runtime.py b/tests/tests_app/runners/test_runtime.py
deleted file mode 100644
index fcf861e07310c..0000000000000
--- a/tests/tests_app/runners/test_runtime.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import os
-import signal
-from unittest import mock
-
-import pytest
-from lightning.app.runners import cloud
-from lightning.app.runners.runtime import dispatch
-from lightning.app.runners.runtime_type import RuntimeType
-
-from tests_app import _PROJECT_ROOT
-
-
-@pytest.mark.parametrize(
- "runtime_type",
- [
- RuntimeType.MULTIPROCESS,
- RuntimeType.CLOUD,
- ],
-)
-@mock.patch("lightning.app.core.queues.QueuingSystem", mock.MagicMock())
-@mock.patch("lightning.app.runners.backends.cloud.LightningClient", mock.MagicMock())
-def test_dispatch(runtime_type, monkeypatch):
- """This test ensures the runtime dispatch method gets called when using dispatch."""
- monkeypatch.setattr(cloud, "CloudBackend", mock.MagicMock())
-
- with pytest.raises(FileNotFoundError, match="doesnt_exists.py"):
- dispatch(
- entrypoint_file=os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/doesnt_exists.py"),
- runtime_type=runtime_type,
- start_server=False,
- )
-
- runtime = runtime_type.get_runtime()
- dispath_method_path = f"{runtime.__module__}.{runtime.__name__}.dispatch"
-
- with mock.patch(dispath_method_path) as dispatch_mock_fn:
- dispatch(
- entrypoint_file=os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/app_metadata.py"),
- runtime_type=runtime_type,
- start_server=False,
- )
- dispatch_mock_fn.assert_called_once()
- assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
diff --git a/tests/tests_app/source_code/test_copytree.py b/tests/tests_app/source_code/test_copytree.py
deleted file mode 100644
index fb6f812f4486b..0000000000000
--- a/tests/tests_app/source_code/test_copytree.py
+++ /dev/null
@@ -1,107 +0,0 @@
-import os
-
-from lightning.app.source_code.copytree import _copytree, _read_lightningignore
-
-
-def test_read_lightningignore(tmpdir):
- """_read_lightningignore() removes comments from ignore files."""
- test_path = tmpdir.join(".lightningignore")
- expected = "test"
- not_expected = "# comment"
- with open(test_path, "a") as f:
- f.write(not_expected)
- f.write(expected)
-
- result = _read_lightningignore(test_path)
- assert not_expected not in result
- assert expected not in result
-
-
-def test_read_lightningignore_excludes_empty_lines(tmpdir):
- """_read_lightningignore() excludes empty lines."""
- test_path = tmpdir.join(".lightningignore")
- gitignore = """
-
- foo
-
- bar
-
-
-
- """
- test_path.write(gitignore)
-
- # results exclude all empty lines
- result = _read_lightningignore(test_path)
- assert len(result) == 2
-
-
-def test_copytree_ignoring_files(tmp_path_factory):
- # lightningignore for ignoring txt file in dir2, the whole dir1 and .zip file everywhere
- test_dir = tmp_path_factory.mktemp("lightningignore-test")
- source = test_dir / "source"
- source.mkdir()
-
- # lightningignore at root
- source.joinpath(".lightningignore").write_text("dir1/*.txt\ndir0\n*.zip")
-
- # not creating the destination directory
- dest = test_dir / "dest"
-
- # # setting up test files and nested lightningignore in dir4
- source.joinpath("dir3").mkdir()
- source.joinpath("dir3").joinpath(".lightningignore").write_text("*.pt")
- source.joinpath("dir3").joinpath("model.pt").write_text("")
- source.joinpath("dir3").joinpath("model.non-pt").write_text("")
-
- source.joinpath("dir0").mkdir() # dir0 is ignored
- source.joinpath("dir0/file1").write_text("") # ignored because the parent dir is ignored
- source.joinpath("dir1").mkdir()
- source.joinpath("dir1/file.tar.gz").write_text("")
- source.joinpath("dir1/file.txt").write_text("") # .txt in dir1 is ignored
- source.joinpath("dir2").mkdir()
- source.joinpath("dir2/file.txt").write_text("")
- source.joinpath("dir2/file.zip").write_text("") # .zip everywhere is ignored
-
- files_copied = _copytree(source, dest)
- relative_names = set()
- for file in files_copied:
- relative_names.add(file.split("source")[1].strip("/").strip("\\"))
-
- if os.name == "nt":
- assert {
- ".lightningignore",
- "dir2\\file.txt",
- "dir3\\.lightningignore",
- "dir3\\model.non-pt",
- "dir1\\file.tar.gz",
- } == relative_names
- else:
- assert {
- ".lightningignore",
- "dir2/file.txt",
- "dir3/.lightningignore",
- "dir3/model.non-pt",
- "dir1/file.tar.gz",
- } == relative_names
-
- first_level_dirs = list(dest.iterdir())
- assert len(first_level_dirs) == 4 # .lightningignore, dir2, dir1 and dir3
- assert {".lightningignore", "dir2", "dir1", "dir3"} == {d.name for d in first_level_dirs}
-
- for d in first_level_dirs:
- if d.name == "dir1":
- assert "file.txt" not in [file.name for file in d.iterdir()]
- assert "file.tar.gz" in [file.name for file in d.iterdir()]
- assert len([file.name for file in d.iterdir()]) == 1
-
- if d.name == "dir2":
- assert "file.zip" not in [file.name for file in d.iterdir()]
- assert "file.txt" in [file.name for file in d.iterdir()]
- assert len([file.name for file in d.iterdir()]) == 1
-
- if d.name == "dir3":
- assert "model.pt" not in [file.name for file in d.iterdir()]
- assert "model.non-pt" in [file.name for file in d.iterdir()]
- assert ".lightningignore" in [file.name for file in d.iterdir()]
- assert len([file.name for file in d.iterdir()]) == 2
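Note on the removed `_read_lightningignore` tests above: they expect comment lines and empty lines to be dropped from an ignore file. The sketch below shows one way to get that behaviour. `read_ignore_lines` is a hypothetical name, and the whitespace stripping is an assumption; this is not the removed implementation.

```python
def read_ignore_lines(text):
    """Return the ignore patterns in ``text``, skipping blanks and comments.

    Hypothetical helper for illustration. Surrounding whitespace is
    stripped (an assumption), so indented entries are still recognised.
    """
    patterns = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue  # skip empty lines and comment lines
        patterns.append(line)
    return patterns


patterns = read_ignore_lines("\n\n    foo\n\n    bar\n\n\n")
print(patterns)
```

Under these assumptions, text written without a newline after a comment, such as `"# commenttest"`, is a single comment line and yields no patterns, consistent with the first removed test asserting that neither string appears in the result.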
diff --git a/tests/tests_app/source_code/test_local.py b/tests/tests_app/source_code/test_local.py
deleted file mode 100644
index 11fa62dd28bf5..0000000000000
--- a/tests/tests_app/source_code/test_local.py
+++ /dev/null
@@ -1,376 +0,0 @@
-import os
-import sys
-import tarfile
-import uuid
-from pathlib import Path
-from unittest import mock
-
-import pytest
-from lightning.app.source_code import LocalSourceCodeDir
-
-
-def test_repository_checksum(tmp_path):
- """LocalRepository.version() generates a different version each time."""
- repository = LocalSourceCodeDir(path=Path(tmp_path))
- version_a = repository.version
-
- # version is different
- repository = LocalSourceCodeDir(path=Path(tmp_path))
- version_b = repository.version
-
- assert version_a != version_b
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="this runs only on linux")
-@mock.patch.dict(os.environ, {"LIGHTNING_VSCODE_WORKSPACE": "something"})
-def test_local_cache_path_tmp(tmp_path):
- """LocalRepository.cache_location is under tmp."""
- repository = LocalSourceCodeDir(path=Path(tmp_path))
- assert str(repository.cache_location).startswith("/tmp")
-
-
-def test_local_cache_path_home(tmp_path):
- """LocalRepository.cache_location is under home."""
- repository = LocalSourceCodeDir(path=Path(tmp_path))
- assert str(repository.cache_location).startswith(str(Path.home()))
-
-
-def test_repository_package(tmp_path, monkeypatch):
- """LocalRepository.package() creates package from local dir."""
- cache_path = Path(tmp_path)
- source_path = cache_path / "nested"
- source_path.mkdir(parents=True, exist_ok=True)
- (source_path / "test.txt").write_text("test")
-
- repository = LocalSourceCodeDir(path=source_path)
- repository.cache_location = cache_path
- repository.package()
-
- # test that package is created
- for file in cache_path.glob("**/*"):
- if file.is_file() and file.name.endswith(".tar.gz"):
- assert file.name == f"{repository.version}.tar.gz"
-
-
-def test_repository_lightningignore(tmp_path):
- """LocalRepository.version uses the assumed checksum correctly."""
- # write .lightningignore file
- lightningignore = """
- # ignore files in this dir
- ignore/
-
- """
- (tmp_path / ".lightningignore").write_text(lightningignore)
- (tmp_path / "test.txt").write_text("test")
-
- repository = LocalSourceCodeDir(path=Path(tmp_path))
-
- assert set(repository.files) == {str(tmp_path / ".lightningignore"), str(tmp_path / "test.txt")}
-
- # write file that needs to be ignored
- (tmp_path / "ignore").mkdir()
- (tmp_path / "ignore/test.txt").write_text(str(uuid.uuid4()))
-
- repository = LocalSourceCodeDir(path=Path(tmp_path))
-
- assert set(repository.files) == {str(tmp_path / ".lightningignore"), str(tmp_path / "test.txt")}
-
-
-def test_repository_filters_with_absolute_relative_path(tmp_path):
- """.lightningignore parsing parses paths starting with / correctly."""
- lightningignore = """
- /ignore_file/test.txt
-
- /ignore_dir
- """
- (tmp_path / ".lightningignore").write_text(lightningignore)
- (tmp_path / "test.txt").write_text("test")
-
- repository = LocalSourceCodeDir(path=Path(tmp_path))
-
- assert set(repository.files) == {str(tmp_path / ".lightningignore"), str(tmp_path / "test.txt")}
-
- # write file that needs to be ignored
- (tmp_path / "ignore_file").mkdir()
- (tmp_path / "ignore_dir").mkdir()
- (tmp_path / "ignore_file/test.txt").write_text(str(uuid.uuid4()))
- (tmp_path / "ignore_dir/test.txt").write_text(str(uuid.uuid4()))
-
- repository = LocalSourceCodeDir(path=Path(tmp_path))
-
- assert set(repository.files) == {str(tmp_path / ".lightningignore"), str(tmp_path / "test.txt")}
-
-
-def test_repository_lightningignore_supports_different_patterns(tmp_path):
- """.lightningignore parsing supports different patterns."""
- # write .lightningignore file
- # default github python .gitignore
- lightningignore = """
- # ignore files in this dir
- ignore/
-
- # Byte-compiled / optimized / DLL files
- __pycache__/
- *.py[cod]
- *$py.class
-
- # C extensions
- *.so
-
- # Distribution / packaging
- .Python
- build/
- develop-eggs/
- dist/
- downloads/
- eggs/
- .eggs/
- lib/
- lib64/
- parts/
- sdist/
- var/
- wheels/
- *.egg-info/
- .installed.cfg
- *.egg
- MANIFEST
-
- # PyInstaller
- # Usually these files are written by a python script from a template
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
- *.manifest
- *.spec
-
- # Installer logs
- pip-log.txt
- pip-delete-this-directory.txt
-
- # Unit test / coverage reports
- htmlcov/
- .tox/
- .coverage
- .coverage.*
- .cache
- nosetests.xml
- coverage.xml
- *.cover
- .hypothesis/
- .pytest_cache/
-
- # Translations
- *.mo
- *.pot
-
- # Django stuff:
- *.log
- local_settings.py
- db.sqlite3
-
- # Flask stuff:
- instance/
- .webassets-cache
-
- # Scrapy stuff:
- .scrapy
-
- # Sphinx documentation
- docs/_build/
-
- # PyBuilder
- target/
-
- # Jupyter Notebook
- .ipynb_checkpoints
-
- # pyenv
- .python-version
-
- # celery beat schedule file
- celerybeat-schedule
-
- # SageMath parsed files
- *.sage.py
-
- # Environments
- .env
- .env.docker
- .venv
- env/
- venv/
- ENV/
- env.bak/
- venv.bak/
-
- # Spyder project settings
- .spyderproject
- .spyproject
-
- # Rope project settings
- .ropeproject
-
- # mkdocs documentation
- /site
-
- # mypy
- .mypy_cache/
-
- # VS Code files
- .vscode/
-
- # UI files
- node_modules/
-
- # Data files
- models/
- models/*
- !grid/openapi/models
- postgresql_data/
- redis_data/
-
- # Secrets folders
- secrets/
-
- # Built UI
- ui/
-
- # Ignores Grid Runner
- vendor/
- ignore_test.py
-
- # Ignore cov report
- *.xml
-
- """
- (tmp_path / ".lightningignore").write_text(lightningignore)
- (tmp_path / "test.txt").write_text("test")
-
- repository = LocalSourceCodeDir(path=Path(tmp_path))
-
- assert set(repository.files) == {str(tmp_path / ".lightningignore"), str(tmp_path / "test.txt")}
-
- # write file that needs to be ignored
- (tmp_path / "ignore").mkdir()
- (tmp_path / "ignore/test.txt").write_text(str(uuid.uuid4()))
-
- # check that version remains the same
- repository = LocalSourceCodeDir(path=Path(tmp_path))
-
- assert set(repository.files) == {str(tmp_path / ".lightningignore"), str(tmp_path / "test.txt")}
-
-
-def test_repository_lightningignore_unpackage(tmp_path, monkeypatch):
- """.lightningignore behaves similarly to the gitignore standard."""
- lorem_ipsum = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
-
- cache_path = tmp_path / "cache"
- source_path = tmp_path / "source"
- source_path.mkdir()
-
- # set cache location to temp dir
-
- lightningignore = """
- # Ignore on all levels
- *.pyc
- *__pycache__/
- build/
- .env
- # Ignore wildcard on one level
- ./*.txt
- /*.md
- ./one-level/*.txt
- /one-level/*.md
- # Ignore only relative
- ./downloads
- /relative_downloads
- # nested
- /nested//level/
- /nested/level/
- """
- (source_path / ".lightningignore").write_text(lightningignore)
-
- # Dir structure
- (source_path / "include.py").write_text(lorem_ipsum)
- (source_path / "exclude.pyc").write_text(lorem_ipsum)
- (source_path / "__pycache__").mkdir()
- (source_path / "__pycache__" / "exclude.py").write_text(
- lorem_ipsum
- ) # Even tho it's .py it's in excluded __pycache__ directory
- (source_path / "__pycache__" / "exclude.pyc").write_text(
- lorem_ipsum
- ) # Even tho it's .py it's in excluded __pycache__ directory
- (source_path / "build.py").write_text(lorem_ipsum) # Common prefix with excluded build but it's not it
- (source_path / "builds").mkdir() # Common prefix with excluded build but it's not excluded
- (source_path / "builds" / "include.py").write_text(lorem_ipsum)
- (source_path / "builds" / "__pycache__").mkdir() # Recursively excluded
- (source_path / "builds" / "__pycache__" / "exclude.py").write_text(lorem_ipsum)
- (source_path / "build").mkdir() # Recursively excluded
- (source_path / "build" / "exclude.db").write_text(lorem_ipsum)
- (source_path / ".env").write_text(lorem_ipsum) # No issues with handling hidden (.dot) files
- (source_path / "downloads").mkdir() # exclude
- (source_path / "downloads" / "something.jpeg").write_text(lorem_ipsum)
- (source_path / "relative_downloads").mkdir() # exclude
- (source_path / "relative_downloads" / "something.jpeg").write_text(lorem_ipsum)
- (source_path / "include").mkdir() # include
- (source_path / "include" / "exclude.pyc").write_text(lorem_ipsum) # exclude because of *.pyc rule
- (source_path / "include" / "include.py").write_text(lorem_ipsum) # include
- (source_path / "include" / "downloads").mkdir() # include because it was excluded only relative to root
- (source_path / "include" / "downloads" / "something.jpeg").write_text(lorem_ipsum)
- (source_path / "include" / "relative_downloads").mkdir() # include because it was excluded only relative to root
- (source_path / "include" / "relative_downloads" / "something.jpeg").write_text(lorem_ipsum)
- (source_path / "exclude.txt").write_text(lorem_ipsum)
- (source_path / "exclude.md").write_text(lorem_ipsum)
- (source_path / "one-level").mkdir()
- (source_path / "one-level" / "exclude.txt").write_text(lorem_ipsum)
- (source_path / "one-level" / "exclude.md").write_text(lorem_ipsum)
- (source_path / "one-level" / "include.py").write_text(lorem_ipsum)
- (source_path / "nested").mkdir()
- (source_path / "nested" / "include.py").write_text(lorem_ipsum)
- (source_path / "nested" / "level").mkdir()
- (source_path / "nested" / "level" / "exclude.py").write_text(lorem_ipsum)
-
- # create repo object
- repository = LocalSourceCodeDir(path=source_path)
- repository.cache_location = cache_path
- repository.package()
-
- unpackage_path = tmp_path / "unpackage"
-
- with tarfile.open(repository.package_path) as f:
- f.extractall(unpackage_path)
-
- assert (unpackage_path / "include.py").exists()
- assert not (unpackage_path / "exclude.pyc").exists() # Excluded by *.pyc
- assert not (unpackage_path / "__pycache__").exists()
- assert not (
- unpackage_path / "__pycache__" / "exclude.py"
- ).exists() # Even though it's a .py file, it's in the excluded __pycache__ directory
- assert not (
- unpackage_path / "__pycache__" / "exclude.pyc"
- ).exists() # Excluded by the *.pyc rule and by being in the excluded __pycache__ directory
- assert (unpackage_path / "build.py").exists() # Shares a prefix with the excluded build/ but is not excluded
- assert (unpackage_path / "builds" / "include.py").exists()
- assert not (unpackage_path / "builds" / "__pycache__").exists() # Recursively excluded
- assert not (unpackage_path / "builds" / "__pycache__" / "exclude.py").exists()
- assert not (unpackage_path / "build").exists() # Recursively excluded
- assert not (unpackage_path / "build" / "exclude.db").exists()
- assert not (unpackage_path / ".env").exists() # No issues with handling hidden (.dot) files
- assert not (unpackage_path / "downloads").exists() # exclude
- assert not (unpackage_path / "downloads" / "something.jpeg").exists()
- assert not (unpackage_path / "relative_downloads").exists() # exclude
- assert not (unpackage_path / "relative_downloads" / "something.jpeg").exists()
- assert not (unpackage_path / "include" / "exclude.pyc").exists() # exclude because of *.pyc rule
- assert (unpackage_path / "include" / "include.py").exists() # include
- assert (
- unpackage_path / "include" / "downloads" / "something.jpeg"
- ).exists() # include because it was excluded only relative to root
- assert (
- unpackage_path / "include" / "relative_downloads" / "something.jpeg"
- ).exists() # include because it was excluded only relative to root
- assert not (unpackage_path / "exclude.txt").exists()
- assert not (unpackage_path / "exclude.md").exists()
- assert not (unpackage_path / "one-level" / "exclude.txt").exists()
- assert not (unpackage_path / "one-level" / "exclude.md").exists()
- assert (unpackage_path / "one-level" / "include.py").exists()
- assert (unpackage_path / "nested" / "include.py").exists()
- assert not (unpackage_path / "nested" / "level" / "exclude.py").exists()
diff --git a/tests/tests_app/source_code/test_tar.py b/tests/tests_app/source_code/test_tar.py
deleted file mode 100644
index a1e8d99bfe8de..0000000000000
--- a/tests/tests_app/source_code/test_tar.py
+++ /dev/null
@@ -1,120 +0,0 @@
-import math
-import os
-import tarfile
-from pathlib import Path
-
-import pytest
-from lightning.app.source_code.tar import MAX_SPLIT_COUNT, _get_dir_size_and_count, _get_split_size, _tar_path
-
-
-def _create_files(basedir: Path):
- source_dir = basedir / "source"
- inner_dir = source_dir / "dir"
- os.makedirs(inner_dir)
- with open(source_dir / "f1", "w") as fp:
- fp.write("f1")
-
- with open(inner_dir / "f2", "w") as fp:
- fp.write("f2")
- return source_dir, inner_dir
-
-
-def test_max_upload_parts():
- import click
-
- with pytest.raises(click.ClickException):
- barely_over = MAX_SPLIT_COUNT * 2**31 + 1
- _get_split_size(barely_over)
-
-
-def test_almost_max_upload_parts():
- barely_under = MAX_SPLIT_COUNT * 2**31 - 1
- assert _get_split_size(barely_under) == math.ceil(barely_under / MAX_SPLIT_COUNT)
-
-
-@pytest.mark.parametrize("size", [1024 * 512, 1024 * 1024 * 5])
-def test_get_dir_size_and_count(tmpdir: Path, size):
- data = os.urandom(size)
- with open(os.path.join(tmpdir, "a"), "wb") as f:
- f.write(data)
- with open(os.path.join(tmpdir, "b"), "wb") as f:
- f.write(data)
- assert _get_dir_size_and_count(tmpdir, "a") == (size, 1)
-
-
-def test_tar_path(tmpdir: Path, monkeypatch):
- source_dir, inner_dir = _create_files(tmpdir)
-
- # Test directory
- target_file = tmpdir / "target.tar.gz"
- results = _tar_path(source_path=source_dir, target_file=target_file)
- assert results.before_size > 0
- assert results.after_size > 0
-
- verify_dir = tmpdir / "verify"
- os.makedirs(verify_dir)
- with tarfile.open(target_file) as tar:
- tar.extractall(verify_dir)
-
- assert (verify_dir / "f1").exists()
- assert (verify_dir / "dir" / "f2").exists()
-
- # Test single file
- f2_path = inner_dir / "f2"
-
- target_file = tmpdir / "target_file.tar.gz"
- results = _tar_path(source_path=f2_path, target_file=target_file)
- assert results.before_size > 0
- assert results.after_size > 0
-
- verify_dir = tmpdir / "verify_file"
- os.makedirs(verify_dir)
- with tarfile.open(target_file) as tar:
- tar.extractall(verify_dir)
-
- assert (verify_dir / "f2").exists()
-
- # Test single file (local)
- monkeypatch.chdir(inner_dir)
-
- f2_path = "f2"
-
- target_file = tmpdir / "target_file_local.tar.gz"
- results = _tar_path(source_path=f2_path, target_file=target_file)
- assert results.before_size > 0
- assert results.after_size > 0
-
- verify_dir = tmpdir / "verify_file_local"
- os.makedirs(verify_dir)
- with tarfile.open(target_file) as tar:
- tar.extractall(verify_dir)
-
- assert (verify_dir / "f2").exists()
-
-
-def test_get_split_size():
- split_size = _get_split_size(minimum_split_size=1024 * 1000 * 10, max_split_count=10000, total_size=200000000001)
-
- # We shouldn't go over the max split count
- assert math.ceil(200000000001 / split_size) <= 10000
-
- split_size = _get_split_size(
- minimum_split_size=1024 * 1000 * 10, max_split_count=10000, total_size=1024 * 500 * 1000 * 10
- )
-
- assert split_size == 1024 * 1000 * 10
-
-
-def test_tar_path_no_compression(tmpdir):
- source_dir, _ = _create_files(tmpdir)
-
- target_file = tmpdir / "target.tar.gz"
- _tar_path(source_path=source_dir, target_file=target_file, compression=False)
-
- verify_dir = tmpdir / "verify"
- os.makedirs(verify_dir)
- with tarfile.open(target_file) as target_tar:
- target_tar.extractall(verify_dir)
-
- assert (verify_dir / "f1").exists()
- assert (verify_dir / "dir" / "f2").exists()
diff --git a/tests/tests_app/source_code/test_uploader.py b/tests/tests_app/source_code/test_uploader.py
deleted file mode 100644
index 7617d9c647510..0000000000000
--- a/tests/tests_app/source_code/test_uploader.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from unittest import mock
-from unittest.mock import ANY, MagicMock
-
-import pytest
-from lightning.app.source_code import uploader
-
-# keeping as global var so individual tests can access/modify it
-response = {"response": MagicMock(headers={"ETag": "test-etag"})}
-
-
-class MockedRequestSession(MagicMock):
- def put(self, url, data):
- assert url == "https://test-url"
- assert data == "test-data"
- return response["response"]
-
- def mount(self, prefix, adapter):
- assert prefix == "https://"
- assert adapter.max_retries.total == 10
-
-
-@mock.patch("builtins.open", mock.mock_open(read_data="test-data"))
-@mock.patch("lightning.app.source_code.uploader.requests.Session", MockedRequestSession)
-def test_file_uploader():
- file_uploader = uploader.FileUploader(
- presigned_url="https://test-url", source_file="test.txt", total_size=100, name="test.txt"
- )
- file_uploader.progress = MagicMock()
-
- file_uploader.upload()
-
- file_uploader.progress.add_task.assert_called_once_with("upload", filename="test.txt", total=100)
- file_uploader.progress.start.assert_called_once()
- file_uploader.progress.update.assert_called_once_with(ANY, advance=9)
-
-
-@mock.patch("builtins.open", mock.mock_open(read_data="test-data"))
-@mock.patch("lightning.app.source_code.uploader.requests.Session", MockedRequestSession)
-def test_file_uploader_failing_when_no_etag():
- response["response"] = MagicMock(headers={})
- presigned_url = "https://test-url"
- file_uploader = uploader.FileUploader(
- presigned_url=presigned_url, source_file="test.txt", total_size=100, name="test.txt"
- )
- file_uploader.progress = MagicMock()
-
- with pytest.raises(ValueError, match=f"Unexpected response from {presigned_url}, response"):
- file_uploader.upload()
diff --git a/tests/tests_app/storage/__init__.py b/tests/tests_app/storage/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/storage/test_copier.py b/tests/tests_app/storage/test_copier.py
deleted file mode 100644
index 14d3f965ce422..0000000000000
--- a/tests/tests_app/storage/test_copier.py
+++ /dev/null
@@ -1,148 +0,0 @@
-import os
-import pathlib
-from unittest import mock
-from unittest.mock import Mock
-
-import lightning.app
-import pytest
-from lightning.app.storage.copier import _Copier, _copy_files
-from lightning.app.storage.path import Path
-from lightning.app.storage.requests import _ExistsRequest, _GetRequest
-from lightning.app.testing.helpers import _MockQueue
-
-
-class MockPatch:
- @staticmethod
- def _handle_get_request(work, request):
- return Path._handle_get_request(work, request)
-
- @staticmethod
- def _handle_exists_request(work, request):
- return Path._handle_exists_request(work, request)
-
-
-@mock.patch("lightning.app.storage.path.pathlib.Path.is_dir")
-@mock.patch("lightning.app.storage.path.pathlib.Path.stat")
-@mock.patch("lightning.app.storage.copier._filesystem")
-def test_copier_copies_all_files(fs_mock, stat_mock, dir_mock, tmpdir):
- """Test that the Copier calls the copy with the information provided in the request."""
- stat_mock().st_size = 0
- dir_mock.return_value = False
- copy_request_queue = _MockQueue()
- copy_response_queue = _MockQueue()
- work = mock.Mock()
- work.name = MockPatch()
- work._paths = {"file": {"source": "src", "path": "file", "hash": "123", "destination": "dest", "name": "name"}}
- with mock.patch.dict(os.environ, {"SHARED_MOUNT_DIRECTORY": str(tmpdir / ".shared")}):
- copier = _Copier(work, copy_request_queue=copy_request_queue, copy_response_queue=copy_response_queue)
- request = _GetRequest(source="src", path="file", hash="123", destination="dest", name="name")
- copy_request_queue.put(request)
- copier.run_once()
- fs_mock().put.assert_called_once_with("file", tmpdir / ".shared" / "123")
-
-
-@mock.patch("lightning.app.storage.path.pathlib.Path.is_dir")
-@mock.patch("lightning.app.storage.path.pathlib.Path.stat")
-def test_copier_handles_exception(stat_mock, dir_mock, monkeypatch):
- """Test that the Copier captures exceptions from the file copy and forwards them through the queue without raising
- them."""
- stat_mock().st_size = 0
- dir_mock.return_value = False
- copy_request_queue = _MockQueue()
- copy_response_queue = _MockQueue()
- fs = mock.Mock()
- fs.exists.return_value = False
- fs.put = mock.Mock(side_effect=OSError("Something went wrong"))
- monkeypatch.setattr(lightning.app.storage.copier, "_filesystem", mock.Mock(return_value=fs))
-
- work = mock.Mock()
- work.name = MockPatch()
- work._paths = {"file": {"source": "src", "path": "file", "hash": "123", "destination": "dest", "name": "name"}}
- copier = _Copier(work, copy_request_queue=copy_request_queue, copy_response_queue=copy_response_queue)
- request = _GetRequest(source="src", path="file", hash="123", destination="dest", name="name")
- copy_request_queue.put(request)
- copier.run_once()
- response = copy_response_queue.get()
- assert type(response.exception) is OSError
- assert response.exception.args[0] == "Something went wrong"
-
-
-def test_copier_existence_check(tmpdir):
- """Test that the Copier responds to an existence check request."""
- copy_request_queue = _MockQueue()
- copy_response_queue = _MockQueue()
-
- work = mock.Mock()
- work.name = MockPatch()
- work._paths = {
- "file": {
- "source": "src",
- "path": str(tmpdir / "notexists"),
- "hash": "123",
- "destination": "dest",
- "name": "name",
- }
- }
-
- copier = _Copier(work, copy_request_queue=copy_request_queue, copy_response_queue=copy_response_queue)
-
- # A Path that does NOT exist
- request = _ExistsRequest(source="src", path=str(tmpdir / "notexists"), destination="dest", name="name", hash="123")
- copy_request_queue.put(request)
- copier.run_once()
- response = copy_response_queue.get()
- assert response.exists is False
-
- # A Path that DOES exist
- request = _ExistsRequest(source="src", path=str(tmpdir), destination="dest", name="name", hash="123")
- copy_request_queue.put(request)
- copier.run_once()
- response = copy_response_queue.get()
- assert response.exists is True
-
-
-def test_copy_files(tmpdir):
- """Test that the `_copy_files` utility can handle both files and folders when the destination does not
- exist."""
- # copy from a src that does not exist
- src = pathlib.Path(tmpdir, "dir1")
- dst = pathlib.Path(tmpdir, "dir2")
- with pytest.raises(FileNotFoundError):
- _copy_files(src, dst)
-
- # copy to a dst dir that does not exist
- src.mkdir()
- (src / "empty.txt").touch()
- assert not dst.exists()
- _copy_files(src, dst)
- assert dst.is_dir()
-
- # copy to a destination dir that already exists (no error should be raised)
- _copy_files(src, dst)
- assert dst.is_dir()
-
- # copy file to a dst that does not exist
- src = pathlib.Path(tmpdir, "dir3", "src-file.txt")
- dst = pathlib.Path(tmpdir, "dir4", "dst-file.txt")
- src.parent.mkdir(parents=True)
- src.touch()
- assert not dst.exists()
- _copy_files(src, dst)
- assert dst.is_file()
-
-
-def test_copy_files_with_exception(tmpdir):
- """Test that the `_copy_files` utility properly raises exceptions from within the ThreadPoolExecutor."""
- fs_mock = Mock()
- fs_mock().put = Mock(side_effect=ValueError("error from thread"))
-
- src = pathlib.Path(tmpdir, "src")
- src.mkdir()
- assert src.is_dir()
- pathlib.Path(src, "file.txt").touch()
- dst = pathlib.Path(tmpdir, "dest")
-
- with mock.patch("lightning.app.storage.copier._filesystem", fs_mock), pytest.raises(
- ValueError, match="error from thread"
- ):
- _copy_files(src, dst)
diff --git a/tests/tests_app/storage/test_drive.py b/tests/tests_app/storage/test_drive.py
deleted file mode 100644
index d9fc9b4504372..0000000000000
--- a/tests/tests_app/storage/test_drive.py
+++ /dev/null
@@ -1,256 +0,0 @@
-import os
-import pathlib
-from copy import deepcopy
-from time import sleep
-
-import pytest
-from deepdiff import DeepDiff
-from lightning.app import LightningFlow, LightningWork
-from lightning.app.core.app import LightningApp
-from lightning.app.runners import MultiProcessRuntime
-from lightning.app.storage.drive import Drive, _maybe_create_drive
-from lightning.app.utilities.component import _set_flow_context
-
-
-class SyncWorkLITDriveA(LightningWork):
- def __init__(self, tmpdir):
- super().__init__()
- self.tmpdir = tmpdir
-
- def run(self, drive: Drive):
- with open(f"{self.tmpdir}/a.txt", "w") as f:
- f.write("example")
-
- drive.root_folder = self.tmpdir
- drive.put("a.txt")
- os.remove(f"{self.tmpdir}/a.txt")
-
-
-class SyncWorkLITDriveB(LightningWork):
- def run(self, drive: Drive):
- assert not os.path.exists("a.txt")
- drive.get("a.txt")
- assert os.path.exists("a.txt")
-
-
-class SyncFlowLITDrives(LightningFlow):
- def __init__(self, tmpdir):
- super().__init__()
- self.log_dir = Drive("lit://log_dir")
- self.work_a = SyncWorkLITDriveA(str(tmpdir))
- self.work_b = SyncWorkLITDriveB()
-
- def run(self):
- self.work_a.run(self.log_dir)
- self.work_b.run(self.log_dir)
- self.stop()
-
-
-@pytest.mark.flaky(reruns=3, reruns_delay=5) # todo: likely dead feature, fine to crash...
-def test_synchronization_lit_drive(tmpdir):
- if os.path.exists("a.txt"):
- os.remove("a.txt")
- app = LightningApp(SyncFlowLITDrives(tmpdir))
- MultiProcessRuntime(app, start_server=False).dispatch()
- if os.path.exists("a.txt"):
- os.remove("a.txt")
-
-
-class LITDriveWork(LightningWork):
- def __init__(self):
- super().__init__(parallel=True)
- self.drive = None
- self.counter = 0
-
- def run(self, *args, **kwargs):
- if self.counter == 0:
- self.drive = Drive("lit://this_drive_id")
- sleep(10)
- with open("a.txt", "w") as f:
- f.write("example")
-
- self.drive.put("a.txt")
- else:
- assert self.drive
- assert self.drive.list(".") == ["a.txt"]
- self.drive.delete("a.txt")
- assert self.drive.list(".") == []
- self.counter += 1
-
-
-class LITDriveWork2(LightningWork):
- def __init__(self):
- super().__init__(parallel=True)
-
- def run(self, drive: Drive, **kwargs):
- assert drive.list(".") == []
- drive.get("a.txt", timeout=60)
- assert drive.list(".") == ["a.txt"]
- assert drive.list(".", component_name=self.name) == []
-
-
-class LITDriveFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.work = LITDriveWork()
- self.work2 = LITDriveWork2()
-
- def run(self):
- self.work.run("0")
- if self.work.drive:
- self.work2.run(self.work.drive, something="hello")
- if self.work2.has_succeeded:
- self.work.run("1")
- if self.work.counter == 2:
- self.stop()
-
-
-@pytest.mark.flaky(reruns=3, reruns_delay=5) # todo: likely dead feature, fine to crash...
-def test_lit_drive_transferring_files():
- app = LightningApp(LITDriveFlow())
- MultiProcessRuntime(app, start_server=False).dispatch()
- os.remove("a.txt")
-
-
-@pytest.mark.xfail(strict=False) # todo: likely dead feature, fine to crash...
-def test_lit_drive():
- with pytest.raises(Exception, match="Unknown protocol for the drive 'id' argument"):
- Drive("invalid_drive_id")
-
- with pytest.raises(
- Exception, match="The id should be unique to identify your drive. Found `this_drive_id/something_else`."
- ):
- Drive("lit://this_drive_id/something_else")
-
- drive = Drive("lit://this_drive_id")
- with pytest.raises(Exception, match="The component name needs to be known to put a path to the Drive."):
- drive.put(".")
-
- with pytest.raises(Exception, match="The component name needs to be known to delete a path to the Drive."):
- drive.delete(".")
-
- with open("a.txt", "w") as f:
- f.write("example")
-
- os.makedirs("checkpoints")
- with open("checkpoints/a.txt", "w") as f:
- f.write("example")
-
- drive = Drive("lit://drive_1", allow_duplicates=False)
- drive.component_name = "root.work_1"
- assert drive.list(".") == []
- drive.put("a.txt")
- assert drive.list(".") == ["a.txt"]
- drive.component_name = "root.work_2"
- with pytest.raises(Exception, match="The file a.txt can't be added as already found in the Drive."):
- drive.put("a.txt")
- drive.get("a.txt")
-
- drive = Drive("lit://drive_2", allow_duplicates=False)
- drive.component_name = "root.work_1"
- drive.put("checkpoints/a.txt")
- drive.component_name = "root.work_2"
- with pytest.raises(Exception, match="The file checkpoints/a.txt can't be added as already found in the Drive."):
- drive.put("checkpoints/a.txt")
-
- drive = Drive("lit://drive_3", allow_duplicates=False)
- drive.component_name = "root.work_1"
- drive.put("checkpoints/")
- drive.component_name = "root.work_2"
- with pytest.raises(Exception, match="The file checkpoints/a.txt can't be added as already found in the Drive."):
- drive.put("checkpoints/a.txt")
-
- drive = Drive("lit://drive_3", allow_duplicates=True)
- drive.component_name = "root.work_1"
- drive.put("checkpoints/")
- drive.component_name = "root.work_2"
- with pytest.raises(
- Exception, match="The file checkpoints/a.txt doesn't exists in the component_name space root.work_2."
- ):
- drive.delete("checkpoints/a.txt")
- drive.put("checkpoints/a.txt")
- drive.delete("checkpoints/a.txt")
-
- drive = Drive("lit://drive_3", allow_duplicates=True)
- drive.component_name = "root.work_1"
- drive.put("checkpoints/")
- with pytest.raises(Exception, match="['root.work_1', 'root.work_2']"):
- drive.get("checkpoints/")
- drive.get("checkpoints/a.txt", component_name="root.work_1")
- drive.get("checkpoints/a.txt", component_name="root.work_1", timeout=1)
-
- with pytest.raises(FileNotFoundError):
- drive.get("checkpoints/b.txt", component_name="root.work_1")
- with pytest.raises(Exception, match="The following checkpoints/b.txt wasn't found in 1 seconds"):
- drive.get("checkpoints/b.txt", component_name="root.work_1", timeout=1)
- drive.component_name = "root.work_2"
- drive.put("checkpoints/")
- drive.component_name = "root.work_3"
- with pytest.raises(Exception, match="We found several matching files created by multiples components"):
- drive.get("checkpoints/a.txt")
- with pytest.raises(Exception, match="We found several matching files created by multiples components"):
- drive.get("checkpoints/a.txt", timeout=1)
-
- drive = Drive("lit://drive_4", allow_duplicates=True)
- drive.component_name = "root.work_1"
- with pytest.raises(Exception, match="The following checkpoints/a.txt wasn't found in 1 seconds."):
- drive.get("checkpoints/a.txt", timeout=1)
-
- drive = Drive("lit://test", allow_duplicates=True)
- drive.component_name = "root.work1"
- drive.put("checkpoints")
- drive.get("checkpoints", overwrite=True)
- with pytest.raises(FileExistsError, match="overwrite=True"):
- drive.get("checkpoints")
-
- drive = Drive("lit://drive_5", allow_duplicates=True)
- drive.component_name = "root.work"
- _set_flow_context()
- with pytest.raises(Exception, match="The flow isn't allowed to put files into a Drive."):
- drive.put("a.txt")
- with pytest.raises(Exception, match="The flow isn't allowed to list files from a Drive."):
- drive.list("a.txt")
- with pytest.raises(Exception, match="The flow isn't allowed to get files from a Drive."):
- drive.get("a.txt")
-
- os.remove("checkpoints/a.txt")
- os.rmdir("checkpoints")
- os.remove("a.txt")
-
-
-@pytest.mark.parametrize("drive_id", ["lit://drive"])
-def test_maybe_create_drive(drive_id):
- drive = Drive(drive_id, allow_duplicates=False)
- drive.component_name = "root.work1"
- assert isinstance(drive.root_folder, pathlib.Path)
- drive_state = drive.to_dict()
- assert isinstance(drive_state["root_folder"], str)
- new_drive = _maybe_create_drive(drive.component_name, drive.to_dict())
- assert isinstance(drive.root_folder, pathlib.Path)
- assert new_drive.protocol == drive.protocol
- assert new_drive.id == drive.id
- assert new_drive.component_name == drive.component_name
- drive_state["root_folder"] = pathlib.Path(drive_state["root_folder"])
- copy_drive_state = deepcopy(drive_state)
- deep_diff = DeepDiff(copy_drive_state, drive_state)
- assert "unprocessed" in deep_diff
- deep_diff.pop("unprocessed")
-
-
-@pytest.mark.parametrize("drive_id", ["lit://drive"])
-def test_drive_deepcopy(drive_id):
- drive = Drive(drive_id, allow_duplicates=True)
- drive.component_name = "root.work1"
- new_drive = deepcopy(drive)
- assert new_drive.id == drive.id
- assert new_drive.component_name == drive.component_name
-
-
-def test_s3_drive_raises_error_telling_users_to_use_mounts():
- with pytest.raises(ValueError, match="Using S3 buckets in a Drive is no longer supported."):
- Drive("s3://foo/")
-
-
-def test_drive_root_folder_breaks():
- with pytest.raises(Exception, match="The provided root_folder isn't a directory: a"):
- Drive("lit://drive", root_folder="a")
diff --git a/tests/tests_app/storage/test_filesystem.py b/tests/tests_app/storage/test_filesystem.py
deleted file mode 100644
index 8bc760209c602..0000000000000
--- a/tests/tests_app/storage/test_filesystem.py
+++ /dev/null
@@ -1,75 +0,0 @@
-import os
-import sys
-
-import pytest
-from lightning.app.storage import FileSystem
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="TODO: Add support for windows")
-def test_filesystem(tmpdir):
- fs = FileSystem()
-
- with open(f"{tmpdir}/a.txt", "w") as f:
- f.write("example")
-
- os.makedirs(f"{tmpdir}/checkpoints", exist_ok=True)
- with open(f"{tmpdir}/checkpoints/a.txt", "w") as f:
- f.write("example")
-
- with open(f"{tmpdir}/info.txt", "w") as f:
- f.write("example")
-
- assert fs.listdir("/") == []
- fs.put(f"{tmpdir}/a.txt", "/a.txt")
- fs.put(f"{tmpdir}/info.txt", "/info.txt")
- assert fs.listdir("/") == ["a.txt"]
-
- assert fs.isfile("/a.txt")
-
- fs.put(f"{tmpdir}/checkpoints", "/checkpoints")
- assert not fs.isfile("/checkpoints")
- assert fs.isdir("/checkpoints")
- assert fs.isfile("/checkpoints/a.txt")
-
- assert fs.listdir("/") == ["a.txt", "checkpoints"]
- assert fs.walk("/") == ["a.txt", "checkpoints/a.txt"]
-
- os.remove(f"{tmpdir}/a.txt")
-
- assert not os.path.exists(f"{tmpdir}/a.txt")
-
- fs.get("/a.txt", f"{tmpdir}/a.txt")
-
- assert os.path.exists(f"{tmpdir}/a.txt")
-
- fs.rm("/a.txt")
-
- assert fs.listdir("/") == ["checkpoints"]
- fs.rm("/checkpoints/a.txt")
- assert fs.listdir("/") == ["checkpoints"]
- assert fs.walk("/checkpoints") == []
- fs.rm("/checkpoints/")
- assert fs.listdir("/") == []
-
- with pytest.raises(FileExistsError, match="HERE"):
- fs.put("HERE", "/HERE")
-
- with pytest.raises(RuntimeError, match="The provided path"):
- fs.listdir("/space")
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="TODO: Add support for windows")
-def test_filesystem_root(tmpdir):
- fs = FileSystem()
-
- with open(f"{tmpdir}/a.txt", "w") as f:
- f.write("example")
-
- os.makedirs(f"{tmpdir}/checkpoints", exist_ok=True)
- with open(f"{tmpdir}/checkpoints/a.txt", "w") as f:
- f.write("example")
-
- assert fs.listdir("/") == []
- fs.put(f"{tmpdir}/a.txt", "/")
- fs.put(f"{tmpdir}/checkpoints", "/")
- assert fs.listdir("/") == ["a.txt", "checkpoints"]
diff --git a/tests/tests_app/storage/test_mount.py b/tests/tests_app/storage/test_mount.py
deleted file mode 100644
index ab4802c76aecf..0000000000000
--- a/tests/tests_app/storage/test_mount.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import pytest
-from lightning.app.storage.mount import Mount
-
-
-def test_create_s3_mount_successfully():
- mount = Mount(source="s3://foo/bar/", mount_path="/foo")
- assert mount.source == "s3://foo/bar/"
- assert mount.mount_path == "/foo"
- assert mount.protocol == "s3://"
-
-
-def test_create_non_s3_mount_fails():
- with pytest.raises(ValueError, match="Unknown protocol for the mount 'source' argument"):
- Mount(source="foo/bar/", mount_path="/foo")
-
- with pytest.raises(ValueError, match="Unknown protocol for the mount 'source' argument"):
- Mount(source="gcs://foo/bar/", mount_path="/foo")
-
- with pytest.raises(ValueError, match="Unknown protocol for the mount 'source' argument"):
- Mount(source="3://foo/bar/", mount_path="/foo")
-
-
-def test_create_s3_mount_without_directory_prefix_fails():
- with pytest.raises(ValueError, match="S3 mounts must end in a trailing slash"):
- Mount(source="s3://foo/bar", mount_path="/foo")
-
- with pytest.raises(ValueError, match="S3 mounts must end in a trailing slash"):
- Mount(source="s3://foo", mount_path="/foo")
-
-
-def test_create_mount_without_mount_path_argument():
- m = Mount(source="s3://foo/")
- assert m.mount_path == "/data/foo"
-
- m = Mount(source="s3://foo/bar/")
- assert m.mount_path == "/data/bar"
-
-
-def test_create_mount_path_with_relative_path_errors():
- with pytest.raises(ValueError, match="mount_path argument must be an absolute path"):
- Mount(source="s3://foo/", mount_path="./doesnotwork")
diff --git a/tests/tests_app/storage/test_orchestrator.py b/tests/tests_app/storage/test_orchestrator.py
deleted file mode 100644
index 41a8d0191098b..0000000000000
--- a/tests/tests_app/storage/test_orchestrator.py
+++ /dev/null
@@ -1,84 +0,0 @@
-from unittest.mock import MagicMock
-
-from lightning.app.storage.orchestrator import StorageOrchestrator
-from lightning.app.storage.requests import _GetRequest, _GetResponse
-from lightning.app.testing.helpers import _MockQueue
-from lightning.app.utilities.enum import WorkStageStatus
-
-
-def test_orchestrator():
- """Simulate orchestration when Work B requests a file from Work A."""
- request_queues = {"work_a": _MockQueue(), "work_b": _MockQueue()}
- response_queues = {"work_a": _MockQueue(), "work_b": _MockQueue()}
- copy_request_queues = {"work_a": _MockQueue(), "work_b": _MockQueue()}
- copy_response_queues = {"work_a": _MockQueue(), "work_b": _MockQueue()}
- app = MagicMock()
- work = MagicMock()
- work.status.stage = WorkStageStatus.RUNNING
- app.get_component_by_name = MagicMock(return_value=work)
-
- orchestrator = StorageOrchestrator(
- app,
- request_queues=request_queues,
- response_queues=response_queues,
- copy_request_queues=copy_request_queues,
- copy_response_queues=copy_response_queues,
- )
-
- # test idle behavior when queues are empty
- orchestrator.run_once("work_a")
- orchestrator.run_once("work_b")
- assert not orchestrator.waiting_for_response
-
- # simulate Work B sending a request for a file in Work A
- request = _GetRequest(source="work_a", path="/a/b/c.txt", hash="", destination="", name="")
- request_queues["work_b"].put(request)
- orchestrator.run_once("work_a")
- assert not orchestrator.waiting_for_response
- orchestrator.run_once("work_b")
-
- # orchestrator is now waiting for a response from the copier in Work A
- assert "work_b" in orchestrator.waiting_for_response
- assert len(request_queues["work_a"]) == 0
- assert request in copy_request_queues["work_a"]
- assert request.destination == "work_b"
-
- # simulate loop while waiting for new elements in the queues
- orchestrator.run_once("work_a")
- orchestrator.run_once("work_b")
-
- # edge case: `None` requests on the queue get ignored
- # TODO: Investigate how `None` values end up in the queue
- request_queues["work_a"].put(None)
- orchestrator.run_once("work_a")
- orchestrator.run_once("work_b")
- assert not request_queues["work_a"]._queue
-
- # simulate copier A confirming that the file is available on the shared volume
- response = _GetResponse(source="work_a", path="/a/b/c.txt", hash="", destination="work_b", name="")
- copy_request_queues["work_a"].get()
- copy_response_queues["work_a"].put(response)
-
- # orchestrator processes confirmation and confirms to the pending request from Work B
- orchestrator.run_once("work_a")
- assert len(copy_response_queues["work_a"]) == 0
- assert response in response_queues["work_b"]
- assert not orchestrator.waiting_for_response
- orchestrator.run_once("work_b")
-
- # simulate loop while waiting for new elements in the queues
- orchestrator.run_once("work_a")
- orchestrator.run_once("work_b")
- assert not orchestrator.waiting_for_response
-
- # simulate Work B receiving the confirmation that the file was copied
- response = response_queues["work_b"].get()
- assert response.source == "work_a"
- assert response.destination == "work_b"
- assert response.exception is None
-
- # all queues should be empty
- assert all(len(queue) == 0 for queue in request_queues.values())
- assert all(len(queue) == 0 for queue in response_queues.values())
- assert all(len(queue) == 0 for queue in copy_request_queues.values())
- assert all(len(queue) == 0 for queue in copy_response_queues.values())
diff --git a/tests/tests_app/storage/test_path.py b/tests/tests_app/storage/test_path.py
deleted file mode 100644
index 2ba617d195ffc..0000000000000
--- a/tests/tests_app/storage/test_path.py
+++ /dev/null
@@ -1,724 +0,0 @@
-import json
-import os
-import pathlib
-import pickle
-import sys
-from re import escape
-from time import sleep
-from unittest import TestCase, mock
-from unittest.mock import MagicMock, Mock
-
-import pytest
-from lightning.app import LightningApp, LightningFlow, LightningWork
-from lightning.app.runners import MultiProcessRuntime
-from lightning.app.storage.path import (
- Path,
- _artifacts_path,
- _filesystem,
- _is_lit_path,
- _shared_storage_path,
- _storage_root_dir,
-)
-from lightning.app.storage.requests import _ExistsResponse, _GetResponse
-from lightning.app.testing.helpers import EmptyWork, _MockQueue, _RunIf
-from lightning.app.utilities.app_helpers import LightningJSONEncoder
-from lightning.app.utilities.component import _context
-from lightning.app.utilities.imports import _IS_WINDOWS, _is_s3fs_available
-
-
-def test_path_instantiation():
- assert Path() == pathlib.Path()
- assert Path("a/b") == pathlib.Path("a/b")
- assert Path("a", "b") == pathlib.Path("a", "b")
- assert Path(pathlib.Path("a"), pathlib.Path("b")) == pathlib.Path("a/b")
- assert Path(Path(Path("a/b"))) == pathlib.Path("a/b")
-
- path = Path()
- assert path._origin is path._consumer is path._request_queue is path._response_queue is None
-
- folder = Path("x/y/z")
- folder._origin = "origin"
- folder._consumer = "consumer"
-
- # from parts where the first is a Lightning Path and the other(s) are strings
- file = Path(folder, "file.txt")
- assert file._origin == "origin"
- assert file._consumer == "consumer"
-
- # from parts that are instances of Path and have no origin
- file = Path(folder, Path("file.txt"))
- assert file._origin == "origin"
- assert file._consumer == "consumer"
-
- # from parts that are instances of Path and have a different origin than the top folder
- filename = Path("file.txt")
- filename._origin = "different"
- with pytest.raises(TypeError, match="Tried to instantiate a Lightning Path from multiple other Paths"):
- Path(folder, filename)
-
- # from parts that are instances of Path and have the SAME origin as the top folder
- filename = Path("file.txt")
- filename._origin = "origin"
- file = Path(folder, filename)
- assert file._origin == "origin"
- assert file._consumer == "consumer"
-
-
-def test_path_instantiation_lit():
- assert Path("lit://") == _storage_root_dir()
- assert Path("lit://a/b") == pathlib.Path(_storage_root_dir(), "a/b")
- assert Path("lit://", "a", "b") == pathlib.Path(_storage_root_dir(), "a", "b")
- assert Path("lit://", pathlib.Path("a"), pathlib.Path("b")) == pathlib.Path(_storage_root_dir(), "a/b")
- assert Path(Path(Path("lit://a/b"))) == pathlib.Path(_storage_root_dir(), "a", "b")
- assert str(Path("lit://lit-path")) == os.path.join(_storage_root_dir(), "lit-path")
-
-
-def test_is_lit_path():
- assert not _is_lit_path("lit")
- assert not _is_lit_path(Path("lit"))
- assert _is_lit_path("lit://")
- assert _is_lit_path(Path("lit://"))
- assert _is_lit_path("lit://a/b/c")
- assert _is_lit_path(Path("lit://a/b/c"))
- assert _is_lit_path(_storage_root_dir())
-
-
-def test_path_copy():
- """Test that Path creates an exact copy when passing a Path instance to the constructor."""
- path = Path("x/y/z")
- path._origin = "origin"
- path._consumer = "consumer"
- path._request_queue = Mock()
- path._response_queue = Mock()
- path_copy = Path(path)
- assert path_copy._origin == path._origin
- assert path_copy._consumer == path._consumer
- assert path_copy._request_queue == path._request_queue
- assert path_copy._response_queue == path._response_queue
-
-
-def test_path_inheritance():
- """Test that the Lightning Path is a drop-in replacement for pathlib.Path without compromises."""
- file = Path("file.txt")
- pathlibfile = pathlib.Path("file.txt")
- assert file == pathlibfile
- assert isinstance(file, Path)
- assert isinstance(file, pathlib.Path)
-
- folder = Path("./x/y")
- file = folder / "file.txt"
- assert isinstance(file, Path)
-
- file = file.with_suffix(".png")
- assert isinstance(file, Path)
-
-
-def test_path_concatenation():
- """Test that path concatenations keep the properties of the paths on the right-hand side of the join."""
- folder = Path("x/y/z")
- folder._origin = "origin"
- folder._consumer = "consumer"
- other = Path("other")
-
- # test __truediv__ when Path is on the left-hand side
- file = folder / other / "more" / "file.txt"
- assert file._origin == "origin"
- assert file._consumer == "consumer"
-
- # test __rtruediv__ when Path is on the right-hand side
- switched = pathlib.Path("/") / folder
- assert isinstance(switched, Path)
- assert file._origin == "origin"
- assert file._consumer == "consumer"
-
-
-def test_path_with_replacement():
- """Test that the ``Path.with_*`` modifiers keep the properties."""
- folder = Path("x", "y", "z")
- folder._origin = "origin"
- folder._consumer = "consumer"
-
- # with_name
- file = folder.with_name("file.txt")
- assert str(file) == os.path.join("x", "y", "file.txt")
- assert file._origin == "origin"
- assert file._consumer == "consumer"
-
- # with_suffix
- file = file.with_suffix(".png")
- assert str(file) == os.path.join("x", "y", "file.png")
- assert file._origin == "origin"
- assert file._consumer == "consumer"
-
- # relative_to
- rel_path = folder.relative_to("x")
- assert str(rel_path) == os.path.join("y", "z")
- assert rel_path._origin == "origin"
- assert rel_path._consumer == "consumer"
-
-
-@_RunIf(min_python="3.9")
-def test_path_with_stem_replacement():
- """Test that the ``Path.with_stem`` modifier keeps the properties.
-
- This is only available in Python 3.9+.
-
- """
- file = Path("x", "y", "file.txt")
- file._origin = "origin"
- file._consumer = "consumer"
- file = file.with_stem("text")
- assert str(file) == os.path.join("x", "y", "text.txt")
- assert file._origin == "origin"
- assert file._consumer == "consumer"
-
-
-def test_path_parents():
- """Test that the ``Path.parent`` and ``Path.parents`` properties return Paths that inherit the origin and consumer
- attributes."""
- path = Path("a", "b", "c", "d")
- path._origin = "origin"
- path._consumer = "consumer"
-
- # .parent
- assert isinstance(path.parent, Path)
- assert str(path.parent) == os.path.join("a", "b", "c")
- assert path.parent._origin == "origin"
- assert path.parent._consumer == "consumer"
-
- # .parents
- assert path.parents == [Path("a", "b", "c"), Path("a", "b"), Path("a"), Path(".")]
- assert all(parent._origin == "origin" for parent in path.parents)
- assert all(parent._consumer == "consumer" for parent in path.parents)
-
-
-def test_path_hash():
- """Test that the value of the Path hash is a function of the path name and the origin."""
- # a path without origin has no hash
- assert Path("one").hash is Path("two").hash is None
-
- # identical paths with identical origins have the same hash
- path1 = Path("one")
- path2 = Path("one")
- path1._origin = "origin1"
- path1._consumer = "consumer1"
- path2._origin = "origin1"
- path2._consumer = "consumer2"
- assert path1.hash == path2.hash
-
- # identical paths with different origins have different hashes
- path2._origin = "origin2"
- assert path1.hash != path2.hash
-
- # different paths with the same origin yield different hashes
- path1 = Path("one")
- path2 = Path("other")
- path1._origin = "same"
- path2._origin = "same"
- assert path1.hash != path2.hash
-
-
-def test_path_pickleable():
- path = Path("a/b/c.txt")
- path._origin = "root.x.y.z"
- path._consumer = "root.p.q.r"
- path._request_queue = Mock()
- path._response_queue = Mock()
- loaded = pickle.loads(pickle.dumps(path))
- assert isinstance(loaded, Path)
- assert loaded == path
- assert loaded._origin == path._origin
- assert loaded._consumer == path._consumer
- assert loaded._request_queue is None
- assert loaded._response_queue is None
-
-
-def test_path_json_serializable():
- path = Path("a/b/c.txt")
- path._origin = "root.x.y.z"
- path._consumer = "root.p.q.r"
- path._request_queue = Mock()
- path._response_queue = Mock()
- json_dump = json.dumps(path, cls=LightningJSONEncoder)
- assert "path" in json_dump
- # the replacement of \ is needed for Windows paths
- assert str(path).replace("\\", "\\\\") in json_dump
- assert "origin_name" in json_dump
- assert path._origin in json_dump
- assert "consumer_name" in json_dump
- assert path._consumer in json_dump
-
-
-def test_path_to_dict_from_dict():
- path = Path("a/b/c.txt")
- path._origin = "root.x.y.z"
- path._consumer = "root.p.q.r"
- path._request_queue = Mock()
- path._response_queue = Mock()
- path_dict = path.to_dict()
- same_path = Path.from_dict(path_dict)
- assert same_path == path
- assert same_path._origin == path._origin
- assert same_path._consumer == path._consumer
- assert same_path._request_queue is None
- assert same_path._response_queue is None
- assert same_path._metadata == path._metadata
-
-
-def test_path_attach_work():
- """Test that attaching a path to a LightningWork will make the Work either the origin or a consumer."""
- path = Path()
- assert path._origin is None
- work1 = EmptyWork()
- work2 = EmptyWork()
- work3 = EmptyWork()
- path._attach_work(work=work1)
- assert path._origin is work1
- # path already has an owner
- path._attach_work(work=work2)
- assert path._origin is work1
- assert path._consumer is work2
-
- # path gets a new consumer
- path._attach_work(work=work3)
- assert path._origin is work1
- assert path._consumer is work3
-
-
-def test_path_attach_queues():
- path = Path()
- request_queue = Mock()
- response_queue = Mock()
- path._attach_queues(request_queue=request_queue, response_queue=response_queue)
- assert path._request_queue is request_queue
- assert path._response_queue is response_queue
-
-
-@pytest.mark.parametrize("cls", [LightningFlow, LightningWork])
-def test_path_in_flow_and_work(cls, tmpdir):
- class PathComponent(cls):
- def __init__(self):
- super().__init__()
- self.path_one = Path("a", "b")
- self.path_one = Path("a", "b", "c")
- self.path_two = Path(tmpdir) / "write.txt"
-
- def run(self):
- self.path_one = self.path_one / "d.txt"
- assert self.path_one == Path("a", "b", "c", "d.txt")
- with open(self.path_two, "w") as file:
- file.write("Hello")
-
- class RootFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.path_component = PathComponent()
-
- def run(self):
- self.path_component.run()
-
- root = RootFlow()
- _ = LightningApp(root) # create an app to convert all paths that got attached
-
- root.run()
-
- assert root.path_component.path_one == Path("a", "b", "c", "d.txt")
- assert root.path_component.path_one == pathlib.Path("a", "b", "c", "d.txt")
- if isinstance(root.path_component, LightningWork):
- assert root.path_component.path_one.origin_name == "root.path_component"
- assert root.path_component.path_one.consumer_name == "root.path_component"
- else:
- assert root.path_component.path_one._origin is None
- assert root.path_component.path_one._consumer is None
- with open(root.path_component.path_two) as fo:
- assert fo.readlines() == ["Hello"]
-
-
-class SourceWork(LightningWork):
- def __init__(self, tmpdir):
- super().__init__(cache_calls=True)
- self.path = Path(tmpdir, "src.txt")
- assert self.path.origin_name == ""
-
- def run(self):
- with open(self.path, "w") as f:
- f.write("Hello from SourceWork")
-
-
-class DestinationWork(LightningWork):
- def __init__(self, source_path):
- super().__init__(cache_calls=True)
- assert source_path.origin_name == "root.src_work"
- self.path = source_path
- assert self.path.origin_name == "root.src_work"
- self.other = Path("other")
- assert self.other.origin_name == ""
-
- def run(self):
- assert self.path.origin_name == "root.src_work"
- assert self.other.origin_name == "root.dst_work"
- # we are running locally, the file is already there (no transfer needed)
- self.path.get(overwrite=True)
- assert self.path.is_file()
- assert self.path.read_text() == "Hello from SourceWork"
-
-
-class SourceToDestFlow(LightningFlow):
- def __init__(self, tmpdir):
- super().__init__()
- self.src_work = SourceWork(tmpdir)
- self.dst_work = DestinationWork(self.src_work.path)
-
- def run(self):
- self.src_work.run()
- if self.src_work.has_succeeded:
- self.dst_work.run()
- if self.dst_work.has_succeeded:
- self.stop()
-
-
-@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="too slow on Windows or macOS")
-def test_multiprocess_path_in_work_and_flow(tmpdir):
- root = SourceToDestFlow(tmpdir)
- app = LightningApp(root, log_level="debug")
- MultiProcessRuntime(app, start_server=False).dispatch()
-
-
-class DynamicSourceToDestFlow(LightningFlow):
- def __init__(self, tmpdir):
- super().__init__()
- self.tmpdir = str(tmpdir)
-
- def run(self):
- if not hasattr(self, "src_work"):
- self.src_work = SourceWork(self.tmpdir)
- self.src_work.run()
- if self.src_work.has_succeeded:
- if not hasattr(self, "dst_work"):
- self.dst_work = DestinationWork(self.src_work.path)
- self.dst_work.run()
- if hasattr(self, "dst_work") and self.dst_work.has_succeeded:
- self.stop()
-
-
-# FIXME(alecmerdler): This test is failing...
-@pytest.mark.skipif(_IS_WINDOWS, reason="strange TimeOut exception")
-@pytest.mark.xfail(strict=False, reason="hanging...")
-def test_multiprocess_path_in_work_and_flow_dynamic(tmpdir):
- root = DynamicSourceToDestFlow(tmpdir)
- app = LightningApp(root)
- MultiProcessRuntime(app).dispatch()
-
-
-class RunPathFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.src_work = PathSourceWork()
- self.run_work = RunPathWork(cache_calls=True)
-
- def run(self):
- self.src_work.run()
- assert self.src_work.src_path_0.origin_name == "root.src_work"
- assert self.src_work.src_path_0.consumer_name == "root.src_work"
-
- # local_path is not attached to any Work
- local_path_0 = Path("local", "file_0.txt")
- local_path_1 = Path("local", "file_1.txt")
- assert local_path_0.origin_name is None
- assert local_path_0.consumer_name is None
-
- nested_local_path = (99, {"nested": local_path_1})
- nested_kwarg_path = ["x", (self.src_work.src_path_1,)]
-
- # TODO: support returning a path from run()
- self.run_work.run(
- self.src_work.src_path_0,
- local_path_0,
- nested_local_path,
- kwarg_path=local_path_1,
- nested_kwarg_path=nested_kwarg_path,
- )
- sleep(1)
- self.stop()
-
-
-class PathSourceWork(EmptyWork):
- def __init__(self):
- super().__init__()
- self.src_path_0 = Path("src", "file_0.txt")
- self.src_path_1 = Path("src", "file_1.txt")
-
-
-class RunPathWork(LightningWork):
- def run(self, src_path_0, local_path_0, nested_local_path, kwarg_path=None, nested_kwarg_path=None):
- all_paths = []
-
- # src_path_0 has an origin which must be preserved, this work becomes consumer
- assert str(src_path_0) == os.path.join("src", "file_0.txt")
- assert src_path_0.origin_name == "root.src_work"
- all_paths.append(src_path_0)
-
- # local_path_0 had no origin, this work becomes both the origin and the consumer
- assert str(local_path_0) == os.path.join("local", "file_0.txt")
- assert local_path_0.origin_name is None
- assert local_path_0.consumer_name is None
- all_paths.append(local_path_0)
-
- # nested_local_path is a nested container that contains a Path
- assert str(nested_local_path[1]["nested"]) == os.path.join("local", "file_1.txt")
- assert nested_local_path[1]["nested"].origin_name is None
- assert nested_local_path[1]["nested"].consumer_name is None
- all_paths.append(nested_local_path[1]["nested"])
-
- # keyword arguments can also contain Paths
- assert str(kwarg_path) == os.path.join("local", "file_1.txt")
- assert kwarg_path.origin_name is None
- assert kwarg_path.consumer_name is None
- all_paths.append(kwarg_path)
-
- assert str(nested_kwarg_path[1][0]) == os.path.join("src", "file_1.txt")
- assert nested_kwarg_path[1][0].origin_name == "root.src_work"
- all_paths.append(nested_kwarg_path[1][0])
-
- all(p._request_queue == self._request_queue for p in all_paths)
- all(p._response_queue == self._response_queue for p in all_paths)
- all(p.consumer_name == self.name == "root.run_work" for p in all_paths)
-
-
-def test_path_as_argument_to_run_method():
- """Test that Path objects can be passed as arguments to the run() method of a Work in various ways such that the
- origin, consumer and queues get automatically attached."""
- root = RunPathFlow()
- app = LightningApp(root)
- MultiProcessRuntime(app, start_server=False).dispatch()
-
-
-def test_path_get_errors(tmpdir):
- with _context("work"):
- with pytest.raises(
- RuntimeError, match="Trying to get the file .* but the path is not attached to a LightningApp"
- ):
- Path().get()
-
- with pytest.raises(
- RuntimeError, match="Trying to get the file .* but the path is not attached to a LightningWork"
- ):
- path = Path()
- path._attach_queues(Mock(), Mock())
- path.get()
-
- with pytest.raises(FileExistsError, match="The file or folder .* exists locally. Pass `overwrite=True"):
- path = Path(tmpdir)
- path._attach_queues(Mock(), Mock())
- path._attach_work(Mock())
- path.get()
-
-
-class SourceOverwriteWork(LightningWork):
- def __init__(self, tmpdir):
- super().__init__(raise_exception=True)
- self.path = Path(tmpdir, "folder")
-
- def run(self):
- self.path.mkdir(parents=True, exist_ok=True)
- (self.path / "file.txt").touch()
- assert self.path.exists_local()
-
-
-class DestinationOverwriteWork(LightningWork):
- def __init__(self, source_path):
- super().__init__(raise_exception=True)
- self.path = source_path
-
- def run(self):
- assert self.path.exists()
- with mock.patch("lightning.app.storage.path.shutil") as shutil_mock:
- self.path.get(overwrite=True)
- shutil_mock.rmtree.assert_called_with(self.path)
- assert self.path.exists()
- assert (self.path / "file.txt").exists()
-
-
-class OverwriteFolderFlow(LightningFlow):
- def __init__(self, tmpdir):
- super().__init__()
- self.src_work = SourceOverwriteWork(tmpdir)
- self.dst_work = DestinationOverwriteWork(self.src_work.path)
-
- def run(self):
- self.src_work.run()
- if self.src_work.has_succeeded:
- self.dst_work.run()
- if self.dst_work.has_succeeded:
- self.stop()
-
-
-def test_path_get_overwrite(tmpdir):
- """Test that .get(overwrite=True) overwrites the entire directory and replaces all files."""
- root = OverwriteFolderFlow(tmpdir)
- app = LightningApp(root, log_level="debug")
- MultiProcessRuntime(app, start_server=False).dispatch()
-
-
-def test_path_get_error_in_flow_context():
- with pytest.raises(RuntimeError, match=escape("`Path.get()` can only be called from within the `run()`")), _context(
- "flow"
- ):
- Path().get()
-
-
-def test_path_response_with_exception(tmpdir):
- request_queue = _MockQueue()
- response_queue = _MockQueue()
- path = Path(tmpdir / "file.txt")
- path._attach_queues(request_queue, response_queue)
- path._origin = "origin"
- path._consumer = "consumer"
-
- # simulate that a response will come with an exception raised
- response_queue.put(
- _GetResponse(
- source="origin",
- path=str(tmpdir / "file.txt"),
- hash=path.hash,
- destination="consumer",
- exception=OSError("Something went wrong"),
- name="",
- )
- )
-
- with pytest.raises(
- RuntimeError, match="An exception was raised while trying to transfer the contents at"
- ), _context("work"):
- path.get()
-
-
-def test_path_response_not_matching_request(tmpdir):
- request_queue = _MockQueue()
- response_queue = _MockQueue()
- path = Path(tmpdir / "file.txt")
- path._attach_queues(request_queue, response_queue)
- path._origin = "origin"
- path._consumer = "consumer"
-
- # simulate a response that has a different owner than the request had
- response = _GetResponse(
- source="other_origin", path=str(tmpdir / "file.txt"), hash=path.hash, destination="consumer", name=""
- )
-
- response_queue.put(response)
- with pytest.raises(
- RuntimeError, match="Tried to get the file .* but received a response for a request it did not send."
- ):
- path.get()
-
- # simulate a response that has a different hash than the request had
- assert len(response_queue) == 0
- response.path = str(path)
- response.hash = "other_hash"
- response_queue.put(response)
- with pytest.raises(
- RuntimeError, match="Tried to get the file .* but received a response for a request it did not send."
- ):
- path.get()
-
-
-def test_path_exists(tmpdir):
- """Test that Path.exists() behaves as expected: first it should check if the file exists locally, and if not,
- send a message to the orchestrator to eventually check the existence on the origin Work."""
- # Local Path (no Work queues attached)
- assert not Path("file").exists()
- assert Path(tmpdir).exists()
- with open(tmpdir / "file", "w"):
- assert Path(tmpdir / "file").exists()
-
- # A local path that exists
- path = Path(tmpdir)
- path.exists_remote = Mock()
- path.exists_local = Mock(return_value=True)
- assert path.exists() is True
- path.exists_local.assert_called_once()
- path.exists_remote.assert_not_called() # don't check remotely
-
- # A local path that does not exist, but has no Work attached
- path = Path("not-exists.txt")
- path.exists_local = Mock(return_value=False)
- path.exists_remote = Mock()
- assert not path.exists()
- path.exists_local.assert_called_once()
- path.exists_remote.assert_not_called() # don't check remotely
-
- # A local path that does not exist, but it exists remotely
- path = Path("exists-remotely-only.txt")
- path.exists_local = Mock(return_value=False)
- path.exists_remote = Mock(return_value=True)
- path._origin = "origin"
- assert path.exists()
- path.exists_local.assert_called_once()
- path.exists_remote.assert_called_once() # check remotely
-
-
-def test_path_exists_local(tmpdir):
- assert not Path("file").exists_local()
- assert Path(tmpdir).exists_local()
- with open(tmpdir / "file", "w"):
- assert Path(tmpdir / "file").exists_local()
-
-
-def test_path_exists_remote(tmpdir):
- path = Path(tmpdir / "not-attached.txt")
- with pytest.raises(RuntimeError, match="the path is not attached to a LightningWork"):
- path.exists_remote()
-
- # If Path does not exist locally, ask the orchestrator
- request_queue = _MockQueue()
- response_queue = _MockQueue()
- path = Path(tmpdir / "not-exists.txt")
- path._attach_queues(request_queue, response_queue)
- path._origin = "origin"
- path._consumer = "consumer"
-
- # Put the response into the queue to simulate the orchestrator responding
- response_queue.put(_ExistsResponse(source=path.origin_name, path=str(path), name="", hash="123", exists=False))
- assert not path.exists_remote()
- assert request_queue.get()
-
- response_queue.put(_ExistsResponse(source=path.origin_name, path=str(path), name="", hash="123", exists=True))
- assert path.exists_remote()
- assert request_queue.get()
-
-
-def test_artifacts_path():
- work = Mock()
- work.name = "root.flow.work"
- assert _artifacts_path(work) == _shared_storage_path() / "artifacts" / "root.flow.work"
-
-
-@pytest.mark.skipif(not _is_s3fs_available(), reason="This test requires s3fs.")
-@mock.patch.dict(os.environ, {"LIGHTNING_BUCKET_ENDPOINT_URL": "a"})
-@mock.patch.dict(os.environ, {"LIGHTNING_BUCKET_NAME": "b"})
-@mock.patch.dict(os.environ, {"LIGHTNING_CLOUD_APP_ID": "e"})
-def test_filesystem(monkeypatch):
- from lightning.app.storage import path
-
- mock = MagicMock()
- monkeypatch.setattr(path, "S3FileSystem", mock)
- fs = _filesystem()
- assert fs == mock()
-
-
-class TestSharedStoragePath(TestCase):
- @mock.patch.dict(os.environ, {"LIGHTNING_STORAGE_PATH": "test-bucket/lightningapps/test-project/test-app"})
- def test_shared_storage_path_storage_path_set(self):
- assert pathlib.Path("test-bucket/lightningapps/test-project/test-app") == _shared_storage_path()
-
- @mock.patch.dict(os.environ, {"LIGHTNING_CLOUD_APP_ID": "test-app", "LIGHTNING_BUCKET_NAME": "test-bucket"})
- def test_shared_storage_path_bucket_and_app_id_set(self):
- assert pathlib.Path("test-bucket/lightningapps/test-app") == _shared_storage_path()
-
- @mock.patch.dict(os.environ, {"SHARED_MOUNT_DIRECTORY": "test-app/.shared"})
- def test_shared_storage_path_mount_directory_set(self):
- assert _shared_storage_path().match("*/test-app/.shared")
-
- def test_shared_storage_path_no_envvars_set(self):
- assert _shared_storage_path().match("*/.shared")
diff --git a/tests/tests_app/storage/test_payload.py b/tests/tests_app/storage/test_payload.py
deleted file mode 100644
index 453cead1c6583..0000000000000
--- a/tests/tests_app/storage/test_payload.py
+++ /dev/null
@@ -1,154 +0,0 @@
-import os
-import pathlib
-import pickle
-from copy import deepcopy
-from unittest import mock
-from unittest.mock import Mock
-
-import pytest
-from lightning.app import LightningApp, LightningFlow, LightningWork
-from lightning.app.runners.multiprocess import MultiProcessRuntime
-from lightning.app.storage.payload import Payload
-from lightning.app.storage.requests import _GetRequest
-
-
-def test_payload_copy():
- """Test that deep-copying a Payload instance creates an exact copy of its attributes."""
- payload = Payload(None)
- payload._origin = "origin"
- payload._consumer = "consumer"
- payload._request_queue = "MockQueue"
- payload._response_queue = "MockQueue"
- payload_copy = deepcopy(payload)
- assert payload_copy._origin == payload._origin
- assert payload_copy._consumer == payload._consumer
- assert payload_copy._request_queue == payload._request_queue
- assert payload_copy._response_queue == payload._response_queue
-
-
-def test_payload_pickable():
- payload = Payload("MyObject")
- payload._origin = "root.x.y.z"
- payload._consumer = "root.p.q.r"
- payload._name = "var_a"
- loaded = pickle.loads(pickle.dumps(payload))
-
- assert isinstance(loaded, Payload)
- assert loaded._origin == payload._origin
- assert loaded._consumer == payload._consumer
- assert loaded._name == payload._name
- assert loaded._request_queue is None
- assert loaded._response_queue is None
-
-
-def test_path_attach_queues():
- path = Payload(None)
- request_queue = Mock()
- response_queue = Mock()
- path._attach_queues(request_queue=request_queue, response_queue=response_queue)
- assert path._request_queue is request_queue
- assert path._response_queue is response_queue
-
-
-class Work(LightningWork):
- def __init__(self):
- super().__init__()
- self.var_a = Payload(None)
-
- def run(self):
- pass
-
-
-def test_payload_in_init():
- with pytest.raises(
- AttributeError, match="The Payload object should be set only within the run method of the work."
- ):
- Work()
-
-
-class WorkRun(LightningWork):
- def __init__(self, tmpdir):
- super().__init__()
- self.var_a = None
- self.tmpdir = tmpdir
-
- def run(self):
- self.var_a = Payload("something")
- assert self.var_a.name == "var_a"
- assert self.var_a._origin == "root.a"
- assert self.var_a.hash == "9bd514ad51fc33d895c50657acd0f0582301cf3e"
- source_path = pathlib.Path(self.tmpdir, self.var_a.name)
- assert not source_path.exists()
- response = self.var_a._handle_get_request(
- self,
- _GetRequest(
- name="var_a",
- hash=self.var_a.hash,
- source="root.a",
- path=str(source_path),
- destination="root",
- ),
- )
- assert source_path.exists()
- assert self.var_a.load(str(source_path)) == "something"
- assert not response.exception
-
-
-def test_payload_in_run(tmpdir):
- work = WorkRun(str(tmpdir))
- work._name = "root.a"
- work.run()
-
-
-class Sender(LightningWork):
- def __init__(self):
- super().__init__(parallel=True)
- self.value_all = None
- self.value_b = None
- self.value_c = None
-
- def run(self):
- self.value_all = Payload(["A", "B", "C"])
- self.value_b = Payload("B")
- self.value_c = Payload("C")
-
-
-class WorkReceive(LightningWork):
- def __init__(self, expected):
- super().__init__(parallel=True)
- self.expected = expected
-
- def run(self, generated):
- assert generated.value == self.expected
-
-
-class Flow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.sender = Sender()
- self.receiver_all = WorkReceive(["A", "B", "C"])
- self.receiver_b = WorkReceive("B")
- self.receiver_c = WorkReceive("C")
-
- def run(self):
- self.sender.run()
- if self.sender.value_all:
- self.receiver_all.run(self.sender.value_all)
- if self.sender.value_b:
- self.receiver_b.run(self.sender.value_b)
- if self.sender.value_c:
- self.receiver_c.run(self.sender.value_c)
- if self.receiver_all.has_succeeded and self.receiver_b.has_succeeded and self.receiver_c.has_succeeded:
- self.stop()
-
-
-@pytest.mark.xfail(strict=False, reason="flaky")
-def test_payload_works(tmpdir):
- """This test validates that the Payload API can be used to transfer return values from one Work to another."""
- with mock.patch("lightning.app.storage.path._storage_root_dir", return_value=pathlib.Path(tmpdir)):
- app = LightningApp(Flow(), log_level="debug")
- MultiProcessRuntime(app, start_server=False).dispatch()
-
- os.remove("value_all")
- os.remove("value_b")
- os.remove("value_c")
diff --git a/tests/tests_app/structures/__init__.py b/tests/tests_app/structures/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/structures/test_structures.py b/tests/tests_app/structures/test_structures.py
deleted file mode 100644
index 2d0c5e58005f0..0000000000000
--- a/tests/tests_app/structures/test_structures.py
+++ /dev/null
@@ -1,564 +0,0 @@
-import os
-from copy import deepcopy
-
-import pytest
-from lightning.app import LightningApp, LightningFlow, LightningWork
-from lightning.app.runners import MultiProcessRuntime
-from lightning.app.storage.payload import Payload
-from lightning.app.structures import Dict, List
-from lightning.app.testing.helpers import EmptyFlow
-from lightning.app.utilities.enum import CacheCallsKeys, WorkStageStatus
-from lightning.app.utilities.imports import _IS_WINDOWS
-
-
-def test_dict():
- class WorkA(LightningWork):
- def __init__(self):
- super().__init__(port=1)
- self.c = 0
-
- def run(self):
- pass
-
- class A(LightningFlow):
- def __init__(self):
- super().__init__()
- self.dict = Dict(**{"work_a": WorkA(), "work_b": WorkA(), "work_c": WorkA(), "work_d": WorkA()})
-
- def run(self):
- pass
-
- flow = A()
-
- # TODO: these assertions are wrong, the works are getting added under "flows" instead of "works"
- # state
- assert len(flow.state["structures"]["dict"]["works"]) == len(flow.dict) == 4
- assert list(flow.state["structures"]["dict"]["works"].keys()) == ["work_a", "work_b", "work_c", "work_d"]
- assert all(
- flow.state["structures"]["dict"]["works"][f"work_{k}"]["vars"]
- == {
- "c": 0,
- "_url": "",
- "_future_url": "",
- "_port": 1,
- "_host": "127.0.0.1",
- "_paths": {},
- "_restarting": False,
- "_display_name": "",
- "_internal_ip": "",
- "_public_ip": "",
- "_cloud_compute": {
- "type": "__cloud_compute__",
- "name": "cpu-small",
- "disk_size": 0,
- "idle_timeout": None,
- "mounts": None,
- "shm_size": 0,
- "_internal_id": "default",
- "interruptible": False,
- "colocation_group_id": None,
- },
- }
- for k in ("a", "b", "c", "d")
- )
- assert all(
- flow.state["structures"]["dict"]["works"][f"work_{k}"]["calls"] == {CacheCallsKeys.LATEST_CALL_HASH: None}
- for k in ("a", "b", "c", "d")
- )
- assert all(flow.state["structures"]["dict"]["works"][f"work_{k}"]["changes"] == {} for k in ("a", "b", "c", "d"))
-
- # state_vars
- assert len(flow.state_vars["structures"]["dict"]["works"]) == len(flow.dict) == 4
- assert list(flow.state_vars["structures"]["dict"]["works"].keys()) == ["work_a", "work_b", "work_c", "work_d"]
- assert all(
- flow.state_vars["structures"]["dict"]["works"][f"work_{k}"]["vars"]
- == {
- "c": 0,
- "_url": "",
- "_future_url": "",
- "_port": 1,
- "_host": "127.0.0.1",
- "_paths": {},
- "_restarting": False,
- "_display_name": "",
- "_internal_ip": "",
- "_public_ip": "",
- "_cloud_compute": {
- "type": "__cloud_compute__",
- "name": "cpu-small",
- "disk_size": 0,
- "idle_timeout": None,
- "mounts": None,
- "shm_size": 0,
- "_internal_id": "default",
- "interruptible": False,
- "colocation_group_id": None,
- },
- }
- for k in ("a", "b", "c", "d")
- )
-
- # state_with_changes
- assert len(flow.state_with_changes["structures"]["dict"]["works"]) == len(flow.dict) == 4
- assert list(flow.state_with_changes["structures"]["dict"]["works"].keys()) == [
- "work_a",
- "work_b",
- "work_c",
- "work_d",
- ]
- assert all(
- flow.state_with_changes["structures"]["dict"]["works"][f"work_{k}"]["vars"]
- == {
- "c": 0,
- "_url": "",
- "_future_url": "",
- "_port": 1,
- "_host": "127.0.0.1",
- "_paths": {},
- "_restarting": False,
- "_display_name": "",
- "_internal_ip": "",
- "_public_ip": "",
- "_cloud_compute": {
- "type": "__cloud_compute__",
- "name": "cpu-small",
- "disk_size": 0,
- "idle_timeout": None,
- "mounts": None,
- "shm_size": 0,
- "_internal_id": "default",
- "interruptible": False,
- "colocation_group_id": None,
- },
- }
- for k in ("a", "b", "c", "d")
- )
- assert all(
- flow.state_with_changes["structures"]["dict"]["works"][f"work_{k}"]["calls"]
- == {CacheCallsKeys.LATEST_CALL_HASH: None}
- for k in ("a", "b", "c", "d")
- )
- assert all(
- flow.state_with_changes["structures"]["dict"]["works"][f"work_{k}"]["changes"] == {}
- for k in ("a", "b", "c", "d")
- )
-
- # set_state
- state = deepcopy(flow.state)
- state["structures"]["dict"]["works"]["work_b"]["vars"]["c"] = 1
- flow.set_state(state)
- assert flow.dict["work_b"].c == 1
-
-
-def test_dict_name():
- d = Dict(a=EmptyFlow(), b=EmptyFlow())
- assert d.name == "root"
- assert d["a"].name == "root.a"
- assert d["b"].name == "root.b"
-
- class RootFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.dict = Dict(x=EmptyFlow(), y=EmptyFlow())
-
- def run(self):
- pass
-
- root = RootFlow()
- assert root.name == "root"
- assert root.dict.name == "root.dict"
- assert root.dict["x"].name == "root.dict.x"
- assert root.dict["y"].name == "root.dict.y"
-
-
-def test_list():
- class WorkA(LightningWork):
- def __init__(self):
- super().__init__(port=1)
- self.c = 0
-
- def run(self):
- pass
-
- class A(LightningFlow):
- def __init__(self):
- super().__init__()
- self.list = List(WorkA(), WorkA(), WorkA(), WorkA())
-
- def run(self):
- pass
-
- flow = A()
-
- # TODO: these assertions are wrong, the works are getting added under "flows" instead of "works"
- # state
- assert len(flow.state["structures"]["list"]["works"]) == len(flow.list) == 4
- assert list(flow.state["structures"]["list"]["works"].keys()) == ["0", "1", "2", "3"]
- assert all(
- flow.state["structures"]["list"]["works"][str(i)]["vars"]
- == {
- "c": 0,
- "_url": "",
- "_future_url": "",
- "_port": 1,
- "_host": "127.0.0.1",
- "_paths": {},
- "_restarting": False,
- "_internal_ip": "",
- "_public_ip": "",
- "_display_name": "",
- "_cloud_compute": {
- "type": "__cloud_compute__",
- "name": "cpu-small",
- "disk_size": 0,
- "idle_timeout": None,
- "mounts": None,
- "shm_size": 0,
- "_internal_id": "default",
- "interruptible": False,
- "colocation_group_id": None,
- },
- }
- for i in range(4)
- )
- assert all(
- flow.state["structures"]["list"]["works"][str(i)]["calls"] == {CacheCallsKeys.LATEST_CALL_HASH: None}
- for i in range(4)
- )
- assert all(flow.state["structures"]["list"]["works"][str(i)]["changes"] == {} for i in range(4))
-
- # state_vars
- assert len(flow.state_vars["structures"]["list"]["works"]) == len(flow.list) == 4
- assert list(flow.state_vars["structures"]["list"]["works"].keys()) == ["0", "1", "2", "3"]
- assert all(
- flow.state_vars["structures"]["list"]["works"][str(i)]["vars"]
- == {
- "c": 0,
- "_url": "",
- "_future_url": "",
- "_port": 1,
- "_host": "127.0.0.1",
- "_paths": {},
- "_restarting": False,
- "_internal_ip": "",
- "_public_ip": "",
- "_display_name": "",
- "_cloud_compute": {
- "type": "__cloud_compute__",
- "name": "cpu-small",
- "disk_size": 0,
- "idle_timeout": None,
- "mounts": None,
- "shm_size": 0,
- "_internal_id": "default",
- "interruptible": False,
- "colocation_group_id": None,
- },
- }
- for i in range(4)
- )
-
- # state_with_changes
- assert len(flow.state_with_changes["structures"]["list"]["works"]) == len(flow.list) == 4
- assert list(flow.state_with_changes["structures"]["list"]["works"].keys()) == ["0", "1", "2", "3"]
- assert all(
- flow.state_with_changes["structures"]["list"]["works"][str(i)]["vars"]
- == {
- "c": 0,
- "_url": "",
- "_future_url": "",
- "_port": 1,
- "_host": "127.0.0.1",
- "_paths": {},
- "_restarting": False,
- "_internal_ip": "",
- "_public_ip": "",
- "_display_name": "",
- "_cloud_compute": {
- "type": "__cloud_compute__",
- "name": "cpu-small",
- "disk_size": 0,
- "idle_timeout": None,
- "mounts": None,
- "shm_size": 0,
- "_internal_id": "default",
- "interruptible": False,
- "colocation_group_id": None,
- },
- }
- for i in range(4)
- )
- assert all(
- flow.state_with_changes["structures"]["list"]["works"][str(i)]["calls"]
- == {CacheCallsKeys.LATEST_CALL_HASH: None}
- for i in range(4)
- )
- assert all(flow.state_with_changes["structures"]["list"]["works"][str(i)]["changes"] == {} for i in range(4))
-
- # set_state
- state = deepcopy(flow.state)
- state["structures"]["list"]["works"]["0"]["vars"]["c"] = 1
- flow.set_state(state)
- assert flow.list[0].c == 1
-
-
-def test_list_name():
- lst = List(EmptyFlow(), EmptyFlow())
- assert lst.name == "root"
- assert lst[0].name == "root.0"
- assert lst[1].name == "root.1"
-
- class RootFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.list = List(EmptyFlow(), EmptyFlow())
-
- def run(self):
- pass
-
- root = RootFlow()
- assert root.name == "root"
- assert root.list.name == "root.list"
- assert root.list[0].name == "root.list.0"
- assert root.list[1].name == "root.list.1"
-
-
-class CounterWork(LightningWork):
- def __init__(self, cache_calls, parallel=False):
- super().__init__(cache_calls=cache_calls, parallel=parallel)
- self.counter = 0
-
- def run(self):
- self.counter += 1
-
-
-@pytest.mark.skipif(_IS_WINDOWS, reason="strange TimeOut exception")
-@pytest.mark.xfail(strict=False, reason="tchaton: Resolve this test.")
-@pytest.mark.parametrize("run_once_iterable", [False, True])
-@pytest.mark.parametrize("cache_calls", [False, True])
-@pytest.mark.parametrize("use_list", [False, True])
-def test_structure_with_iterate_and_fault_tolerance(run_once_iterable, cache_calls, use_list):
- class DummyFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.counter = 0
-
- def run(self):
- pass
-
- class RootFlow(LightningFlow):
- def __init__(self, use_list, run_once_iterable, cache_calls):
- super().__init__()
- self.looping = 0
- self.run_once_iterable = run_once_iterable
- self.restarting = False
- if use_list:
- self.iter = List(
- CounterWork(cache_calls),
- CounterWork(cache_calls),
- CounterWork(cache_calls),
- CounterWork(cache_calls),
- DummyFlow(),
- )
- else:
- self.iter = Dict(**{
- "0": CounterWork(cache_calls),
- "1": CounterWork(cache_calls),
- "2": CounterWork(cache_calls),
- "3": CounterWork(cache_calls),
- "4": DummyFlow(),
- })
-
- def run(self):
- for work_idx, work in self.experimental_iterate(enumerate(self.iter), run_once=self.run_once_iterable):
- if not self.restarting and work_idx == 1:
- # gives time to the delta to be sent.
- self.stop()
- if isinstance(work, str) and isinstance(self.iter, Dict):
- work = self.iter[work]
- work.run()
- if self.looping > 0:
- self.stop()
- self.looping += 1
-
- app = LightningApp(RootFlow(use_list, run_once_iterable, cache_calls))
- MultiProcessRuntime(app, start_server=False).dispatch()
- assert app.root.iter[0 if use_list else "0"].counter == 1
- assert app.root.iter[1 if use_list else "1"].counter == 0
- assert app.root.iter[2 if use_list else "2"].counter == 0
- assert app.root.iter[3 if use_list else "3"].counter == 0
-
- app = LightningApp(RootFlow(use_list, run_once_iterable, cache_calls))
- app.root.restarting = True
- MultiProcessRuntime(app, start_server=False).dispatch()
-
- expected_value = 1 if run_once_iterable else 1 if cache_calls else 2
- assert app.root.iter[0 if use_list else "0"].counter == expected_value
- assert app.root.iter[1 if use_list else "1"].counter == expected_value
- assert app.root.iter[2 if use_list else "2"].counter == expected_value
- assert app.root.iter[3 if use_list else "3"].counter == expected_value
-
-
-class CheckpointCounter(LightningWork):
- def __init__(self):
- super().__init__(cache_calls=False)
- self.counter = 0
-
- def run(self):
- self.counter += 1
-
-
-class CheckpointFlow(LightningFlow):
- def __init__(self, collection, depth=0, exit=11):
- super().__init__()
- self.depth = depth
- self.exit = exit
- if depth == 0:
- self.counter = 0
-
- if depth >= 4:
- self.collection = collection
- else:
- self.flow = CheckpointFlow(collection, depth + 1)
-
- def run(self):
- if hasattr(self, "counter"):
- self.counter += 1
- if self.counter >= self.exit:
- self.stop()
- if self.depth >= 4:
- self.collection.run()
- else:
- self.flow.run()
-
-
-class SimpleCounterWork(LightningWork):
- def __init__(self):
- super().__init__()
- self.counter = 0
-
- def run(self):
- self.counter += 1
-
-
-class FlowDict(LightningFlow):
- def __init__(self):
- super().__init__()
- self.dict = Dict()
-
- def run(self):
- if "w" not in self.dict:
- self.dict["w"] = SimpleCounterWork()
-
- if self.dict["w"].status.stage == WorkStageStatus.SUCCEEDED:
- self.stop()
-
- self.dict["w"].run()
-
-
-def test_dict_with_queues():
- app = LightningApp(FlowDict())
- MultiProcessRuntime(app, start_server=False).dispatch()
-
-
-class FlowList(LightningFlow):
- def __init__(self):
- super().__init__()
- self.list = List()
-
- def run(self):
- if not len(self.list):
- self.list.append(SimpleCounterWork())
-
- if self.list[-1].status.stage == WorkStageStatus.SUCCEEDED:
- self.stop()
-
- self.list[-1].run()
-
-
-def test_list_with_queues():
- app = LightningApp(FlowList())
- MultiProcessRuntime(app, start_server=False).dispatch()
-
-
-class WorkS(LightningWork):
- def __init__(self):
- super().__init__()
- self.payload = None
-
- def run(self):
- self.payload = Payload(2)
-
-
-class WorkD(LightningWork):
- def run(self, payload):
- assert payload.value == 2
-
-
-class FlowPayload(LightningFlow):
- def __init__(self):
- super().__init__()
- self.src = WorkS()
- self.dst = Dict(**{"0": WorkD(parallel=True), "1": WorkD(parallel=True)})
-
- def run(self):
- self.src.run()
- if self.src.payload:
- for work in self.dst.values():
- work.run(self.src.payload)
- if all(w.has_succeeded for w in self.dst.values()):
- self.stop()
-
-
-@pytest.mark.xfail(strict=False, reason="flaky")
-def test_structures_with_payload():
- app = LightningApp(FlowPayload(), log_level="debug")
- MultiProcessRuntime(app, start_server=False).dispatch()
- os.remove("payload")
-
-
-def test_structures_have_name_on_init():
- """Test that the children in structures have the correct name assigned upon initialization."""
-
- class ChildWork(LightningWork):
- def run(self):
- pass
-
- class Collection(EmptyFlow):
- def __init__(self):
- super().__init__()
- self.list_structure = List()
- self.list_structure.append(ChildWork())
-
- self.dict_structure = Dict()
- self.dict_structure["dict_child"] = ChildWork()
-
- flow = Collection()
- LightningApp(flow) # wrap in app to init all component names
- assert flow.list_structure[0].name == "root.list_structure.0"
- assert flow.dict_structure["dict_child"].name == "root.dict_structure.dict_child"
-
-
-class FlowWiStructures(LightningFlow):
- def __init__(self):
- super().__init__()
-
- self.ws = [EmptyFlow(), EmptyFlow()]
-
- self.ws1 = {"a": EmptyFlow(), "b": EmptyFlow()}
-
- self.ws2 = {
- "a": EmptyFlow(),
- "b": EmptyFlow(),
- "c": List(EmptyFlow(), EmptyFlow()),
- "d": Dict(**{"a": EmptyFlow()}),
- }
-
- def run(self):
- pass
-
-
-def test_flow_without_structures():
- flow = FlowWiStructures()
- assert isinstance(flow.ws, List)
- assert isinstance(flow.ws1, Dict)
diff --git a/tests/tests_app/test_imports.py b/tests/tests_app/test_imports.py
deleted file mode 100644
index 8c17a611552ec..0000000000000
--- a/tests/tests_app/test_imports.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import importlib
-import inspect
-import os
-import types
-from typing import TypeVar
-
-import lightning.app
-
-
-def _is_attribute(member, module):
- return all([
- hasattr(member, "__module__") and module.__name__ in member.__module__,
- not isinstance(member, TypeVar),
- not isinstance(member, types.ModuleType),
- ])
-
-
-def _find_exports(module):
- members = inspect.getmembers(module)
- attributes = {member[0] for member in members if _is_attribute(member[1], module)}
- public_attributes = list(filter(lambda attribute: not attribute.startswith("_"), attributes))
- exports = {attribute: module.__name__ for attribute in public_attributes}
-
- if module.__file__ is not None and "__init__.py" in module.__file__:
- root = os.path.dirname(module.__file__)
- submodule_paths = os.listdir(root)
- submodule_paths = [path for path in submodule_paths if not path.startswith("_")]
- submodules = [
- path.replace(".py", "")
- for path in submodule_paths
- if os.path.isdir(os.path.join(root, path)) or path.endswith(".py")
- ]
- for submodule in submodules:
- deeper_exports = _find_exports(importlib.import_module(f".{submodule}", module.__name__))
- exports = {**deeper_exports, **exports}
-
- return exports or {}
-
-
-def test_import_depth(
- ignore=[
- "lightning.app.cli",
- "lightning.app.components.serve.types",
- "lightning.app.core",
- "lightning.app.launcher",
- "lightning.app.runners",
- "lightning.app.utilities",
- ],
-):
- """This test ensures that any public exports (functions, classes, etc.) can be imported by users with at most a
- depth of two. This guarantees that everything user-facing can be imported with (at most) ``lightning.app.*.*``.
-
- Args:
- ignore: Sub-module paths to ignore (usually sub-modules that are not intended to be user-facing).
-
- """
- exports = _find_exports(lightning.app)
- depths = {export: len(path.replace("lightning.app", "").split(".")) for export, path in exports.items()}
- deep_exports = [export for export, depth in depths.items() if depth > 2]
- deep_exports = list(
- filter(lambda export: not any(exports[export].startswith(path) for path in ignore), deep_exports)
- )
- if len(deep_exports) > 0:
- raise RuntimeError(
- "Found exports with a depth greater than two. "
- "Either expose them at a higher level or make them private. "
- f"Found: {', '.join(sorted(f'{exports[export]}.{export}' for export in deep_exports))}"
- )
diff --git a/tests/tests_app/utilities/__init__.py b/tests/tests_app/utilities/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/utilities/packaging/__init__.py b/tests/tests_app/utilities/packaging/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/utilities/packaging/projects/Dockerfile.cpu b/tests/tests_app/utilities/packaging/projects/Dockerfile.cpu
deleted file mode 100644
index f5bb8e842e9b6..0000000000000
--- a/tests/tests_app/utilities/packaging/projects/Dockerfile.cpu
+++ /dev/null
@@ -1 +0,0 @@
-FROM pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.13-cuda11.7.1
diff --git a/tests/tests_app/utilities/packaging/projects/dock/__init__.py b/tests/tests_app/utilities/packaging/projects/dock/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/utilities/packaging/projects/dock/app.py b/tests/tests_app/utilities/packaging/projects/dock/app.py
deleted file mode 100644
index 683e04c51cb30..0000000000000
--- a/tests/tests_app/utilities/packaging/projects/dock/app.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import os
-import sys
-
-from lightning.app import LightningApp
-
-if __name__ == "__main__":
- sys.path.append(os.path.dirname(__file__))
-
- from compo.a.a import AA
- from compo.b.b import BB
-
- app = LightningApp(BB(AA()))
diff --git a/tests/tests_app/utilities/packaging/projects/dock/compo/__init__.py b/tests/tests_app/utilities/packaging/projects/dock/compo/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/utilities/packaging/projects/dock/compo/a/__init__.py b/tests/tests_app/utilities/packaging/projects/dock/compo/a/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/utilities/packaging/projects/dock/compo/a/a.py b/tests/tests_app/utilities/packaging/projects/dock/compo/a/a.py
deleted file mode 100644
index 64274fa7a3a1e..0000000000000
--- a/tests/tests_app/utilities/packaging/projects/dock/compo/a/a.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import logging
-
-from lightning.app import LightningWork
-
-logger = logging.getLogger(__name__)
-
-
-class AA(LightningWork):
- def __init__(self):
- super().__init__()
- self.message = None
-
- def run(self):
- self.message = "hello world!"
diff --git a/tests/tests_app/utilities/packaging/projects/dock/compo/b/__init__.py b/tests/tests_app/utilities/packaging/projects/dock/compo/b/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/utilities/packaging/projects/dock/compo/b/b.py b/tests/tests_app/utilities/packaging/projects/dock/compo/b/b.py
deleted file mode 100644
index 8f2622363861e..0000000000000
--- a/tests/tests_app/utilities/packaging/projects/dock/compo/b/b.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from lightning.app import LightningFlow
-
-
-class BB(LightningFlow):
- def __init__(self, work):
- super().__init__()
- self.work = work
-
- def run(self):
- self.stop()
diff --git a/tests/tests_app/utilities/packaging/projects/dockerfile/__init__.py b/tests/tests_app/utilities/packaging/projects/dockerfile/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/utilities/packaging/projects/dockerfile/app.py b/tests/tests_app/utilities/packaging/projects/dockerfile/app.py
deleted file mode 100644
index 3481d181686d0..0000000000000
--- a/tests/tests_app/utilities/packaging/projects/dockerfile/app.py
+++ /dev/null
@@ -1,11 +0,0 @@
-import os
-import sys
-
-from lightning.app import LightningApp
-
-if __name__ == "__main__":
- sys.path.append(os.path.dirname(__file__))
- from comp_dockerfile.a.a import AAA
- from comp_dockerfile.b.b import BBB
-
- app = LightningApp(BBB(AAA()))
diff --git a/tests/tests_app/utilities/packaging/projects/dockerfile/comp_dockerfile/__init__.py b/tests/tests_app/utilities/packaging/projects/dockerfile/comp_dockerfile/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/utilities/packaging/projects/dockerfile/comp_dockerfile/a/Dockerfile b/tests/tests_app/utilities/packaging/projects/dockerfile/comp_dockerfile/a/Dockerfile
deleted file mode 100644
index f5bb8e842e9b6..0000000000000
--- a/tests/tests_app/utilities/packaging/projects/dockerfile/comp_dockerfile/a/Dockerfile
+++ /dev/null
@@ -1 +0,0 @@
-FROM pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.13-cuda11.7.1
diff --git a/tests/tests_app/utilities/packaging/projects/dockerfile/comp_dockerfile/a/__init__.py b/tests/tests_app/utilities/packaging/projects/dockerfile/comp_dockerfile/a/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/utilities/packaging/projects/dockerfile/comp_dockerfile/a/a.py b/tests/tests_app/utilities/packaging/projects/dockerfile/comp_dockerfile/a/a.py
deleted file mode 100644
index d7195668bdceb..0000000000000
--- a/tests/tests_app/utilities/packaging/projects/dockerfile/comp_dockerfile/a/a.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from lightning.app import LightningWork
-
-
-class AAA(LightningWork):
- def run(self):
- pass
diff --git a/tests/tests_app/utilities/packaging/projects/dockerfile/comp_dockerfile/b/__init__.py b/tests/tests_app/utilities/packaging/projects/dockerfile/comp_dockerfile/b/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/utilities/packaging/projects/dockerfile/comp_dockerfile/b/b.py b/tests/tests_app/utilities/packaging/projects/dockerfile/comp_dockerfile/b/b.py
deleted file mode 100644
index 8a12b3fd9ea5d..0000000000000
--- a/tests/tests_app/utilities/packaging/projects/dockerfile/comp_dockerfile/b/b.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from lightning.app import LightningFlow
-
-
-class BBB(LightningFlow):
- def __init__(self, work):
- super().__init__()
- self.work = work
-
- def run(self):
- self.stop()
diff --git a/tests/tests_app/utilities/packaging/projects/no_req/__init__.py b/tests/tests_app/utilities/packaging/projects/no_req/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/utilities/packaging/projects/no_req/app.py b/tests/tests_app/utilities/packaging/projects/no_req/app.py
deleted file mode 100644
index 5e7155e7d3e5c..0000000000000
--- a/tests/tests_app/utilities/packaging/projects/no_req/app.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import os
-import sys
-
-from lightning.app import LightningApp
-
-if __name__ == "__main__":
- sys.path.append(os.path.dirname(__file__))
-
- from comp.a.a import AA
- from comp.b.b import BB
-
- app = LightningApp(BB(AA()))
diff --git a/tests/tests_app/utilities/packaging/projects/no_req/comp/__init__.py b/tests/tests_app/utilities/packaging/projects/no_req/comp/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/utilities/packaging/projects/no_req/comp/a/__init__.py b/tests/tests_app/utilities/packaging/projects/no_req/comp/a/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/utilities/packaging/projects/no_req/comp/a/a.py b/tests/tests_app/utilities/packaging/projects/no_req/comp/a/a.py
deleted file mode 100644
index c7c529ba7c1d0..0000000000000
--- a/tests/tests_app/utilities/packaging/projects/no_req/comp/a/a.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import pandas # noqa F401
-
-from lightning.app import LightningWork
-
-
-class AA(LightningWork):
- def run(self):
- pass
diff --git a/tests/tests_app/utilities/packaging/projects/no_req/comp/b/__init__.py b/tests/tests_app/utilities/packaging/projects/no_req/comp/b/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/utilities/packaging/projects/no_req/comp/b/b.py b/tests/tests_app/utilities/packaging/projects/no_req/comp/b/b.py
deleted file mode 100644
index 8f2622363861e..0000000000000
--- a/tests/tests_app/utilities/packaging/projects/no_req/comp/b/b.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from lightning.app import LightningFlow
-
-
-class BB(LightningFlow):
- def __init__(self, work):
- super().__init__()
- self.work = work
-
- def run(self):
- self.stop()
diff --git a/tests/tests_app/utilities/packaging/projects/req/__init__.py b/tests/tests_app/utilities/packaging/projects/req/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/utilities/packaging/projects/req/app.py b/tests/tests_app/utilities/packaging/projects/req/app.py
deleted file mode 100644
index ea23603efc478..0000000000000
--- a/tests/tests_app/utilities/packaging/projects/req/app.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import os
-import sys
-
-from lightning.app import LightningApp
-
-if __name__ == "__main__":
- sys.path.append(os.path.dirname(__file__))
-
- from comp_req.a.a import A
- from comp_req.b.b import B
-
- app = LightningApp(B(A()))
diff --git a/tests/tests_app/utilities/packaging/projects/req/comp_req/__init__.py b/tests/tests_app/utilities/packaging/projects/req/comp_req/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/utilities/packaging/projects/req/comp_req/a/__init__.py b/tests/tests_app/utilities/packaging/projects/req/comp_req/a/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/utilities/packaging/projects/req/comp_req/a/a.py b/tests/tests_app/utilities/packaging/projects/req/comp_req/a/a.py
deleted file mode 100644
index 6a77dc180b3b8..0000000000000
--- a/tests/tests_app/utilities/packaging/projects/req/comp_req/a/a.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import pandas # noqa F401
-
-from lightning.app import LightningWork
-
-
-class A(LightningWork):
- def run(self):
- pass
diff --git a/tests/tests_app/utilities/packaging/projects/req/comp_req/a/requirements.txt b/tests/tests_app/utilities/packaging/projects/req/comp_req/a/requirements.txt
deleted file mode 100644
index 44ccabb86bcec..0000000000000
--- a/tests/tests_app/utilities/packaging/projects/req/comp_req/a/requirements.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-pandas
-pytorch_lightning==1.5.9
-git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0
diff --git a/tests/tests_app/utilities/packaging/projects/req/comp_req/b/__init__.py b/tests/tests_app/utilities/packaging/projects/req/comp_req/b/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_app/utilities/packaging/projects/req/comp_req/b/b.py b/tests/tests_app/utilities/packaging/projects/req/comp_req/b/b.py
deleted file mode 100644
index 6d6f8ee3ddf11..0000000000000
--- a/tests/tests_app/utilities/packaging/projects/req/comp_req/b/b.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from lightning.app import LightningFlow
-
-
-class B(LightningFlow):
- def __init__(self, work):
- super().__init__()
- self.work = work
-
- def run(self):
- self.stop()
diff --git a/tests/tests_app/utilities/packaging/projects/requirements.txt b/tests/tests_app/utilities/packaging/projects/requirements.txt
deleted file mode 100644
index 7bf0698bf1180..0000000000000
--- a/tests/tests_app/utilities/packaging/projects/requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-cloud-stars
diff --git a/tests/tests_app/utilities/packaging/test_app_config.py b/tests/tests_app/utilities/packaging/test_app_config.py
deleted file mode 100644
index d6346d67ffa31..0000000000000
--- a/tests/tests_app/utilities/packaging/test_app_config.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import pathlib
-
-from lightning.app.utilities.packaging.app_config import AppConfig, _get_config_file
-
-
-def _make_empty_config_file(folder):
- file = pathlib.Path(folder / ".lightning")
- file.parent.mkdir(parents=True, exist_ok=True)
- file.touch()
- return file
-
-
-def test_get_config_file(tmpdir):
- _ = _make_empty_config_file(tmpdir)
- config_file1 = _make_empty_config_file(tmpdir)
-
- assert _get_config_file(tmpdir) == pathlib.Path(tmpdir, ".lightning")
- assert _get_config_file(config_file1) == pathlib.Path(tmpdir, ".lightning")
-
-
-def test_app_config_save_load(tmpdir):
- config = AppConfig("my_app")
- config.save_to_file(tmpdir / ".lightning")
- loaded_config = AppConfig.load_from_file(tmpdir / ".lightning")
- assert config == loaded_config
-
- config = AppConfig("my_app2")
- config.save_to_dir(tmpdir)
- loaded_config = AppConfig.load_from_dir(tmpdir)
- assert config == loaded_config
-
-
-def test_app_config_default_name():
- """Test that the default name gets auto-generated."""
- config = AppConfig()
- assert config.name
diff --git a/tests/tests_app/utilities/packaging/test_build_spec.py b/tests/tests_app/utilities/packaging/test_build_spec.py
deleted file mode 100644
index a34409a9f568b..0000000000000
--- a/tests/tests_app/utilities/packaging/test_build_spec.py
+++ /dev/null
@@ -1,105 +0,0 @@
-import logging
-import os
-import sys
-from unittest.mock import Mock
-
-from lightning.app.testing import LightningTestApp, application_testing
-from lightning.app.utilities.packaging.build_config import BuildConfig
-
-from tests_app import _TESTS_ROOT
-
-EXTRAS_ARGS = ["--blocking", "False", "--multiprocess", "--open-ui", "False"]
-
-
-class NoRequirementsLightningTestApp(LightningTestApp):
- def on_after_run_once(self):
- assert self.root.work.local_build_config.requirements == []
- assert self.root.work.cloud_build_config.requirements == []
- return super().on_after_run_once()
-
-
-def test_build_config_no_requirements():
- command_line = [os.path.join(_TESTS_ROOT, "utilities/packaging/projects/no_req/app.py")]
- application_testing(NoRequirementsLightningTestApp, command_line + EXTRAS_ARGS)
- sys.path = sys.path[:-1]
-
-
-def test_build_config_requirements_provided():
- spec = BuildConfig(requirements=["dask", "./projects/req/comp_req/a/requirements.txt"])
- assert spec.requirements == [
- "dask",
- "pandas",
- "pytorch_lightning==1.5.9",
- "git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0",
- ]
- assert spec == BuildConfig.from_dict(spec.to_dict())
-
-
-class BuildSpecTest(BuildConfig):
- def build_commands(self):
- return super().build_commands() + ["pip install redis"]
-
-
-def test_build_config_invalid_requirements():
- spec = BuildSpecTest(requirements=["./projects/requirements.txt"])
- assert spec.requirements == ["cloud-stars"]
- assert spec.build_commands() == ["pip install redis"]
-
-
-def test_build_config_dockerfile_provided():
- spec = BuildConfig(dockerfile="./projects/Dockerfile.cpu")
- assert not spec.requirements
- # ugly hack due to replacing `pytorch_lightning string
- assert "pytorchlightning/pytorch_lightning" in spec.dockerfile.data[0]
-
-
-class DockerfileLightningTestApp(LightningTestApp):
- def on_after_run_once(self):
- print(self.root.work.local_build_config.dockerfile)
- # ugly hack due to replacing `pytorch_lightning string
- assert "pytorchlightning/pytorch_" + "lightning" in self.root.work.local_build_config.dockerfile.data[0]
- return super().on_after_run_once()
-
-
-def test_build_config_dockerfile():
- command_line = [os.path.join(_TESTS_ROOT, "utilities/packaging/projects/dockerfile/app.py")]
- application_testing(DockerfileLightningTestApp, command_line + EXTRAS_ARGS)
- sys.path = sys.path[:-1]
-
-
-class RequirementsLightningTestApp(LightningTestApp):
- def on_after_run_once(self):
- assert self.root.work.local_build_config.requirements == [
- "git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0",
- "pandas",
- "pytorch_" + "lightning==1.5.9", # ugly hack due to replacing `pytorch_lightning string
- ]
- return super().on_after_run_once()
-
-
-def test_build_config_requirements():
- command_line = [os.path.join(_TESTS_ROOT, "utilities/packaging/projects/req/app.py")]
- application_testing(RequirementsLightningTestApp, command_line + EXTRAS_ARGS)
- sys.path = sys.path[:-1]
-
-
-def test_build_config_requirements_warns(monkeypatch, caplog):
- requirements = ["foo", "bar"]
- bc = BuildConfig(requirements=requirements)
- monkeypatch.setattr(bc, "_find_requirements", lambda *_, **__: ["baz"])
- work = Mock()
- with caplog.at_level(logging.INFO):
- bc.on_work_init(work)
- assert "requirements.txt' exists with ['baz'] but ['foo', 'bar']" in caplog.text
- assert bc.requirements == requirements # they are not merged or replaced
-
-
-def test_build_config_dockerfile_warns(monkeypatch, caplog):
- dockerfile = "foo"
- bc = BuildConfig(dockerfile=dockerfile)
- monkeypatch.setattr(bc, "_find_dockerfile", lambda *_, **__: "bar")
- work = Mock()
- with caplog.at_level(logging.INFO):
- bc.on_work_init(work)
- assert "exists at 'bar' but 'foo' was passed" in caplog.text
- assert bc.dockerfile == dockerfile # they are not merged or replaced
diff --git a/tests/tests_app/utilities/packaging/test_cloud_compute.py b/tests/tests_app/utilities/packaging/test_cloud_compute.py
deleted file mode 100644
index bc17c53555eb9..0000000000000
--- a/tests/tests_app/utilities/packaging/test_cloud_compute.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import pytest
-from lightning.app import CloudCompute
-from lightning.app.storage import Mount
-
-
-def test_cloud_compute_names():
- assert CloudCompute().name == "cpu-small"
- assert CloudCompute("cpu-small").name == "cpu-small"
- assert CloudCompute("coconut").name == "coconut" # the backend is responsible for validation of names
-
-
-def test_cloud_compute_shared_memory():
- cloud_compute = CloudCompute("gpu", shm_size=1100)
- assert cloud_compute.shm_size == 1100
-
- cloud_compute = CloudCompute("gpu")
- assert cloud_compute.shm_size == 1024
-
- cloud_compute = CloudCompute("cpu")
- assert cloud_compute.shm_size == 0
-
-
-def test_cloud_compute_with_mounts():
- mount_1 = Mount(source="s3://foo/", mount_path="/foo")
- mount_2 = Mount(source="s3://foo/bar/", mount_path="/bar")
-
- cloud_compute = CloudCompute("gpu", mounts=mount_1)
- assert cloud_compute.mounts == mount_1
-
- cloud_compute = CloudCompute("gpu", mounts=[mount_1, mount_2])
- assert cloud_compute.mounts == [mount_1, mount_2]
-
- cc_dict = cloud_compute.to_dict()
- assert "mounts" in cc_dict
- assert cc_dict["mounts"] == [
- {"mount_path": "/foo", "source": "s3://foo/"},
- {"mount_path": "/bar", "source": "s3://foo/bar/"},
- ]
-
- assert CloudCompute.from_dict(cc_dict) == cloud_compute
-
-
-def test_cloud_compute_with_non_unique_mount_root_dirs():
- mount_1 = Mount(source="s3://foo/", mount_path="/foo")
- mount_2 = Mount(source="s3://foo/bar/", mount_path="/foo")
-
- with pytest.raises(ValueError, match="Every Mount attached to a work must have a unique"):
- CloudCompute("gpu", mounts=[mount_1, mount_2])
-
-
-def test_cloud_compute_clone():
- c1 = CloudCompute("gpu")
- c2 = c1.clone()
-
- assert isinstance(c2, CloudCompute)
-
- c1_dict = c1.to_dict()
- c2_dict = c2.to_dict()
-
- assert len(c1_dict) == len(c2_dict)
-
- for k in c1_dict:
- if k == "_internal_id":
- assert c1_dict[k] != c2_dict[k]
- else:
- assert c1_dict[k] == c2_dict[k]
-
-
-def test_interruptible(monkeypatch):
- """Test interruptible can be enabled with env variables and for GPU only."""
- with pytest.raises(ValueError, match="isn't supported yet"):
- CloudCompute("gpu", interruptible=True)
-
- monkeypatch.setenv("LIGHTNING_INTERRUPTIBLE_WORKS", "1")
- with pytest.raises(ValueError, match="supported only with GPU"):
- CloudCompute("cpu", interruptible=True)
-
- cloud_compute = CloudCompute("gpu", interruptible=True)
- assert hasattr(cloud_compute, "interruptible")
- # TODO: To be removed once the platform is updated.
- assert hasattr(cloud_compute, "preemptible")
diff --git a/tests/tests_app/utilities/packaging/test_docker.py b/tests/tests_app/utilities/packaging/test_docker.py
deleted file mode 100644
index 95a9edb1882c5..0000000000000
--- a/tests/tests_app/utilities/packaging/test_docker.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import os
-from time import sleep, time
-
-import pytest
-from lightning.app import LightningWork
-from lightning.app.core.queues import QueuingSystem
-from lightning.app.testing.helpers import _RunIf
-from lightning.app.utilities.imports import _is_docker_available
-from lightning.app.utilities.load_app import load_app_from_file
-from lightning.app.utilities.packaging.docker import DockerRunner
-from lightning.app.utilities.redis import check_if_redis_running
-
-
-@pytest.mark.xfail(strict=False, reason="FIXME (tchaton)")
-@pytest.mark.skipif(not _is_docker_available(), reason="docker is required for this test.")
-@pytest.mark.skipif(not check_if_redis_running(), reason="redis is required for this test.")
-@_RunIf(skip_windows=True)
-def test_docker_runner():
- """This test validates that the lightning run work is executable within a container and deltas are sent back
- through the Redis caller_queue."""
- queues = QueuingSystem.REDIS
- queue_id = f"test_docker_runner_{str(int(time()))}"
- app_file = os.path.join(os.path.dirname(__file__), "projects/dock/app.py")
- app = load_app_from_file(app_file)
- work: LightningWork = app.root.work
-
- call_hash = "run:fe3fa0f34fc1317e152e5afb023332995392071046f1ea51c34c7c9766e3676c"
- work._calls[call_hash] = {
- "args": (),
- "kwargs": {},
- "call_hash": call_hash,
- "run_started_counter": 1,
- "statuses": [],
- }
-
- # The script_path needs to be relative to the container.
- docker_runner = DockerRunner(
- "/home/tests/utilities/packaging/projects/dock/app.py", work, queue_id, create_base=True
- )
- docker_runner.run()
-
- caller_queue = queues.get_caller_queue(work_name=work.name, queue_id=queue_id)
- caller_queue.put({
- "args": (),
- "kwargs": {},
- "call_hash": call_hash,
- "state": work.state,
- })
- delta_queue = queues.get_delta_queue(queue_id=queue_id)
- delta_1 = delta_queue.get()
- delta_2 = delta_queue.get()
- delta_3 = delta_queue.get()
-
- def get_item(delta):
- return delta.delta.to_dict()["iterable_item_added"]
-
- assert delta_1.id == "root.work"
- assert delta_2.id == "root.work"
- assert delta_3.id == "root.work"
- assert get_item(delta_1)[f"root['calls']['{call_hash}']['statuses'][0]"]["stage"] == "starting"
- assert delta_2.delta.to_dict()["type_changes"]["root['vars']['message']"]["new_value"] == "hello world!"
- assert get_item(delta_3)[f"root['calls']['{call_hash}']['statuses'][1]"]["stage"] == "succeeded"
- sleep(1)
- assert "Starting WorkRunner" in docker_runner.get_container_logs()
- docker_runner.kill()
diff --git a/tests/tests_app/utilities/packaging/test_lightning_utils.py b/tests/tests_app/utilities/packaging/test_lightning_utils.py
deleted file mode 100644
index 339c75cf9f0aa..0000000000000
--- a/tests/tests_app/utilities/packaging/test_lightning_utils.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import glob
-import os
-from unittest import mock
-
-import pytest
-from lightning.app.testing.helpers import _RunIf
-from lightning.app.utilities.git import check_github_repository, get_dir_name
-from lightning.app.utilities.packaging import lightning_utils
-from lightning.app.utilities.packaging.lightning_utils import (
- _prepare_lightning_wheels_and_requirements,
- _verify_lightning_version,
- get_dist_path_if_editable_install,
-)
-from lightning_utilities.core.imports import module_available
-
-
-@pytest.mark.skipif(not module_available("lightning"), reason="TODO: should work for lightning.app too")
-def test_prepare_lightning_wheels_and_requirement(tmpdir):
- """This test ensures the lightning source gets packaged inside the lightning repo."""
- package_name = "lightning"
- if not get_dist_path_if_editable_install(package_name):
- pytest.skip("Requires --editable install")
-
- git_dir_name = get_dir_name() if check_github_repository() else None
- if git_dir_name != package_name:
- pytest.skip("Needs to be run from within the repo")
-
- cleanup_handle = _prepare_lightning_wheels_and_requirements(tmpdir, package_name=package_name)
- assert len(os.listdir(tmpdir)) == 1
- assert len(glob.glob(str(tmpdir / "lightning-*.tar.gz"))) == 1
-
- cleanup_handle()
- assert os.listdir(tmpdir) == []
-
-
-def _mocked_get_dist_path_if_editable_install(*args, **kwargs):
- return None
-
-
-@mock.patch(
- "lightning.app.utilities.packaging.lightning_utils.get_dist_path_if_editable_install",
- new=_mocked_get_dist_path_if_editable_install,
-)
-def test_prepare_lightning_wheels_and_requirement_for_packages_installed_in_editable_mode(tmpdir):
- """This test ensures the source does not get packaged inside the lightning repo if not installed in editable
- mode."""
- cleanup_handle = _prepare_lightning_wheels_and_requirements(tmpdir)
- assert cleanup_handle is None
-
-
-@pytest.mark.xfail(strict=False, reason="TODO: Find a way to check for the latest version")
-@_RunIf(skip_windows=True)
-def test_verify_lightning_version(monkeypatch):
- monkeypatch.setattr(lightning_utils, "__version__", "0.0.1")
- monkeypatch.setattr(lightning_utils, "_fetch_latest_version", lambda _: "0.0.2")
-
- # Not latest version
- with pytest.raises(Exception, match="You need to use the latest version of Lightning"):
- _verify_lightning_version()
-
- # Latest version
- monkeypatch.setattr(lightning_utils, "__version__", "0.0.1")
- monkeypatch.setattr(lightning_utils, "_fetch_latest_version", lambda _: "0.0.1")
- _verify_lightning_version()
diff --git a/tests/tests_app/utilities/test_app_commands.py b/tests/tests_app/utilities/test_app_commands.py
deleted file mode 100644
index 117e580b7c52a..0000000000000
--- a/tests/tests_app/utilities/test_app_commands.py
+++ /dev/null
@@ -1,88 +0,0 @@
-import os
-import sys
-
-import pytest
-from lightning.app.utilities.app_commands import CommandLines, _execute_app_commands, _extract_commands_from_file
-from lightning.app.utilities.exceptions import MisconfigurationException
-
-
-@pytest.mark.parametrize(
- ("filename", "expected_commands", "expected_line_numbers"),
- [
- ("single_command.txt", ['echo "foo"'], [1]),
- ("multiple_commands.txt", ['echo "foo"', 'echo "bar"'], [1, 2]),
- ("commands_with_mixed_comments_1.txt", ['echo "foo"', 'echo "bar"'], [1, 3]),
- ("commands_with_mixed_comments_2.txt", ['echo "foo"', 'echo "bar"'], [2, 4]),
- ("command_after_first_non_comment_line.txt", ['echo "foo"', 'echo "bar"'], [2, 4]),
- ("bang_not_at_start_of_line.txt", ['echo "foo"'], [2]),
- ("space_between_bang_and_command.txt", ['echo "foo"'], [1]),
- ("multiple_spaces_between_band_and_command.txt", ['echo "foo"'], [1]),
- ("app_commands_to_ignore.txt", [], []),
- ],
-)
-def test_extract_app_commands_from_file(filename, expected_commands, expected_line_numbers):
- dir_path = os.path.dirname(os.path.realpath(__file__))
- test_file_path = os.path.join(dir_path, "testdata", "app_commands", filename)
-
- res = _extract_commands_from_file(file_name=test_file_path)
-
- assert res.file == test_file_path
- assert res.commands == expected_commands
- assert res.line_numbers == expected_line_numbers
-
-
-def test_execute_app_commands_runs_single_command(capfd):
- cl = CommandLines(
- file="foo.txt",
- commands=['echo "foo"'],
- line_numbers=[1],
- )
- _execute_app_commands(cl)
- out, _ = capfd.readouterr()
- assert "foo" in out
-
-
-def test_execute_app_commands_runs_multiple_commands(capfd):
- cl = CommandLines(
- file="foo.txt",
- commands=['echo "foo"', 'echo "bar"'],
- line_numbers=[1, 2],
- )
- _execute_app_commands(cl)
- out, _ = capfd.readouterr()
- assert "foo" in out
- assert "bar" in out
-
-
-@pytest.mark.skipif(sys.platform.startswith("win"), reason="env command is not available on windows")
-def test_execute_app_commands_runs_with_env_vars_patched(monkeypatch, capfd):
- monkeypatch.setenv("TEST_EXECUTE_APP_COMMANDS_RUNS_WITH_ENV_VARS_PATCHED", "TRUE")
- cl = CommandLines(
- file="foo.txt",
- commands=["env"],
- line_numbers=[1],
- )
- _execute_app_commands(cl)
- out, _ = capfd.readouterr()
- assert "TEST_EXECUTE_APP_COMMANDS_RUNS_WITH_ENV_VARS_PATCHED=TRUE" in out
-
-
-def test_execute_app_commands_raises_appropriate_traceback_on_error(capfd):
- cl = CommandLines(
- file="foo.txt",
- commands=['echo "foo"', 'CommandDoesNotExist "somearg"'],
- line_numbers=[1, 3],
- )
- with pytest.raises(
- MisconfigurationException,
- match='There was a problem on line 3 of foo.txt while executing the command: CommandDoesNotExist "somearg"',
- ):
- _execute_app_commands(cl)
- out, err = capfd.readouterr()
- assert "foo" in out
- if sys.platform.startswith("linux"):
- assert "CommandDoesNotExist: not found" in err
- elif sys.platform.startswith("darwin"):
- assert "CommandDoesNotExist: command not found" in err
- else:
- assert "CommandDoesNotExist' is not recognized" in err
diff --git a/tests/tests_app/utilities/test_app_helpers.py b/tests/tests_app/utilities/test_app_helpers.py
deleted file mode 100644
index 55e48cceef76c..0000000000000
--- a/tests/tests_app/utilities/test_app_helpers.py
+++ /dev/null
@@ -1,205 +0,0 @@
-from functools import partial
-from unittest import mock
-
-import pytest
-from lightning.app import LightningApp, LightningFlow, LightningWork
-from lightning.app.core.flow import _RootFlow
-from lightning.app.frontend import StaticWebFrontend
-from lightning.app.utilities.app_helpers import (
- AppStatePlugin,
- BaseStatePlugin,
- InMemoryStateStore,
- StateStore,
- _is_headless,
- _MagicMockJsonSerializable,
- is_overridden,
- is_static_method,
-)
-from lightning.app.utilities.exceptions import LightningAppStateException
-
-
-class Work(LightningWork):
- def run(self):
- pass
-
-
-class Flow(LightningFlow):
- def run(self):
- pass
-
-
-def test_is_overridden():
- # edge cases
- assert not is_overridden("whatever", None)
- with pytest.raises(ValueError, match="Expected a parent"):
- is_overridden("whatever", object())
- flow = Flow()
- assert not is_overridden("whatever", flow)
- assert not is_overridden("whatever", flow, parent=Flow)
- # normal usage
- assert is_overridden("run", flow)
- work = Work()
- assert is_overridden("run", work)
-
-
-def test_simple_app_store():
- store = InMemoryStateStore()
- user_id = "1234"
- store.add(user_id)
- state = {"data": user_id}
- store.set_app_state(user_id, state)
- store.set_served_state(user_id, state)
- store.set_served_session_id(user_id, user_id)
- assert store.get_app_state(user_id) == state
- assert store.get_served_state(user_id) == state
- assert store.get_served_session_id(user_id) == user_id
- store.remove(user_id)
- assert isinstance(store, StateStore)
-
-
-@mock.patch("lightning.app.core.constants.APP_STATE_MAX_SIZE_BYTES", 120)
-def test_simple_app_store_warning():
- store = InMemoryStateStore()
- user_id = "1234"
- store.add(user_id)
- state = {"data": "I'm a state that's larger than 120 bytes"}
- with pytest.raises(LightningAppStateException, match="is larger than the"):
- store.set_app_state(user_id, state)
-
-
-def test_base_state_plugin():
- class DummyStatePlugin(BaseStatePlugin):
- def should_update_app(self, deep_diff):
- super().should_update_app(deep_diff)
-
- def get_context(self):
- super().get_context()
-
- def render_non_authorized(self):
- super().render_non_authorized()
-
- plugin = DummyStatePlugin()
- plugin.should_update_app(None)
- plugin.get_context()
- plugin.render_non_authorized()
-
- plugin = AppStatePlugin()
- plugin.should_update_app(None)
- plugin.get_context()
- plugin.render_non_authorized()
-
-
-def test_is_static_method():
- class A:
- @staticmethod
- def a(self):
- pass
-
- @staticmethod
- def b(a):
- pass
-
- def c(self):
- pass
-
- assert is_static_method(A, "a")
- assert is_static_method(A, "b")
- assert not is_static_method(A, "c")
-
-
-class FlowWithURLLayout(Flow):
- def configure_layout(self):
- return {"name": "test", "content": "https://appurl"}
-
-
-class FlowWithFrontend(Flow):
- def configure_layout(self):
- return StaticWebFrontend(".")
-
-
-class FlowWithMockedFrontend(Flow):
- def configure_layout(self):
- return _MagicMockJsonSerializable()
-
-
-class FlowWithMockedContent(Flow):
- def configure_layout(self):
- return [{"name": "test", "content": _MagicMockJsonSerializable()}]
-
-
-class NestedFlow(Flow):
- def __init__(self):
- super().__init__()
-
- self.flow = Flow()
-
-
-class NestedFlowWithURLLayout(Flow):
- def __init__(self):
- super().__init__()
-
- self.flow = FlowWithURLLayout()
-
-
-class WorkWithStringLayout(Work):
- def configure_layout(self):
- return "http://appurl"
-
-
-class WorkWithMockedFrontendLayout(Work):
- def configure_layout(self):
- return _MagicMockJsonSerializable()
-
-
-class WorkWithFrontendLayout(Work):
- def configure_layout(self):
- return StaticWebFrontend(".")
-
-
-class WorkWithNoneLayout(Work):
- def configure_layout(self):
- return None
-
-
-class FlowWithWorkLayout(Flow):
- def __init__(self, work):
- super().__init__()
-
- self.work = work()
-
- def configure_layout(self):
- return {"name": "test", "content": self.work}
-
-
-class WorkClassRootFlow(_RootFlow):
- """A ``_RootFlow`` which takes a work class rather than the work itself."""
-
- def __init__(self, work):
- super().__init__(work())
-
-
-@pytest.mark.parametrize(
- ("flow", "expected"),
- [
- (Flow, True),
- (FlowWithURLLayout, False),
- (FlowWithFrontend, False),
- (FlowWithMockedFrontend, False),
- (FlowWithMockedContent, False),
- (NestedFlow, True),
- (NestedFlowWithURLLayout, False),
- (partial(WorkClassRootFlow, WorkWithStringLayout), False),
- (partial(WorkClassRootFlow, WorkWithMockedFrontendLayout), False),
- (partial(WorkClassRootFlow, WorkWithFrontendLayout), False),
- (partial(WorkClassRootFlow, WorkWithNoneLayout), True),
- (partial(FlowWithWorkLayout, Work), False),
- (partial(FlowWithWorkLayout, WorkWithStringLayout), False),
- (partial(FlowWithWorkLayout, WorkWithMockedFrontendLayout), False),
- (partial(FlowWithWorkLayout, WorkWithFrontendLayout), False),
- (partial(FlowWithWorkLayout, WorkWithNoneLayout), True),
- ],
-)
-def test_is_headless(flow, expected):
- flow = flow()
- app = LightningApp(flow)
- assert _is_headless(app) == expected
diff --git a/tests/tests_app/utilities/test_app_logs.py b/tests/tests_app/utilities/test_app_logs.py
deleted file mode 100644
index ccbff41b7c814..0000000000000
--- a/tests/tests_app/utilities/test_app_logs.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from datetime import datetime
-from time import sleep
-from unittest.mock import MagicMock
-
-from lightning.app.utilities.app_logs import _LogEvent
-
-
-def test_log_event():
- event_1 = _LogEvent("", datetime.now(), MagicMock(), MagicMock())
- sleep(0.1)
- event_2 = _LogEvent("", datetime.now(), MagicMock(), MagicMock())
- assert event_1 < event_2
- assert event_1 <= event_2
diff --git a/tests/tests_app/utilities/test_auth.py b/tests/tests_app/utilities/test_auth.py
deleted file mode 100644
index a8d11e8f573c2..0000000000000
--- a/tests/tests_app/utilities/test_auth.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from typing import Dict
-
-import pytest
-from lightning.app.utilities.auth import _credential_string_to_basic_auth_params
-
-
-@pytest.mark.parametrize(
- ("credential_string", "expected_parsed", "exception_message"),
- [
- ("", None, "Credential string must follow the format username:password; the provided one ('') does not."),
- (":", None, "Username cannot be empty."),
- (":pass", None, "Username cannot be empty."),
- ("user:", None, "Password cannot be empty."),
- ("user:pass", {"username": "user", "password": "pass"}, ""),
- ],
-)
-def test__credential_string_to_basic_auth_params(
- credential_string: str, expected_parsed: Dict[str, str], exception_message: str
-):
- if expected_parsed is not None:
- assert _credential_string_to_basic_auth_params(credential_string) == expected_parsed
- else:
- with pytest.raises(ValueError) as exception:
- _credential_string_to_basic_auth_params(credential_string)
- assert exception_message == str(exception.value)
diff --git a/tests/tests_app/utilities/test_cli_helpers.py b/tests/tests_app/utilities/test_cli_helpers.py
deleted file mode 100644
index 8d6684080d7a5..0000000000000
--- a/tests/tests_app/utilities/test_cli_helpers.py
+++ /dev/null
@@ -1,196 +0,0 @@
-import os
-import sys
-from unittest.mock import Mock, patch
-
-import arrow
-import lightning.app
-import pytest
-from lightning.app.utilities.cli_helpers import (
- _arrow_time_callback,
- _check_environment_and_redirect,
- _format_input_env_variables,
- _get_newer_version,
-)
-
-
-def test_format_input_env_variables():
- with pytest.raises(Exception, match="Invalid format of environment variable"):
- _format_input_env_variables(("invalid-env",))
-
- with pytest.raises(Exception, match="Invalid format of environment variable"):
- _format_input_env_variables(("=invalid",))
-
- with pytest.raises(Exception, match="Invalid format of environment variable"):
- _format_input_env_variables(("=invalid=",))
-
- with pytest.raises(Exception, match="is duplicated. Please only include it once."):
- _format_input_env_variables((
- "FOO=bar",
- "FOO=bar",
- ))
-
- with pytest.raises(
- Exception,
- match="is not a valid name. It is only allowed to contain digits 0-9, letters A-Z",
- ):
- _format_input_env_variables(("*FOO#=bar",))
-
- assert _format_input_env_variables(("FOO=bar", "BLA=bloz")) == {"FOO": "bar", "BLA": "bloz"}
-
-
-def test_arrow_time_callback():
- # Check ISO 8601 variations
- assert _arrow_time_callback(Mock(), Mock(), "2022.08.23") == arrow.Arrow(2022, 8, 23)
-
- assert _arrow_time_callback(Mock(), Mock(), "2022.08.23 12:34") == arrow.Arrow(2022, 8, 23, 12, 34)
-
- assert _arrow_time_callback(Mock(), Mock(), "2022-08-23 12:34") == arrow.Arrow(2022, 8, 23, 12, 34)
-
- assert _arrow_time_callback(Mock(), Mock(), "2022-08-23 12:34:00.000") == arrow.Arrow(2022, 8, 23, 12, 34)
-
- # Just check humanized format is parsed
- assert type(_arrow_time_callback(Mock(), Mock(), "48 hours ago")) is arrow.Arrow
-
- assert type(_arrow_time_callback(Mock(), Mock(), "60 minutes ago")) is arrow.Arrow
-
- assert type(_arrow_time_callback(Mock(), Mock(), "120 seconds ago")) is arrow.Arrow
-
- # Check raising errors
- with pytest.raises(Exception, match="cannot parse time Mon"):
- _arrow_time_callback(Mock(), Mock(), "Mon")
-
- with pytest.raises(Exception, match="cannot parse time Mon Sep 08 16:41:45 2022"):
- _arrow_time_callback(Mock(), Mock(), "Mon Sep 08 16:41:45 2022")
-
- with pytest.raises(Exception, match="cannot parse time 2022.125.12"):
- _arrow_time_callback(Mock(), Mock(), "2022.125.12")
-
- with pytest.raises(Exception, match="cannot parse time 1 time unit ago"):
- _arrow_time_callback(Mock(), Mock(), "1 time unit ago")
-
-
-@pytest.mark.parametrize(
- ("response", "current_version", "newer_version"),
- [
- (
- {
- "info": {
- "version": "2.0.0",
- "yanked": False,
- },
- "releases": {
- "1.0.0": {},
- "2.0.0": {},
- },
- },
- "1.0.0",
- "2.0.0",
- ),
- (
- {
- "info": {
- "version": "2.0.0",
- "yanked": True,
- },
- "releases": {
- "1.0.0": {},
- "2.0.0": {},
- },
- },
- "1.0.0",
- None,
- ),
- (
- {
- "info": {
- "version": "1.0.0",
- "yanked": False,
- },
- "releases": {
- "1.0.0": {},
- },
- },
- "1.0.0",
- None,
- ),
- (
- {
- "info": {
- "version": "2.0.0rc0",
- "yanked": False,
- },
- "releases": {
- "1.0.0": {},
- "2.0.0": {},
- },
- },
- "1.0.0",
- None,
- ),
- (
- {
- "info": {
- "version": "2.0.0",
- "yanked": False,
- },
- "releases": {
- "1.0.0": {},
- "2.0.0": {},
- },
- },
- "1.0.0dev",
- None,
- ),
- ({"this will trigger an error": True}, "1.0.0", None),
- ({}, "1.0.0rc0", None),
- ],
-)
-@patch("lightning.app.utilities.cli_helpers.requests")
-def test_get_newer_version(mock_requests, response, current_version, newer_version):
- mock_requests.get().json.return_value = response
-
- lightning.app.utilities.cli_helpers.__version__ = current_version
-
- _get_newer_version.cache_clear()
- assert _get_newer_version() == newer_version
-
-
-@patch("lightning.app.utilities.cli_helpers._redirect_command")
-def test_check_environment_and_redirect(mock_redirect_command, tmpdir, monkeypatch):
- # Ensure that the test fails if it tries to redirect
- mock_redirect_command.side_effect = RuntimeError
-
- # Test normal executable on the path
- # Ensure current executable is on the path
- monkeypatch.setenv("PATH", f"{os.path.dirname(sys.executable)}")
-
- assert _check_environment_and_redirect() is None
-
- # Test executable on the path with redirect
- fake_python_path = os.path.join(tmpdir, "python")
-
- os.symlink(sys.executable, fake_python_path)
-
- monkeypatch.setenv("PATH", f"{tmpdir}")
- assert _check_environment_and_redirect() is None
-
- os.remove(fake_python_path)
-
- descriptor = os.open(
- fake_python_path,
- flags=(
- os.O_WRONLY # access mode: write only
- | os.O_CREAT # create if not exists
- | os.O_TRUNC # truncate the file to zero
- ),
- mode=0o777,
- )
-
- with open(descriptor, "w") as f:
- f.writelines([
- "#!/bin/bash\n",
- f'{sys.executable} "$@"',
- ])
-
- monkeypatch.setenv("PATH", f"{tmpdir}")
- assert _check_environment_and_redirect() is None
diff --git a/tests/tests_app/utilities/test_cloud.py b/tests/tests_app/utilities/test_cloud.py
deleted file mode 100644
index 4982874ccb24a..0000000000000
--- a/tests/tests_app/utilities/test_cloud.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import os
-from unittest import mock
-
-from lightning.app.utilities.cloud import _get_project, is_running_in_cloud
-from lightning_cloud.openapi.models import V1Project
-
-
-def test_get_project_queries_by_project_id_directly_if_it_is_passed():
- lightning_client = mock.MagicMock()
- lightning_client.projects_service_get_project = mock.MagicMock(return_value=V1Project(id="project_id"))
- project = _get_project(lightning_client, project_id="project_id")
- assert project.project_id == "project_id"
- lightning_client.projects_service_get_project.assert_called_once_with("project_id")
-
-
-def test_is_running_cloud():
- """We can determine if Lightning is running in the cloud."""
- with mock.patch.dict(os.environ, {}, clear=True):
- assert not is_running_in_cloud()
-
- with mock.patch.dict(os.environ, {"LAI_RUNNING_IN_CLOUD": "0"}, clear=True):
- assert not is_running_in_cloud()
-
- # in the cloud, LIGHTNING_APP_STATE_URL is defined
- with mock.patch.dict(os.environ, {"LIGHTNING_APP_STATE_URL": "defined"}, clear=True):
- assert is_running_in_cloud()
-
- # LAI_RUNNING_IN_CLOUD is used to fake the value of `is_running_in_cloud` when loading the app for --cloud
- with mock.patch.dict(os.environ, {"LAI_RUNNING_IN_CLOUD": "1"}):
- assert is_running_in_cloud()
diff --git a/tests/tests_app/utilities/test_commands.py b/tests/tests_app/utilities/test_commands.py
deleted file mode 100644
index fb74e03ba6e56..0000000000000
--- a/tests/tests_app/utilities/test_commands.py
+++ /dev/null
@@ -1,165 +0,0 @@
-import argparse
-import sys
-from multiprocessing import Process
-from time import sleep
-from unittest.mock import MagicMock
-
-import pytest
-import requests
-from lightning.app import LightningApp, LightningFlow
-from lightning.app.cli.commands.app_commands import _run_app_command
-from lightning.app.cli.connect.app import connect_app, disconnect_app
-from lightning.app.core.constants import APP_SERVER_PORT
-from lightning.app.runners import MultiProcessRuntime
-from lightning.app.testing.helpers import _RunIf
-from lightning.app.utilities.commands.base import ClientCommand, _download_command, _validate_client_command
-from lightning.app.utilities.state import AppState
-from pydantic import BaseModel
-
-
-class SweepConfig(BaseModel):
- sweep_name: str
- num_trials: int
-
-
-class SweepCommand(ClientCommand):
- def run(self) -> None:
- parser = argparse.ArgumentParser()
- parser.add_argument("--sweep_name", type=str)
- parser.add_argument("--num_trials", type=int)
- hparams = parser.parse_args()
-
- config = SweepConfig(sweep_name=hparams.sweep_name, num_trials=hparams.num_trials)
- response = self.invoke_handler(config=config)
- assert response is True
-
-
-class FlowCommands(LightningFlow):
- def __init__(self):
- super().__init__()
- self.names = []
- self.has_sweep = False
-
- def run(self):
- if self.has_sweep and len(self.names) == 1:
- sleep(1)
- self.stop()
-
- def trigger_method(self, name: str):
- print(name)
- self.names.append(name)
-
- def sweep(self, config: SweepConfig):
- self.has_sweep = True
- return True
-
- def configure_commands(self):
- return [{"user command": self.trigger_method}, {"sweep": SweepCommand(self.sweep)}]
-
-
-class DummyConfig(BaseModel):
- something: str
- something_else: int
-
-
-class DummyCommand(ClientCommand):
- def run(self, something: str, something_else: int) -> None:
- config = DummyConfig(something=something, something_else=something_else)
- response = self.invoke_handler(config=config)
- assert response == {"body": 0}
-
-
-def run(config: DummyConfig):
- assert isinstance(config, DummyConfig)
-
-
-def run_failure_0(name: str):
- pass
-
-
-def run_failure_1(name):
- pass
-
-
-class CustomModel(BaseModel):
- pass
-
-
-def run_failure_2(name: CustomModel):
- pass
-
-
-@_RunIf(skip_windows=True)
-def test_validate_client_command():
- with pytest.raises(Exception, match="The provided annotation for the argument name"):
- _validate_client_command(ClientCommand(run_failure_0))
-
- with pytest.raises(Exception, match="annotate your method"):
- _validate_client_command(ClientCommand(run_failure_1))
-
- starts = "lightning.app".replace(".", "/")
- with pytest.raises(Exception, match=f"{starts}/utilities/commands/base.py"):
- _validate_client_command(ClientCommand(run_failure_2))
-
-
-def test_client_commands(monkeypatch):
- import requests
-
- resp = MagicMock()
- resp.status_code = 200
- value = {"body": 0}
- resp.json = MagicMock(return_value=value)
- post = MagicMock()
- post.return_value = resp
- monkeypatch.setattr(requests, "post", post)
- url = "http//"
- kwargs = {"something": "1", "something_else": "1"}
- command = DummyCommand(run)
- _validate_client_command(command)
- client_command = _download_command(
- command_name="something",
- cls_path=__file__,
- cls_name="DummyCommand",
- )
- client_command._setup("something", app_url=url)
- client_command.run(**kwargs)
-
-
-def target():
- app = LightningApp(FlowCommands())
- MultiProcessRuntime(app).dispatch()
-
-
-@pytest.mark.xfail(strict=False, reason="failing for some reason, need to be fixed.") # fixme
-def test_configure_commands(monkeypatch):
- """This test validates command can be used locally with connect and disconnect."""
- process = Process(target=target)
- process.start()
- time_left = 15
- while time_left > 0:
- try:
- requests.get(f"http://localhost:{APP_SERVER_PORT}/healthz")
- break
- except requests.exceptions.ConnectionError:
- sleep(0.1)
- time_left -= 0.1
-
- sleep(0.5)
- monkeypatch.setattr(sys, "argv", ["lightning", "user", "command", "--name=something"])
- connect_app("localhost")
- _run_app_command("localhost", None)
- sleep(2)
- state = AppState()
- state._request_state()
- assert state.names == ["something"]
- monkeypatch.setattr(sys, "argv", ["lightning", "sweep", "--sweep_name=my_name", "--num_trials=1"])
- _run_app_command("localhost", None)
- time_left = 15
- while time_left > 0:
- if process.exitcode == 0:
- break
- sleep(0.1)
- time_left -= 0.1
- assert process.exitcode == 0
- disconnect_app()
- process.kill()
diff --git a/tests/tests_app/utilities/test_component.py b/tests/tests_app/utilities/test_component.py
deleted file mode 100644
index 050216361eb81..0000000000000
--- a/tests/tests_app/utilities/test_component.py
+++ /dev/null
@@ -1,77 +0,0 @@
-import pytest
-from lightning.app.storage.path import Path
-from lightning.app.testing.helpers import EmptyFlow, EmptyWork
-from lightning.app.utilities.component import (
- _context,
- _convert_paths_after_init,
- _get_context,
- _is_flow_context,
- _is_work_context,
- _set_context,
- _set_work_context,
-)
-from lightning.app.utilities.enum import ComponentContext
-
-
-def test_convert_paths_after_init():
- """Test that we can convert paths after the Flow/Work initialization, i.e., when the LightningApp is fully
- instantiated."""
-
- # TODO: Add a test case for the Lightning List and Dict containers
-
- class Flow1(EmptyFlow):
- def __init__(self):
- super().__init__()
- self.path1 = Path("a")
- self.path2 = Path("b")
-
- flow1 = Flow1()
- assert flow1._paths == {}
- _convert_paths_after_init(flow1)
- assert flow1._paths == {"path1": Path("a").to_dict(), "path2": Path("b").to_dict()}
-
- class Work1(EmptyWork):
- def __init__(self):
- super().__init__()
- self.path3 = Path("c")
-
- class Flow2(EmptyFlow):
- def __init__(self):
- super().__init__()
- self.work1 = Work1()
- self.path4 = Path("d")
-
- flow2 = Flow2()
- assert flow2._paths == {}
- assert flow2.work1._paths == {}
- _convert_paths_after_init(flow2)
- assert flow2._paths == {"path4": Path("d").to_dict()}
- assert set(flow2.work1._paths.keys()) == {"path3"}
- assert flow2.work1._paths["path3"]["origin_name"] == "root.work1"
- assert flow2.work1._paths["path3"]["consumer_name"] == "root.work1"
-
-
-@pytest.mark.parametrize("ctx", [c.value for c in ComponentContext])
-def test_context_context_manager(ctx):
- with _context(ctx):
- assert _get_context().value == ctx
- assert _get_context() is None
-
-
-@pytest.mark.parametrize("ctx", [c.value for c in ComponentContext])
-def test_set_get_context(ctx):
- assert _get_context() is None
- _set_context(ctx)
- assert _get_context().value == ctx
-
-
-def test_is_context():
- _set_context("flow")
- assert _is_flow_context()
-
- _set_work_context()
- assert _is_work_context()
-
- _set_context(None)
- assert not _is_flow_context()
- assert not _is_work_context()
diff --git a/tests/tests_app/utilities/test_dependency_caching.py b/tests/tests_app/utilities/test_dependency_caching.py
deleted file mode 100644
index ef63e254f29a0..0000000000000
--- a/tests/tests_app/utilities/test_dependency_caching.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from pathlib import Path
-
-from lightning.app.utilities.dependency_caching import get_hash
-
-
-def test_get_hash(tmpdir):
- req_path = tmpdir / "requirements.txt"
- Path(req_path).touch()
-
- # empty requirements file
- assert get_hash(req_path) == "3345524abf6bbe1809449224b5972c41790b6cf2"
-
- # requirements file with one dependency
- req_path.write_text("lightning==1.0", encoding="utf-8")
- assert get_hash(req_path) == "6177677a74b5d256e331cb9e390af58106e20220"
diff --git a/tests/tests_app/utilities/test_exceptions.py b/tests/tests_app/utilities/test_exceptions.py
deleted file mode 100644
index 96ac20deb9fdc..0000000000000
--- a/tests/tests_app/utilities/test_exceptions.py
+++ /dev/null
@@ -1,84 +0,0 @@
-from json import dumps
-from unittest.mock import MagicMock
-
-import pytest
-from click import ClickException, group
-from click.testing import CliRunner
-from lightning.app.utilities.exceptions import _ApiExceptionHandler
-from lightning_cloud.openapi.rest import ApiException
-from urllib3 import HTTPResponse
-
-
-@pytest.fixture()
-def mock_api_handled_group():
- @group(cls=_ApiExceptionHandler)
- def g():
- pass
-
- return g
-
-
-@pytest.fixture()
-def mock_subcommand(mock_api_handled_group):
- @mock_api_handled_group.command()
- def cmd():
- pass
-
- return cmd
-
-
-@pytest.fixture()
-def api_error_msg():
- return "This is an internal error message"
-
-
-class Test_ApiExceptionHandler:
- def test_4xx_exceptions_caught_in_subcommands(self, mock_api_handled_group, mock_subcommand, api_error_msg):
- mock_subcommand.invoke = MagicMock(
- side_effect=ApiException(
- http_resp=HTTPResponse(
- status=400,
- reason="Bad Request",
- body=dumps(
- {
- "code": 3,
- "message": api_error_msg,
- "details": [],
- },
- ),
- )
- )
- )
-
- runner = CliRunner()
- result = runner.invoke(
- mock_api_handled_group,
- [mock_subcommand.name],
- standalone_mode=False, # stop runner from raising SystemExit on ClickException
- )
-
- mock_subcommand.invoke.assert_called()
- assert result.exit_code == 1
- assert type(result.exception) is ClickException
- assert api_error_msg == str(result.exception)
-
- def test_original_thrown_if_cannot_decode_body(self, mock_api_handled_group, mock_subcommand):
- mock_subcommand.invoke = MagicMock(
- side_effect=ApiException(
- http_resp=HTTPResponse(
- status=400,
- reason="Bad Request",
- body="message from server is not json encoded!",
- )
- )
- )
-
- runner = CliRunner()
- result = runner.invoke(
- mock_api_handled_group,
- [mock_subcommand.name],
- )
-
- mock_subcommand.invoke.assert_called()
- assert result.exit_code == 1
- assert type(result.exception) is ApiException
diff --git a/tests/tests_app/utilities/test_git.py b/tests/tests_app/utilities/test_git.py
deleted file mode 100644
index b655baae5f947..0000000000000
--- a/tests/tests_app/utilities/test_git.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import os
-import sys
-from unittest.mock import patch
-
-from lightning.app.utilities.git import (
- check_github_repository,
- check_if_remote_head_is_different,
- get_dir_name,
- get_git_relative_path,
- has_uncommitted_files,
-)
-
-
-def mock_execute_git_command(args, cwd=None) -> str:
- if args == ["config", "--get", "remote.origin.url"]:
- return "https://github.com/Lightning-AI/lightning.git"
-
- if args == ["rev-parse", "--show-toplevel"]:
- return os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
-
- if args == ["update-index", "--refresh"]:
- return ""
-
- if args == ["rev-parse", "@"]:
- return "local-sha"
-
- if args == ["rev-parse", r"@{u}"] or args == ["merge-base", "@", r"@{u}"]:
- return "remote-sha"
-
- return "Error: Unexpected call"
-
-
-@patch("lightning.app.utilities.git.execute_git_command", mock_execute_git_command)
-def test_execute_git_command():
- assert get_dir_name() == "lightning"
-
- assert check_github_repository()
-
- if sys.platform == "win32":
- assert get_git_relative_path(__file__) == "tests\\tests_app\\utilities\\test_git.py"
- else:
- assert get_git_relative_path(__file__) == "tests/tests_app/utilities/test_git.py"
-
- assert check_if_remote_head_is_different()
-
- assert not has_uncommitted_files()
diff --git a/tests/tests_app/utilities/test_imports.py b/tests/tests_app/utilities/test_imports.py
deleted file mode 100644
index bf124ac6ed198..0000000000000
--- a/tests/tests_app/utilities/test_imports.py
+++ /dev/null
@@ -1,49 +0,0 @@
-import os
-from unittest import mock
-
-import pytest
-from lightning.app import __package_name__
-from lightning.app.utilities.imports import _get_extras, requires
-
-
-def test_get_extras():
- extras = "app-cloud" if __package_name__ == "lightning" else "cloud"
- extras = _get_extras(extras)
- assert "docker" in extras
- assert "redis" in extras
-
- assert _get_extras("fake-extras") == ""
-
-
-@mock.patch.dict(os.environ, {"LIGHTING_TESTING": "0"})
-def test_requires():
- @requires("lightning.app")
- def fn():
- pass
-
- fn()
-
- @requires("shouldnotexist")
- def fn_raise():
- pass
-
- with pytest.raises(ModuleNotFoundError, match="Please run: pip install 'shouldnotexist'"):
- fn_raise()
-
- class ClassRaise:
- @requires("shouldnotexist")
- def __init__(self):
- pass
-
- with pytest.raises(ModuleNotFoundError, match="Please run: pip install 'shouldnotexist'"):
- ClassRaise()
-
-
-@mock.patch.dict(os.environ, {"LIGHTING_TESTING": "0"})
-def test_requires_multiple():
- @requires(["shouldnotexist1", "shouldnotexist2"])
- def fn_raise():
- pass
-
- with pytest.raises(ModuleNotFoundError, match="Please run: pip install 'shouldnotexist1' 'shouldnotexist2'"):
- fn_raise()
diff --git a/tests/tests_app/utilities/test_introspection.py b/tests/tests_app/utilities/test_introspection.py
deleted file mode 100644
index ce5d54f4e0746..0000000000000
--- a/tests/tests_app/utilities/test_introspection.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import os
-from numbers import Rational
-
-from lightning.app import LightningApp, LightningFlow
-from lightning.app.testing.helpers import _RunIf
-from lightning.app.utilities.imports import _is_pytorch_lightning_available
-from lightning.app.utilities.introspection import Scanner
-
-if _is_pytorch_lightning_available():
- from pytorch_lightning import Trainer
- from pytorch_lightning.cli import LightningCLI
-
-from tests_app import _PROJECT_ROOT
-
-
-def test_introspection():
- """This test validates the scanner can find some class within the provided files."""
- scanner = Scanner(str(os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/example_1.py")))
- assert scanner.has_class(Rational)
- assert not scanner.has_class(LightningApp)
-
- scanner = Scanner(str(os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/example_2.py")))
- assert scanner.has_class(LightningApp)
- assert not scanner.has_class(LightningFlow)
-
-
-@_RunIf(pl=True)
-def test_introspection_lightning():
- """This test validates the scanner can find some PyTorch Lightning class within the provided files."""
- scanner = Scanner(str(os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/lightning_cli.py")))
- assert not scanner.has_class(Trainer)
- assert scanner.has_class(LightningCLI)
-
- scanner = Scanner(str(os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/lightning_trainer.py")))
- assert scanner.has_class(Trainer)
- assert not scanner.has_class(LightningCLI)
-
-
-@_RunIf(pl=True)
-def test_introspection_lightning_overrides():
- """This test validates the scanner can find all the subclasses of primitive classes from PyTorch Lightning in
- the provided files."""
- scanner = Scanner(str(os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/lightning_cli.py")))
- scan = scanner.scan()
- assert set(scan) == {"LightningDataModule", "LightningModule"}
-
- scanner = Scanner(str(os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/lightning_overrides.py")))
- scan = scanner.scan()
- assert set(scan) == {
- "Accelerator",
- "Profiler",
- "Callback",
- "LightningDataModule",
- "Fabric",
- "Logger",
- "LightningModule",
- "Metric",
- "PrecisionPlugin",
- "Trainer",
- }
diff --git a/tests/tests_app/utilities/test_layout.py b/tests/tests_app/utilities/test_layout.py
deleted file mode 100644
index 49d9a75ce2e9e..0000000000000
--- a/tests/tests_app/utilities/test_layout.py
+++ /dev/null
@@ -1,142 +0,0 @@
-import pytest
-from lightning.app.core.flow import LightningFlow
-from lightning.app.core.work import LightningWork
-from lightning.app.frontend.web import StaticWebFrontend
-from lightning.app.utilities.layout import _collect_layout
-
-
-class _MockApp:
- def __init__(self) -> None:
- self.frontends = {}
-
-
-class FlowWithFrontend(LightningFlow):
- def configure_layout(self):
- return StaticWebFrontend(".")
-
-
-class WorkWithFrontend(LightningWork):
- def run(self):
- pass
-
- def configure_layout(self):
- return StaticWebFrontend(".")
-
-
-class FlowWithWorkWithFrontend(LightningFlow):
- def __init__(self):
- super().__init__()
-
- self.work = WorkWithFrontend()
-
- def configure_layout(self):
- return {"name": "work", "content": self.work}
-
-
-class FlowWithUrl(LightningFlow):
- def configure_layout(self):
- return {"name": "test", "content": "https://test"}
-
-
-class WorkWithUrl(LightningWork):
- def run(self):
- pass
-
- def configure_layout(self):
- return "https://test"
-
-
-class FlowWithWorkWithUrl(LightningFlow):
- def __init__(self):
- super().__init__()
-
- self.work = WorkWithUrl()
-
- def configure_layout(self):
- return {"name": "test", "content": self.work}
-
-
-@pytest.mark.parametrize(
- ("flow", "expected_layout", "expected_frontends"),
- [
- (FlowWithFrontend, {}, [("root", StaticWebFrontend)]),
- (FlowWithWorkWithFrontend, {}, [("root", StaticWebFrontend)]),
- (FlowWithUrl, [{"name": "test", "content": "https://test", "target": "https://test"}], []),
- (FlowWithWorkWithUrl, [{"name": "test", "content": "https://test", "target": "https://test"}], []),
- ],
-)
-def test_collect_layout(flow, expected_layout, expected_frontends):
- app = _MockApp()
- flow = flow()
- layout = _collect_layout(app, flow)
-
- assert layout == expected_layout
- assert set(app.frontends.keys()) == {key for key, _ in expected_frontends}
- for key, frontend_type in expected_frontends:
- assert isinstance(app.frontends[key], frontend_type)
-
-
-class FlowWithBadLayout(LightningFlow):
- def configure_layout(self):
- return 100
-
-
-class FlowWithBadLayoutDict(LightningFlow):
- def configure_layout(self):
- return {"this_key_should_not_be_here": "http://appurl"}
-
-
-class FlowWithBadContent(LightningFlow):
- def configure_layout(self):
- return {"content": 100}
-
-
-class WorkWithBadLayout(LightningWork):
- def run(self):
- pass
-
- def configure_layout(self):
- return 100
-
-
-class FlowWithWorkWithBadLayout(LightningFlow):
- def __init__(self):
- super().__init__()
-
- self.work = WorkWithBadLayout()
-
- def configure_layout(self):
- return {"name": "test", "content": self.work}
-
-
-class FlowWithMultipleWorksWithFrontends(LightningFlow):
- def __init__(self):
- super().__init__()
-
- self.work1 = WorkWithFrontend()
- self.work2 = WorkWithFrontend()
-
- def configure_layout(self):
- return [{"name": "test1", "content": self.work1}, {"name": "test2", "content": self.work2}]
-
-
-@pytest.mark.parametrize(
- ("flow", "error_type", "match"),
- [
- (FlowWithBadLayout, TypeError, "is an unsupported layout format"),
- (FlowWithBadLayoutDict, ValueError, "missing a key 'content'."),
- (FlowWithBadContent, ValueError, "contains an unsupported entry."),
- (FlowWithWorkWithBadLayout, TypeError, "is of an unsupported type."),
- (
- FlowWithMultipleWorksWithFrontends,
- TypeError,
- "The tab containing a `WorkWithFrontend` must be the only tab",
- ),
- ],
-)
-def test_collect_layout_errors(flow, error_type, match):
- app = _MockApp()
- flow = flow()
-
- with pytest.raises(error_type, match=match):
- _collect_layout(app, flow)
diff --git a/tests/tests_app/utilities/test_load_app.py b/tests/tests_app/utilities/test_load_app.py
deleted file mode 100644
index 62e09764c1b3c..0000000000000
--- a/tests/tests_app/utilities/test_load_app.py
+++ /dev/null
@@ -1,109 +0,0 @@
-import os
-import sys
-from unittest.mock import ANY
-
-import pytest
-from lightning.app.utilities.exceptions import MisconfigurationException
-from lightning.app.utilities.load_app import extract_metadata_from_app, load_app_from_file
-
-
-def test_load_app_from_file_errors():
- test_script_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "core", "scripts")
- with pytest.raises(MisconfigurationException, match="There should not be multiple apps instantiated within a file"):
- load_app_from_file(os.path.join(test_script_dir, "two_apps.py"))
-
- with pytest.raises(MisconfigurationException, match="The provided file .* does not contain a LightningApp"):
- load_app_from_file(os.path.join(test_script_dir, "empty.py"))
-
- with pytest.raises(SystemExit, match="1"):
- load_app_from_file(os.path.join(test_script_dir, "script_with_error.py"))
-
-
-@pytest.mark.parametrize("app_path", ["app_metadata.py", "app_with_local_import.py"])
-def test_load_app_from_file(app_path):
- """Test that apps load without error and that sys.path and main module are set."""
- original_main = sys.modules["__main__"]
- test_script_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "core", "scripts")
- load_app_from_file(os.path.join(test_script_dir, app_path), raise_exception=True)
-
- assert test_script_dir in sys.path
- assert sys.modules["__main__"] != original_main
-
-
-def test_extract_metadata_from_component():
- test_script_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "core", "scripts")
- app = load_app_from_file(os.path.join(test_script_dir, "app_metadata.py"))
- metadata = extract_metadata_from_app(app)
- assert metadata == [
- {"affiliation": ["root"], "cls_name": "RootFlow", "module": "__main__", "docstring": "RootFlow."},
- {
- "affiliation": ["root", "flow_a_1"],
- "cls_name": "FlowA",
- "module": "__main__",
- "docstring": "FlowA Component.",
- },
- {
- "affiliation": ["root", "flow_a_1", "work_a"],
- "cls_name": "WorkA",
- "module": "__main__",
- "docstring": "WorkA.",
- "local_build_config": {"__build_config__": {"requirements": [], "dockerfile": None, "image": None}},
- "cloud_build_config": {"__build_config__": {"requirements": [], "dockerfile": None, "image": None}},
- "cloud_compute": {
- "type": "__cloud_compute__",
- "name": "cpu-small",
- "disk_size": 0,
- "idle_timeout": None,
- "shm_size": 0,
- "mounts": None,
- "_internal_id": "default",
- "interruptible": False,
- "colocation_group_id": None,
- },
- },
- {
- "affiliation": ["root", "flow_a_2"],
- "cls_name": "FlowA",
- "module": "__main__",
- "docstring": "FlowA Component.",
- },
- {
- "affiliation": ["root", "flow_a_2", "work_a"],
- "cls_name": "WorkA",
- "module": "__main__",
- "docstring": "WorkA.",
- "local_build_config": {"__build_config__": {"requirements": [], "dockerfile": None, "image": None}},
- "cloud_build_config": {"__build_config__": {"requirements": [], "dockerfile": None, "image": None}},
- "cloud_compute": {
- "type": "__cloud_compute__",
- "name": "cpu-small",
- "disk_size": 0,
- "idle_timeout": None,
- "shm_size": 0,
- "mounts": None,
- "_internal_id": "default",
- "interruptible": False,
- "colocation_group_id": None,
- },
- },
- {"affiliation": ["root", "flow_b"], "cls_name": "FlowB", "module": "__main__", "docstring": "FlowB."},
- {
- "affiliation": ["root", "flow_b", "work_b"],
- "cls_name": "WorkB",
- "module": "__main__",
- "docstring": "WorkB.",
- "local_build_config": {"__build_config__": {"requirements": [], "dockerfile": None, "image": None}},
- "cloud_build_config": {"__build_config__": {"requirements": [], "dockerfile": None, "image": None}},
- "cloud_compute": {
- "type": "__cloud_compute__",
- "name": "gpu",
- "disk_size": 0,
- "idle_timeout": None,
- "shm_size": 1024,
- "mounts": None,
- "_internal_id": ANY,
- "interruptible": False,
- "colocation_group_id": None,
- },
- },
- ]
diff --git a/tests/tests_app/utilities/test_log_helpers.py b/tests/tests_app/utilities/test_log_helpers.py
deleted file mode 100644
index 6eae52388383c..0000000000000
--- a/tests/tests_app/utilities/test_log_helpers.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from unittest import TestCase, mock
-
-from lightning.app.utilities.log_helpers import _error_callback
-
-
-class TestErrorCallback(TestCase):
- def test_known_error(self):
- websocket = mock.Mock()
- with self.assertLogs("lightning.app.utilities.log_helpers") as captured:
- _error_callback(websocket, ValueError())
- # check that there is only one log message
- assert len(captured.records) == 1
- # and it contains the error message expected
- assert "Error while reading logs (Malformed date format)" in captured.records[0].getMessage()
-
- def test_unknown_error(self):
- websocket = mock.Mock()
- with self.assertLogs("lightning.app.utilities.log_helpers") as captured:
- _error_callback(websocket, OSError())
- # check that there is only one log message
- assert len(captured.records) == 1
- # and it contains the error message expected
- assert "Error while reading logs (Unknown)" in captured.records[0].getMessage()
diff --git a/tests/tests_app/utilities/test_login.py b/tests/tests_app/utilities/test_login.py
deleted file mode 100644
index d8f6d3d558707..0000000000000
--- a/tests/tests_app/utilities/test_login.py
+++ /dev/null
@@ -1,154 +0,0 @@
-import os
-from unittest import mock
-
-import pytest
-import requests
-from lightning.app.utilities import login
-
-LIGHTNING_CLOUD_URL = os.getenv("LIGHTNING_CLOUD_URL", "https://lightning.ai")
-
-
-@pytest.fixture(autouse=True)
-def before_each():
- for key in login.Keys:
- os.environ.pop(key.value, None)
- login.Auth().clear()
-
-
-class TestAuthentication:
- def test_can_store_credentials(self):
- auth = login.Auth()
- auth.save(username="superman", user_id="kr-1234")
- assert auth.secrets_file.exists()
-
- auth.clear()
- assert not auth.secrets_file.exists()
-
- def test_e2e(self):
- auth = login.Auth()
- auth.save(username="superman", user_id="kr-1234")
- assert auth.secrets_file.exists()
-
- def test_get_auth_header_should_raise_error(self):
- with pytest.raises(AttributeError):
- login.Auth().auth_header
-
- def test_credentials_file_io(self):
- auth = login.Auth()
- assert not auth.secrets_file.exists()
- assert auth.load() is False
- auth.save(username="", user_id="123", api_key="123")
- assert auth.secrets_file.exists()
- assert auth.load() is True
-
- def test_auth_header(self):
- # fake credentials
- os.environ.setdefault("LIGHTNING_USER_ID", "7c8455e3-7c5f-4697-8a6d-105971d6b9bd")
- os.environ.setdefault("LIGHTNING_API_KEY", "e63fae57-2b50-498b-bc46-d6204cbf330e")
- auth = login.Auth()
- auth.clear()
- auth.authenticate()
-
- assert "Basic" in auth.auth_header
- assert (
- auth.auth_header
- == "Basic N2M4NDU1ZTMtN2M1Zi00Njk3LThhNmQtMTA1OTcxZDZiOWJkOmU2M2ZhZTU3LTJiNTAtNDk4Yi1iYzQ2LWQ2MjA0Y2JmMzMwZQ==" # noqa E501
- )
-
-
-def test_authentication_with_invalid_environment_vars():
- # if api key is passed without user id
- os.environ.setdefault("LIGHTNING_API_KEY", "123")
- with pytest.raises(ValueError):
- auth = login.Auth()
- auth.clear()
- auth.authenticate()
-
-
-@mock.patch("lightning.app.utilities.login.AuthServer.login_with_browser")
-def test_authentication_with_environment_vars(browser_login: mock.MagicMock):
- os.environ.setdefault("LIGHTNING_USER_ID", "abc")
- os.environ.setdefault("LIGHTNING_API_KEY", "abc")
-
- auth = login.Auth()
- auth.clear()
- auth.authenticate()
-
- assert auth.user_id == "abc"
- assert auth.auth_header == "Basic YWJjOmFiYw=="
- assert auth.authenticate() == auth.auth_header
- # should not run login flow when env vars are passed
- browser_login.assert_not_called()
-
- # Check credentials file
- assert auth.secrets_file.exists()
- assert auth.load() is True
-
-
-def test_get_auth_url():
- auth_url = login.AuthServer().get_auth_url(1234)
- assert (
- auth_url == f"{LIGHTNING_CLOUD_URL}/sign-in?redirectTo=http%3A%2F%2Flocalhost%3A1234%2Flogin-complete"
- ) # E501
-
-
-@mock.patch("lightning.app.utilities.login.find_free_network_port")
-@mock.patch("uvicorn.Server.run")
-@mock.patch("requests.head")
-@mock.patch("click.launch")
-def test_login_with_browser(
- click_launch: mock.MagicMock, head: mock.MagicMock, run: mock.MagicMock, port: mock.MagicMock
-):
- port.return_value = 1234
- login.Auth()._run_server()
- url = f"{LIGHTNING_CLOUD_URL}/sign-in?redirectTo=http%3A%2F%2Flocalhost%3A1234%2Flogin-complete" # E501
- head.assert_called_once_with(url)
- click_launch.assert_called_once_with(url)
- run.assert_called_once()
-
-
-@mock.patch("lightning.app.utilities.login.find_free_network_port")
-@mock.patch("uvicorn.Server.run")
-@mock.patch("requests.head")
-@mock.patch("click.launch")
-def test_authenticate(click_launch: mock.MagicMock, head: mock.MagicMock, run: mock.MagicMock, port: mock.MagicMock):
- port.return_value = 1234
- auth = login.Auth()
- auth.clear()
-
- click_launch.side_effect = lambda _: auth.save("", "user_id", "api_key", "user_id")
-
- auth.authenticate()
- url = f"{LIGHTNING_CLOUD_URL}/sign-in?redirectTo=http%3A%2F%2Flocalhost%3A1234%2Flogin-complete" # E501
- head.assert_called_with(url)
- click_launch.assert_called_with(url)
- run.assert_called()
-
- assert auth.auth_header == "Basic dXNlcl9pZDphcGlfa2V5"
-
- auth.authenticate()
- assert auth.auth_header == "Basic dXNlcl9pZDphcGlfa2V5"
-
-
-@mock.patch("uvicorn.Server.run")
-@mock.patch("requests.head")
-def test_network_failure(
- head: mock.MagicMock,
- run: mock.MagicMock,
-):
- head.side_effect = requests.ConnectionError()
- with pytest.raises(requests.ConnectionError):
- login.Auth()._run_server()
- run.assert_not_called()
-
- head.side_effect = requests.RequestException()
- with pytest.raises(requests.RequestException):
- login.Auth()._run_server()
- run.assert_not_called()
-
-
-def test_with_api_key_only():
- auth = login.Auth()
- auth.save(user_id="7c8455e3-7c5f-4697-8a6d-105971d6b9bd", api_key="e63fae57-2b50-498b-bc46-d6204cbf330e")
- hash_ = "N2M4NDU1ZTMtN2M1Zi00Njk3LThhNmQtMTA1OTcxZDZiOWJkOmU2M2ZhZTU3LTJiNTAtNDk4Yi1iYzQ2LWQ2MjA0Y2JmMzMwZQ"
- assert auth.authenticate() == f"Basic {hash_}==" # E501
diff --git a/tests/tests_app/utilities/test_network.py b/tests/tests_app/utilities/test_network.py
deleted file mode 100644
index e3ccaf662d57d..0000000000000
--- a/tests/tests_app/utilities/test_network.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from unittest import mock
-
-import pytest
-from lightning.app.core import constants
-from lightning.app.utilities.network import find_free_network_port
-
-
-def test_find_free_network_port():
- """Tests that `find_free_network_port` gives expected outputs and raises if a free port couldn't be found."""
- assert find_free_network_port()
-
- with mock.patch("lightning.app.utilities.network.socket") as mock_socket:
- mock_socket.socket().getsockname.return_value = [0, 8888]
- assert find_free_network_port() == 8888
-
- with pytest.raises(RuntimeError, match="Couldn't find a free port."):
- find_free_network_port()
-
- mock_socket.socket().getsockname.return_value = [0, 9999]
- assert find_free_network_port() == 9999
-
-
-@mock.patch("lightning.app.utilities.network.socket")
-@pytest.mark.parametrize(
- "patch_constants",
- [{"LIGHTNING_CLOUDSPACE_HOST": "any", "LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT": 10}],
- indirect=True,
-)
-def test_find_free_network_port_cloudspace(_, patch_constants):
- """Tests that `find_free_network_port` gives expected outputs and raises if a free port couldn't be found when
- cloudspace env variables are set."""
- ports = set()
- num_ports = 0
-
- with pytest.raises(RuntimeError, match="All 10 ports are already in use."):
- for _ in range(11):
- ports.add(find_free_network_port())
- num_ports = num_ports + 1
-
- # Check that all ports are unique
- assert len(ports) == num_ports
-
- # Shouldn't use the APP_SERVER_PORT
- assert constants.APP_SERVER_PORT not in ports
diff --git a/tests/tests_app/utilities/test_port.py b/tests/tests_app/utilities/test_port.py
deleted file mode 100644
index 68a19e3f46c0a..0000000000000
--- a/tests/tests_app/utilities/test_port.py
+++ /dev/null
@@ -1,108 +0,0 @@
-from unittest.mock import MagicMock
-
-import pytest
-from lightning.app.utilities import port
-from lightning.app.utilities.port import _find_lit_app_port, disable_port, enable_port
-from lightning_cloud.openapi import V1NetworkConfig
-
-
-def test_find_lit_app_port(monkeypatch):
- client = MagicMock()
- monkeypatch.setattr(port, "LightningClient", MagicMock(return_value=client))
-
- assert _find_lit_app_port(5701) == 5701
-
- resp = MagicMock()
- lit_app = MagicMock()
- lit_app.id = "a"
- lit_app.spec.network_config = [
- V1NetworkConfig(host="a", port=0, enable=True),
- V1NetworkConfig(host="a", port=1, enable=False),
- ]
- resp.lightningapps = [lit_app]
- client.lightningapp_instance_service_list_lightningapp_instances.return_value = resp
-
- monkeypatch.setenv("LIGHTNING_CLOUD_APP_ID", "a")
- monkeypatch.setenv("LIGHTNING_CLOUD_PROJECT_ID", "a")
- monkeypatch.setenv("ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER", "1")
-
- assert _find_lit_app_port(5701) == 1
-
- lit_app.spec.network_config = [
- V1NetworkConfig(host="a", port=0, enable=True),
- V1NetworkConfig(host="a", port=1, enable=True),
- ]
-
- with pytest.raises(RuntimeError, match="No available port was found. Please"):
- _find_lit_app_port(5701)
-
-
-def test_enable_port(monkeypatch):
- client = MagicMock()
- monkeypatch.setattr(port, "LightningClient", MagicMock(return_value=client))
-
- assert _find_lit_app_port(5701) == 5701
-
- resp = MagicMock()
- lit_app = MagicMock()
- lit_app.id = "a"
- lit_app.spec.network_config = [
- V1NetworkConfig(host="a", port=0, enable=True),
- V1NetworkConfig(host="a", port=1, enable=False),
- ]
- resp.lightningapps = [lit_app]
- client.lightningapp_instance_service_list_lightningapp_instances.return_value = resp
-
- monkeypatch.setenv("LIGHTNING_CLOUD_APP_ID", "a")
- monkeypatch.setenv("LIGHTNING_CLOUD_PROJECT_ID", "a")
- monkeypatch.setenv("ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER", "1")
-
- assert enable_port()
-
- lit_app.spec.network_config = [
- V1NetworkConfig(host="a", port=0, enable=True),
- V1NetworkConfig(host="a", port=1, enable=True),
- ]
-
- with pytest.raises(RuntimeError, match="No available port was found. Please"):
- assert enable_port()
-
-
-def test_disable_port(monkeypatch):
- client = MagicMock()
- monkeypatch.setattr(port, "LightningClient", MagicMock(return_value=client))
-
- assert _find_lit_app_port(5701) == 5701
-
- resp = MagicMock()
- lit_app = MagicMock()
- lit_app.id = "a"
- lit_app.spec.network_config = [
- V1NetworkConfig(host="a", port=0, enable=True),
- V1NetworkConfig(host="a", port=1, enable=False),
- ]
- resp.lightningapps = [lit_app]
- client.lightningapp_instance_service_list_lightningapp_instances.return_value = resp
-
- monkeypatch.setenv("LIGHTNING_CLOUD_APP_ID", "a")
- monkeypatch.setenv("LIGHTNING_CLOUD_PROJECT_ID", "a")
- monkeypatch.setenv("ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER", "1")
-
- disable_port(0)
- assert not lit_app.spec.network_config[0].enable
-
- lit_app.spec.network_config = [
- V1NetworkConfig(host="a", port=0, enable=True),
- V1NetworkConfig(host="a", port=1, enable=False),
- ]
-
- with pytest.raises(RuntimeError, match="The port 1 was already disabled."):
- disable_port(1, ignore_disabled=False)
-
- lit_app.spec.network_config = [
- V1NetworkConfig(host="a", port=0, enable=True),
- V1NetworkConfig(host="a", port=1, enable=False),
- ]
-
- with pytest.raises(ValueError, match=r"\[0, 1\]"):
- assert disable_port(10)
diff --git a/tests/tests_app/utilities/test_proxies.py b/tests/tests_app/utilities/test_proxies.py
deleted file mode 100644
index 3c5d830e30e02..0000000000000
--- a/tests/tests_app/utilities/test_proxies.py
+++ /dev/null
@@ -1,795 +0,0 @@
-import contextlib
-import logging
-import os
-import pathlib
-import sys
-import time
-import traceback
-from copy import deepcopy
-from queue import Empty
-from unittest import mock
-from unittest.mock import MagicMock, Mock
-
-import pytest
-from deepdiff import DeepDiff, Delta
-from lightning.app import LightningApp, LightningFlow, LightningWork
-from lightning.app.runners import MultiProcessRuntime
-from lightning.app.storage import Drive
-from lightning.app.storage.path import Path, _artifacts_path
-from lightning.app.storage.requests import _GetRequest
-from lightning.app.testing.helpers import EmptyFlow, _MockQueue
-from lightning.app.utilities.component import _convert_paths_after_init
-from lightning.app.utilities.enum import AppStage, CacheCallsKeys, WorkFailureReasons, WorkStageStatus
-from lightning.app.utilities.exceptions import CacheMissException, ExitAppException
-from lightning.app.utilities.proxies import (
- ComponentDelta,
- LightningWorkSetAttrProxy,
- ProxyWorkRun,
- WorkRunner,
- WorkStateObserver,
- persist_artifacts,
-)
-
-logger = logging.getLogger(__name__)
-
-
-class Work(LightningWork):
- def __init__(self, cache_calls=True, parallel=True):
- super().__init__(cache_calls=cache_calls, parallel=parallel)
- self.counter = 0
-
- def run(self):
- self.counter = 1
- return 1
-
-
-def test_lightning_work_setattr():
- """This test validates that the `LightningWorkSetAttrProxy` pushes a delta to the `caller_queue` every time an
- attribute of the work state is changed."""
-
- w = Work()
- # prepare
- w._name = "root.b"
- # create queue
- caller_queue = _MockQueue("caller_queue")
-
- def proxy_setattr():
- w._setattr_replacement = LightningWorkSetAttrProxy(w._name, w, caller_queue, MagicMock())
-
- proxy_setattr()
- w.run()
- assert len(caller_queue) == 1
- work_proxy_output = caller_queue._queue[0]
- assert isinstance(work_proxy_output, ComponentDelta)
- assert work_proxy_output.id == w._name
- assert work_proxy_output.delta.to_dict() == {"values_changed": {"root['vars']['counter']": {"new_value": 1}}}
-
-
-@pytest.mark.parametrize(
- ("parallel", "cache_calls"),
- [
- (True, True),
- (True, False),
- (False, True),
- pytest.param(False, False, marks=pytest.mark.xfail(strict=False, reason="failing...")), # fixme
- ],
-)
-@mock.patch("lightning.app.utilities.proxies._Copier", MagicMock())
-@pytest.mark.flaky(reruns=3)
-@pytest.mark.xfail(sys.platform == "win32", strict=False, reason="Fix this on Windows") # TODO @ethanwharris
-def test_work_runner(parallel, cache_calls, *_):
- """This test validates the `WorkRunner` runs the work.run method and properly populates the `delta_queue`,
- `error_queue` and `readiness_queue`."""
-
- class Work(LightningWork):
- def __init__(self, cache_calls=True, parallel=True):
- super().__init__(cache_calls=cache_calls, parallel=parallel)
- self.counter = 0
- self.dummy_path = "lit://test"
-
- def run(self):
- self.counter = 1
-
- class Flow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.w = Work(cache_calls=cache_calls, parallel=parallel)
-
- def run(self):
- pass
-
- class BlockingQueue(_MockQueue):
- """A Mock for the file copier queues that keeps blocking until we want to end the thread."""
-
- keep_blocking = True
-
- def get(self, timeout: int = 0):
- while BlockingQueue.keep_blocking:
- pass
- # A dummy request so the Copier gets something to process without an error
- return _GetRequest(source="src", name="dummy_path", path="test", hash="123", destination="dst")
-
- app = LightningApp(Flow())
- work = app.root.w
- caller_queue = _MockQueue("caller_queue")
- delta_queue = _MockQueue("delta_queue")
- readiness_queue = _MockQueue("readiness_queue")
- error_queue = _MockQueue("error_queue")
- request_queue = _MockQueue("request_queue")
- response_queue = _MockQueue("response_queue")
- copy_request_queue = BlockingQueue("copy_request_queue")
- copy_response_queue = BlockingQueue("copy_response_queue")
-
- call_hash = "run:fe3fa0f34fc1317e152e5afb023332995392071046f1ea51c34c7c9766e3676c"
- work._calls[call_hash] = {
- "args": (),
- "kwargs": {},
- "call_hash": call_hash,
- "run_started_counter": 1,
- "statuses": [],
- }
- caller_queue.put({
- "args": (),
- "kwargs": {},
- "call_hash": call_hash,
- "state": work.state,
- })
- work_runner = WorkRunner(
- work,
- work.name,
- caller_queue,
- delta_queue,
- readiness_queue,
- error_queue,
- request_queue,
- response_queue,
- copy_request_queue,
- copy_response_queue,
- )
- with contextlib.suppress(Empty, Exception):
- work_runner()
-
- assert readiness_queue._queue[0]
- if parallel:
- assert isinstance(error_queue._queue[0], Exception)
- else:
- assert isinstance(error_queue._queue[0], Empty)
- assert len(delta_queue._queue) in [3, 4]
- res = delta_queue._queue[0].delta.to_dict()["iterable_item_added"]
- assert res[f"root['calls']['{call_hash}']['statuses'][0]"]["stage"] == "running"
- assert delta_queue._queue[1].delta.to_dict() == {
- "values_changed": {"root['vars']['counter']": {"new_value": 1}}
- }
- index = 3 if len(delta_queue._queue) == 4 else 2
- res = delta_queue._queue[index].delta.to_dict()["dictionary_item_added"]
- assert res[f"root['calls']['{call_hash}']['ret']"] is None
-
- # Stop blocking and let the thread join
- BlockingQueue.keep_blocking = False
- work_runner.copier.join()
-
-
-def test_pathlike_as_argument_to_run_method_warns(tmpdir):
- """Test that Lightning produces a special warning for strings that look like paths."""
- # all these paths are not proper paths or don't have a file or folder that exists
- no_warning_expected = (
- "looks/like/path",
- pathlib.Path("looks/like/path"),
- "i am not a path",
- 1,
- Path("lightning/path"),
- )
- for path in no_warning_expected:
- _pass_path_argument_to_work_and_test_warning(path=path, warning_expected=False)
-
- # warn if it looks like a folder and the folder exists
- _pass_path_argument_to_work_and_test_warning(path=tmpdir, warning_expected=True)
-
- # warn if it looks like a string or pathlib Path and the file exists
- file = pathlib.Path(tmpdir, "file_exists.txt")
- file.write_text("test")
- assert os.path.exists(file)
- _pass_path_argument_to_work_and_test_warning(path=file, warning_expected=True)
- _pass_path_argument_to_work_and_test_warning(path=str(file), warning_expected=True)
-
- # do not warn if the path is wrapped in Lightning Path (and the file exists)
- file = Path(tmpdir, "file_exists.txt")
- file.write_text("test")
- assert os.path.exists(file)
- _pass_path_argument_to_work_and_test_warning(path=file, warning_expected=False)
-
-
-def _pass_path_argument_to_work_and_test_warning(path, warning_expected):
- class WarnRunPathWork(LightningWork):
- def run(self, *args, **kwargs):
- pass
-
- class Flow(EmptyFlow):
- def __init__(self):
- super().__init__()
- self.work = WarnRunPathWork()
-
- flow = Flow()
- work = flow.work
- proxy_run = ProxyWorkRun(work.run, "some", work, Mock())
-
- warn_ctx = pytest.warns(UserWarning, match="You passed a the value") if warning_expected else pytest.warns(None)
- with warn_ctx as record, pytest.raises(CacheMissException):
- proxy_run(path)
-
- assert warning_expected or all("You passed a the value" not in str(msg.message) for msg in record)
-
-
-class WorkTimeout(LightningWork):
- def __init__(self):
- super().__init__(parallel=True, start_with_flow=False)
- self.counter = 0
-
- def run(self):
- self.counter += 1
-
-
-class FlowTimeout(LightningFlow):
- def __init__(self):
- super().__init__()
- self.counter = 0
- self.work = WorkTimeout()
-
- def run(self):
- if not self.work.has_started:
- self.work.run()
- if self.work.has_timeout:
- self.stop()
-
-
-class WorkRunnerPatch(WorkRunner):
- counter = 0
-
- def __call__(self):
- call_hash = "fe3fa0f"
- while True:
- try:
- called = self.caller_queue.get()
- self.work.set_state(called["state"])
- state = deepcopy(self.work.state)
- self.work._calls[call_hash]["statuses"].append({
- "name": self.work.name,
- "stage": WorkStageStatus.FAILED,
- "reason": WorkFailureReasons.TIMEOUT,
- "timestamp": time.time(),
- "message": None,
- })
- self.delta_queue.put(
- ComponentDelta(id=self.work_name, delta=Delta(DeepDiff(state, self.work.state, verbose_level=2)))
- )
- self.counter += 1
- except Exception as ex:
- logger.error(traceback.format_exc())
- self.error_queue.put(ex)
- raise ExitAppException
-
-
-@mock.patch("lightning.app.runners.backends.mp_process.WorkRunner", WorkRunnerPatch)
-def test_proxy_timeout():
- app = LightningApp(FlowTimeout(), log_level="debug")
- MultiProcessRuntime(app, start_server=False).dispatch()
-
- call_hash = app.root.work._calls[CacheCallsKeys.LATEST_CALL_HASH]
- assert len(app.root.work._calls[call_hash]["statuses"]) == 3
- assert app.root.work._calls[call_hash]["statuses"][0]["stage"] == "pending"
- assert app.root.work._calls[call_hash]["statuses"][1]["stage"] == "failed"
- assert app.root.work._calls[call_hash]["statuses"][2]["stage"] == "stopped"
-
-
-@mock.patch("lightning.app.utilities.proxies._Copier")
-def test_path_argument_to_transfer(*_):
- """Test that any Lightning Path objects passed to the run method get transferred automatically (if they exist)."""
-
- class TransferPathWork(LightningWork):
- def run(self, *args, **kwargs):
- raise ExitAppException
-
- work = TransferPathWork()
-
- path1 = Path("exists-locally.txt")
- path2 = Path("exists-remotely.txt")
- path3 = Path("exists-nowhere.txt")
-
- path1.get = Mock()
- path2.get = Mock()
- path3.get = Mock()
-
- path1.exists_remote = Mock(return_value=False)
- path2.exists_remote = Mock(return_value=True)
- path3.exists_remote = Mock(return_value=False)
-
- path1._origin = "origin"
- path2._origin = "origin"
- path3._origin = "origin"
-
- call = {
- "args": (path1, path2),
- "kwargs": {"path3": path3},
- "call_hash": "any",
- "state": {
- "vars": {"_paths": {}, "_urls": {}},
- "calls": {
- CacheCallsKeys.LATEST_CALL_HASH: "any",
- "any": {
- "name": "run",
- "call_hash": "any",
- "use_args": False,
- "statuses": [{"stage": "requesting", "message": None, "reason": None, "timestamp": 1}],
- },
- },
- "changes": {},
- },
- }
-
- caller_queue = _MockQueue()
- caller_queue.put(call)
-
- runner = WorkRunner(
- work=work,
- work_name="name",
- caller_queue=caller_queue,
- delta_queue=_MockQueue(),
- readiness_queue=_MockQueue(),
- error_queue=_MockQueue(),
- request_queue=_MockQueue(),
- response_queue=_MockQueue(),
- copy_request_queue=_MockQueue(),
- copy_response_queue=_MockQueue(),
- )
-
- with contextlib.suppress(ExitAppException):
- runner()
-
- path1.exists_remote.assert_called_once()
- path1.get.assert_not_called()
-
- path2.exists_remote.assert_called_once()
- path2.get.assert_called_once()
-
- path3.exists_remote.assert_called()
- path3.get.assert_not_called()
-
-
-@pytest.mark.parametrize(
- ("origin", "exists_remote", "expected_get"),
- [
- (None, False, False),
- ("root.work", True, False),
- ("root.work", False, False),
- ("origin", True, True),
- ],
-)
-@mock.patch("lightning.app.utilities.proxies._Copier")
-def test_path_attributes_to_transfer(_, origin, exists_remote, expected_get):
- """Test that Lightning Path attributes of the Work get transferred automatically (if they exist remotely and
- originate from another component)."""
- path_mock = Mock()
- path_mock.origin_name = origin
- path_mock.exists_remote = Mock(return_value=exists_remote)
-
- class TransferPathWork(LightningWork):
- def __init__(self):
- super().__init__()
- self.path = Path("test-path.txt")
-
- def run(self):
- raise ExitAppException
-
- def __getattr__(self, item):
- if item == "path":
- return path_mock
- return super().__getattr__(item)
-
- class Flow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.work = TransferPathWork()
-
- def run(self):
- self.work.run()
-
- flow = Flow()
- _convert_paths_after_init(flow)
-
- call = {
- "args": (),
- "kwargs": {},
- "call_hash": "any",
- "state": {
- "vars": {"_paths": flow.work._paths, "_urls": {}},
- "calls": {
- CacheCallsKeys.LATEST_CALL_HASH: "any",
- "any": {
- "name": "run",
- "call_hash": "any",
- "use_args": False,
- "statuses": [{"stage": "requesting", "message": None, "reason": None, "timestamp": 1}],
- },
- },
- "changes": {},
- },
- }
-
- caller_queue = _MockQueue()
- caller_queue.put(call)
-
- runner = WorkRunner(
- work=flow.work,
- work_name=flow.work.name,
- caller_queue=caller_queue,
- delta_queue=_MockQueue(),
- readiness_queue=_MockQueue(),
- error_queue=_MockQueue(),
- request_queue=_MockQueue(),
- response_queue=_MockQueue(),
- copy_request_queue=_MockQueue(),
- copy_response_queue=_MockQueue(),
- )
- with contextlib.suppress(ExitAppException):
- runner()
-
- assert path_mock.get.call_count == expected_get
-
-
-def test_proxy_work_run_paths_replace_origin_lightning_work_by_their_name():
- class Work(LightningWork):
- def __init__(self):
- super().__init__(parallel=True)
- self.path = None
-
- def run(self, path):
- assert isinstance(path._origin, str)
-
- class Flow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.w1 = Work()
- self.w = Work()
-
- def run(self):
- pass
-
- app = LightningApp(Flow())
- work = app.root.w
- caller_queue = _MockQueue("caller_queue")
- app.root.w1.path = Path(__file__)
- assert app.root.w1.path._origin == app.root.w1
- ProxyWorkRun(work.run, work.name, work, caller_queue)(path=app.root.w1.path)
- assert caller_queue._queue[0]["kwargs"]["path"]._origin == app.root.w1.name
-
-
-def test_persist_artifacts(tmp_path):
- """Test that the `persist_artifacts` utility copies the artifacts that exist to the persistent storage."""
-
- class ArtifactWork(LightningWork):
- def __init__(self):
- super().__init__()
- self.file = None
- self.folder = None
- self.not_my_path = None
- self.not_exists = None
-
- def run(self):
- # single file
- self.file = Path(tmp_path, "file.txt")
- self.file.write_text("single file")
- # folder with files
- self.folder = Path(tmp_path, "folder")
- self.folder.mkdir()
- Path(tmp_path, "folder", "file1.txt").write_text("file 1")
- Path(tmp_path, "folder", "file2.txt").write_text("file 2")
-
- # simulate a Path that was synced to this Work from another Work
- self.not_my_path = Path(tmp_path, "external.txt")
- self.not_my_path.touch()
- self.not_my_path._origin = Mock()
-
- self.not_exists = Path(tmp_path, "not-exists")
-
- work = ArtifactWork()
- work._name = "root.work"
-
- rel_tmpdir_path = Path(*tmp_path.parts[1:])
-
- assert not os.path.exists(_artifacts_path(work) / rel_tmpdir_path / "file.txt")
- assert not os.path.exists(_artifacts_path(work) / rel_tmpdir_path / "folder")
- assert not os.path.exists(_artifacts_path(work) / rel_tmpdir_path / "not-exists")
-
- work.run()
-
- with pytest.warns(UserWarning, match="1 artifacts could not be saved because they don't exist"):
- persist_artifacts(work)
-
- assert os.path.exists(_artifacts_path(work) / rel_tmpdir_path / "file.txt")
- assert os.path.exists(_artifacts_path(work) / rel_tmpdir_path / "folder")
- assert not os.path.exists(_artifacts_path(work) / rel_tmpdir_path / "not-exists")
- assert not os.path.exists(_artifacts_path(work) / rel_tmpdir_path / "external.txt")
-
-
-def test_work_state_observer():
- """Tests that the WorkStateObserver sends deltas to the queue when state residuals remain that haven't been handled
- by the setattr."""
-
- class WorkWithoutSetattr(LightningWork):
- def __init__(self):
- super().__init__()
- self.var = 1
- self.list = []
- self.dict = {"counter": 0}
-
- def run(self, use_setattr=False, use_containers=False):
- if use_setattr:
- self.var += 1
- if use_containers:
- self.list.append(1)
- self.dict["counter"] += 1
-
- work = WorkWithoutSetattr()
- delta_queue = _MockQueue()
- observer = WorkStateObserver(work, delta_queue)
- setattr_proxy = LightningWorkSetAttrProxy(
- work=work,
- work_name="work_name",
- delta_queue=delta_queue,
- state_observer=observer,
- )
- work._setattr_replacement = setattr_proxy
-
- ##############################
- # 1. Simulate no state changes
- ##############################
- work.run(use_setattr=False, use_containers=False)
- assert len(delta_queue) == 0
-
- ############################
- # 2. Simulate a setattr call
- ############################
- work.run(use_setattr=True, use_containers=False)
-
- # this is necessary only in this test where we simulate the calls
- work._calls.clear()
- work._calls.update({CacheCallsKeys.LATEST_CALL_HASH: None})
-
- delta = delta_queue.get().delta.to_dict()
- assert delta["values_changed"] == {"root['vars']['var']": {"new_value": 2}}
- assert len(observer._delta_memory) == 1
-
- # The observer should not trigger any deltas being sent and only consume the delta memory
- assert len(delta_queue) == 0
- observer.run_once()
- assert len(delta_queue) == 0
- assert not observer._delta_memory
-
- ################################
- # 3. Simulate a container update
- ################################
- work.run(use_setattr=False, use_containers=True)
- assert len(delta_queue) == 0
- assert not observer._delta_memory
- observer.run_once()
- observer.run_once() # multiple runs should not affect how many deltas are sent unless there are changes
- delta = delta_queue.get().delta.to_dict()
- assert delta["values_changed"] == {"root['vars']['dict']['counter']": {"new_value": 1}}
- assert delta["iterable_item_added"] == {"root['vars']['list'][0]": 1}
-
- ##########################
- # 4. Simulate both updates
- ##########################
- work.run(use_setattr=True, use_containers=True)
-
- # this is necessary only in this test where we simulate the calls
- work._calls.clear()
- work._calls.update({CacheCallsKeys.LATEST_CALL_HASH: None})
-
- delta = delta_queue.get().delta.to_dict()
- assert delta == {"values_changed": {"root['vars']['var']": {"new_value": 3}}}
- assert len(delta_queue) == 0
- assert len(observer._delta_memory) == 1
- observer.run_once()
-
- delta = delta_queue.get().delta.to_dict()
- assert delta["values_changed"] == {"root['vars']['dict']['counter']": {"new_value": 2}}
- assert delta["iterable_item_added"] == {"root['vars']['list'][1]": 1}
-
- assert len(delta_queue) == 0
- assert not observer._delta_memory
-
-
-class WorkState(LightningWork):
- def __init__(self):
- super().__init__(parallel=True)
- self.vars = []
- self.counter = 0
-
- def run(self, *args):
- for counter in range(1, 11):
- self.vars.append(counter)
- self.counter = counter
-
-
-class FlowState(LightningFlow):
- def __init__(self):
- super().__init__()
- self.w = WorkState()
- self.counter = 1
-
- def run(self):
- self.w.run()
- if self.counter == 1:
- if len(self.w.vars) == 10 and self.w.counter == 10:
- self.w.vars = []
- self.w.counter = 0
- self.w.run("")
- self.counter = 2
- elif self.counter == 2 and len(self.w.vars) == 10 and self.w.counter == 10:
- self.stop()
-
-
-def test_state_observer():
- app = LightningApp(FlowState())
- MultiProcessRuntime(app, start_server=False).dispatch()
-
-
-@pytest.mark.parametrize(
- ("patch_constants", "environment", "expected_public_ip", "expected_private_ip"),
- [
- ({}, {}, "", "127.0.0.1"),
- ({"LIGHTNING_CLOUDSPACE_HOST": "any"}, {}, "", "0.0.0.0"), # noqa: S104
- (
- {},
- {"LIGHTNING_NODE_IP": "85.44.2.25", "LIGHTNING_NODE_PRIVATE_IP": "10.10.10.5"},
- "85.44.2.25",
- "10.10.10.5",
- ),
- ],
- indirect=["patch_constants"],
-)
-def test_work_runner_sets_public_and_private_ip(patch_constants, environment, expected_public_ip, expected_private_ip):
- """Test that the WorkRunner updates the public and private address as soon as the Work starts running."""
-
- class Work(LightningWork):
- def run(self):
- pass
-
- work = Work()
- work_runner = WorkRunner(
- work,
- work.name,
- caller_queue=_MockQueue("caller_queue"),
- delta_queue=Mock(),
- readiness_queue=Mock(),
- error_queue=Mock(),
- request_queue=Mock(),
- response_queue=Mock(),
- copy_request_queue=Mock(),
- copy_response_queue=Mock(),
- )
-
- # Make a fake call
- call_hash = "run:fe3fa0f34fc1317e152e5afb023332995392071046f1ea51c34c7c9766e3676c"
- work._calls[call_hash] = {
- "args": (),
- "kwargs": {},
- "call_hash": call_hash,
- "run_started_counter": 1,
- "statuses": [],
- }
- work_runner.caller_queue.put({
- "args": (),
- "kwargs": {},
- "call_hash": call_hash,
- "state": work.state,
- })
-
- with mock.patch.dict(os.environ, environment, clear=True):
- work_runner.setup()
- # The public ip address only becomes available once the hardware is up / the work is running.
- assert work.public_ip == ""
- assert work.internal_ip == ""
- with contextlib.suppress(Empty):
- work_runner.run_once()
- assert work.public_ip == expected_public_ip
- assert work.internal_ip == expected_private_ip
-
-
-class WorkBi(LightningWork):
- def __init__(self):
- super().__init__(parallel=True)
- self.finished = False
- self.counter = 0
- self.counter_2 = 0
-
- def run(self):
- while not self.finished:
- self.counter_2 += 1
- time.sleep(0.1)
- self.counter = -1
- time.sleep(1)
-
-
-class FlowBi(LightningFlow):
- def __init__(self):
- super().__init__()
- self.w = WorkBi()
-
- def run(self):
- self.w.run()
- if not self.w.finished:
- self.w.counter += 1
- if self.w.counter > 3:
- self.w.finished = True
- if self.w.counter == -1 and self.w.has_succeeded:
- self.stop()
-
-
-def test_bi_directional_proxy():
- app = LightningApp(FlowBi())
- MultiProcessRuntime(app, start_server=False).dispatch()
-
-
-class WorkBi2(LightningWork):
- def __init__(self):
- super().__init__(parallel=True)
- self.finished = False
- self.counter = 0
- self.d = {}
-
- def run(self):
- self.counter -= 1
- while not self.finished:
- self.counter -= 1
- time.sleep(1)
-
-
-class FlowBi2(LightningFlow):
- def __init__(self):
- super().__init__()
- self.w = WorkBi2()
-
- def run(self):
- self.w.run()
- if self.w.counter == 1:
- self.w.d["self.w.counter"] = 0
- if not self.w.finished:
- self.w.counter += 1
-
-
-def test_bi_directional_proxy_forbidden(monkeypatch):
- mock = MagicMock()
- monkeypatch.setattr(sys, "exit", mock)
- app = LightningApp(FlowBi2())
- MultiProcessRuntime(app, start_server=False).dispatch()
- assert app.stage == AppStage.FAILED
- assert "A forbidden operation to update the work" in str(app.exception)
-
-
-class WorkDrive(LightningFlow):
- def __init__(self, drive):
- super().__init__()
- self.drive = drive
- self.path = Path("data")
-
- def run(self):
- pass
-
-
-class FlowDrive(LightningFlow):
- def __init__(self):
- super().__init__()
- self.data = Drive("lit://data")
- self.counter = 0
-
- def run(self):
- if not hasattr(self, "w"):
- self.w = WorkDrive(self.data)
- self.counter += 1
-
-
-def test_bi_directional_proxy_filtering():
- app = LightningApp(FlowDrive())
- app.root.run()
- assert app._extract_vars_from_component_name(app.root.w.name, app.state) == {}
diff --git a/tests/tests_app/utilities/test_safe_pickle.py b/tests/tests_app/utilities/test_safe_pickle.py
deleted file mode 100644
index 2c9e6d49a2448..0000000000000
--- a/tests/tests_app/utilities/test_safe_pickle.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import subprocess
-from pathlib import Path
-
-
-def test_safe_pickle_app():
- test_dir = Path(__file__).parent / "testdata"
- proc = subprocess.Popen(
- ["lightning_app", "run", "app", "safe_pickle_app.py", "--open-ui", "false"],
- stdout=subprocess.PIPE,
- cwd=test_dir,
- )
- stdout, _ = proc.communicate()
- assert "Exiting the pickling app successfully" in stdout.decode("UTF-8")
diff --git a/tests/tests_app/utilities/test_secrets.py b/tests/tests_app/utilities/test_secrets.py
deleted file mode 100644
index 1d28af4531b1a..0000000000000
--- a/tests/tests_app/utilities/test_secrets.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from typing import Dict, List
-from unittest import mock
-from unittest.mock import MagicMock
-
-import lightning.app
-import pytest
-from lightning.app.utilities.secrets import _names_to_ids
-from lightning_cloud.openapi import V1ListMembershipsResponse, V1ListSecretsResponse, V1Membership, V1Secret
-
-
-@pytest.mark.parametrize(
- ("secret_names", "secrets", "expected", "expected_exception"),
- [
- ([], [], {}, False),
- (
- ["first-secret", "second-secret"],
- [
- V1Secret(name="first-secret", id="1234"),
- ],
- {},
- True,
- ),
- (
- ["first-secret", "second-secret"],
- [V1Secret(name="first-secret", id="1234"), V1Secret(name="second-secret", id="5678")],
- {"first-secret": "1234", "second-secret": "5678"},
- False,
- ),
- ],
-)
-@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
-def test_names_to_ids(
- secret_names: List[str],
- secrets: List[V1Secret],
- expected: Dict[str, str],
- expected_exception: bool,
- monkeypatch,
-):
- class FakeLightningClient:
- def projects_service_list_memberships(self, *_, **__):
- return V1ListMembershipsResponse(memberships=[V1Membership(project_id="default-project")])
-
- def secret_service_list_secrets(self, *_, **__):
- return V1ListSecretsResponse(secrets=secrets)
-
- monkeypatch.setattr(lightning.app.utilities.secrets, "LightningClient", FakeLightningClient)
-
- if expected_exception:
- with pytest.raises(ValueError):
- _names_to_ids(secret_names)
- else:
- assert _names_to_ids(secret_names) == expected
diff --git a/tests/tests_app/utilities/test_state.py b/tests/tests_app/utilities/test_state.py
deleted file mode 100644
index ec96c17339375..0000000000000
--- a/tests/tests_app/utilities/test_state.py
+++ /dev/null
@@ -1,335 +0,0 @@
-import os
-from re import escape
-from unittest import mock
-
-import lightning.app
-import pytest
-import requests
-from lightning.app import LightningApp, LightningFlow, LightningWork
-from lightning.app.structures import Dict, List
-from lightning.app.utilities.app_helpers import AppStatePlugin, BaseStatePlugin
-from lightning.app.utilities.state import AppState
-from lightning_cloud.openapi import Externalv1LightningappInstance, V1LightningappInstanceStatus
-
-
-@mock.patch("lightning.app.utilities.state._configure_session", return_value=requests)
-def test_app_state_not_connected(_):
- """Test an error message when a disconnected AppState tries to access attributes."""
- state = AppState(port=8000)
- with pytest.raises(AttributeError, match="Failed to connect and fetch the app state"):
- _ = state.value
- with pytest.raises(AttributeError, match="Failed to connect and fetch the app state"):
- state.value = 1
-
-
-@pytest.mark.parametrize(
- ("my_affiliation", "global_affiliation", "expected"),
- [
- (None, (), ()),
- ((), (), ()),
- ((), ("a", "b"), ()),
- (None, ("a", "b"), ("a", "b")),
- ],
-)
-@mock.patch("lightning.app.utilities.state._configure_session", return_value=requests)
-def test_app_state_affiliation(_, my_affiliation, global_affiliation, expected):
- AppState._MY_AFFILIATION = global_affiliation
- state = AppState(my_affiliation=my_affiliation)
- assert state._my_affiliation == expected
- AppState._MY_AFFILIATION = ()
-
-
-def test_app_state_state_access():
- """Test the many ways an AppState object can be accessed to set or get attributes on the state."""
- mocked_state = {
- "vars": {"root_var": "root"},
- "flows": {
- "child0": {
- "vars": {"child_var": 1},
- "flows": {},
- "works": {},
- }
- },
- "works": {
- "work0": {
- "vars": {"work_var": 2},
- "flows": {},
- "works": {},
- }
- },
- }
-
- state = AppState()
- state._state = state._last_state = mocked_state
-
- assert state.root_var == "root"
- assert isinstance(state.child0, AppState)
- assert isinstance(state.work0, AppState)
- assert state.child0.child_var == 1
- assert state.work0.work_var == 2
-
- with pytest.raises(AttributeError, match="Failed to access 'non_existent_var' through `AppState`."):
- _ = state.work0.non_existent_var
-
- with pytest.raises(AttributeError, match="Failed to access 'non_existent_var' through `AppState`."):
- state.work0.non_existent_var = 22
-
- # TODO: improve msg
- with pytest.raises(AttributeError, match="You shouldn't set the flows"):
- state.child0 = "child0"
-
- # TODO: verify with tchaton
- with pytest.raises(AttributeError, match="You shouldn't set the works"):
- state.work0 = "work0"
-
-
-@mock.patch("lightning.app.utilities.state.AppState.send_delta")
-def test_app_state_state_access_under_affiliation(*_):
- """Test the access to attributes when the state is restricted under the given affiliation."""
- mocked_state = {
- "vars": {"root_var": "root"},
- "flows": {
- "child0": {
- "vars": {"child0_var": 0},
- "flows": {
- "child1": {
- "vars": {"child1_var": 1},
- "flows": {
- "child2": {
- "vars": {"child2_var": 2},
- "flows": {},
- "works": {},
- },
- },
- "works": {},
- },
- },
- "works": {
- "work1": {
- "vars": {"work1_var": 11},
- },
- },
- },
- },
- "works": {},
- }
-
- # root-level affiliation
- state = AppState(my_affiliation=())
- state._store_state(mocked_state)
- assert isinstance(state.child0, AppState)
- assert state.child0.child0_var == 0
- assert state.child0.child1.child1_var == 1
- assert state.child0.child1.child2.child2_var == 2
-
- # one child deep
- state = AppState(my_affiliation=("child0",))
- state._store_state(mocked_state)
- assert state._state == mocked_state["flows"]["child0"]
- with pytest.raises(AttributeError, match="Failed to access 'child0' through `AppState`"):
- _ = state.child0
- assert state.child0_var == 0
- assert state.child1.child1_var == 1
- assert state.child1.child2.child2_var == 2
-
- # two flows deep
- state = AppState(my_affiliation=("child0", "child1"))
- state._store_state(mocked_state)
- assert state._state == mocked_state["flows"]["child0"]["flows"]["child1"]
- with pytest.raises(AttributeError, match="Failed to access 'child1' through `AppState`"):
- _ = state.child1
- state.child1_var = 111
- assert state.child1_var == 111
- assert state.child2.child2_var == 2
-
- # access to work
- state = AppState(my_affiliation=("child0", "work1"))
- state._store_state(mocked_state)
- assert state._state == mocked_state["flows"]["child0"]["works"]["work1"]
- with pytest.raises(AttributeError, match="Failed to access 'child1' through `AppState`"):
- _ = state.child1
- assert state.work1_var == 11
- state.work1_var = 111
- assert state.work1_var == 111
-
- # affiliation does not match state
- state = AppState(my_affiliation=("child1", "child0"))
- with pytest.raises(
- ValueError, match=escape("Failed to extract the state under the affiliation '('child1', 'child0')'")
- ):
- state._store_state(mocked_state)
-
-
-def test_app_state_repr():
- app_state = AppState()
- assert repr(app_state) == "None"
-
- app_state = AppState()
- app_state._store_state({"vars": {"x": 1, "y": 2}})
- assert repr(app_state) == "{'vars': {'x': 1, 'y': 2}}"
-
- app_state = AppState()
- app_state._store_state({"vars": {"x": 1, "y": 2}})
- assert repr(app_state.y) == "2"
-
- app_state = AppState()
- app_state._store_state({"vars": {}, "flows": {"child": {"vars": {"child_var": "child_val"}}}})
- assert repr(app_state.child) == "{'vars': {'child_var': 'child_val'}}"
-
-
-def test_app_state_bool():
- app_state = AppState()
- assert not bool(app_state)
-
- app_state = AppState()
- app_state._store_state({"vars": {"x": 1, "y": 2}})
- assert bool(app_state)
-
-
-class _CustomAppStatePlugin(BaseStatePlugin):
- def should_update_app(self, deep_diff):
- pass
-
- def get_context(self):
- pass
-
- def render_non_authorized(self):
- pass
-
-
-def test_attach_plugin():
- """Test how plugins get attached to the AppState and the default behavior when no plugin is specified."""
- app_state = AppState()
- assert isinstance(app_state._plugin, AppStatePlugin)
-
- app_state = AppState(plugin=_CustomAppStatePlugin())
- assert isinstance(app_state._plugin, _CustomAppStatePlugin)
-
-
-@mock.patch("lightning.app.utilities.state._configure_session", return_value=requests)
-def test_app_state_connection_error(_):
- """Test an error message when a connection to retrieve the state can't be established."""
- app_state = AppState(port=8000)
- with pytest.raises(AttributeError, match=r"Failed to connect and fetch the app state\. Is the app running?"):
- app_state._request_state()
-
- with pytest.raises(AttributeError, match=r"Failed to connect and fetch the app state\. Is the app running?"):
- app_state.var = 1
-
-
-class Work(LightningWork):
- def __init__(self):
- super().__init__()
- self.counter = 0
-
- def run(self):
- self.counter += 1
-
-
-class Flow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.should_start = False
- self.w = Work()
-
- def run(self):
- if self.should_start:
- self.w.run()
- self.stop()
-
-
-class MockResponse:
- def __init__(self, state, status_code):
- self._state = state
- self.status_code = status_code
-
- def json(self):
- return self._state
-
-
-def test_get_send_request(monkeypatch):
- app = LightningApp(Flow())
- monkeypatch.setattr(lightning.app.utilities.state, "_configure_session", mock.MagicMock())
-
- state = AppState(plugin=AppStatePlugin())
- state._session.get._mock_return_value = MockResponse(app.state_with_changes, 500)
- state._request_state()
- state._session.get._mock_return_value = MockResponse(app.state_with_changes, 200)
- state._request_state()
- assert state._my_affiliation == ()
- with pytest.raises(Exception, match="The response from"):
- state._session.post._mock_return_value = MockResponse(app.state_with_changes, 500)
- state.w.counter = 1
- state._session.post._mock_return_value = MockResponse(app.state_with_changes, 200)
- state.w.counter = 1
-
-
-@mock.patch.dict(
- os.environ,
- {
- "LIGHTNING_APP_STATE_URL": "https://lightning-cloud.com",
- "LIGHTNING_CLOUD_PROJECT_ID": "test-project-id",
- "LIGHTNING_CLOUD_APP_ID": "test-app-id",
- },
-)
-@mock.patch("lightning.app.utilities.state.LightningClient")
-def test_app_state_with_env_var(mock_client):
- mock_client().lightningapp_instance_service_get_lightningapp_instance.return_value = Externalv1LightningappInstance(
- status=V1LightningappInstanceStatus(ip_address="test-ip"),
- )
- state = AppState()
- url = state._url
-
- mock_client().lightningapp_instance_service_get_lightningapp_instance.assert_called_once_with(
- "test-project-id",
- "test-app-id",
- )
-
- assert url == "http://test-ip:8080"
- assert not state._port
-
-
-@mock.patch.dict(os.environ, {})
-def test_app_state_with_no_env_var(**__):
- state = AppState()
- assert state._host == "http://127.0.0.1"
- assert state._port == 7501
- assert state._url == "http://127.0.0.1:7501"
-
-
-class FlowStructures(LightningFlow):
- def __init__(self):
- super().__init__()
- self.w_list = List(Work(), Work())
- self.w_dict = Dict(**{"toto": Work(), "toto_2": Work()})
-
- def run(self):
- self.stop()
-
-
-class FlowStructuresEmpty(LightningFlow):
- def __init__(self):
- super().__init__()
- self.w_list = List()
- self.w_dict = Dict()
-
- def run(self):
- self.stop()
-
-
-def test_app_state_with_structures():
- app = LightningApp(FlowStructures())
- state = AppState()
- state._last_state = app.state
- state._state = app.state
- assert state.w_list["0"].counter == 0
- assert len(state.w_list) == 2
- assert state.w_dict["toto"].counter == 0
- assert [k for k, _ in state.w_dict.items()] == ["toto", "toto_2"]
- assert [k for k, _ in state.w_list.items()] == ["0", "1"]
-
- app = LightningApp(FlowStructuresEmpty())
- state = AppState()
- state._last_state = app.state
- state._state = app.state
- assert state.w_list
diff --git a/tests/tests_app/utilities/test_tracer.py b/tests/tests_app/utilities/test_tracer.py
deleted file mode 100644
index a48d2ca4e2566..0000000000000
--- a/tests/tests_app/utilities/test_tracer.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import os
-import sys
-
-from lightning.app.testing.helpers import _RunIf
-from lightning.app.utilities.tracer import Tracer
-
-from tests_app import _PROJECT_ROOT
-
-
-@_RunIf(pl=True)
-def test_tracer():
- from pytorch_lightning import Trainer
-
- def pre_fn(self, *args, **kwargs):
- kwargs["fast_dev_run"] = True
- return {}, args, kwargs
-
- def post_fn(self, ret):
- return {}, ret
-
- tracer = Tracer()
- tracer.add_traced(Trainer, "__init__", pre_fn=pre_fn, post_fn=post_fn)
- traced_file = os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/lightning_trainer.py")
- assert os.path.exists(traced_file)
- # This is required to get the right sys.argv for `runpy`.
- sys.argv = [traced_file]
- tracer.trace(traced_file)
diff --git a/tests/tests_app/utilities/test_tree.py b/tests/tests_app/utilities/test_tree.py
deleted file mode 100644
index 62ee29e946f21..0000000000000
--- a/tests/tests_app/utilities/test_tree.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import pytest
-from lightning.app import LightningFlow, LightningWork
-from lightning.app.testing.helpers import EmptyFlow, EmptyWork
-from lightning.app.utilities.tree import breadth_first
-
-
-class LeafFlow(EmptyFlow):
- pass
-
-
-class LeafWork(EmptyWork):
- pass
-
-
-class SimpleFlowTree(EmptyFlow):
- def __init__(self):
- super().__init__()
- self.simple_flow_left = LeafFlow()
- self.simple_flow_right = LeafFlow()
-
-
-class SimpleWorkTree(EmptyFlow):
- def __init__(self):
- super().__init__()
- self.simple_work_left = LeafWork()
- self.simple_work_right = LeafWork()
-
-
-class MixedTree(EmptyFlow):
- def __init__(self):
- super().__init__()
- self.mixed_left = SimpleFlowTree()
- self.work_tree = SimpleWorkTree()
- self.mixed_right = SimpleFlowTree()
-
-
-@pytest.mark.parametrize(
- ("input_tree", "types", "expected_sequence"),
- [
- (LeafFlow(), (LightningFlow,), ["root"]),
- (LeafWork(), (LightningFlow,), []),
- (
- SimpleFlowTree(),
- (LightningFlow,),
- [
- "root",
- "root.simple_flow_left",
- "root.simple_flow_right",
- ],
- ),
- (SimpleWorkTree(), (LightningFlow,), ["root"]),
- (
- SimpleWorkTree(),
- (LightningFlow, LightningWork),
- [
- "root",
- "root.simple_work_left",
- "root.simple_work_right",
- ],
- ),
- (
- MixedTree(),
- (LightningFlow,),
- [
- "root",
- "root.mixed_left",
- "root.mixed_right",
- "root.work_tree",
- "root.mixed_left.simple_flow_left",
- "root.mixed_left.simple_flow_right",
- "root.mixed_right.simple_flow_left",
- "root.mixed_right.simple_flow_right",
- ],
- ),
- (
- MixedTree(),
- (LightningWork,),
- [
- "root.work_tree.simple_work_left",
- "root.work_tree.simple_work_right",
- ],
- ),
- (
- MixedTree(),
- (LightningFlow, LightningWork),
- [
- "root",
- "root.mixed_left",
- "root.mixed_right",
- "root.work_tree",
- "root.mixed_left.simple_flow_left",
- "root.mixed_left.simple_flow_right",
- "root.mixed_right.simple_flow_left",
- "root.mixed_right.simple_flow_right",
- "root.work_tree.simple_work_left",
- "root.work_tree.simple_work_right",
- ],
- ),
- ],
-)
-def test_breadth_first(input_tree, types, expected_sequence):
- assert [node.name for node in breadth_first(input_tree, types=types)] == expected_sequence
diff --git a/tests/tests_app/utilities/testdata/app_commands/app_commands_to_ignore.txt b/tests/tests_app/utilities/testdata/app_commands/app_commands_to_ignore.txt
deleted file mode 100644
index 60ee57ca921a3..0000000000000
--- a/tests/tests_app/utilities/testdata/app_commands/app_commands_to_ignore.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-#!/usr/bin/python
-#!/usr/local/bin/python
-#!/usr/bin/env python
-#!/usr/bin/env python3
diff --git a/tests/tests_app/utilities/testdata/app_commands/bang_not_at_start_of_line.txt b/tests/tests_app/utilities/testdata/app_commands/bang_not_at_start_of_line.txt
deleted file mode 100644
index a937beff29d1f..0000000000000
--- a/tests/tests_app/utilities/testdata/app_commands/bang_not_at_start_of_line.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-# This is somefile.py! should not execute this!
-# !echo "foo"
diff --git a/tests/tests_app/utilities/testdata/app_commands/command_after_first_non_comment_line.txt b/tests/tests_app/utilities/testdata/app_commands/command_after_first_non_comment_line.txt
deleted file mode 100644
index 1cd80f15779df..0000000000000
--- a/tests/tests_app/utilities/testdata/app_commands/command_after_first_non_comment_line.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-
-# !echo "foo"
-import lighting
-# !echo "bar"
diff --git a/tests/tests_app/utilities/testdata/app_commands/commands_with_mixed_comments_1.txt b/tests/tests_app/utilities/testdata/app_commands/commands_with_mixed_comments_1.txt
deleted file mode 100644
index f98505df2bef6..0000000000000
--- a/tests/tests_app/utilities/testdata/app_commands/commands_with_mixed_comments_1.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-# !echo "foo"
-# some other explanation
-# !echo "bar"
-# another comment
diff --git a/tests/tests_app/utilities/testdata/app_commands/commands_with_mixed_comments_2.txt b/tests/tests_app/utilities/testdata/app_commands/commands_with_mixed_comments_2.txt
deleted file mode 100644
index 0c1641c234f8d..0000000000000
--- a/tests/tests_app/utilities/testdata/app_commands/commands_with_mixed_comments_2.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-# filename.py
-# !echo "foo"
-# some other explanation
-# !echo "bar"
-# another comment
diff --git a/tests/tests_app/utilities/testdata/app_commands/multiple_commands.txt b/tests/tests_app/utilities/testdata/app_commands/multiple_commands.txt
deleted file mode 100644
index 52e45d039fa44..0000000000000
--- a/tests/tests_app/utilities/testdata/app_commands/multiple_commands.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-# !echo "foo"
-# !echo "bar"
diff --git a/tests/tests_app/utilities/testdata/app_commands/multiple_spaces_between_band_and_command.txt b/tests/tests_app/utilities/testdata/app_commands/multiple_spaces_between_band_and_command.txt
deleted file mode 100644
index cf2599bd9ae49..0000000000000
--- a/tests/tests_app/utilities/testdata/app_commands/multiple_spaces_between_band_and_command.txt
+++ /dev/null
@@ -1 +0,0 @@
-# ! echo "foo"
diff --git a/tests/tests_app/utilities/testdata/app_commands/single_command.txt b/tests/tests_app/utilities/testdata/app_commands/single_command.txt
deleted file mode 100644
index 46ff010622e9c..0000000000000
--- a/tests/tests_app/utilities/testdata/app_commands/single_command.txt
+++ /dev/null
@@ -1 +0,0 @@
-# !echo "foo"
diff --git a/tests/tests_app/utilities/testdata/app_commands/space_between_bang_and_command.txt b/tests/tests_app/utilities/testdata/app_commands/space_between_bang_and_command.txt
deleted file mode 100644
index d2df15d9c19ec..0000000000000
--- a/tests/tests_app/utilities/testdata/app_commands/space_between_bang_and_command.txt
+++ /dev/null
@@ -1 +0,0 @@
-# ! echo "foo"
diff --git a/tests/tests_app/utilities/testdata/safe_pickle_app.py b/tests/tests_app/utilities/testdata/safe_pickle_app.py
deleted file mode 100644
index 7a02b0d1ade52..0000000000000
--- a/tests/tests_app/utilities/testdata/safe_pickle_app.py
+++ /dev/null
@@ -1,63 +0,0 @@
-"""
-This app tests three things
-1. Can a work pickle `self`
-2. Can the pickled work be unpickled in another work
-3. Can the pickled work be unpickled from a script
-"""
-
-import subprocess
-from pathlib import Path
-
-from lightning.app import LightningApp, LightningFlow, LightningWork
-from lightning.app.utilities import safe_pickle
-
-
-class SelfPicklingWork(LightningWork):
- def run(self):
- with open("work.pkl", "wb") as f:
- safe_pickle.dump(self, f)
-
- def get_test_string(self):
- return f"Hello from {self.__class__.__name__}!"
-
-
-class WorkThatLoadsPickledWork(LightningWork):
- def run(self):
- with open("work.pkl", "rb") as f:
- work = safe_pickle.load(f)
- assert work.get_test_string() == "Hello from SelfPicklingWork!"
-
-
-script_load_pickled_work = """
-import pickle
-work = pickle.load(open("work.pkl", "rb"))
-print(work.get_test_string())
-"""
-
-
-class RootFlow(LightningFlow):
- def __init__(self):
- super().__init__()
- self.self_pickling_work = SelfPicklingWork()
- self.work_that_loads_pickled_work = WorkThatLoadsPickledWork()
-
- def run(self):
- self.self_pickling_work.run()
- self.work_that_loads_pickled_work.run()
-
- with open("script_that_loads_pickled_work.py", "w") as f:
- f.write(script_load_pickled_work)
-
- # read the output from subprocess
- proc = subprocess.Popen(["python", "script_that_loads_pickled_work.py"], stdout=subprocess.PIPE)
- assert "Hello from SelfPicklingWork" in proc.stdout.read().decode("UTF-8")
-
- # deleting the script
- Path("script_that_loads_pickled_work.py").unlink()
- # deleting the pkl file
- Path("work.pkl").unlink()
-
- self.stop("Exiting the pickling app successfully!!")
-
-
-app = LightningApp(RootFlow())
diff --git a/tests/tests_fabric/accelerators/test_cuda.py b/tests/tests_fabric/accelerators/test_cuda.py
index 4b2265670b8bf..e323ada908cd1 100644
--- a/tests/tests_fabric/accelerators/test_cuda.py
+++ b/tests/tests_fabric/accelerators/test_cuda.py
@@ -25,8 +25,6 @@
CUDAAccelerator,
_check_cuda_matmul_precision,
find_usable_cuda_devices,
- is_cuda_available,
- num_cuda_devices,
)
from tests_fabric.helpers.runif import RunIf
@@ -67,18 +65,6 @@ def test_set_cuda_device(_, set_device_mock):
set_device_mock.assert_called_once_with(device)
-@mock.patch("lightning.fabric.accelerators.cuda._device_count_nvml", return_value=-1)
-@mock.patch("torch.cuda.is_available", return_value=True)
-@mock.patch("torch.cuda.device_count", return_value=100)
-def test_num_cuda_devices_without_nvml(*_):
- """Test that if NVML can't be loaded, our helper functions fall back to the default implementation for determining
- CUDA availability."""
- num_cuda_devices.cache_clear()
- assert is_cuda_available()
- assert num_cuda_devices() == 100
- num_cuda_devices.cache_clear()
-
-
@mock.patch.dict(os.environ, {}, clear=True)
def test_force_nvml_based_cuda_check():
"""Test that we force PyTorch to use the NVML-based CUDA checks."""
diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py
index 7dbd8da055995..446994167d0a1 100644
--- a/tests/tests_fabric/conftest.py
+++ b/tests/tests_fabric/conftest.py
@@ -14,6 +14,7 @@
import os
import sys
import threading
+from pathlib import Path
from typing import List
from unittest.mock import Mock
@@ -22,7 +23,7 @@
import torch.distributed
from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver
-from lightning.fabric.utilities.distributed import _distributed_is_initialized
+from lightning.fabric.utilities.distributed import _destroy_dist_connection
if sys.version_info >= (3, 9):
from concurrent.futures.process import _ExecutorManagerThread
@@ -77,8 +78,7 @@ def restore_env_variables():
def teardown_process_group():
"""Ensures that the distributed process group gets closed before the next test runs."""
yield
- if _distributed_is_initialized():
- torch.distributed.destroy_process_group()
+ _destroy_dist_connection()
@pytest.fixture(autouse=True)
@@ -101,7 +101,11 @@ def thread_police_duuu_daaa_duuu_daaa():
assert not thread.is_alive()
elif isinstance(thread, _ChildProcessObserver):
thread.join(timeout=10)
- elif thread.name == "QueueFeederThread": # tensorboardX
+ elif (
+ thread.name == "QueueFeederThread" # tensorboardX
+ or thread.name == "QueueManagerThread" # torch.compile
+ or "(_read_thread)" in thread.name # torch.compile
+ ):
thread.join(timeout=20)
elif (
sys.version_info >= (3, 9)
@@ -185,6 +189,17 @@ def caplog(caplog):
lightning_logger.propagate = propagate
+@pytest.fixture(autouse=True)
+def leave_no_artifacts_behind():
+ tests_root = Path(__file__).parent.parent
+ files_before = {p for p in tests_root.rglob("*") if "__pycache__" not in p.parts}
+ yield
+ files_after = {p for p in tests_root.rglob("*") if "__pycache__" not in p.parts}
+ difference = files_after - files_before
+ difference = {str(f.relative_to(tests_root)) for f in difference}
+ assert not difference, f"Test left artifacts behind: {difference}"
+
+
def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.Config) -> None:
"""An adaptation of `tests/tests_pytorch/conftest.py::pytest_collection_modifyitems`"""
initial_size = len(items)
diff --git a/tests/tests_fabric/loggers/test_tensorboard.py b/tests/tests_fabric/loggers/test_tensorboard.py
index d8929ca527e80..fa685241ea1b5 100644
--- a/tests/tests_fabric/loggers/test_tensorboard.py
+++ b/tests/tests_fabric/loggers/test_tensorboard.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
import os
from argparse import Namespace
from unittest import mock
@@ -213,8 +212,7 @@ def test_tensorboard_finalize(monkeypatch, tmp_path):
logger.experiment.close.assert_called()
-@mock.patch("lightning.fabric.loggers.tensorboard.log")
-def test_tensorboard_with_symlink(log, tmp_path, monkeypatch):
+def test_tensorboard_with_symlink(tmp_path, monkeypatch):
"""Tests a specific failure case when tensorboard logger is used with empty name, symbolic link ``save_dir``, and
relative paths."""
monkeypatch.chdir(tmp_path) # need to use relative paths
@@ -226,16 +224,3 @@ def test_tensorboard_with_symlink(log, tmp_path, monkeypatch):
logger = TensorBoardLogger(root_dir=dest, name="")
_ = logger.version
-
- log.warning.assert_not_called()
-
-
-def test_tensorboard_missing_folder_warning(tmp_path, caplog):
- """Verify that the logger throws a warning for invalid directory."""
- name = "fake_dir"
- logger = TensorBoardLogger(root_dir=tmp_path, name=name)
-
- with caplog.at_level(logging.WARNING):
- assert logger.version == 0
-
- assert "Missing logger folder:" in caplog.text
diff --git a/tests/tests_fabric/plugins/precision/test_amp.py b/tests/tests_fabric/plugins/precision/test_amp.py
index 34f14b8871ea3..93d53eb406f71 100644
--- a/tests/tests_fabric/plugins/precision/test_amp.py
+++ b/tests/tests_fabric/plugins/precision/test_amp.py
@@ -17,11 +17,13 @@
import pytest
import torch
from lightning.fabric.plugins.precision.amp import MixedPrecision
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
def test_amp_precision_default_scaler():
precision = MixedPrecision(precision="16-mixed", device=Mock())
- assert isinstance(precision.scaler, torch.cuda.amp.GradScaler)
+ scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
+ assert isinstance(precision.scaler, scaler_cls)
def test_amp_precision_scaler_with_bf16():
@@ -36,7 +38,8 @@ def test_amp_precision_forward_context():
"""Test to ensure that the context manager correctly is set to bfloat16 on CPU and CUDA."""
precision = MixedPrecision(precision="16-mixed", device="cuda")
assert precision.device == "cuda"
- assert isinstance(precision.scaler, torch.cuda.amp.GradScaler)
+ scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
+ assert isinstance(precision.scaler, scaler_cls)
assert torch.get_default_dtype() == torch.float32
with precision.forward_context():
assert torch.get_autocast_gpu_dtype() == torch.float16
diff --git a/tests/tests_fabric/plugins/precision/test_amp_integration.py b/tests/tests_fabric/plugins/precision/test_amp_integration.py
index 5d88a7d9babb9..aa6c6cfce4504 100644
--- a/tests/tests_fabric/plugins/precision/test_amp_integration.py
+++ b/tests/tests_fabric/plugins/precision/test_amp_integration.py
@@ -17,6 +17,7 @@
import torch
import torch.nn as nn
from lightning.fabric import Fabric, seed_everything
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from tests_fabric.helpers.runif import RunIf
@@ -82,7 +83,8 @@ def run(fused=False):
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, fused=fused)
model, optimizer = fabric.setup(model, optimizer)
- assert isinstance(fabric._precision.scaler, torch.cuda.amp.GradScaler)
+ scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
+ assert isinstance(fabric._precision.scaler, scaler_cls)
data = torch.randn(10, 10, device="cuda")
target = torch.randn(10, 10, device="cuda")
diff --git a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py
index ec02796b4b51c..b8b9020b201a7 100644
--- a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py
+++ b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py
@@ -22,6 +22,7 @@
from lightning.fabric.connector import _Connector
from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision
from lightning.fabric.utilities.init import _materialize_meta_tensors
+from lightning.fabric.utilities.load import _lazy_load
from tests_fabric.helpers.runif import RunIf
@@ -93,7 +94,7 @@ def __init__(self):
precision.convert_module(model)
-@RunIf(min_cuda_gpus=1)
+@RunIf(min_cuda_gpus=1, max_torch="2.4")
@pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
@pytest.mark.parametrize(
("args", "expected"),
@@ -148,7 +149,7 @@ def __init__(self):
assert model.l.weight.dtype == expected
-@RunIf(min_cuda_gpus=1, min_torch="2.1")
+@RunIf(min_cuda_gpus=1)
@pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
@pytest.mark.parametrize(
("args", "expected"),
@@ -225,8 +226,49 @@ def __init__(self):
model = MyModel()
ckpt_path = tmp_path / "foo.ckpt"
torch.save(state_dict, ckpt_path)
- torch.load(str(ckpt_path), mmap=True)
+ torch.load(str(ckpt_path), mmap=True, weights_only=True)
keys = model.load_state_dict(state_dict, strict=True, assign=True) # quantizes
assert not keys.missing_keys
assert model.l.weight.device.type == "cuda"
assert model.l.weight.dtype == expected
+
+
+@RunIf(min_cuda_gpus=1, max_torch="2.4")
+@pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
+def test_load_quantized_checkpoint(tmp_path):
+ """Test that a checkpoint saved from a quantized model can be loaded back into a quantized model."""
+
+ class Model(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(16, 16, bias=False)
+
+ def forward(self, x):
+ return self.linear(x)
+
+ fabric = Fabric(accelerator="cuda", devices=1, plugins=BitsandbytesPrecision("nf4-dq"))
+ model = Model()
+ model = fabric.setup(model)
+ model(torch.randn(2, 16, device=fabric.device))
+ state_dict = model.state_dict()
+ # The checkpoint contains quantized weights
+ assert state_dict["linear.weight"].dtype == torch.uint8
+ assert state_dict["linear.weight"].shape == (128, 1)
+ torch.save(state_dict, tmp_path / "checkpoint.pt")
+
+ fabric = Fabric(accelerator="cuda", devices=1, plugins=BitsandbytesPrecision("nf4-dq"))
+ model = Model()
+ model = fabric.setup(model)
+ state_dict = torch.load(tmp_path / "checkpoint.pt", weights_only=True)
+ model.load_state_dict(state_dict)
+ assert model.linear.weight.dtype == torch.uint8
+ assert model.linear.weight.shape == (128, 1)
+ # Shapes match during forward (weight is being dequantized during forward)
+ model(torch.randn(2, 16, device=fabric.device))
+
+ # Test with lazy load (LitGPT uses this)
+ # TODO: Replace `_lazy_load` with `torch.load(..., mmap=True)` in LitGPT
+ state_dict = _lazy_load(tmp_path / "checkpoint.pt")
+ model.load_state_dict(state_dict)
+ assert model.linear.weight.dtype == torch.uint8
+ assert model.linear.weight.shape == (128, 1)
diff --git a/tests/tests_fabric/plugins/precision/test_fsdp.py b/tests/tests_fabric/plugins/precision/test_fsdp.py
index 74c1034518c39..e42df493dd725 100644
--- a/tests/tests_fabric/plugins/precision/test_fsdp.py
+++ b/tests/tests_fabric/plugins/precision/test_fsdp.py
@@ -26,25 +26,9 @@
[
("16-true", (torch.float16, torch.float16, torch.float16)),
("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
- pytest.param(
- "16-mixed", (torch.float32, torch.float16, torch.float16), marks=RunIf(min_torch="2.0"), id="16-mixed-ge2_0"
- ),
- pytest.param(
- "16-mixed", (None, torch.float16, torch.float16), marks=RunIf(max_torch="2.0"), id="16-mixed-lt2_0"
- ),
- pytest.param(
- "bf16-mixed",
- (torch.float32, torch.bfloat16, torch.bfloat16),
- marks=RunIf(min_torch="2.0"),
- id="bf16-mixed-ge2_0",
- ),
- pytest.param(
- "bf16-mixed", (None, torch.bfloat16, torch.bfloat16), marks=RunIf(max_torch="2.0"), id="bf16-mixed-lt2_0"
- ),
- pytest.param(
- "32-true", (torch.float32, torch.float32, torch.float32), marks=RunIf(min_torch="2.0"), id="32-true-ge2_0"
- ),
- pytest.param("32-true", (None, torch.float32, torch.float32), marks=RunIf(max_torch="2.0"), id="32-true-lt2_0"),
+ ("16-mixed", (torch.float32, torch.float16, torch.float16)),
+ ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
+ ("32-true", (torch.float32, torch.float32, torch.float32)),
],
)
def test_fsdp_precision_config(precision, expected):
@@ -74,8 +58,10 @@ def test_fsdp_precision_scaler_with_bf16():
@RunIf(min_cuda_gpus=1)
def test_fsdp_precision_forward_context():
"""Test to ensure that the context manager correctly is set to bfloat16."""
+ from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
precision = FSDPPrecision(precision="16-mixed")
- assert isinstance(precision.scaler, torch.cuda.amp.GradScaler)
+ assert isinstance(precision.scaler, ShardedGradScaler)
assert torch.get_default_dtype() == torch.float32
with precision.forward_context():
assert torch.get_autocast_gpu_dtype() == torch.float16
diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py
index beea7eccb69c2..56d9875dfefed 100644
--- a/tests/tests_fabric/strategies/test_ddp.py
+++ b/tests/tests_fabric/strategies/test_ddp.py
@@ -23,7 +23,6 @@
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies import DDPStrategy
from lightning.fabric.strategies.ddp import _DDPBackwardSyncControl
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from torch.nn.parallel import DistributedDataParallel
from tests_fabric.helpers.runif import RunIf
@@ -128,7 +127,7 @@ def __instancecheck__(self, instance):
def test_module_init_context(precision, expected_dtype):
"""Test that the module under the init-context gets moved to the right device and dtype."""
parallel_devices = [torch.device("cuda", 0), torch.device("cuda", 1)]
- expected_device = parallel_devices[1] if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
+ expected_device = parallel_devices[1]
strategy = DDPStrategy(
parallel_devices=parallel_devices, precision=precision, cluster_environment=LightningEnvironment()
diff --git a/tests/tests_fabric/strategies/test_ddp_integration.py b/tests/tests_fabric/strategies/test_ddp_integration.py
index 65eaacde2ff2c..a7ed09b00b09e 100644
--- a/tests/tests_fabric/strategies/test_ddp_integration.py
+++ b/tests/tests_fabric/strategies/test_ddp_integration.py
@@ -19,7 +19,8 @@
import pytest
import torch
from lightning.fabric import Fabric
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
+from lightning_utilities.core.imports import RequirementCache
+from torch._dynamo import OptimizedModule
from torch.nn.parallel.distributed import DistributedDataParallel
from tests_fabric.helpers.runif import RunIf
@@ -27,6 +28,10 @@
from tests_fabric.test_fabric import BoringModel
+@pytest.mark.skipif(
+ RequirementCache("torch<2.4") and RequirementCache("numpy>=2.0"),
+ reason="torch.distributed not compatible with numpy>=2.0",
+)
@pytest.mark.parametrize(
"accelerator",
[
@@ -70,16 +75,11 @@ def assert_params_equal(params0, params1):
assert_params_equal(params_before, wrapped_model.parameters())
-@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", dynamo=True)
-@mock.patch(
- "lightning.fabric.wrappers.torch.compile",
- Mock(wraps=(torch.compile if _TORCH_GREATER_EQUAL_2_0 else None)),
-)
+@RunIf(min_cuda_gpus=2, standalone=True, dynamo=True)
+@mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
@mock.patch.dict(os.environ, {})
def test_reapply_compile():
"""Test that Fabric can rewrap a compiled module such that compilation happens over the DDP-wrapper."""
- from torch._dynamo import OptimizedModule
-
fabric = Fabric(accelerator="cuda", devices=2, strategy="ddp")
fabric.launch()
diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py
index 2476cd3961504..3be535effa078 100644
--- a/tests/tests_fabric/strategies/test_deepspeed_integration.py
+++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py
@@ -312,11 +312,11 @@ def _assert_saved_model_is_equal(fabric, model, checkpoint_path):
single_ckpt_path = checkpoint_path / "single_model.pt"
# the tag is hardcoded in DeepSpeedStrategy
convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, single_ckpt_path, tag="checkpoint")
- state_dict = torch.load(single_ckpt_path)
+ state_dict = torch.load(single_ckpt_path, weights_only=False)
else:
# 'checkpoint' is the tag, hardcoded in DeepSpeedStrategy
single_ckpt_path = checkpoint_path / "checkpoint" / "mp_rank_00_model_states.pt"
- state_dict = torch.load(single_ckpt_path)["module"]
+ state_dict = torch.load(single_ckpt_path, weights_only=False)["module"]
model = model.cpu()
diff --git a/tests/tests_fabric/strategies/test_dp.py b/tests/tests_fabric/strategies/test_dp.py
index 572bbd20d357c..e50abb1882870 100644
--- a/tests/tests_fabric/strategies/test_dp.py
+++ b/tests/tests_fabric/strategies/test_dp.py
@@ -74,6 +74,7 @@ def __instancecheck__(self, instance):
assert strategy.get_module_state_dict(wrapped_module).keys() == original_module.state_dict().keys()
+@pytest.mark.filterwarnings("ignore::FutureWarning")
@pytest.mark.parametrize(
"precision",
[
diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py
index 3f2d02e06be2a..0c46e7ac1763c 100644
--- a/tests/tests_fabric/strategies/test_fsdp.py
+++ b/tests/tests_fabric/strategies/test_fsdp.py
@@ -16,7 +16,6 @@
from unittest import mock
from unittest.mock import ANY, MagicMock, Mock
-import lightning.fabric
import pytest
import torch
import torch.nn as nn
@@ -26,24 +25,22 @@
from lightning.fabric.strategies.fsdp import (
_FSDPBackwardSyncControl,
_get_full_state_dict_context,
- _has_meta_device_parameters,
_is_sharded_checkpoint,
)
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1, _TORCH_GREATER_EQUAL_2_2
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.optim import Adam
-from tests_fabric.helpers.runif import RunIf
-
-def test_fsdp_custom_mixed_precision():
+def test_custom_mixed_precision():
"""Test that passing a custom mixed precision config works."""
config = MixedPrecision()
strategy = FSDPStrategy(mixed_precision=config)
assert strategy.mixed_precision_config == config
-def test_fsdp_cpu_offload():
+def test_cpu_offload():
"""Test the different ways cpu offloading can be enabled."""
# bool
strategy = FSDPStrategy(cpu_offload=True)
@@ -55,7 +52,7 @@ def test_fsdp_cpu_offload():
assert strategy.cpu_offload == config
-def test_fsdp_sharding_strategy():
+def test_sharding_strategy():
"""Test the different ways the sharding strategy can be set."""
from torch.distributed.fsdp import ShardingStrategy
@@ -74,9 +71,8 @@ def test_fsdp_sharding_strategy():
assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD
-@RunIf(min_torch="2.0")
@pytest.mark.parametrize("sharding_strategy", ["HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"])
-def test_fsdp_hybrid_shard_configuration(sharding_strategy):
+def test_hybrid_shard_configuration(sharding_strategy, monkeypatch):
"""Test that the hybrid sharding strategies can only be used with automatic wrapping or a manually specified pg."""
with pytest.raises(RuntimeError, match="The hybrid sharding strategy requires you to pass at least one of"):
FSDPStrategy(sharding_strategy=sharding_strategy)
@@ -89,6 +85,11 @@ def test_fsdp_hybrid_shard_configuration(sharding_strategy):
assert strategy.sharding_strategy.name == sharding_strategy
assert strategy._fsdp_kwargs["process_group"] is process_group
+ monkeypatch.setattr("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", False)
+ with pytest.raises(ValueError, match="`device_mesh` argument is only supported in torch >= 2.2."):
+ FSDPStrategy(device_mesh=Mock())
+
+ monkeypatch.setattr("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", True)
device_mesh = Mock()
strategy = FSDPStrategy(sharding_strategy=sharding_strategy, device_mesh=device_mesh)
assert strategy.sharding_strategy.name == sharding_strategy
@@ -98,7 +99,7 @@ def test_fsdp_hybrid_shard_configuration(sharding_strategy):
FSDPStrategy(sharding_strategy=sharding_strategy, process_group=process_group, device_mesh=device_mesh)
-def test_fsdp_checkpoint_io_unsupported():
+def test_checkpoint_io_unsupported():
"""Test that the FSDP strategy does not support the `CheckpointIO` plugin."""
strategy = FSDPStrategy()
with pytest.raises(NotImplementedError, match="does not use the `CheckpointIO` plugin"):
@@ -108,24 +109,8 @@ def test_fsdp_checkpoint_io_unsupported():
strategy.checkpoint_io = Mock()
-@pytest.mark.parametrize("torch_ge_2_0", [False, True])
-def test_fsdp_setup_optimizer_validation(torch_ge_2_0):
- """Test that `setup_optimizer()` validates the param groups and reference to FSDP parameters."""
- module = nn.Linear(2, 2)
- with mock.patch("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_0", torch_ge_2_0):
- strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")])
- bad_optimizer = Adam(module.parameters())
-
- if torch_ge_2_0:
- strategy.setup_optimizer(bad_optimizer)
- else:
- with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameter"):
- strategy.setup_optimizer(bad_optimizer)
-
-
-@RunIf(min_torch="2.0.0")
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.setup_module")
-def test_fsdp_setup_use_orig_params(_):
+def test_setup_use_orig_params(_):
module = nn.Linear(2, 2)
optimizer = Adam(module.parameters())
@@ -141,7 +126,7 @@ def test_fsdp_setup_use_orig_params(_):
assert strategy._fsdp_kwargs["use_orig_params"]
-def test_fsdp_no_backward_sync():
+def test_no_backward_sync():
"""Test that the backward sync control calls `.no_sync()`, and only on a module wrapped in
FullyShardedDataParallel."""
@@ -162,14 +147,7 @@ def test_fsdp_no_backward_sync():
module.no_sync.assert_called_once()
-def test_fsdp_activation_checkpointing_support(monkeypatch):
- """Test that we error out if activation checkpointing requires a newer PyTorch version."""
- monkeypatch.setattr(lightning.fabric.strategies.fsdp, "_TORCH_GREATER_EQUAL_2_1", False)
- with pytest.raises(ValueError, match="activation_checkpointing_policy` requires torch >= 2.1.0"):
- FSDPStrategy(activation_checkpointing_policy=Mock())
-
-
-def test_fsdp_activation_checkpointing():
+def test_activation_checkpointing():
"""Test that the FSDP strategy can apply activation checkpointing to the given layers."""
class Block1(nn.Linear):
@@ -185,28 +163,13 @@ def __init__(self):
self.layer1 = Block2(2, 2)
self.layer2 = nn.Linear(3, 3)
- if _TORCH_GREATER_EQUAL_2_1:
- from torch.distributed.fsdp.wrap import ModuleWrapPolicy
-
- strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
- assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
- assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
+ strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
+ assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
+ assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
- strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2}))
- assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
- assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
- else:
- strategy = FSDPStrategy(activation_checkpointing=Block1)
- assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
-
- strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2])
- assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
-
- strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
- assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
-
- strategy = FSDPStrategy(activation_checkpointing_policy={Block1, Block2})
- assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
+ strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2}))
+ assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
+ assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
strategy._parallel_devices = [torch.device("cuda", 0)]
with mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), mock.patch(
@@ -216,7 +179,7 @@ def __init__(self):
apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs)
-def test_fsdp_forbidden_precision_raises():
+def test_forbidden_precision_raises():
with pytest.raises(TypeError, match="can only work with the `FSDPPrecision"):
FSDPStrategy(precision=HalfPrecision())
@@ -225,7 +188,7 @@ def test_fsdp_forbidden_precision_raises():
strategy.precision = HalfPrecision()
-def test_fsdp_grad_clipping_norm_error():
+def test_grad_clipping_norm_error():
strategy = FSDPStrategy()
with pytest.raises(
TypeError,
@@ -234,21 +197,19 @@ def test_fsdp_grad_clipping_norm_error():
strategy.clip_gradients_norm(Mock(), Mock(), Mock())
-@RunIf(min_torch="2.0.0")
-def test_fsdp_save_checkpoint_storage_options(tmp_path):
+def test_save_checkpoint_storage_options(tmp_path):
"""Test that the FSDP strategy does not accept storage options for saving checkpoints."""
strategy = FSDPStrategy()
with pytest.raises(TypeError, match=escape("FSDPStrategy.save_checkpoint(..., storage_options=...)` is not")):
strategy.save_checkpoint(path=tmp_path, state=Mock(), storage_options=Mock())
-@RunIf(min_torch="2.0.0")
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
@mock.patch("lightning.fabric.strategies.fsdp._get_full_state_dict_context")
@mock.patch("lightning.fabric.strategies.fsdp._get_sharded_state_dict_context")
@mock.patch("lightning.fabric.strategies.fsdp.torch.save")
@mock.patch("lightning.fabric.strategies.fsdp.shutil")
-def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, tmp_path):
+def test_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, tmp_path):
strategy = FSDPStrategy(state_dict_type="full")
# state_dict_type='full', path exists, path is not a sharded checkpoint: error
@@ -305,9 +266,7 @@ def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___,
assert path.is_dir()
-@RunIf(min_torch="2.0.0")
-@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
-def test_fsdp_save_checkpoint_one_fsdp_module_required(tmp_path):
+def test_save_checkpoint_one_fsdp_module_required(tmp_path):
"""Test that the FSDP strategy can only save one FSDP model per checkpoint."""
strategy = FSDPStrategy()
@@ -315,7 +274,7 @@ def test_fsdp_save_checkpoint_one_fsdp_module_required(tmp_path):
with pytest.raises(ValueError, match="Could not find a FSDP model in the provided checkpoint state."):
strategy.save_checkpoint(path=tmp_path, state={})
with pytest.raises(ValueError, match="Could not find a FSDP model in the provided checkpoint state."):
- strategy.load_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})
+ strategy.save_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})
# multiple FSDP models
model1 = Mock(spec=FullyShardedDataParallel)
@@ -326,8 +285,7 @@ def test_fsdp_save_checkpoint_one_fsdp_module_required(tmp_path):
strategy.save_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})
-@RunIf(min_torch="2.0.0")
-def test_fsdp_load_checkpoint_no_state(tmp_path):
+def test_load_checkpoint_no_state(tmp_path):
"""Test that the FSDP strategy can't load the full state without access to a model instance from the user."""
strategy = FSDPStrategy()
with pytest.raises(ValueError, match=escape("Got FSDPStrategy.load_checkpoint(..., state=None")):
@@ -336,10 +294,10 @@ def test_fsdp_load_checkpoint_no_state(tmp_path):
strategy.load_checkpoint(path=tmp_path, state={})
-@RunIf(min_torch="2.0.0")
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
-@mock.patch("lightning.fabric.strategies.fsdp._lazy_load", Mock())
-def test_fsdp_load_checkpoint_one_fsdp_module_required(tmp_path):
+@mock.patch("lightning.fabric.strategies.model_parallel._lazy_load", Mock())
+@mock.patch("lightning.fabric.strategies.model_parallel.torch.load", Mock())
+def test_load_checkpoint_one_fsdp_module_required(tmp_path):
"""Test that the FSDP strategy can only load one FSDP model per checkpoint."""
strategy = FSDPStrategy()
@@ -359,14 +317,14 @@ def test_fsdp_load_checkpoint_one_fsdp_module_required(tmp_path):
# A raw nn.Module instead of a dictionary is ok
model = Mock(spec=nn.Module)
+ model.parameters.return_value = [torch.zeros(2, 1)]
path = tmp_path / "full.ckpt"
path.touch()
strategy.load_checkpoint(path=path, state=model)
-@RunIf(min_torch="2.0.0")
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
-def test_fsdp_save_checkpoint_unknown_state_dict_type(tmp_path):
+def test_save_checkpoint_unknown_state_dict_type(tmp_path):
strategy = FSDPStrategy(state_dict_type="invalid")
model = Mock(spec=FullyShardedDataParallel)
model.modules.return_value = [model]
@@ -374,8 +332,7 @@ def test_fsdp_save_checkpoint_unknown_state_dict_type(tmp_path):
strategy.save_checkpoint(path=tmp_path, state={"model": model})
-@RunIf(min_torch="2.0.0")
-def test_fsdp_load_unknown_checkpoint_type(tmp_path):
+def test_load_unknown_checkpoint_type(tmp_path):
"""Test that the strategy validates the contents at the checkpoint path."""
strategy = FSDPStrategy()
model = Mock(spec=FullyShardedDataParallel)
@@ -386,8 +343,7 @@ def test_fsdp_load_unknown_checkpoint_type(tmp_path):
strategy.load_checkpoint(path=path, state={"model": model})
-@RunIf(min_torch="2.0.0")
-def test_fsdp_load_raw_checkpoint_validate_single_file(tmp_path):
+def test_load_raw_checkpoint_validate_single_file(tmp_path):
"""Test that we validate the given checkpoint is a single file when loading a raw PyTorch state-dict checkpoint."""
strategy = FSDPStrategy()
model = Mock(spec=nn.Module)
@@ -397,8 +353,7 @@ def test_fsdp_load_raw_checkpoint_validate_single_file(tmp_path):
strategy.load_checkpoint(path=path, state=model)
-@RunIf(min_torch="2.0.0")
-def test_fsdp_load_raw_checkpoint_optimizer_unsupported(tmp_path):
+def test_load_raw_checkpoint_optimizer_unsupported(tmp_path):
"""Validate that the FSDP strategy does not yet support loading the raw PyTorch state-dict for an optimizer."""
strategy = FSDPStrategy()
optimizer = Mock(spec=torch.optim.Optimizer)
@@ -424,35 +379,13 @@ def test_set_timeout(init_process_group_mock):
)
-def test_has_meta_device_parameters():
- """Test that the `_has_meta_device_parameters` function can find meta-device parameters in models and
- optimizers."""
- # nn.Module
- module = nn.Linear(2, 2)
- meta_module = nn.Linear(2, 2, device="meta")
- assert not _has_meta_device_parameters(module)
- assert _has_meta_device_parameters(meta_module)
- assert _has_meta_device_parameters(nn.Sequential(module, meta_module, nn.ReLU()))
- # optim.Optimizer
- optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
- meta_optimizer = torch.optim.SGD(meta_module.parameters(), lr=0.1)
- assert not _has_meta_device_parameters(optimizer)
- assert _has_meta_device_parameters(meta_optimizer)
- # unsupported objects
- with pytest.raises(TypeError, match="Expected `torch.nn.Module` or `torch.optim.Optimizer`"):
- _has_meta_device_parameters(None)
-
-
-@RunIf(min_torch="2.0")
-@pytest.mark.parametrize("torch_ge_2_1", [True, False])
@mock.patch("torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel.set_state_dict_type")
-def test_get_full_state_dict_context_offload(set_type_mock, monkeypatch, torch_ge_2_1):
- """Test that the state dict context manager handles CPU offloading depending on the PyTorch version."""
- monkeypatch.setattr("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_1", torch_ge_2_1)
+def test_get_full_state_dict_context_offload(set_type_mock, monkeypatch):
+ """Test that the state dict context manager handles CPU offloading."""
with _get_full_state_dict_context(module=Mock(spec=FullyShardedDataParallel), world_size=1):
- assert set_type_mock.call_args_list[0][0][2].offload_to_cpu is torch_ge_2_1 # model config
- assert set_type_mock.call_args_list[0][0][3].offload_to_cpu is torch_ge_2_1 # optim config
+ assert set_type_mock.call_args_list[0][0][2].offload_to_cpu # model config
+ assert set_type_mock.call_args_list[0][0][3].offload_to_cpu # optim config
set_type_mock.reset_mock()
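
Reviewer note on the `call_args_list[0][0][2]` indexing used in `test_get_full_state_dict_context_offload` above: the first index picks the recorded call, the second picks its positional-argument tuple, and the third picks one argument. The sketch below uses only `unittest.mock`; the stand-in mock and the `SimpleNamespace` config objects are hypothetical placeholders, not the real FSDP types, and the argument order is an assumption made for illustration.

```python
from types import SimpleNamespace
from unittest.mock import Mock

# Hypothetical stand-in for the patched `set_state_dict_type`; not the real FSDP API.
set_type_mock = Mock()

# Placeholder config objects carrying only the attribute the test inspects.
model_cfg = SimpleNamespace(offload_to_cpu=True)
optim_cfg = SimpleNamespace(offload_to_cpu=True)

# Assumed positional order: module, state-dict type, model config, optimizer config.
set_type_mock("module", "state_dict_type", model_cfg, optim_cfg)

first_call = set_type_mock.call_args_list[0]  # first recorded call
positional_args = first_call[0]               # index 0 of a call object is its args tuple

assert positional_args[2] is model_cfg   # model config
assert positional_args[3] is optim_cfg   # optimizer config
assert first_call.args == positional_args  # same tuple via the named attribute (Python 3.8+)
```

The named `.args` attribute is arguably easier to read than the double index, but both refer to the same tuple.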
diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py
index 4a971294a326d..0697c3043d496 100644
--- a/tests/tests_fabric/strategies/test_fsdp_integration.py
+++ b/tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -23,12 +23,9 @@
from lightning.fabric import Fabric
from lightning.fabric.plugins import FSDPPrecision
from lightning.fabric.strategies import FSDPStrategy
-from lightning.fabric.utilities.imports import (
- _TORCH_GREATER_EQUAL_2_0,
- _TORCH_GREATER_EQUAL_2_1,
-)
from lightning.fabric.utilities.load import _load_distributed_checkpoint
from lightning.fabric.wrappers import _FabricOptimizer
+from torch._dynamo import OptimizedModule
from torch.distributed.fsdp import FlatParameter, FullyShardedDataParallel, OptimStateKeyType
from torch.distributed.fsdp.wrap import always_wrap_policy, wrap
from torch.nn import Parameter
@@ -121,10 +118,10 @@ def get_model(self):
return model
-@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
+@RunIf(min_cuda_gpus=2, standalone=True, max_torch="2.4")
@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
@pytest.mark.parametrize("manual_wrapping", [True, False])
-def test_fsdp_train_save_load(tmp_path, manual_wrapping, precision):
+def test_train_save_load(tmp_path, manual_wrapping, precision):
"""Test FSDP training, saving and loading with different wrapping and precision settings."""
trainer_cls = _TrainerManualWrapping if manual_wrapping else _Trainer
fabric = Fabric(
@@ -176,8 +173,9 @@ def test_fsdp_train_save_load(tmp_path, manual_wrapping, precision):
assert state["coconut"] == 11
-@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
-def test_fsdp_save_full_state_dict(tmp_path):
+@pytest.mark.filterwarnings("ignore::FutureWarning")
+@RunIf(min_cuda_gpus=2, standalone=True)
+def test_save_full_state_dict(tmp_path):
"""Test that FSDP saves the full state into a single file with `state_dict_type="full"`."""
fabric = Fabric(
accelerator="cuda",
@@ -193,7 +191,7 @@ def test_fsdp_save_full_state_dict(tmp_path):
state = {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 1}
fabric.save(checkpoint_path, state)
- checkpoint = torch.load(checkpoint_path)
+ checkpoint = torch.load(checkpoint_path, weights_only=True)
assert checkpoint["steps"] == 1
loaded_state_dict = checkpoint["model"]
@@ -250,7 +248,7 @@ def test_fsdp_save_full_state_dict(tmp_path):
# get optimizer state after loading
normal_checkpoint_path = Path(fabric.broadcast(str(tmp_path / "normal-checkpoint.pt")))
fabric.save(normal_checkpoint_path, {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 2})
- optimizer_state_after = torch.load(normal_checkpoint_path)["optimizer"]
+ optimizer_state_after = torch.load(normal_checkpoint_path, weights_only=True)["optimizer"]
optimizer_state_after = FullyShardedDataParallel.rekey_optim_state_dict(
optimizer_state_after, optim_state_key_type=OptimStateKeyType.PARAM_NAME, model=trainer.model
)
@@ -290,8 +288,9 @@ def test_fsdp_save_full_state_dict(tmp_path):
trainer.run()
-@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
-def test_fsdp_load_full_state_dict_into_sharded_model(tmp_path):
+@pytest.mark.filterwarnings("ignore::FutureWarning")
+@RunIf(min_cuda_gpus=2, standalone=True)
+def test_load_full_state_dict_into_sharded_model(tmp_path):
"""Test that the strategy can load a full-state checkpoint into a FSDP sharded model."""
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -331,7 +330,7 @@ def test_fsdp_load_full_state_dict_into_sharded_model(tmp_path):
# Create a raw state-dict checkpoint to test `Fabric.load_raw` too
raw_checkpoint_path = checkpoint_path.with_name("model-state-dict")
if fabric.global_rank == 0:
- checkpoint = torch.load(checkpoint_path)
+ checkpoint = torch.load(checkpoint_path, weights_only=True)
torch.save(checkpoint["model"], raw_checkpoint_path)
fabric.barrier()
@@ -362,11 +361,7 @@ def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
# the linear layer got sharded and each part is on the expected device
assert next(fabric_model.parameters()).device == torch.device("cuda", fabric.local_rank)
assert next(fabric_model.parameters()).numel() == 50
- if _TORCH_GREATER_EQUAL_2_0:
- # In PyTorch >= 2.0 we set `use_orig_params=True` and don't see flattened parameters
- assert isinstance(next(fabric_model.parameters()), Parameter)
- else:
- assert isinstance(next(fabric_model.parameters()), FlatParameter)
+ assert isinstance(next(fabric_model.parameters()), Parameter)
# The _DeviceDtypeModuleMixin currently can't represent the device in a meaningful way for models with pieces on
# different devices
@@ -374,7 +369,7 @@ def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
assert fabric.device == torch.device("cuda", fabric.local_rank)
-@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="2.0.0")
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
def test_setup_with_orig_params_and_multiple_param_groups():
"""Test that Fabric sets `use_orig_params` for the user when jointly setting up model and optimizer."""
strategy = FSDPStrategy(auto_wrap_policy=always_wrap_policy)
@@ -406,16 +401,11 @@ def test_setup_with_orig_params_and_multiple_param_groups():
assert not isinstance(layer.weight, FlatParameter)
-@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", dynamo=True, skip_windows=True)
-@mock.patch(
- "lightning.fabric.wrappers.torch.compile",
- Mock(wraps=(torch.compile if _TORCH_GREATER_EQUAL_2_0 else None)),
-)
+@RunIf(min_cuda_gpus=2, standalone=True, dynamo=True, skip_windows=True)
+@mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
@mock.patch.dict(os.environ, {})
def test_reapply_compile():
"""Test that Fabric can rewrap a compiled module such that compilation happens over the FSDP-wrapper."""
- from torch._dynamo import OptimizedModule
-
strategy = FSDPStrategy(auto_wrap_policy=always_wrap_policy)
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
@@ -477,16 +467,13 @@ def _run_setup_assertions(empty_init, expected_device):
# Case 1: No empty init
_run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))
- if _TORCH_GREATER_EQUAL_2_1:
- # Case 2: Empty-init with PyTorch >= 2.1 supports meta device
- _run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
- else:
- # Case 2: Empty-init with PyTorch < 2.1 only supports `torch.empty()`-init
- _run_setup_assertions(empty_init=True, expected_device=torch.device("cpu"))
+ # Case 2: Empty-init with meta device
+ _run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
-@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
-def test_fsdp_save_filter(tmp_path):
+@pytest.mark.filterwarnings("ignore::FutureWarning")
+@RunIf(min_cuda_gpus=2, standalone=True)
+def test_save_filter(tmp_path):
fabric = Fabric(accelerator="cuda", strategy=FSDPStrategy(state_dict_type="full"), devices=2)
fabric.launch()
model = nn.Linear(32, 2)
@@ -498,9 +485,9 @@ def test_fsdp_save_filter(tmp_path):
checkpoint_path = tmp_path / "full.pth"
fabric.save(checkpoint_path, state, filter=filter)
- checkpoint = torch.load(checkpoint_path)["model"]
+ checkpoint = torch.load(checkpoint_path, weights_only=True)["model"]
assert set(checkpoint) == {"bias"}
- assert isinstance(checkpoint["bias"], torch.Tensor)
+ assert type(checkpoint["bias"]) is torch.Tensor
fabric.strategy._state_dict_type = "sharded"
checkpoint_path = tmp_path / "sharded"
@@ -509,7 +496,7 @@ def test_fsdp_save_filter(tmp_path):
@RunIf(min_cuda_gpus=1)
-def test_fsdp_manual_activation_checkpointing():
+def test_manual_activation_checkpointing():
model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.Linear(1, 1))
strategy = FSDPStrategy(activation_checkpointing_policy={torch.nn.Linear})
fabric = Fabric(devices=1, accelerator="cuda", strategy=strategy)
@@ -549,9 +536,6 @@ def test_rewrap_warnings():
assert not isinstance(model._forward_module, FullyShardedDataParallel)
assert isinstance(model._forward_module[2], FullyShardedDataParallel)
- if not _TORCH_GREATER_EQUAL_2_1:
- return
-
with fabric.init_module(empty_init=True):
model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), wrap(torch.nn.Linear(1, 1)))
assert model[0].weight.is_meta
@@ -621,6 +605,7 @@ def test_clip_gradients(clip_type, precision):
optimizer.zero_grad()
+@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0")
def test_save_sharded_and_consolidate_and_load(tmp_path):
"""Test the consolidation of a FSDP-sharded checkpoint into a single file."""
@@ -668,3 +653,22 @@ def test_save_sharded_and_consolidate_and_load(tmp_path):
model, optimizer = fabric.setup(model, optimizer)
state = {"model": model, "optimizer": optimizer, "steps": 1}
fabric.load(checkpoint_path_full, state)
+
+
+@RunIf(min_cuda_gpus=2, standalone=True)
+def test_no_call_to_apply(monkeypatch):
+ """Regression test to ensure we're not calling `FSDP.apply()` indirectly (see #19755)."""
+ monkeypatch.setattr(torch.distributed.fsdp.FullyShardedDataParallel, "apply", Mock())
+
+ fabric = Fabric(
+ accelerator="cuda",
+ strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
+ devices=2,
+ )
+ fabric.launch()
+
+ for setup_method in ("setup", "setup_module"):
+ model = BoringModel()
+ setup = getattr(fabric, setup_method)
+ model = setup(model)
+ model._forward_module.apply.assert_not_called()
diff --git a/tests/tests_fabric/strategies/test_model_parallel.py b/tests/tests_fabric/strategies/test_model_parallel.py
new file mode 100644
index 0000000000000..1f8b5b783b73e
--- /dev/null
+++ b/tests/tests_fabric/strategies/test_model_parallel.py
@@ -0,0 +1,357 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from datetime import timedelta
+from re import escape
+from unittest import mock
+from unittest.mock import Mock
+
+import pytest
+import torch
+import torch.nn as nn
+from lightning.fabric.plugins.environments import LightningEnvironment
+from lightning.fabric.strategies import ModelParallelStrategy
+from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint
+from lightning.fabric.strategies.model_parallel import _ParallelBackwardSyncControl
+from torch.optim import Adam
+
+from tests_fabric.helpers.runif import RunIf
+
+
+@mock.patch("lightning.fabric.strategies.model_parallel._TORCH_GREATER_EQUAL_2_4", False)
+def test_torch_greater_equal_2_4():
+ with pytest.raises(ImportError, match="ModelParallelStrategy requires PyTorch 2.4 or higher"):
+ ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
+
+
+@RunIf(min_torch="2.4")
+def test_device_mesh_access():
+ strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
+ with pytest.raises(RuntimeError, match="Accessing the device mesh .* not allowed"):
+ _ = strategy.device_mesh
+
+
+@RunIf(min_torch="2.4")
+@pytest.mark.parametrize(
+ ("num_nodes", "devices", "invalid_dp_size", "invalid_tp_size"),
+ [
+ (1, 4, 1, 1),
+ (1, 4, 2, 3),
+ (1, 4, 4, 2),
+ (2, 4, 1, 4),
+ (2, 4, 2, 1),
+ ],
+)
+def test_validate_device_mesh_dimensions(num_nodes, devices, invalid_dp_size, invalid_tp_size):
+ """Test that passing sizes whose product does not equal the world size raises an error."""
+ strategy = ModelParallelStrategy(
+ parallelize_fn=(lambda m, _: m),
+ data_parallel_size=invalid_dp_size,
+ tensor_parallel_size=invalid_tp_size,
+ )
+ strategy._setup_distributed = Mock()
+ strategy._accelerator = Mock()
+ strategy.cluster_environment = Mock(
+ world_size=Mock(return_value=(num_nodes * devices)), local_rank=Mock(return_value=1)
+ )
+ strategy.parallel_devices = [torch.device("cpu")] * devices
+ strategy.num_nodes = num_nodes
+ with pytest.raises(RuntimeError, match="multiplied should equal the world size"):
+ strategy.setup_environment()
+
+
+@RunIf(min_torch="2.4")
+def test_checkpoint_io_unsupported():
+ """Test that the ModelParallel strategy does not support the `CheckpointIO` plugin."""
+ strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
+ with pytest.raises(NotImplementedError, match="does not use the `CheckpointIO` plugin"):
+ _ = strategy.checkpoint_io
+
+ with pytest.raises(NotImplementedError, match="does not support setting a `CheckpointIO` plugin"):
+ strategy.checkpoint_io = Mock()
+
+
+@RunIf(min_torch="2.4")
+def test_fsdp_v1_modules_unsupported():
+ """Test that the strategy won't allow setting up a module wrapped with the legacy FSDP API."""
+ from torch.distributed.fsdp import FullyShardedDataParallel
+
+ module = Mock(modules=Mock(return_value=[Mock(spec=FullyShardedDataParallel)]))
+ strategy = ModelParallelStrategy(parallelize_fn=(lambda x, _: x))
+ with pytest.raises(TypeError, match="only supports the new FSDP2 APIs in PyTorch >= 2.4"):
+ strategy.setup_module(module)
+
+
+@RunIf(min_torch="2.4")
+def test_parallelize_fn_call():
+ model = nn.Linear(2, 2)
+ optimizer = Adam(model.parameters())
+
+ parallel_model_mock = Mock(spec=nn.Module, parameters=Mock(return_value=[]), buffers=Mock(return_value=[]))
+ parallelize_fn = Mock(return_value=parallel_model_mock)
+ strategy = ModelParallelStrategy(parallelize_fn=parallelize_fn)
+ strategy._device_mesh = Mock()
+ strategy.parallel_devices = [torch.device("cpu")]
+ model_setup, [optimizer_setup] = strategy.setup_module_and_optimizers(model, [optimizer])
+ assert model_setup is parallel_model_mock
+ assert optimizer_setup is optimizer
+ parallelize_fn.assert_called_with(model, strategy.device_mesh)
+
+ # Raises an error if parallelize_fn does not return a module
+ parallelize_fn = Mock(return_value=None)
+ strategy = ModelParallelStrategy(parallelize_fn=parallelize_fn)
+ strategy._device_mesh = Mock()
+ strategy.parallel_devices = [torch.device("cpu")]
+ with pytest.raises(TypeError, match="The `parallelize_fn` must return a `nn.Module` instance"):
+ strategy.setup_module_and_optimizers(model, [optimizer])
+
+
+@RunIf(min_torch="2.4")
+def test_no_backward_sync():
+ """Test that the backward sync control disables gradient sync on modules that benefit from it."""
+ from torch.distributed._composable.fsdp import FSDPModule
+
+ strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
+ assert isinstance(strategy._backward_sync_control, _ParallelBackwardSyncControl)
+
+ fsdp_layer = Mock(spec=FSDPModule)
+ other_layer = nn.Linear(2, 2)
+ module = Mock()
+ module.modules = Mock(return_value=[fsdp_layer, other_layer])
+
+ with strategy._backward_sync_control.no_backward_sync(module, True):
+ fsdp_layer.set_requires_gradient_sync.assert_called_with(False, recurse=False)
+ fsdp_layer.set_requires_gradient_sync.assert_called_with(True, recurse=False)
+
+ with strategy._backward_sync_control.no_backward_sync(module, False):
+ fsdp_layer.set_requires_gradient_sync.assert_called_with(True, recurse=False)
+ fsdp_layer.set_requires_gradient_sync.assert_called_with(False, recurse=False)
+
+
+@RunIf(min_torch="2.4")
+def test_save_checkpoint_storage_options(tmp_path):
+ """Test that the strategy does not accept storage options for saving checkpoints."""
+ strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
+ with pytest.raises(
+ TypeError, match=escape("ModelParallelStrategy.save_checkpoint(..., storage_options=...)` is not")
+ ):
+ strategy.save_checkpoint(path=tmp_path, state=Mock(), storage_options=Mock())
+
+
+@RunIf(min_torch="2.4")
+@mock.patch("lightning.fabric.strategies.model_parallel.ModelParallelStrategy.broadcast", lambda _, x: x)
+@mock.patch("lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True)
+@mock.patch("torch.distributed.checkpoint.state_dict.get_model_state_dict", return_value={})
+@mock.patch("torch.distributed.checkpoint.state_dict.get_optimizer_state_dict", return_value={})
+@mock.patch("lightning.fabric.strategies.model_parallel.torch.save")
+@mock.patch("lightning.fabric.strategies.model_parallel.shutil")
+def test_save_checkpoint_path_exists(shutil_mock, torch_save_mock, _, __, ___, tmp_path):
+ strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m), save_distributed_checkpoint=False)
+
+ # save_distributed_checkpoint=False, path exists, path is not a sharded checkpoint: error
+ path = tmp_path / "not-empty"
+ path.mkdir()
+ (path / "file").touch()
+ assert not _is_sharded_checkpoint(path)
+ with pytest.raises(IsADirectoryError, match="exists and is a directory"):
+ strategy.save_checkpoint(path=path, state=Mock())
+
+ # save_distributed_checkpoint=False, path exists, path is a sharded checkpoint: no error (overwrite)
+ path = tmp_path / "sharded-checkpoint"
+ path.mkdir()
+ (path / "meta.pt").touch()
+ assert _is_sharded_checkpoint(path)
+ model = Mock()
+ model.modules.return_value = [model]
+ strategy.save_checkpoint(path=path, state={"model": model})
+ shutil_mock.rmtree.assert_called_once_with(path)
+
+ # save_distributed_checkpoint=False, path exists, path is a file: no error (overwrite)
+ path = tmp_path / "file.pt"
+ path.touch()
+ model = Mock(spec=nn.Module)
+ torch_save_mock.reset_mock()
+ strategy.save_checkpoint(path=path, state={"model": model})
+ torch_save_mock.assert_called_once()
+
+ strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m), save_distributed_checkpoint=True)
+ save_mock = mock.patch("torch.distributed.checkpoint.save")
+
+ # save_distributed_checkpoint=True, path exists, path is a folder: no error (overwrite)
+ path = tmp_path / "not-empty-2"
+ path.mkdir()
+ (path / "file").touch()
+ model = Mock(spec=nn.Module)
+ with save_mock:
+ strategy.save_checkpoint(path=path, state={"model": model})
+ assert (path / "file").exists()
+
+ # save_distributed_checkpoint=True, path exists, path is a file: no error (overwrite)
+ path = tmp_path / "file-2.pt"
+ path.touch()
+ model = Mock(spec=nn.Module)
+ with save_mock:
+ strategy.save_checkpoint(path=path, state={"model": model})
+ assert path.is_dir()
+
+
+@RunIf(min_torch="2.4")
+def test_save_checkpoint_one_dist_module_required(tmp_path):
+ """Test that the ModelParallelStrategy can only save one distributed model per checkpoint."""
+ strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
+
+ # missing DTensor model
+ with pytest.raises(ValueError, match="Could not find a distributed model in the provided checkpoint state."):
+ strategy.save_checkpoint(path=tmp_path, state={})
+ with pytest.raises(ValueError, match="Could not find a distributed model in the provided checkpoint state."):
+ strategy.save_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})
+
+ # multiple DTensor models
+ with mock.patch("lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True):
+ model1 = Mock(spec=nn.Module)
+ model1.modules.return_value = [model1]
+ model2 = Mock(spec=nn.Module)
+ model2.modules.return_value = [model2]
+ with pytest.raises(ValueError, match="Found multiple distributed models in the given state."):
+ strategy.save_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})
+
+
+@RunIf(min_torch="2.4")
+def test_load_checkpoint_no_state(tmp_path):
+ """Test that the ModelParallelStrategy can't load the full state without access to a model instance from
+ the user."""
+ strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
+ with pytest.raises(ValueError, match=escape("Got ModelParallelStrategy.load_checkpoint(..., state=None")):
+ strategy.load_checkpoint(path=tmp_path, state=None)
+ with pytest.raises(ValueError, match=escape("Got ModelParallelStrategy.load_checkpoint(..., state={})")):
+ strategy.load_checkpoint(path=tmp_path, state={})
+
+
+@RunIf(min_torch="2.4")
+@mock.patch("lightning.fabric.strategies.model_parallel.ModelParallelStrategy.broadcast", lambda _, x: x)
+@mock.patch("lightning.fabric.strategies.model_parallel.torch.load", Mock())
+def test_load_checkpoint_one_dist_module_required(tmp_path):
+ """Test that the ModelParallelStrategy can only load one distributed model per checkpoint."""
+ strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
+
+ # missing DTensor model
+ with pytest.raises(ValueError, match="Could not find a distributed model in the provided checkpoint state."):
+ strategy.load_checkpoint(path=tmp_path, state={"other": "data"})
+ with pytest.raises(ValueError, match="Could not find a distributed model in the provided checkpoint state."):
+ strategy.load_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})
+
+ # multiple DTensor models
+ with mock.patch("lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True):
+ model1 = Mock(spec=nn.Module)
+ model1.modules.return_value = [model1]
+ model2 = Mock(spec=nn.Module)
+ model2.modules.return_value = [model2]
+ with pytest.raises(ValueError, match="Found multiple distributed models in the given state."):
+ strategy.load_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})
+
+ # A raw nn.Module instead of a dictionary is ok
+ model = Mock(spec=nn.Module)
+ model.parameters.return_value = [torch.zeros(2, 1)]
+ path = tmp_path / "full.ckpt"
+ path.touch()
+ strategy.load_checkpoint(path=path, state=model)
+
+
+@RunIf(min_torch="2.4")
+@mock.patch("lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True)
+def test_load_unknown_checkpoint_type(_, tmp_path):
+ """Test that the strategy validates the contents at the checkpoint path."""
+ strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
+ model = Mock()
+ path = tmp_path / "empty_dir" # neither a single file nor a directory with a meta file
+ path.mkdir()
+ with pytest.raises(ValueError, match="does not point to a valid checkpoint"):
+ strategy.load_checkpoint(path=path, state={"model": model})
+
+
+@RunIf(min_torch="2.4")
+def test_load_raw_checkpoint_validate_single_file(tmp_path):
+ """Test that we validate the given checkpoint is a single file when loading a raw PyTorch state-dict checkpoint."""
+ strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
+ model = Mock(spec=nn.Module)
+ path = tmp_path / "folder"
+ path.mkdir()
+ with pytest.raises(ValueError, match="The given path must be a single file containing the full state dict"):
+ strategy.load_checkpoint(path=path, state=model)
+
+
+@RunIf(min_torch="2.4")
+def test_load_raw_checkpoint_optimizer_unsupported(tmp_path):
+ """Validate that the ModelParallelStrategy does not yet support loading the raw PyTorch state-dict for an
+ optimizer."""
+ strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
+ optimizer = Mock(spec=torch.optim.Optimizer)
+ with pytest.raises(
+ NotImplementedError, match="Loading a single optimizer object from a checkpoint is not supported"
+ ):
+ strategy.load_checkpoint(path=tmp_path, state=optimizer)
+
+
+@RunIf(min_torch="2.4")
+@mock.patch("lightning.fabric.strategies.model_parallel._setup_device_mesh")
+@mock.patch("torch.distributed.init_process_group")
+def test_set_timeout(init_process_group_mock, _):
+ """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function."""
+ test_timedelta = timedelta(seconds=30)
+ strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m), timeout=test_timedelta)
+ strategy.parallel_devices = [torch.device("cpu")]
+ strategy.cluster_environment = LightningEnvironment()
+ strategy.accelerator = Mock()
+ strategy.setup_environment()
+ process_group_backend = strategy._get_process_group_backend()
+ global_rank = strategy.cluster_environment.global_rank()
+ world_size = strategy.cluster_environment.world_size()
+ init_process_group_mock.assert_called_with(
+ process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
+ )
+
+
+@RunIf(min_torch="2.4")
+def test_meta_device_materialization():
+ """Test that the `setup_module()` method materializes meta-device tensors in the module."""
+
+ class NoResetParameters(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(4, 4))
+
+ class CustomModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ # nn.Sequential as a parameterless module
+ self.layer1 = nn.Sequential(NoResetParameters(), NoResetParameters())
+ self.layer2 = nn.Linear(4, 4)
+ self.register_buffer("buffer", torch.rand(2))
+
+ def reset_parameters(self):
+ self.buffer.fill_(1.0)
+
+ strategy = ModelParallelStrategy(parallelize_fn=(lambda x, _: x))
+ strategy._device_mesh = Mock()
+ strategy._parallel_devices = [torch.device("cpu")]
+
+ with torch.device("meta"):
+ model = CustomModel()
+ assert model.layer1[0].weight.is_meta
+ assert model.layer2.weight.is_meta
+ assert model.buffer.is_meta
+
+ with pytest.warns(UserWarning, match=r"`reset_parameters\(\)` method for re-initialization: NoResetParameters"):
+ model = strategy.setup_module(model)
+ assert all(not p.is_meta for p in model.parameters())
+ assert all(not b.is_meta for b in model.buffers())
diff --git a/tests/tests_fabric/strategies/test_model_parallel_integration.py b/tests/tests_fabric/strategies/test_model_parallel_integration.py
new file mode 100644
index 0000000000000..dfbdb16b10060
--- /dev/null
+++ b/tests/tests_fabric/strategies/test_model_parallel_integration.py
@@ -0,0 +1,699 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from copy import deepcopy
+from pathlib import Path
+from unittest import mock
+
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from lightning.fabric import Fabric
+from lightning.fabric.strategies.model_parallel import ModelParallelStrategy, _load_raw_module_state
+from lightning.fabric.utilities.load import _load_distributed_checkpoint
+from torch.utils.data import DataLoader, DistributedSampler
+
+from tests_fabric.helpers.datasets import RandomDataset
+from tests_fabric.helpers.runif import RunIf
+
+
+class FeedForward(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.w1 = nn.Linear(32, 64)
+ self.w2 = nn.Linear(32, 64)
+ self.w3 = nn.Linear(64, 32)
+
+ def forward(self, x):
+ return self.w3(F.silu(self.w1(x)) * self.w2(x))
+
+
+def _parallelize_feed_forward_tp(model, device_mesh):
+ from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
+
+ tp_mesh = device_mesh["tensor_parallel"]
+ tp_plan = {
+ "w1": ColwiseParallel(),
+ "w2": ColwiseParallel(),
+ "w3": RowwiseParallel(),
+ }
+ parallelize_module(model, tp_mesh, tp_plan)
+ return model
+
+
+def _parallelize_feed_forward_fsdp2(model, device_mesh):
+ from torch.distributed._composable.fsdp.fully_shard import fully_shard
+
+ dp_mesh = device_mesh["data_parallel"]
+ assert dp_mesh.ndim == 1 # Hybrid-sharding not supported
+
+ # Fully-shard each layer
+ fully_shard(model.w1, mesh=dp_mesh)
+ fully_shard(model.w2, mesh=dp_mesh)
+ fully_shard(model.w3, mesh=dp_mesh)
+
+ # TODO: Re-enable activation checkpointing
+ # Currently, state dict keys get prefixed with '_checkpoint_wrapper',
+ # which leads to mismatches when loading weights into a checkpoint-wrapped module.
+ # PyTorch should handle this automatically.
+
+ # model = checkpoint_wrapper(model)
+
+ return model
+
+
+def _parallelize_feed_forward_fsdp2_tp(model, device_mesh):
+ model = _parallelize_feed_forward_tp(model, device_mesh)
+ model = _parallelize_feed_forward_fsdp2(model, device_mesh)
+ return model
+
+
+@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
+def test_setup_device_mesh():
+ from torch.distributed.device_mesh import DeviceMesh
+
+ for dp_size, tp_size in ((1, 4), (4, 1), (2, 2)):
+ strategy = ModelParallelStrategy(
+ parallelize_fn=(lambda m, _: m),
+ data_parallel_size=dp_size,
+ tensor_parallel_size=tp_size,
+ )
+ fabric = Fabric(accelerator="auto", devices=4, strategy=strategy)
+ fabric.launch()
+
+ device_mesh = fabric.strategy.device_mesh
+ assert isinstance(device_mesh, DeviceMesh)
+ assert device_mesh.device_type == fabric.device.type
+ assert device_mesh.mesh_dim_names == ("data_parallel", "tensor_parallel")
+ assert device_mesh.size(0) == dp_size
+ assert device_mesh.size(1) == tp_size
+ assert device_mesh.ndim == 2
+
+ fabric.barrier()
+
+ # Passing "auto" will select internode and intranode dimensions automatically
+ strategy = ModelParallelStrategy(
+ parallelize_fn=(lambda m, _: m),
+ data_parallel_size="auto",
+ tensor_parallel_size="auto",
+ )
+ fabric = Fabric(accelerator="auto", devices=4, num_nodes=1, strategy=strategy)
+ fabric.launch()
+ assert fabric.strategy.device_mesh.mesh_dim_names == ("data_parallel", "tensor_parallel")
+ assert fabric.strategy.device_mesh.size(0) == 1
+ assert fabric.strategy.device_mesh.size(1) == 4
+
+
+@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2)
+def test_tensor_parallel():
+ from torch.distributed._tensor import DTensor
+
+ strategy = ModelParallelStrategy(parallelize_fn=_parallelize_feed_forward_tp)
+ fabric = Fabric(accelerator="auto", devices=2, strategy=strategy)
+ fabric.launch()
+
+ fabric.seed_everything(0)
+
+ with fabric.init_module(empty_init=True):
+ model = FeedForward()
+
+ model = fabric.setup(model)
+ optimizer = torch.optim.AdamW(model.parameters())
+ optimizer = fabric.setup_optimizers(optimizer)
+
+ device_mesh = fabric.strategy.device_mesh
+ assert all(tensor.device_mesh == device_mesh["tensor_parallel"] for tensor in optimizer.param_groups[0]["params"])
+ assert all(isinstance(weight, DTensor) for weight in model.parameters())
+ assert model.w1.weight.device_mesh == device_mesh["tensor_parallel"]
+
+ dataset_size = 6
+ dataset = RandomDataset(32, dataset_size)
+ dataloader = DataLoader(dataset, batch_size=2)
+ dataloader = fabric.setup_dataloaders(dataloader)
+
+ # No data sharding, all GPUs get the same input inside a TP group
+ assert len(dataloader) == dataset_size // dataloader.batch_size
+ assert isinstance(dataloader.sampler, DistributedSampler)
+
+ for _, batch in enumerate(dataloader):
+ # All batches must be identical across TP group
+ batches = fabric.all_gather(batch)
+ assert all(torch.equal(batches[0], batches[i]) for i in range(1, len(batches)))
+
+ output = model(batch)
+ fabric.backward(output.sum())
+ assert isinstance(model.w1.weight.grad, DTensor)
+ assert model.w1.weight.grad.device_mesh == device_mesh["tensor_parallel"]
+ optimizer.step()
+ optimizer.zero_grad()
+
+
+@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
+def test_fsdp2_tensor_parallel():
+ from torch.distributed._tensor import DTensor
+
+ strategy = ModelParallelStrategy(
+ parallelize_fn=_parallelize_feed_forward_fsdp2_tp,
+ data_parallel_size=2,
+ tensor_parallel_size=2,
+ )
+ fabric = Fabric(accelerator="auto", devices=4, strategy=strategy)
+ fabric.launch()
+
+ fabric.seed_everything(0)
+
+ with fabric.init_module(empty_init=True):
+ model = FeedForward()
+
+ model = fabric.setup(model)
+ optimizer = torch.optim.AdamW(model.parameters())
+ optimizer = fabric.setup_optimizers(optimizer)
+
+ assert all(isinstance(weight, DTensor) for weight in model.parameters())
+ assert all(isinstance(tensor, DTensor) for tensor in optimizer.param_groups[0]["params"])
+ assert model.w1.weight.device_mesh.ndim == 2
+ assert model.w1.weight.device_mesh.size(0) == 2
+ assert model.w1.weight.device_mesh.size(1) == 2
+ assert all(weight.device.type != "meta" for weight in model.parameters())
+ assert all(tensor.device_mesh.ndim == 2 for tensor in optimizer.param_groups[0]["params"])
+ assert all(tensor.device.type != "meta" for tensor in optimizer.param_groups[0]["params"])
+
+ dataset_size = 8
+ dataset = RandomDataset(32, dataset_size)
+ dataloader = DataLoader(dataset, batch_size=2)
+ dataloader = fabric.setup_dataloaders(dataloader)
+
+ # No data sharding across TP dimension, sharding across data-parallel dimension only
+ device_mesh = fabric.strategy.device_mesh
+ dp_mesh = device_mesh["data_parallel"]
+ tp_mesh = device_mesh["tensor_parallel"]
+ assert len(dataloader) == dataset_size // dataloader.batch_size // dp_mesh.size()
+ assert isinstance(dataloader.sampler, DistributedSampler)
+
+ for _, batch in enumerate(dataloader):
+ batches = fabric.all_gather(batch)
+ # Batches across the TP dimension must be identical
+ batches_tp = batches[tp_mesh.mesh]
+ assert all(torch.equal(batches_tp[0], batches_tp[i]) for i in range(1, len(batches_tp)))
+ # Batches across the DP dimension must be different
+ batches_dp = batches[dp_mesh.mesh]
+ assert all(not torch.equal(batches_dp[0], batches_dp[i]) for i in range(1, len(batches_dp)))
+
+ output = model(batch)
+ fabric.backward(output.sum())
+ assert isinstance(model.w1.weight.grad, DTensor)
+ assert model.w1.weight.grad.device_mesh == device_mesh
+ optimizer.step()
+ optimizer.zero_grad()
+
+
+def _train(fabric, model=None, optimizer=None):
+ fabric.seed_everything(0)
+
+ if model is None:
+ with fabric.init_module(empty_init=True):
+ model = FeedForward()
+ model = fabric.setup(model)
+ if optimizer is None:
+ optimizer = torch.optim.AdamW(model.parameters())
+ optimizer = fabric.setup_optimizers(optimizer)
+
+ output = model(torch.rand(2, 32, device=fabric.device))
+ fabric.backward(output.sum())
+ optimizer.step()
+ optimizer.zero_grad()
+ return model, optimizer
+
+
+@RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True)
+@pytest.mark.parametrize(
+ "precision",
+ [
+ pytest.param("32-true"),
+ pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
+ ],
+)
+def test_train_save_load(precision, tmp_path):
+ """Test 2D-parallel training, saving and loading with different precision settings."""
+ strategy = ModelParallelStrategy(
+ _parallelize_feed_forward_fsdp2_tp,
+ data_parallel_size=2,
+ tensor_parallel_size=2,
+ )
+ fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy, precision=precision)
+ fabric.launch()
+ model, optimizer = _train(fabric)
+
+ checkpoint_path = fabric.broadcast(str(tmp_path / "dist-checkpoint"))
+
+ params_before = [p.full_tensor().clone() for p in model.parameters()]
+ state = {"model": model, "optimizer": optimizer, "steps": 1}
+ fabric.save(checkpoint_path, state)
+ assert set(os.listdir(checkpoint_path)) == {
+ ".metadata",
+ "__0_0.distcp",
+ "__1_0.distcp",
+ "__2_0.distcp",
+ "__3_0.distcp",
+ "meta.pt",
+ }
+
+ # re-init all objects and resume
+ strategy = ModelParallelStrategy(
+ _parallelize_feed_forward_fsdp2_tp,
+ data_parallel_size=2,
+ tensor_parallel_size=2,
+ )
+ fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy, precision=precision)
+ fabric.launch()
+ model, optimizer = _train(fabric)
+
+ # check correctness with loaded state
+ state = {"model": model, "optimizer": optimizer, "steps": 0}
+ metadata = fabric.load(checkpoint_path, state)
+ for p0, p1 in zip(params_before, (p.full_tensor() for p in model.parameters())):
+ torch.testing.assert_close(p0, p1, atol=0, rtol=0, equal_nan=True)
+
+ # check user data in state reloaded
+ assert state["steps"] == 1
+ assert not metadata
+
+ # attempt to load a key not in the metadata checkpoint
+ state = {"model": model, "coconut": 11}
+ with pytest.raises(KeyError, match="The requested state contains a key 'coconut' that does not exist"):
+ fabric.load(checkpoint_path, state)
+
+ # `strict=False` ignores the missing key
+ state = {"model": model, "coconut": 11}
+ fabric.load(checkpoint_path, state, strict=False)
+ assert state["coconut"] == 11
+
+
+@pytest.mark.filterwarnings("ignore::FutureWarning")
+@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
+def test_save_full_state_dict(tmp_path):
+ """Test that ModelParallelStrategy saves the full state into a single file with
+ `save_distributed_checkpoint=False`."""
+ from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict
+
+ strategy = ModelParallelStrategy(
+ _parallelize_feed_forward_fsdp2,
+ data_parallel_size=2,
+ tensor_parallel_size=1,
+ save_distributed_checkpoint=False,
+ )
+ fabric = Fabric(accelerator="cuda", strategy=strategy, devices=2)
+ fabric.launch()
+ model, optimizer = _train(fabric)
+
+ checkpoint_path = Path(fabric.broadcast(str(tmp_path / "fsdp-checkpoint.pt")))
+ state = {"model": model, "optimizer": optimizer, "steps": 1}
+ fabric.save(checkpoint_path, state)
+
+ checkpoint = torch.load(checkpoint_path, weights_only=True)
+ assert checkpoint["steps"] == 1
+ loaded_state_dict = checkpoint["model"]
+
+ # assert the correct model state was saved
+ state_dict = model.state_dict()
+ assert set(loaded_state_dict.keys()) == set(state_dict.keys())
+ for param_name in state_dict:
+ assert torch.equal(loaded_state_dict[param_name], state_dict[param_name].full_tensor().cpu())
+ params_before = [p.full_tensor().cpu() for p in model.parameters()]
+
+ # assert the correct optimizer state was saved
+ optimizer_state_before = get_optimizer_state_dict(model, optimizer)
+ assert set(checkpoint["optimizer"].keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
+
+ # 1. verify the FSDP state can be loaded back into a FSDP model/strategy directly
+ strategy = ModelParallelStrategy(_parallelize_feed_forward_fsdp2, data_parallel_size=2, tensor_parallel_size=1)
+ fabric = Fabric(accelerator="cuda", strategy=strategy, devices=2)
+ fabric.launch()
+ model, optimizer = _train(fabric)
+
+ metadata = fabric.load(checkpoint_path, {"model": model, "optimizer": optimizer})
+ assert metadata == {"steps": 1}
+
+ params_after = [p.full_tensor() for p in model.parameters()]
+ assert all(torch.equal(p0.cpu(), p1.cpu()) for p0, p1 in zip(params_before, params_after))
+
+ optimizer_state_after = get_optimizer_state_dict(model, optimizer)
+ optimizer_state_after["param_groups"][0]["betas"] = tuple(optimizer_state_after["param_groups"][0]["betas"])
+ assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
+ torch.testing.assert_close(optimizer_state_after["state"], optimizer_state_before["state"], atol=0, rtol=0)
+ assert optimizer_state_after["param_groups"] == optimizer_state_before["param_groups"]
+
+ # run a step to verify the optimizer state is correct
+ _train(fabric, model, optimizer)
+
+ # 2. verify the FSDP state can be loaded back into a single-device model/strategy
+ fabric = Fabric(accelerator="cpu", devices=1)
+ model, optimizer = _train(fabric)
+ metadata = fabric.load(checkpoint_path, {"model": model, "optimizer": optimizer})
+ assert metadata == {"steps": 1}
+ params_after = list(model.parameters())
+ assert all(torch.equal(p0, p1) for p0, p1 in zip(params_before, params_after))
+
+ # get optimizer state after loading
+ normal_checkpoint_path = Path(fabric.broadcast(str(tmp_path / "normal-checkpoint.pt")))
+ fabric.save(normal_checkpoint_path, {"model": model, "optimizer": optimizer, "steps": 2})
+
+ optimizer_state_after = torch.load(normal_checkpoint_path, weights_only=True)["optimizer"]
+ assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
+ assert torch.equal(
+ optimizer_state_after["state"][0]["exp_avg"],
+ optimizer_state_before["state"]["_forward_module.w1.weight"]["exp_avg"].full_tensor().cpu(),
+ )
+
+ # run a step to verify the optimizer state is correct
+ _train(fabric, model, optimizer)
+
+ # 3. verify that single-device model/strategy states can be loaded into a FSDP model/strategy
+ strategy = ModelParallelStrategy(_parallelize_feed_forward_fsdp2, data_parallel_size=2, tensor_parallel_size=1)
+ fabric = Fabric(accelerator="cuda", strategy=strategy, devices=2)
+ fabric.launch()
+ model, optimizer = _train(fabric)
+
+ metadata = fabric.load(normal_checkpoint_path, {"model": model, "optimizer": optimizer})
+ assert metadata == {"steps": 2}
+
+ params_after = [p.full_tensor() for p in model.parameters()]
+ assert all(torch.equal(p0.cpu(), p1.cpu()) for p0, p1 in zip(params_before, params_after))
+
+ optimizer_state_after = get_optimizer_state_dict(model, optimizer)
+ optimizer_state_after["param_groups"][0]["betas"] = tuple(optimizer_state_after["param_groups"][0]["betas"])
+ assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
+ torch.testing.assert_close(optimizer_state_after["state"], optimizer_state_before["state"], atol=0, rtol=0)
+ assert optimizer_state_after["param_groups"] == optimizer_state_before["param_groups"]
+
+ # run a step to verify the optimizer state is correct
+ _train(fabric, model, optimizer)
+
+
+@pytest.mark.filterwarnings("ignore::FutureWarning")
+@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
+def test_load_full_state_dict_into_sharded_model(tmp_path):
+ """Test that the strategy can load a full-state checkpoint into a distributed model."""
+ fabric = Fabric(accelerator="cuda", devices=1)
+ fabric.seed_everything(0)
+ model, optimizer = _train(fabric)
+
+ # Save a full-state-dict checkpoint
+ checkpoint_path = Path(fabric.broadcast(str(tmp_path / "full-checkpoint.pt")))
+ state = {"model": model, "optimizer": optimizer, "steps": 1}
+ fabric.save(checkpoint_path, state)
+
+ # Gather all weights and store a copy manually
+ params_before = torch.cat([p.cpu().view(-1) for p in model.parameters()])
+
+ # Create a FSDP sharded model
+ strategy = ModelParallelStrategy(_parallelize_feed_forward_fsdp2, data_parallel_size=2, tensor_parallel_size=1)
+ fabric = Fabric(accelerator="cuda", strategy=strategy, devices=2)
+ fabric.launch()
+ model, optimizer = _train(fabric)
+
+ state = {"model": model, "optimizer": optimizer, "steps": 44}
+ fabric.load(checkpoint_path, state)
+ assert state["steps"] == 1
+
+ # Gather all weights and compare
+ params_after = torch.cat([p.full_tensor().cpu().view(-1) for p in model.parameters()])
+ assert torch.equal(params_before, params_after)
+
+ # Create a raw state-dict checkpoint to test `Fabric.load_raw` too
+ raw_checkpoint_path = checkpoint_path.with_name("model-state-dict")
+ if fabric.global_rank == 0:
+ checkpoint = torch.load(checkpoint_path, weights_only=True)
+ torch.save(checkpoint["model"], raw_checkpoint_path)
+ fabric.barrier()
+
+ _train(fabric, model, optimizer)
+ fabric.load_raw(raw_checkpoint_path, model)
+
+ # Gather all weights and compare
+ params_after = torch.cat([p.full_tensor().cpu().view(-1) for p in model.parameters()])
+ assert torch.equal(params_before, params_after)
+
+
+@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
+@pytest.mark.parametrize("move_to_device", [True, False])
+@mock.patch("lightning.fabric.wrappers._FabricModule")
+def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
+ """Test that `move_to_device` does nothing; ModelParallel decides which parameters get moved to which device
+ (sharding)."""
+ from torch.distributed._tensor import DTensor
+
+ strategy = ModelParallelStrategy(parallelize_fn=_parallelize_feed_forward_fsdp2)
+ fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
+ fabric.launch()
+
+ model = FeedForward()
+ fabric_model = fabric.setup_module(model, move_to_device=move_to_device)
+ fabric_module_mock.assert_not_called()
+
+ # the linear layer got sharded and each part is on the expected device
+ assert fabric_model.w1.weight.device == torch.device("cuda", fabric.local_rank)
+ assert isinstance(fabric_model.w1.weight, DTensor)
+
+ # The _DeviceDtypeModuleMixin currently can't represent the device in a meaningful way for models with pieces on
+ # different devices
+ assert fabric_model.device == torch.device("cuda", fabric.local_rank)
+ assert fabric.device == torch.device("cuda", fabric.local_rank)
+
+
+@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
+@pytest.mark.parametrize(
+ ("precision", "expected_dtype"),
+ [
+ ("32-true", torch.float32),
+ ("16-true", torch.float16),
+ pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)),
+ ],
+)
+def test_module_init_context(precision, expected_dtype):
+ """Test that the module under the init-context gets moved to the right device and dtype."""
+ strategy = ModelParallelStrategy(parallelize_fn=_parallelize_feed_forward_fsdp2)
+ fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy, precision=precision)
+ fabric.launch()
+
+ def _run_setup_assertions(empty_init, expected_device):
+ with fabric.init_module(empty_init=empty_init):
+ model = FeedForward()
+
+ # The model is on the CPU/meta-device until after `.setup()`
+ assert all(weight.device == expected_device for weight in model.parameters())
+ assert all(weight.dtype == expected_dtype for weight in model.parameters())
+ model = fabric.setup(model)
+ # Parameters get sharded in `.setup()` and moved to the target device
+ assert all(weight.device == torch.device("cuda", fabric.local_rank) for weight in model.parameters())
+ assert all(weight.dtype == expected_dtype for weight in model.parameters())
+
+ _run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))
+ _run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
+
+
+@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
+def test_save_filter(tmp_path):
+ strategy = ModelParallelStrategy(
+ parallelize_fn=_parallelize_feed_forward_fsdp2,
+ save_distributed_checkpoint=False,
+ )
+ fabric = Fabric(accelerator="cuda", strategy=strategy, devices=2)
+ fabric.launch()
+ model = FeedForward()
+ model = fabric.setup_module(model)
+
+ tmp_path = Path(fabric.broadcast(str(tmp_path)))
+ state = {"model": model}
+ filter = {"model": lambda k, v: "bias" in k}
+
+ checkpoint_path = tmp_path / "full.pth"
+ fabric.save(checkpoint_path, state, filter=filter)
+ checkpoint = torch.load(checkpoint_path, weights_only=True)["model"]
+ assert set(checkpoint) == {"w1.bias", "w2.bias", "w3.bias"}
+ assert type(checkpoint["w1.bias"]) is torch.Tensor
+
+ fabric.strategy._save_distributed_checkpoint = True
+ checkpoint_path = tmp_path / "distributed"
+ with pytest.raises(NotImplementedError, match="doesn't support loading distributed filtered"):
+ fabric.save(checkpoint_path, state, filter=filter)
+
+
+def _parallelize_single_linear_tp_fsdp2(model, device_mesh):
+ from torch.distributed._composable.fsdp.fully_shard import fully_shard
+ from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
+
+ dp_mesh = device_mesh["data_parallel"]
+ tp_mesh = device_mesh["tensor_parallel"]
+
+ parallelize_module(model, tp_mesh, ColwiseParallel())
+ fully_shard(model, mesh=dp_mesh)
+ return model
+
+
+@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
+@pytest.mark.parametrize(
+ "precision",
+ [
+ "32-true",
+ pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
+ ],
+)
+@pytest.mark.parametrize(
+ "clip_type",
+ [
+ pytest.param("norm", marks=pytest.mark.skip("Gradient clipping by norm is not correct.")),
+ "val",
+ ],
+)
+def test_clip_gradients(clip_type, precision):
+ strategy = ModelParallelStrategy(_parallelize_single_linear_tp_fsdp2)
+ fabric = Fabric(accelerator="auto", devices=2, precision=precision, strategy=strategy)
+ fabric.launch()
+
+ in_features, out_features = 32, 2
+ model = torch.nn.Linear(in_features, out_features, bias=False)
+ model.weight.data.fill_(0.01)
+
+ model = fabric.setup(model)
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+ optimizer = fabric.setup_optimizers(optimizer)
+
+ batch = torch.full((1, in_features), 0.1, device=fabric.device)
+ loss = model(batch).sum()
+
+ # The example is constructed such that the gradients are all the same
+ fabric.backward(loss)
+
+ if clip_type == "norm":
+ norm = torch.linalg.vector_norm(model.weight.grad.full_tensor().detach().cpu(), 2, dtype=torch.float32).item()
+ new_norm = norm / 10
+ fabric.clip_gradients(model, optimizer, max_norm=new_norm * 10)
+ assert torch.allclose(
+ torch.linalg.vector_norm(model.weight.grad.full_tensor().detach().cpu(), 2, dtype=torch.float32),
+ torch.tensor(new_norm),
+ )
+ elif clip_type == "val":
+ val = model.weight.grad.full_tensor()[0, 0].item()
+ new_val = val / 2.0
+ fabric.clip_gradients(model, optimizer, clip_val=new_val)
+ assert torch.allclose(
+ model.weight.grad.full_tensor(), torch.full_like(model.weight.grad.full_tensor(), new_val)
+ )
+ else:
+ raise AssertionError(f"Unknown clip type: {clip_type}")
+
+ optimizer.step()
+ optimizer.zero_grad()
+
+
+@RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True)
+def test_save_sharded_and_consolidate_and_load(tmp_path):
+ """Test the consolidation of a distributed (DTensor) checkpoint into a single file."""
+ strategy = ModelParallelStrategy(
+ _parallelize_feed_forward_fsdp2_tp,
+ data_parallel_size=2,
+ tensor_parallel_size=2,
+ )
+ fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy)
+ fabric.launch()
+
+ model = FeedForward()
+ model = fabric.setup(model)
+ optimizer = torch.optim.Adam(model.parameters())
+ optimizer = fabric.setup_optimizers(optimizer)
+ state = {"model": model, "optimizer": optimizer, "steps": 1}
+
+ # run one iteration to init the state of the optimizer
+ loss = model(torch.rand(1, 32, device=fabric.device)).sum()
+ fabric.backward(loss)
+ optimizer.step()
+
+ checkpoint_path_sharded = fabric.broadcast(str(tmp_path / "checkpoint_sharded"))
+ fabric.save(checkpoint_path_sharded, state)
+ assert set(os.listdir(checkpoint_path_sharded)) == {
+ ".metadata",
+ "__0_0.distcp",
+ "__1_0.distcp",
+ "__2_0.distcp",
+ "__3_0.distcp",
+ "meta.pt",
+ }
+
+ # consolidate the checkpoint to a single file
+ checkpoint_path_full = fabric.broadcast(str(tmp_path / "checkpoint_full.pt"))
+ if fabric.global_rank == 0:
+ checkpoint = _load_distributed_checkpoint(Path(checkpoint_path_sharded))
+ torch.save(checkpoint, checkpoint_path_full)
+ fabric.barrier()
+
+ # re-init and load from full checkpoint
+ strategy = ModelParallelStrategy(
+ _parallelize_feed_forward_fsdp2_tp,
+ data_parallel_size=2,
+ tensor_parallel_size=2,
+ )
+ fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy)
+ fabric.launch()
+
+ model = FeedForward()
+ model = fabric.setup(model)
+ optimizer = torch.optim.Adam(model.parameters())
+ optimizer = fabric.setup_optimizers(optimizer)
+ state = {"model": model, "optimizer": optimizer, "steps": 1}
+ fabric.load(checkpoint_path_full, state)
+
+
+@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
+def test_load_raw_module_state():
+ from torch.distributed.device_mesh import init_device_mesh
+ from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
+
+ class CustomModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.parameter = nn.Parameter(torch.rand(2, 2))
+ self.layer1 = nn.Linear(4, 4)
+ self.layer2 = nn.Linear(4, 4)
+ self.register_buffer("persistent_buffer", torch.rand(2), persistent=True)
+ self.register_buffer("non_persistent_buffer", torch.rand(2), persistent=False)
+
+ fabric = Fabric(accelerator="cuda", devices=2)
+ fabric.launch()
+ fabric.seed_everything(0)
+
+ with fabric.init_module():
+ model = CustomModel()
+
+ state_dict = deepcopy(model.state_dict())
+
+ with fabric.init_module():
+ model = CustomModel()
+
+ device_mesh = init_device_mesh("cuda", mesh_shape=(2,), mesh_dim_names=("tp",))
+ plan = {"layer1": ColwiseParallel()}
+ parallelize_module(model, device_mesh, plan)
+ _load_raw_module_state(state_dict, model, strict=True)
+
+ assert torch.equal(model.parameter, state_dict["parameter"])
+ assert torch.equal(model.layer1.weight.full_tensor(), state_dict["layer1.weight"])
+ assert torch.equal(model.layer2.weight, state_dict["layer2.weight"])
+ assert torch.equal(model.persistent_buffer, state_dict["persistent_buffer"])
+
+ state_dict.pop("parameter")
+ with pytest.raises(KeyError, match="The model contains a key 'parameter' that does not exist"):
+ _load_raw_module_state(state_dict, model, strict=True)
+
+ _load_raw_module_state(state_dict, model, strict=False)
diff --git a/tests/tests_fabric/strategies/test_registry.py b/tests/tests_fabric/strategies/test_registry.py
index 1865328cf59bf..8efd19541a298 100644
--- a/tests/tests_fabric/strategies/test_registry.py
+++ b/tests/tests_fabric/strategies/test_registry.py
@@ -42,6 +42,7 @@ def __init__(self, param1, param2):
def test_available_strategies_in_registry():
expected = {
"ddp",
+ "ddp_find_unused_parameters_true",
"deepspeed",
"deepspeed_stage_1",
"deepspeed_stage_1_offload",
diff --git a/tests/tests_fabric/strategies/test_strategy.py b/tests/tests_fabric/strategies/test_strategy.py
index cbbbf963b3607..a7a1dba87cb97 100644
--- a/tests/tests_fabric/strategies/test_strategy.py
+++ b/tests/tests_fabric/strategies/test_strategy.py
@@ -18,7 +18,6 @@
import torch
from lightning.fabric.plugins import DoublePrecision, HalfPrecision, Precision
from lightning.fabric.strategies import SingleDeviceStrategy
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.types import _Stateful
from tests_fabric.helpers.runif import RunIf
@@ -239,8 +238,7 @@ def test_module_init_context(device, precision, dtype, empty_init, monkeypatch):
with strategy.module_init_context(empty_init=empty_init):
module = torch.nn.Linear(2, 2)
- expected_device = device if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
- assert module.weight.device == module.bias.device == expected_device
+ assert module.weight.device == module.bias.device == device
assert module.weight.dtype == module.bias.dtype == dtype
if not empty_init:
init_mock.assert_called()
@@ -274,8 +272,7 @@ def test_tensor_init_context(device, precision, dtype):
tensor1 = torch.tensor(42)
tensor2 = torch.tensor(42.0, dtype=torch.half)
- expected_device = device if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
- assert tensor0.device == tensor1.device == tensor2.device == expected_device
+ assert tensor0.device == tensor1.device == tensor2.device == device
assert tensor0.dtype == dtype
assert tensor1.dtype == torch.long # `.init_tensor()` only affects floating point dtypes
assert tensor2.dtype == torch.half # this tensor was created with an explicit dtype assignment
diff --git a/tests/tests_fabric/strategies/test_xla_fsdp.py b/tests/tests_fabric/strategies/test_xla_fsdp.py
index bcd2a6e637417..e2864b684c4a7 100644
--- a/tests/tests_fabric/strategies/test_xla_fsdp.py
+++ b/tests/tests_fabric/strategies/test_xla_fsdp.py
@@ -27,7 +27,7 @@
from tests_fabric.helpers.runif import RunIf
-@RunIf(min_torch="2.0", tpu=True)
+@RunIf(tpu=True)
def test_xla_fsdp_setup_optimizer_validation():
"""Test that `setup_optimizer()` validates the param groups and reference to FSDP parameters."""
module = nn.Linear(2, 2)
@@ -39,7 +39,7 @@ def test_xla_fsdp_setup_optimizer_validation():
strategy.setup_optimizer(bad_optimizer)
-@RunIf(min_torch="2.0", tpu=True)
+@RunIf(tpu=True)
def test_xla_fsdp_no_backward_sync():
"""Test that the backward sync control calls `.no_sync()`, and only on a module wrapped in
XlaFullyShardedDataParallel."""
@@ -64,7 +64,7 @@ def test_xla_fsdp_no_backward_sync():
module.no_sync.assert_called_once()
-@RunIf(min_torch="2.0", tpu=True)
+@RunIf(tpu=True)
def test_xla_fsdp_grad_clipping_value_error():
strategy = XLAFSDPStrategy()
with pytest.raises(NotImplementedError, match="does not support to clip gradients by value"):
diff --git a/tests/tests_fabric/strategies/test_xla_fsdp_integration.py b/tests/tests_fabric/strategies/test_xla_fsdp_integration.py
index 999b8473b28aa..20c2ef042272e 100644
--- a/tests/tests_fabric/strategies/test_xla_fsdp_integration.py
+++ b/tests/tests_fabric/strategies/test_xla_fsdp_integration.py
@@ -45,7 +45,7 @@ def _xla_fsdp_rewrap_warning(fabric: Fabric):
assert isinstance(model._forward_module[2], XlaFullyShardedDataParallel)
-@RunIf(min_torch="2.0", tpu=True, standalone=True)
+@RunIf(tpu=True, standalone=True)
def test_xla_fsdp_rewrap_warning():
"""Test that XLAFSDP warns about rewrapping the modules."""
from torch_xla.distributed.fsdp.wrap import always_wrap_policy
@@ -159,7 +159,7 @@ def step(model, batch):
torch.testing.assert_close(p0, p1, atol=0, rtol=0, equal_nan=True)
-@RunIf(min_torch="2.0", tpu=True, standalone=True)
+@RunIf(tpu=True, standalone=True)
@pytest.mark.parametrize(
("use_auto_wrap_policy", "state_dict_type", "sequential_save"),
[
@@ -196,7 +196,7 @@ def _test_setup_module_move_to_device(fabric, move_to_device):
assert fabric.device.type == "xla"
-@RunIf(min_torch="2.0", tpu=True, standalone=True)
+@RunIf(tpu=True, standalone=True)
@pytest.mark.parametrize("move_to_device", [True, False])
def test_setup_module_move_to_device(move_to_device):
"""Test that `move_to_device` does nothing, FSDP decides which device parameters get moved to which device
diff --git a/tests/tests_fabric/test_cli.py b/tests/tests_fabric/test_cli.py
index 0e58acb3c7267..a57f413ff6081 100644
--- a/tests/tests_fabric/test_cli.py
+++ b/tests/tests_fabric/test_cli.py
@@ -71,8 +71,9 @@ def test_run_env_vars_strategy(_, strategy, monkeypatch, fake_script):
def test_run_get_supported_strategies():
"""Test to ensure that when new strategies get added, we must consider updating the list of supported ones in the
CLI."""
- assert len(_get_supported_strategies()) == 7
+ assert len(_get_supported_strategies()) == 8
assert "fsdp" in _get_supported_strategies()
+ assert "ddp_find_unused_parameters_true" in _get_supported_strategies()
@pytest.mark.parametrize("strategy", ["ddp_spawn", "ddp_fork", "ddp_notebook", "deepspeed_stage_3_offload"])
diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py
index fba4c34c07d63..08d6dbb45ed91 100644
--- a/tests/tests_fabric/test_connector.py
+++ b/tests/tests_fabric/test_connector.py
@@ -14,6 +14,7 @@
import inspect
import os
import sys
+from contextlib import nullcontext
from typing import Any, Dict
from unittest import mock
from unittest.mock import Mock
@@ -53,6 +54,7 @@
DDPStrategy,
DeepSpeedStrategy,
FSDPStrategy,
+ ModelParallelStrategy,
SingleDeviceStrategy,
SingleDeviceXLAStrategy,
XLAFSDPStrategy,
@@ -866,6 +868,18 @@ def test_precision_selection_amp_ddp(strategy, devices, is_custom_plugin, plugin
assert isinstance(connector.precision, plugin_cls)
+@RunIf(min_torch="2.4")
+@pytest.mark.parametrize(
+ ("precision", "raises"),
+ [("32-true", False), ("16-true", False), ("bf16-true", False), ("16-mixed", True), ("bf16-mixed", False)],
+)
+@mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
+def test_precision_selection_model_parallel(_, precision, raises):
+ error_context = pytest.raises(ValueError, match=f"does not support .*{precision}") if raises else nullcontext()
+ with error_context:
+ _Connector(precision=precision, strategy=ModelParallelStrategy(lambda x, _: x))
+
+
def test_bitsandbytes_precision_cuda_required(monkeypatch):
monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True)
monkeypatch.setitem(sys.modules, "bitsandbytes", Mock())
@@ -978,6 +992,16 @@ def test_fsdp_unsupported_on_cpu(_):
with pytest.raises(ValueError, match="You selected the FSDP strategy but FSDP is only available on GPU"):
_Connector(accelerator="cpu", strategy="fsdp")
+ class FSDPStrategySubclass(FSDPStrategy):
+ pass
+
+ class AcceleratorSubclass(CPUAccelerator):
+ pass
+
+ # we allow subclasses of FSDPStrategy to be used with other accelerators
+ _Connector(accelerator="cpu", strategy=FSDPStrategySubclass())
+ _Connector(accelerator=AcceleratorSubclass(), strategy=FSDPStrategySubclass())
+
def test_connector_defaults_match_fabric_defaults():
"""Test that the default values for the init arguments of Connector match the ones in Fabric."""
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index fde9479c73eaf..70d04d5431404 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -289,7 +289,7 @@ def test_setup_optimizers_not_supported(strategy_cls):
fabric.setup_optimizers(optimizer)
-@RunIf(min_cuda_gpus=1, min_torch="2.1")
+@RunIf(min_cuda_gpus=1)
def test_setup_optimizer_on_meta_device():
"""Test that the setup-methods validate that the optimizer doesn't have references to meta-device parameters."""
fabric = Fabric(strategy="fsdp", devices=1)
@@ -623,7 +623,7 @@ def test_backward():
("auto", "32-true", False),
("auto", "bf16-true", False),
("auto", "bf16-mixed", True),
- pytest.param("fsdp", "32-true", True, marks=RunIf(min_cuda_gpus=1, min_torch="2.0.0")),
+ pytest.param("fsdp", "32-true", True, marks=RunIf(min_cuda_gpus=1)),
],
)
@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
@@ -855,7 +855,6 @@ def test_module_sharding_context():
def test_init_module_context(monkeypatch):
"""Test that the strategy returns the context manager for initializing the module."""
- import lightning.fabric
fabric = Fabric(accelerator="cpu")
strategy = SingleDeviceStrategy(device=torch.device("cuda"))
@@ -866,18 +865,8 @@ def test_init_module_context(monkeypatch):
strategy.module_init_context.assert_called_once_with(empty_init=None)
strategy.module_init_context.reset_mock()
- # Pretend we are using PyTorch < 2.0
- monkeypatch.setattr(lightning.fabric.fabric, "_TORCH_GREATER_EQUAL_2_0", False)
- with pytest.warns(PossibleUserWarning, match="can't place the model parameters on the device"): # noqa: SIM117
- with fabric.init_module():
- pass
- strategy.module_init_context.assert_called_once()
-
def test_init_tensor_context(monkeypatch):
- """Test that `.init_tensor()` warns if using PyTorch < 2.0."""
- import lightning.fabric
-
fabric = Fabric(accelerator="cpu")
strategy = SingleDeviceStrategy(device=torch.device("cuda"))
strategy.tensor_init_context = Mock(wraps=strategy.tensor_init_context)
@@ -887,13 +876,6 @@ def test_init_tensor_context(monkeypatch):
strategy.tensor_init_context.assert_called_once()
strategy.tensor_init_context.reset_mock()
- # Pretend we are using PyTorch < 2.0
- monkeypatch.setattr(lightning.fabric.fabric, "_TORCH_GREATER_EQUAL_2_0", False)
- with pytest.warns(PossibleUserWarning, match="can't place tensors on the device directly"): # noqa: SIM117
- with fabric.init_tensor():
- pass
- strategy.tensor_init_context.assert_called_once()
-
def test_callbacks_input():
"""Test the various ways in which callbacks can be registered with Fabric."""
diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py
index 0923c601d51c3..363da022d285b 100644
--- a/tests/tests_fabric/test_wrappers.py
+++ b/tests/tests_fabric/test_wrappers.py
@@ -19,7 +19,7 @@
from lightning.fabric.fabric import Fabric
from lightning.fabric.plugins import Precision
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
+from lightning.fabric.utilities.types import Optimizable
from lightning.fabric.wrappers import (
_FabricDataLoader,
_FabricModule,
@@ -28,6 +28,7 @@
_unwrap_objects,
is_wrapped,
)
+from torch._dynamo import OptimizedModule
from torch.utils.data import BatchSampler, DistributedSampler
from torch.utils.data.dataloader import DataLoader
@@ -101,15 +102,20 @@ def __init__(self, module):
super().__init__()
self.wrapped = module
+ def forward(self, *args, **kwargs):
+ return self.wrapped(*args, **kwargs)
+
# Regular case: forward_module == original_module -> no warnings
original_module = OriginalModule()
fabric_module = _FabricModule(forward_module=original_module, strategy=Mock(), original_module=original_module)
assert fabric_module.method_without_module_invocation() == 100
- # Special case: original module wrapped by forward module: -> warn if method accepts args
+ # Special case: original module wrapped by forward module: -> error if method requires rerouting
original_module = OriginalModule()
wrapped_module = ModuleWrapper(original_module)
- fabric_module = _FabricModule(forward_module=wrapped_module, strategy=Mock(), original_module=original_module)
+ fabric_module = _FabricModule(
+ forward_module=wrapped_module, strategy=Mock(precision=Precision()), original_module=original_module
+ )
assert fabric_module.method_without_module_invocation() == 100
with pytest.raises(
RuntimeError, match=r"You are calling the method `OriginalModule.method_with_submodule_invocation\(\)` from"
@@ -120,6 +126,51 @@ def __init__(self, module):
):
assert fabric_module.method_with_self_invocation() == 102
+ # No error if explicitly marked as forward method
+ fabric_module.mark_forward_method("method_with_self_invocation")
+ assert fabric_module.method_with_self_invocation() == 102
+
+
+def test_fabric_module_mark_forward_method():
+ class OriginalModule(torch.nn.Module):
+ attribute = 1
+
+ def forward(self, x):
+ return x
+
+ def special(self):
+ pass
+
+ original_module = OriginalModule()
+ fabric_module = _FabricModule(original_module, Mock(), original_module=original_module)
+
+ with pytest.raises(ValueError, match="You cannot mark the forward method itself"):
+ fabric_module.mark_forward_method("forward")
+
+ with pytest.raises(AttributeError, match="`OriginalModule.not_exist` does not exist or is not a method."):
+ fabric_module.mark_forward_method("not_exist")
+
+ with pytest.raises(AttributeError, match="`OriginalModule.attribute` does not exist or is not a method."):
+ fabric_module.mark_forward_method("attribute")
+
+ def special(x):
+ return x
+
+ with pytest.raises(TypeError, match="Expected a method or a string"):
+ fabric_module.mark_forward_method(special)
+
+ lightning_module_methods = {"training_step", "validation_step", "test_step", "predict_step"}
+ assert fabric_module._forward_methods == lightning_module_methods
+
+ # Mark via name
+ fabric_module.mark_forward_method("special")
+ assert fabric_module._forward_methods == {"special"} | lightning_module_methods
+
+ # Mark by passing in the method itself
+ fabric_module = _FabricModule(original_module, Mock(), original_module=original_module)
+ fabric_module.mark_forward_method(original_module.special)
+ assert fabric_module._forward_methods == {"special"} | lightning_module_methods
+
def test_fabric_module_setattr():
"""Test that setattr sets attributes on the original module."""
@@ -217,14 +268,13 @@ def __init__(self):
assert torch.equal(fabric_module.layer.weight, weight)
assert torch.equal(fabric_module.layer.bias, bias)
- if _TORCH_GREATER_EQUAL_2_1:
- # Can use additional `assign` argument in PyTorch >= 2.1
- with torch.device("meta"):
- original_module = OriginalModule()
- fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module)
- assert fabric_module.layer.weight.is_meta
- fabric_module.load_state_dict({"layer.weight": weight, "layer.bias": bias}, assign=True)
- assert not fabric_module.layer.weight.is_meta
+ # Can use additional `assign` argument
+ with torch.device("meta"):
+ original_module = OriginalModule()
+ fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module)
+ assert fabric_module.layer.weight.is_meta
+ fabric_module.load_state_dict({"layer.weight": weight, "layer.bias": bias}, assign=True)
+ assert not fabric_module.layer.weight.is_meta
@pytest.mark.parametrize(
@@ -448,6 +498,7 @@ def test_fabric_optimizer_steps():
# with model as optimizer
strategy = Mock(spec=["optimizer_step", "model"])
+ strategy.model = Mock(spec=Optimizable)
fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
fabric_optimizer.step()
strategy.optimizer_step.assert_called_once_with(strategy.model)
@@ -492,8 +543,6 @@ def test_is_wrapped(compile):
# _FabricModule inside an OptimizedModule
if compile:
- from torch._dynamo import OptimizedModule
-
module = torch.nn.Linear(2, 2)
wrapped = torch.compile(_FabricModule(module, Mock()))
assert isinstance(wrapped, OptimizedModule)
@@ -550,8 +599,8 @@ def test_unwrap_objects(compile):
def test_step_method_redirection():
- """Test that the FabricModule redirects the special `LightningModule.*_step` methods through the forward-
- module."""
+ """Test that the FabricModule redirects methods marked as 'forward methods' through forward to avoid bypassing the
+ DDP/FSDP wrappers."""
class DDP(torch.nn.Module):
def __init__(self, module):
@@ -571,11 +620,11 @@ def training_step(self, arg, kwarg=None):
assert kwarg == "train_kwarg"
return "training_step_return"
- def validation_step(self, arg, kwarg=None):
+ def marked_method(self, arg, kwarg=None):
assert self() == "forward_return"
- assert arg == "val_arg"
- assert kwarg == "val_kwarg"
- return "validation_step_return"
+ assert arg == "marked_arg"
+ assert kwarg == "marked_kwarg"
+ return "marked_method_return"
def normal_method(self):
pass
@@ -603,10 +652,18 @@ def normal_method(self):
assert original_module.forward.__name__ == "forward"
# The special methods get redirected correctly to produce the expected output
+ strategy.precision.forward_context.reset_mock()
assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return"
assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return" # call 2nd time
- assert fabric_module.validation_step("val_arg", kwarg="val_kwarg") == "validation_step_return"
- strategy.precision.forward_context.assert_called()
+ assert strategy.precision.forward_context.call_count == 2
+
+ # Other methods must be marked explicitly to be redirected
+ strategy.precision.forward_context.reset_mock()
+ with pytest.raises(RuntimeError, match="You are calling the method .* from outside the model"):
+ fabric_module.marked_method("marked_arg", kwarg="marked_kwarg")
+ fabric_module.mark_forward_method("marked_method")
+ assert fabric_module.marked_method("marked_arg", kwarg="marked_kwarg") == "marked_method_return"
+ strategy.precision.forward_context.assert_called_once()
# The forward method remains untouched/unpatched after the special methods have been called
assert original_module.forward.__name__ == "forward"
@@ -614,7 +671,7 @@ def normal_method(self):
# Special case: forward_module == original_module -> no special treatment applied
fabric_module = _FabricModule(forward_module=original_module, strategy=Mock(), original_module=original_module)
assert fabric_module.training_step == original_module.training_step
- assert fabric_module.validation_step == original_module.validation_step
+ assert fabric_module.marked_method == original_module.marked_method
@RunIf(dynamo=True)
@@ -624,18 +681,13 @@ def test_unwrap_compiled():
# We wrap `torch.compile` on import of lightning in `wrappers.py`
assert torch.compile.__wrapped__
- with mock.patch("lightning.fabric.wrappers", "_TORCH_GREATER_EQUAL_2_0", False):
- unwrapped, compile_kwargs = _unwrap_compiled(model)
- assert unwrapped is model
- assert compile_kwargs is None
-
compiled = torch.compile(model, fullgraph=True, dynamic=True, disable=False)
assert compiled._compile_kwargs == {"fullgraph": True, "dynamic": True, "disable": False}
unwrapped, compile_kwargs = _unwrap_compiled(compiled)
assert unwrapped is compiled._orig_mod
assert compile_kwargs == {"fullgraph": True, "dynamic": True, "disable": False}
- del compiled._compile_kwargs
+ compiled._compile_kwargs = None
with pytest.raises(RuntimeError, match="Failed to determine the arguments that were used to compile the module"):
_unwrap_compiled(compiled)
diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py
index 0d8f03fcdd120..cc6c23bddbd7b 100644
--- a/tests/tests_fabric/utilities/test_distributed.py
+++ b/tests/tests_fabric/utilities/test_distributed.py
@@ -3,7 +3,9 @@
from functools import partial
from pathlib import Path
from unittest import mock
+from unittest.mock import Mock
+import lightning.fabric
import pytest
import torch
from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
@@ -11,13 +13,17 @@
from lightning.fabric.strategies import DDPStrategy, SingleDeviceStrategy
from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning.fabric.utilities.distributed import (
+ _destroy_dist_connection,
_gather_all_tensors,
_InfiniteBarrier,
+ _init_dist_connection,
+ _is_dtensor,
_set_num_threads_if_needed,
_suggested_max_num_threads,
_sync_ddp,
is_shared_filesystem,
)
+from lightning_utilities.core.imports import RequirementCache
from tests_fabric.helpers.runif import RunIf
@@ -119,6 +125,10 @@ def test_collective_operations(devices, process):
spawn_launch(process, devices)
+@pytest.mark.skipif(
+ RequirementCache("torch<2.4") and RequirementCache("numpy>=2.0"),
+ reason="torch.distributed not compatible with numpy>=2.0",
+)
@pytest.mark.flaky(reruns=3) # flaky with "process 0 terminated with signal SIGABRT" (GLOO)
def test_is_shared_filesystem(tmp_path, monkeypatch):
# In the non-distributed case, every location is interpreted as 'shared'
@@ -217,3 +227,24 @@ def test_infinite_barrier():
barrier.__exit__(None, None, None)
assert barrier.barrier.call_count == 2
dist_mock.destroy_process_group.assert_called_once()
+
+
+@mock.patch("lightning.fabric.utilities.distributed.atexit")
+@mock.patch("lightning.fabric.utilities.distributed.torch.distributed.init_process_group")
+def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
+ _init_dist_connection(LightningEnvironment(), "nccl")
+ atexit_mock.register.assert_called_once_with(_destroy_dist_connection)
+ atexit_mock.reset_mock()
+ _init_dist_connection(LightningEnvironment(), "gloo")
+ atexit_mock.register.assert_not_called()
+
+
+@RunIf(min_torch="2.4")
+def test_is_dtensor(monkeypatch):
+ from torch.distributed._tensor import DTensor
+
+ assert _is_dtensor(Mock(spec=DTensor))
+ assert not _is_dtensor(torch.zeros(2, 2))
+
+ monkeypatch.setattr(lightning.fabric.utilities.distributed, "_TORCH_GREATER_EQUAL_2_4", False)
+ assert not _is_dtensor(Mock(spec=DTensor))
diff --git a/tests/tests_fabric/utilities/test_init.py b/tests/tests_fabric/utilities/test_init.py
index 25d53abd7d261..dd08dec020669 100644
--- a/tests/tests_fabric/utilities/test_init.py
+++ b/tests/tests_fabric/utilities/test_init.py
@@ -16,7 +16,11 @@
import pytest
import torch.nn
-from lightning.fabric.utilities.init import _EmptyInit, _materialize_meta_tensors
+from lightning.fabric.utilities.init import (
+ _EmptyInit,
+ _has_meta_device_parameters_or_buffers,
+ _materialize_meta_tensors,
+)
from tests_fabric.helpers.runif import RunIf
@@ -54,7 +58,6 @@ def test_empty_init_speed():
assert normal_init_time > 2 * empty_init_time
-@RunIf(min_torch="2.1")
def test_materialize_meta_tensors():
class Submodule(torch.nn.Module):
def __init__(self):
@@ -85,3 +88,30 @@ def reset_parameters(self):
assert model.buf.device.type == "cpu"
assert len(list(model.parameters())) == 4
assert all(p.device.type == "cpu" for p in model.parameters())
+
+
+def test_has_meta_device_parameters_or_buffers():
+ """Test that the `_has_meta_device_parameters_or_buffers` function can find meta-device parameters in models and
+ optimizers."""
+
+ class BufferModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer("buffer", torch.ones(2, device="meta"))
+
+ # nn.Module
+ module = torch.nn.Linear(2, 2)
+ meta_module = torch.nn.Linear(2, 2, device="meta")
+ buffer_meta_module = BufferModule()
+ assert not _has_meta_device_parameters_or_buffers(module)
+ assert _has_meta_device_parameters_or_buffers(meta_module)
+ assert _has_meta_device_parameters_or_buffers(torch.nn.Sequential(module, meta_module, torch.nn.ReLU()))
+ assert _has_meta_device_parameters_or_buffers(buffer_meta_module)
+ # optim.Optimizer
+ optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
+ meta_optimizer = torch.optim.SGD(meta_module.parameters(), lr=0.1)
+ assert not _has_meta_device_parameters_or_buffers(optimizer)
+ assert _has_meta_device_parameters_or_buffers(meta_optimizer)
+ # unsupported objects
+ with pytest.raises(TypeError, match="Expected `torch.nn.Module` or `torch.optim.Optimizer`"):
+ _has_meta_device_parameters_or_buffers(None)
diff --git a/tests/tests_fabric/utilities/test_load.py b/tests/tests_fabric/utilities/test_load.py
index 574f8bf36247b..39d257f8b685b 100644
--- a/tests/tests_fabric/utilities/test_load.py
+++ b/tests/tests_fabric/utilities/test_load.py
@@ -21,10 +21,7 @@
_NotYetLoadedTensor,
)
-from tests_fabric.helpers.runif import RunIf
-
-@RunIf(min_torch="2.0.0")
def test_lazy_load_module(tmp_path):
model0 = nn.Linear(2, 2)
torch.save(model0.state_dict(), tmp_path / "model.pt")
@@ -34,6 +31,8 @@ def test_lazy_load_module(tmp_path):
model1.load_state_dict(checkpoint)
assert isinstance(checkpoint["weight"], _NotYetLoadedTensor)
+ assert checkpoint["weight"].device == torch.device("cpu")
+ assert type(checkpoint["weight"].to("cpu")) is torch.Tensor
assert type(model0.weight.data) is torch.Tensor
assert torch.equal(model0.weight, model1.weight)
assert torch.equal(model0.bias, model1.bias)
@@ -43,7 +42,6 @@ class ATensor(torch.Tensor):
pass
-@RunIf(min_torch="2.0.0")
def test_lazy_load_tensor(tmp_path):
"""Test that lazy load can handle different classes of tensors."""
expected = {
@@ -61,7 +59,6 @@ def test_lazy_load_tensor(tmp_path):
assert torch.equal(t0, t1_materialized)
-@RunIf(min_torch="2.0.0")
def test_lazy_load_mixed_state(tmp_path):
model0 = nn.Linear(2, 2)
optim0 = torch.optim.Adam(model0.parameters())
@@ -82,13 +79,11 @@ def test_lazy_load_mixed_state(tmp_path):
optim1.load_state_dict(loaded_checkpoint["optimizer"])
-@RunIf(min_torch="2.0.0")
def test_lazy_load_raises():
with pytest.raises(FileNotFoundError, match="foo' does not exist"):
_lazy_load("foo")
-@RunIf(min_torch="2.0.0")
def test_materialize_tensors(tmp_path):
# Single tensor
tensor = torch.tensor([1, 2])
diff --git a/tests/tests_fabric/utilities/test_logger.py b/tests/tests_fabric/utilities/test_logger.py
index 5b6211331474a..0f6500cb42be1 100644
--- a/tests/tests_fabric/utilities/test_logger.py
+++ b/tests/tests_fabric/utilities/test_logger.py
@@ -11,14 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from argparse import Namespace
from dataclasses import dataclass
+from pathlib import Path
import numpy as np
import torch
from lightning.fabric.utilities.logger import (
_add_prefix,
+ _convert_json_serializable,
_convert_params,
_flatten_dict,
_sanitize_callable_params,
@@ -91,7 +92,7 @@ class B:
def test_sanitize_callable_params():
- """Callback function are not serializiable.
+ """Callback functions are not serializable.
Therefore, we give them a chance to return something and if the returned type is not accepted, return None.
@@ -103,11 +104,21 @@ def return_something():
def wrapper_something():
return return_something
+ class ClassNoArgs:
+ def __init__(self):
+ pass
+
+ class ClassWithCall:
+ def __call__(self):
+ return "name"
+
params = Namespace(
foo="bar",
something=return_something,
wrapper_something_wo_name=(lambda: lambda: "1"),
wrapper_something=wrapper_something,
+ class_no_args=ClassNoArgs,
+ class_with_call=ClassWithCall,
)
params = _convert_params(params)
@@ -117,6 +128,8 @@ def wrapper_something():
assert params["something"] == "something"
assert params["wrapper_something"] == "wrapper_something"
assert params["wrapper_something_wo_name"] == ""
+ assert params["class_no_args"] == "ClassNoArgs"
+ assert params["class_with_call"] == "ClassWithCall"
def test_sanitize_params():
@@ -167,3 +180,29 @@ def test_add_prefix():
assert "prefix-metric2" not in metrics
assert metrics["prefix2_prefix-metric1"] == 1
assert metrics["prefix2_prefix-metric2"] == 2
+
+
+def test_convert_json_serializable():
+ data = {
+ # JSON-serializable
+ "none": None,
+ "int": 1,
+ "float": 1.1,
+ "bool": True,
+ "dict": {"a": 1},
+ "list": [2, 3, 4],
+ # not JSON-serializable
+ "path": Path("path"),
+ "tensor": torch.tensor(1),
+ }
+ expected = {
+ "none": None,
+ "int": 1,
+ "float": 1.1,
+ "bool": True,
+ "dict": {"a": 1},
+ "list": [2, 3, 4],
+ "path": "path",
+ "tensor": "tensor(1)",
+ }
+ assert _convert_json_serializable(data) == expected
diff --git a/tests/tests_fabric/utilities/test_optimizer.py b/tests/tests_fabric/utilities/test_optimizer.py
index 3aa78d507c346..83c7ed44120b9 100644
--- a/tests/tests_fabric/utilities/test_optimizer.py
+++ b/tests/tests_fabric/utilities/test_optimizer.py
@@ -1,36 +1,86 @@
-import collections
import dataclasses
+import pytest
import torch
from lightning.fabric.utilities.optimizer import _optimizer_to_device
from torch import Tensor
+from tests_fabric.helpers.runif import RunIf
-def test_optimizer_to_device():
- @dataclasses.dataclass(frozen=True)
+
+@pytest.mark.parametrize(
+ "optimizer_class",
+ [
+ torch.optim.Adam,
+ torch.optim.AdamW,
+ torch.optim.SGD,
+ torch.optim.RMSprop,
+ torch.optim.Adagrad,
+ torch.optim.Adadelta,
+ torch.optim.Adamax,
+ ],
+)
+@pytest.mark.parametrize(
+ "src_device",
+ [
+ torch.device("cpu"),
+ pytest.param(torch.device("cuda"), marks=RunIf(min_cuda_gpus=1)),
+ ],
+)
+@pytest.mark.parametrize(
+ "dst_device",
+ [
+ torch.device("cpu"),
+ pytest.param(torch.device("cuda"), marks=RunIf(min_cuda_gpus=1)),
+ ],
+)
+def test_optimizer_to_device(optimizer_class, src_device, dst_device):
+ # Optimizer with no state initialized
+ model = torch.nn.Linear(2, 2, device=src_device)
+ optimizer = optimizer_class(model.parameters(), lr=0.1)
+ _optimizer_to_device(optimizer, dst_device)
+ _assert_opt_parameters_on_device(optimizer, dst_device)
+
+ # Optimizer with state initialized
+ model = torch.nn.Linear(2, 2, device=src_device)
+ optimizer = optimizer_class(model.parameters(), lr=0.1)
+ model(torch.randn(2, 2, device=src_device)).sum().backward()
+ optimizer.step()
+ _optimizer_to_device(optimizer, dst_device)
+ _assert_opt_parameters_on_device(optimizer, dst_device)
+
+
+def _assert_opt_parameters_on_device(opt, device):
+ for _, v in opt.state.items():
+ for key, item in v.items():
+ if not isinstance(item, Tensor):
+ continue
+ if key == "step":
+ # The "step" tensor needs to remain on CPU
+ assert item.device.type == "cpu"
+ else:
+ assert item.device.type == device.type
+
+
+@RunIf(min_cuda_gpus=1)
+@pytest.mark.parametrize("frozen", [True, False])
+def test_optimizer_to_device_with_dataclass_in_state(frozen):
+ src_device = torch.device("cpu")
+ dst_device = torch.device("cuda")
+ model = torch.nn.Linear(32, 2, device=src_device)
+
+ @dataclasses.dataclass(frozen=frozen)
class FooState:
- bar: int
+ integer: int
+ tensor: Tensor
class TestOptimizer(torch.optim.SGD):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.state["dummy"] = torch.tensor(0)
- self.state["frozen"] = FooState(0)
-
- layer = torch.nn.Linear(32, 2)
- opt = TestOptimizer(layer.parameters(), lr=0.1)
- _optimizer_to_device(opt, "cpu")
- if torch.cuda.is_available():
- _optimizer_to_device(opt, "cuda")
- assert_opt_parameters_on_device(opt, "cuda")
-
-
-def assert_opt_parameters_on_device(opt, device: str):
- for param in opt.state.values():
- # Not sure there are any global tensors in the state dict
- if isinstance(param, Tensor):
- assert param.data.device.type == device
- elif isinstance(param, collections.abc.Mapping):
- for subparam in param.values():
- if isinstance(subparam, Tensor):
- assert param.data.device.type == device
+ self.state[model.weight] = {"dummy": torch.tensor(0)}
+ self.state[model.bias] = FooState(0, torch.tensor(0))
+
+ optimizer = TestOptimizer(model.parameters(), lr=0.1)
+ _optimizer_to_device(optimizer, dst_device)
+ assert optimizer.state[model.weight]["dummy"].device.type == dst_device.type
+ assert optimizer.state[model.bias].tensor.device.type == ("cpu" if frozen else dst_device.type)
diff --git a/tests/tests_fabric/utilities/test_registry.py b/tests/tests_fabric/utilities/test_registry.py
index 75e6e12f5abff..a06e5c8e82615 100644
--- a/tests/tests_fabric/utilities/test_registry.py
+++ b/tests/tests_fabric/utilities/test_registry.py
@@ -1,8 +1,8 @@
import contextlib
from unittest import mock
-from unittest.mock import Mock
+from unittest.mock import MagicMock, Mock
-from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
+from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_10_0
from lightning.fabric.utilities.registry import _load_external_callbacks
@@ -47,18 +47,14 @@ def factory_multiple_callbacks_list():
@contextlib.contextmanager
def _make_entry_point_query_mock(callback_factory):
- query_mock = Mock()
+ query_mock = MagicMock()
entry_point = Mock()
entry_point.name = "mocked"
entry_point.load.return_value = callback_factory
if _PYTHON_GREATER_EQUAL_3_10_0:
query_mock.return_value = [entry_point]
- import_path = "importlib.metadata.entry_points"
- elif _PYTHON_GREATER_EQUAL_3_8_0:
- query_mock().get.return_value = [entry_point]
- import_path = "importlib.metadata.entry_points"
else:
- query_mock.return_value = [entry_point]
- import_path = "pkg_resources.iter_entry_points"
- with mock.patch(import_path, query_mock):
+ query_mock().get.return_value = [entry_point]
+
+ with mock.patch("lightning.fabric.utilities.registry.entry_points", query_mock):
yield
diff --git a/tests/tests_fabric/utilities/test_seed.py b/tests/tests_fabric/utilities/test_seed.py
index 351f6a47b74cd..be2ecba3294b1 100644
--- a/tests/tests_fabric/utilities/test_seed.py
+++ b/tests/tests_fabric/utilities/test_seed.py
@@ -1,17 +1,24 @@
import os
+import random
from unittest import mock
from unittest.mock import Mock
-import lightning.fabric.utilities
+import numpy
import pytest
import torch
-from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
+from lightning.fabric.utilities.seed import (
+ _collect_rng_states,
+ _set_rng_states,
+ pl_worker_init_function,
+ reset_seed,
+ seed_everything,
+)
@mock.patch.dict(os.environ, clear=True)
def test_default_seed():
"""Test that the default seed is 0 when no seed provided and no environment variable set."""
- assert lightning.fabric.utilities.seed.seed_everything() == 0
+ assert seed_everything() == 0
assert os.environ["PL_GLOBAL_SEED"] == "0"
@@ -19,11 +26,11 @@ def test_default_seed():
def test_seed_stays_same_with_multiple_seed_everything_calls():
"""Ensure that after the initial seed everything, the seed stays the same for the same run."""
with pytest.warns(UserWarning, match="No seed found"):
- lightning.fabric.utilities.seed.seed_everything()
+ seed_everything()
initial_seed = os.environ.get("PL_GLOBAL_SEED")
with pytest.warns(None) as record:
- lightning.fabric.utilities.seed.seed_everything()
+ seed_everything()
assert not record # does not warn
seed = os.environ.get("PL_GLOBAL_SEED")
@@ -33,14 +40,14 @@ def test_seed_stays_same_with_multiple_seed_everything_calls():
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True)
def test_correct_seed_with_environment_variable():
"""Ensure that the PL_GLOBAL_SEED environment is read."""
- assert lightning.fabric.utilities.seed.seed_everything() == 2020
+ assert seed_everything() == 2020
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True)
def test_invalid_seed():
"""Ensure that we still fix the seed even if an invalid seed is given."""
with pytest.warns(UserWarning, match="Invalid seed found"):
- seed = lightning.fabric.utilities.seed.seed_everything()
+ seed = seed_everything()
assert seed == 0
@@ -49,7 +56,7 @@ def test_invalid_seed():
def test_out_of_bounds_seed(seed):
"""Ensure that we still fix the seed even if an out-of-bounds seed is given."""
with pytest.warns(UserWarning, match="is not in bounds"):
- actual = lightning.fabric.utilities.seed.seed_everything(seed)
+ actual = seed_everything(seed)
assert actual == 0
@@ -57,7 +64,7 @@ def test_reset_seed_no_op():
"""Test that the reset_seed function is a no-op when seed_everything() was not used."""
assert "PL_GLOBAL_SEED" not in os.environ
seed_before = torch.initial_seed()
- lightning.fabric.utilities.seed.reset_seed()
+ reset_seed()
assert torch.initial_seed() == seed_before
assert "PL_GLOBAL_SEED" not in os.environ
@@ -68,18 +75,26 @@ def test_reset_seed_everything(workers):
assert "PL_GLOBAL_SEED" not in os.environ
assert "PL_SEED_WORKERS" not in os.environ
- lightning.fabric.utilities.seed.seed_everything(123, workers)
+ seed_everything(123, workers)
before = torch.rand(1)
assert os.environ["PL_GLOBAL_SEED"] == "123"
assert os.environ["PL_SEED_WORKERS"] == str(int(workers))
- lightning.fabric.utilities.seed.reset_seed()
+ reset_seed()
after = torch.rand(1)
assert os.environ["PL_GLOBAL_SEED"] == "123"
assert os.environ["PL_SEED_WORKERS"] == str(int(workers))
assert torch.allclose(before, after)
+def test_reset_seed_non_verbose(caplog):
+ seed_everything(123)
+ assert len(caplog.records) == 1
+ caplog.clear()
+ reset_seed() # should call `seed_everything(..., verbose=False)`
+ assert len(caplog.records) == 0
+
+
def test_backward_compatibility_rng_states_dict():
"""Test that an older rng_states_dict without the "torch.cuda" key does not crash."""
states = _collect_rng_states()
@@ -95,3 +110,26 @@ def test_collect_rng_states_if_cuda_init_fails(get_rng_state_all_mock):
get_rng_state_all_mock.side_effect = RuntimeError("The NVIDIA driver on your system is too old")
states = _collect_rng_states()
assert states["torch.cuda"] == []
+
+
+@pytest.mark.parametrize(("num_workers", "num_ranks"), [(64, 64)])
+@pytest.mark.parametrize("base_seed", [100, 1024, 2**32 - 1])
+def test_pl_worker_init_function(base_seed, num_workers, num_ranks):
+ """Test that Lightning's `worker_init_fn` sets unique seeds per worker/rank derived from the base seed."""
+ torch_rands = set()
+ stdlib_rands = set()
+ numpy_rands = set()
+
+ for worker_id in range(num_workers):
+ for rank in range(num_ranks):
+ seed_everything(base_seed)
+ pl_worker_init_function(worker_id, rank)
+ torch_rands.add(tuple(torch.randint(0, 1_000_000, (100,)).tolist()))
+ stdlib_rands.add(tuple(random.randint(0, 1_000_000) for _ in range(100)))
+ numpy_rands.add(tuple(numpy.random.randint(0, 1_000_000, (100,)).tolist()))
+
+ # Assert there are no duplicates (no collisions)
+ assert len(torch_rands) == num_ranks * num_workers
+ assert len(stdlib_rands) == num_ranks * num_workers
+ assert len(numpy_rands) == num_ranks * num_workers
+ assert len(torch_rands | stdlib_rands | numpy_rands) == 3 * num_workers * num_ranks
diff --git a/tests/tests_fabric/utilities/test_throughput.py b/tests/tests_fabric/utilities/test_throughput.py
index f2c3de30a325b..eefadb285af02 100644
--- a/tests/tests_fabric/utilities/test_throughput.py
+++ b/tests/tests_fabric/utilities/test_throughput.py
@@ -13,11 +13,9 @@
measure_flops,
)
-from tests_fabric.helpers.runif import RunIf
from tests_fabric.test_fabric import BoringModel
-@RunIf(min_torch="2.1")
def test_measure_flops():
with torch.device("meta"):
model = BoringModel()
diff --git a/tests/tests_pytorch/__init__.py b/tests/tests_pytorch/__init__.py
index df603cce7b830..a43ffae6a83b4 100644
--- a/tests/tests_pytorch/__init__.py
+++ b/tests/tests_pytorch/__init__.py
@@ -13,24 +13,19 @@
# limitations under the License.
import os
import warnings
+from pathlib import Path
import pytest
-_TEST_ROOT = os.path.dirname(__file__)
-_PROJECT_ROOT = os.path.dirname(_TEST_ROOT)
-_TEMP_PATH = os.path.join(_PROJECT_ROOT, "test_temp")
-_PATH_DATASETS = os.path.join(_PROJECT_ROOT, "Datasets")
-_PATH_LEGACY = os.path.join(_PROJECT_ROOT, "legacy")
+_TEST_ROOT = Path(__file__).parent.parent
+_PROJECT_ROOT = _TEST_ROOT.parent
+_PATH_DATASETS = _PROJECT_ROOT / "Datasets"
+_PATH_LEGACY = _TEST_ROOT / "legacy"
# todo: this setting `PYTHONPATH` may not be used by other envs like Conda for import packages
-if _PROJECT_ROOT not in os.getenv("PYTHONPATH", ""):
+if str(_PROJECT_ROOT) not in os.getenv("PYTHONPATH", ""):
splitter = ":" if os.environ.get("PYTHONPATH", "") else ""
os.environ["PYTHONPATH"] = f'{_PROJECT_ROOT}{splitter}{os.environ.get("PYTHONPATH", "")}'
-
-if not os.path.isdir(_TEMP_PATH):
- os.mkdir(_TEMP_PATH)
-
-
# Ignore cleanup warnings from pytest (rarely happens due to a race condition when executing pytest in parallel)
warnings.filterwarnings("ignore", category=pytest.PytestWarning, message=r".*\(rm_rf\) error removing.*")
diff --git a/tests/tests_pytorch/accelerators/test_xla.py b/tests/tests_pytorch/accelerators/test_xla.py
index 16fd5491d2ddd..48b346006786f 100644
--- a/tests/tests_pytorch/accelerators/test_xla.py
+++ b/tests/tests_pytorch/accelerators/test_xla.py
@@ -56,7 +56,7 @@ def test_resume_training_on_cpu(tmp_path):
"""Checks if training can be resumed from a saved checkpoint on CPU."""
# Train a model on TPU
model = BoringModel()
- trainer = Trainer(max_epochs=1, accelerator="tpu", devices="auto")
+ trainer = Trainer(max_epochs=1, accelerator="tpu", devices="auto", default_root_dir=tmp_path)
trainer.fit(model)
if trainer.world_size != trainer.num_devices:
@@ -67,7 +67,7 @@ def test_resume_training_on_cpu(tmp_path):
model_path = trainer.checkpoint_callback.best_model_path
# Verify saved Tensors are on CPU
- ckpt = torch.load(model_path)
+ ckpt = torch.load(model_path, weights_only=True)
weight_tensor = list(ckpt["state_dict"].values())[0]
assert weight_tensor.device == torch.device("cpu")
diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
index a74bc300f4194..b8d3d6d36c075 100644
--- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
+++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
@@ -143,7 +143,7 @@ def on_train_start(self) -> None:
with mock.patch(
"lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True
- ) as mock_progress_stop:
+ ) as mock_progress_stop, pytest.raises(SystemExit):
progress_bar = RichProgressBar()
trainer = Trainer(
default_root_dir=tmp_path,
@@ -308,20 +308,6 @@ def test_rich_progress_bar_counter_with_val_check_interval(tmp_path):
assert val_bar.total == 4
-@RunIf(rich=True)
-@mock.patch("lightning.pytorch.callbacks.progress.rich_progress._detect_light_colab_theme", return_value=True)
-def test_rich_progress_bar_colab_light_theme_update(*_):
- theme = RichProgressBar().theme
- assert theme.description == "black"
- assert theme.batch_progress == "black"
- assert theme.metrics == "black"
-
- theme = RichProgressBar(theme=RichProgressBarTheme(description="blue", metrics="red")).theme
- assert theme.description == "blue"
- assert theme.batch_progress == "black"
- assert theme.metrics == "red"
-
-
@RunIf(rich=True)
def test_rich_progress_bar_metric_display_task_id(tmp_path):
class CustomModel(BoringModel):
@@ -447,9 +433,10 @@ def test_rich_progress_bar_padding():
@RunIf(rich=True)
-def test_rich_progress_bar_can_be_pickled():
+def test_rich_progress_bar_can_be_pickled(tmp_path):
bar = RichProgressBar()
trainer = Trainer(
+ default_root_dir=tmp_path,
callbacks=[bar],
max_epochs=1,
limit_train_batches=1,
diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
index cbe07164fb427..d5187d5a1e325 100644
--- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
+++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
@@ -18,7 +18,7 @@
from collections import defaultdict
from typing import Union
from unittest import mock
-from unittest.mock import ANY, PropertyMock, call
+from unittest.mock import ANY, Mock, PropertyMock, call
import pytest
import torch
@@ -550,9 +550,10 @@ def test_tqdm_progress_bar_print_disabled(tqdm_write, mock_print, tmp_path):
tqdm_write.assert_not_called()
-def test_tqdm_progress_bar_can_be_pickled():
+def test_tqdm_progress_bar_can_be_pickled(tmp_path):
bar = TQDMProgressBar()
trainer = Trainer(
+ default_root_dir=tmp_path,
callbacks=[bar],
max_epochs=1,
limit_train_batches=1,
@@ -782,3 +783,20 @@ def test_tqdm_progress_bar_disabled_when_not_rank_zero(is_global_zero):
pbar.enable()
trainer.test(model)
assert pbar.is_disabled
+
+
+@pytest.mark.parametrize("leave", [True, False])
+def test_tqdm_leave(leave, tmp_path):
+ pbar = TQDMProgressBar(leave=leave)
+ pbar.init_train_tqdm = Mock(wraps=pbar.init_train_tqdm)
+ model = BoringModel()
+ trainer = Trainer(
+ default_root_dir=tmp_path,
+ callbacks=[pbar],
+ max_epochs=3,
+ limit_train_batches=1,
+ limit_val_batches=1,
+ benchmark=True,
+ )
+ trainer.fit(model)
+ assert pbar.init_train_tqdm.call_count == (4 if leave else 1)
diff --git a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py
index eecb8b975533a..aacee958faa45 100644
--- a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py
+++ b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py
@@ -162,7 +162,7 @@ def test_device_stats_monitor_warning_when_psutil_not_available(monkeypatch, tmp
monkeypatch.setattr(imports, "_PSUTIL_AVAILABLE", False)
monitor = DeviceStatsMonitor()
- trainer = Trainer(logger=CSVLogger(tmp_path))
+ trainer = Trainer(accelerator="cpu", logger=CSVLogger(tmp_path))
assert trainer.strategy.root_device == torch.device("cpu")
with pytest.raises(ModuleNotFoundError, match="psutil` is not installed"):
monitor.setup(trainer, Mock(), "fit")
diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py
index 651e5978e7aca..b7e52ee549bcc 100644
--- a/tests/tests_pytorch/callbacks/test_early_stopping.py
+++ b/tests/tests_pytorch/callbacks/test_early_stopping.py
@@ -15,6 +15,7 @@
import math
import os
import pickle
+from contextlib import nullcontext
from typing import List, Optional
from unittest import mock
from unittest.mock import Mock
@@ -22,6 +23,7 @@
import cloudpickle
import pytest
import torch
+from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -56,7 +58,7 @@ def on_train_epoch_end(self, trainer, pl_module):
self.saved_states.append(self.state_dict().copy())
-@RunIf(sklearn=True)
+@RunIf(sklearn=True, skip_windows=True) # Flaky test on Windows for unknown reasons
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_resume_early_stopping_from_checkpoint(tmp_path):
"""Prevent regressions to bugs:
@@ -82,7 +84,7 @@ def test_resume_early_stopping_from_checkpoint(tmp_path):
checkpoint_filepath = checkpoint_callback.kth_best_model_path
# ensure state is persisted properly
- checkpoint = torch.load(checkpoint_filepath)
+ checkpoint = torch.load(checkpoint_filepath, weights_only=True)
# the checkpoint saves "epoch + 1"
early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"]]
assert len(early_stop_callback.saved_states) == 4
@@ -191,11 +193,13 @@ def test_pickling():
early_stopping = EarlyStopping(monitor="foo")
early_stopping_pickled = pickle.dumps(early_stopping)
- early_stopping_loaded = pickle.loads(early_stopping_pickled)
+ with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
+ early_stopping_loaded = pickle.loads(early_stopping_pickled)
assert vars(early_stopping) == vars(early_stopping_loaded)
early_stopping_pickled = cloudpickle.dumps(early_stopping)
- early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
+ with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
+ early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
assert vars(early_stopping) == vars(early_stopping_loaded)
diff --git a/tests/tests_pytorch/callbacks/test_finetuning_callback.py b/tests/tests_pytorch/callbacks/test_finetuning_callback.py
index 078e3d4462b44..0c09ae5d5042a 100644
--- a/tests/tests_pytorch/callbacks/test_finetuning_callback.py
+++ b/tests/tests_pytorch/callbacks/test_finetuning_callback.py
@@ -15,6 +15,7 @@
import pytest
import torch
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from lightning.pytorch import LightningModule, Trainer, seed_everything
from lightning.pytorch.callbacks import BackboneFinetuning, BaseFinetuning, ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
@@ -113,7 +114,7 @@ def configure_optimizers(self):
trainer.fit(model)
assert model.backbone.has_been_used
- trainer = Trainer(max_epochs=3)
+ trainer = Trainer(default_root_dir=tmp_path, max_epochs=3)
trainer.fit(model, ckpt_path=chk.last_model_path)
@@ -245,7 +246,7 @@ def configure_optimizers(self):
model = FreezeModel()
cb = OnEpochLayerFinetuning()
- trainer = Trainer(max_epochs=10, callbacks=[cb])
+ trainer = Trainer(default_root_dir=tmp_path, max_epochs=10, callbacks=[cb])
with pytest.raises(IndexError, match="index 6 is out of range"):
trainer.fit(model, ckpt_path=chk.last_model_path)
@@ -359,6 +360,8 @@ def test_callbacks_restore(tmp_path):
"foreach": None,
"differentiable": False,
}
+ if _TORCH_GREATER_EQUAL_2_3:
+ expected["fused"] = None
assert callback._internal_optimizer_metadata[0][0] == expected
@@ -374,6 +377,8 @@ def test_callbacks_restore(tmp_path):
"foreach": None,
"differentiable": False,
}
+ if _TORCH_GREATER_EQUAL_2_3:
+ expected["fused"] = None
assert callback._internal_optimizer_metadata[0][1] == expected
diff --git a/tests/tests_pytorch/callbacks/test_lambda_function.py b/tests/tests_pytorch/callbacks/test_lambda_function.py
index 483c8c73e99e2..40d694bb35ebc 100644
--- a/tests/tests_pytorch/callbacks/test_lambda_function.py
+++ b/tests/tests_pytorch/callbacks/test_lambda_function.py
@@ -13,6 +13,7 @@
# limitations under the License.
from functools import partial
+import pytest
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import Callback, LambdaCallback
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -23,10 +24,13 @@
def test_lambda_call(tmp_path):
seed_everything(42)
+ class CustomException(Exception):
+ pass
+
class CustomModel(BoringModel):
def on_train_epoch_start(self):
if self.current_epoch > 1:
- raise KeyboardInterrupt
+ raise CustomException("Custom exception to trigger `on_exception` hooks")
checker = set()
@@ -59,7 +63,8 @@ def call(hook, *_, **__):
limit_predict_batches=1,
callbacks=[LambdaCallback(**hooks_args)],
)
- trainer.fit(model, ckpt_path=ckpt_path)
+ with pytest.raises(CustomException):
+ trainer.fit(model, ckpt_path=ckpt_path)
trainer.test(model)
trainer.predict(model)
diff --git a/tests/tests_pytorch/callbacks/test_lr_monitor.py b/tests/tests_pytorch/callbacks/test_lr_monitor.py
index ebe21e272aac8..4aedb4f23fa14 100644
--- a/tests/tests_pytorch/callbacks/test_lr_monitor.py
+++ b/tests/tests_pytorch/callbacks/test_lr_monitor.py
@@ -44,6 +44,9 @@ def test_lr_monitor_single_lr(tmp_path):
assert lr_monitor.lrs, "No learning rates logged"
assert all(v is None for v in lr_monitor.last_momentum_values.values()), "Momentum should not be logged by default"
+ assert all(
+ v is None for v in lr_monitor.last_weight_decay_values.values()
+ ), "Weight decay should not be logged by default"
assert len(lr_monitor.lrs) == len(trainer.lr_scheduler_configs)
assert list(lr_monitor.lrs) == ["lr-SGD"]
diff --git a/tests/tests_pytorch/callbacks/test_model_summary.py b/tests/tests_pytorch/callbacks/test_model_summary.py
index 0f255367f1a10..b42907dc9a38d 100644
--- a/tests/tests_pytorch/callbacks/test_model_summary.py
+++ b/tests/tests_pytorch/callbacks/test_model_summary.py
@@ -49,6 +49,7 @@ def summarize(
total_parameters: int,
trainable_parameters: int,
model_size: float,
+ total_training_modes,
**summarize_kwargs: Any,
) -> None:
assert summary_data[1][0] == "Name"
@@ -64,6 +65,8 @@ def summarize(
assert summary_data[4][0] == "Mode"
assert summary_data[4][1][0] == "train"
+ assert total_training_modes == {"train": 1, "eval": 0}
+
model = BoringModel()
trainer = Trainer(default_root_dir=tmp_path, callbacks=CustomModelSummary(), max_steps=1)
diff --git a/tests/tests_pytorch/callbacks/test_prediction_writer.py b/tests/tests_pytorch/callbacks/test_prediction_writer.py
index 343716d332279..02604f5a195fe 100644
--- a/tests/tests_pytorch/callbacks/test_prediction_writer.py
+++ b/tests/tests_pytorch/callbacks/test_prediction_writer.py
@@ -35,7 +35,7 @@ def test_prediction_writer_invalid_write_interval():
DummyPredictionWriter("something")
-def test_prediction_writer_hook_call_intervals():
+def test_prediction_writer_hook_call_intervals(tmp_path):
"""Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined interval."""
DummyPredictionWriter.write_on_batch_end = Mock()
DummyPredictionWriter.write_on_epoch_end = Mock()
@@ -44,7 +44,7 @@ def test_prediction_writer_hook_call_intervals():
model = BoringModel()
cb = DummyPredictionWriter("batch_and_epoch")
- trainer = Trainer(limit_predict_batches=4, callbacks=cb)
+ trainer = Trainer(default_root_dir=tmp_path, logger=False, limit_predict_batches=4, callbacks=cb)
results = trainer.predict(model, dataloaders=dataloader)
assert len(results) == 4
assert cb.write_on_batch_end.call_count == 4
@@ -54,7 +54,7 @@ def test_prediction_writer_hook_call_intervals():
DummyPredictionWriter.write_on_epoch_end.reset_mock()
cb = DummyPredictionWriter("batch_and_epoch")
- trainer = Trainer(limit_predict_batches=4, callbacks=cb)
+ trainer = Trainer(default_root_dir=tmp_path, logger=False, limit_predict_batches=4, callbacks=cb)
trainer.predict(model, dataloaders=dataloader, return_predictions=False)
assert cb.write_on_batch_end.call_count == 4
assert cb.write_on_epoch_end.call_count == 1
@@ -63,7 +63,7 @@ def test_prediction_writer_hook_call_intervals():
DummyPredictionWriter.write_on_epoch_end.reset_mock()
cb = DummyPredictionWriter("batch")
- trainer = Trainer(limit_predict_batches=4, callbacks=cb)
+ trainer = Trainer(default_root_dir=tmp_path, logger=False, limit_predict_batches=4, callbacks=cb)
trainer.predict(model, dataloaders=dataloader, return_predictions=False)
assert cb.write_on_batch_end.call_count == 4
assert cb.write_on_epoch_end.call_count == 0
@@ -72,21 +72,21 @@ def test_prediction_writer_hook_call_intervals():
DummyPredictionWriter.write_on_epoch_end.reset_mock()
cb = DummyPredictionWriter("epoch")
- trainer = Trainer(limit_predict_batches=4, callbacks=cb)
+ trainer = Trainer(default_root_dir=tmp_path, logger=False, limit_predict_batches=4, callbacks=cb)
trainer.predict(model, dataloaders=dataloader, return_predictions=False)
assert cb.write_on_batch_end.call_count == 0
assert cb.write_on_epoch_end.call_count == 1
@pytest.mark.parametrize("num_workers", [0, 2])
-def test_prediction_writer_batch_indices(num_workers):
+def test_prediction_writer_batch_indices(num_workers, tmp_path):
DummyPredictionWriter.write_on_batch_end = Mock()
DummyPredictionWriter.write_on_epoch_end = Mock()
dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers)
model = BoringModel()
writer = DummyPredictionWriter("batch_and_epoch")
- trainer = Trainer(limit_predict_batches=4, callbacks=writer)
+ trainer = Trainer(default_root_dir=tmp_path, logger=False, limit_predict_batches=4, callbacks=writer)
trainer.predict(model, dataloaders=dataloader)
writer.write_on_batch_end.assert_has_calls([
@@ -101,7 +101,7 @@ def test_prediction_writer_batch_indices(num_workers):
])
-def test_batch_level_batch_indices():
+def test_batch_level_batch_indices(tmp_path):
"""Test that batch_indices are returned when `return_predictions=False`."""
DummyPredictionWriter.write_on_batch_end = Mock()
@@ -112,7 +112,7 @@ def on_predict_epoch_end(self, *args, **kwargs):
writer = DummyPredictionWriter("batch")
model = CustomBoringModel()
dataloader = DataLoader(RandomDataset(32, 64), batch_size=4)
- trainer = Trainer(limit_predict_batches=4, callbacks=writer)
+ trainer = Trainer(default_root_dir=tmp_path, logger=False, limit_predict_batches=4, callbacks=writer)
trainer.predict(model, dataloaders=dataloader, return_predictions=False)
writer.write_on_batch_end.assert_has_calls([
diff --git a/tests/tests_pytorch/callbacks/test_pruning.py b/tests/tests_pytorch/callbacks/test_pruning.py
index e0c6a1d9236dc..f3ec7e2ccc029 100644
--- a/tests/tests_pytorch/callbacks/test_pruning.py
+++ b/tests/tests_pytorch/callbacks/test_pruning.py
@@ -190,7 +190,7 @@ def test_pruning_callback_ddp_cpu(tmp_path):
@pytest.mark.parametrize("resample_parameters", [False, True])
-def test_pruning_lth_callable(tmp_path, resample_parameters: bool):
+def test_pruning_lth_callable(tmp_path, resample_parameters):
model = TestModel()
class ModelPruningTestCallback(ModelPruning):
@@ -206,7 +206,7 @@ def apply_lottery_ticket_hypothesis(self):
curr, curr_name = self._parameters_to_prune[i]
assert name == curr_name
actual, expected = getattr(curr, name).data, getattr(copy, name).data
- allclose = torch.allclose(actual, expected)
+ allclose = torch.allclose(actual.cpu(), expected)
assert not allclose if self._resample_parameters else allclose
pruning = ModelPruningTestCallback(
@@ -273,7 +273,7 @@ def test_multiple_pruning_callbacks(tmp_path, caplog, make_pruning_permanent: bo
filepath = str(tmp_path / "foo.ckpt")
trainer.save_checkpoint(filepath)
- model.load_state_dict(torch.load(filepath), strict=False)
+ model.load_state_dict(torch.load(filepath, weights_only=True), strict=False)
has_pruning = hasattr(model.layer.mlp_1, "weight_orig")
assert not has_pruning if make_pruning_permanent else has_pruning
@@ -310,7 +310,13 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
ckpt_callback = ModelCheckpoint(
monitor="test", save_top_k=2, save_last=True, save_on_train_epoch_end=save_on_train_epoch_end
)
- trainer = Trainer(callbacks=[pruning_callback, ckpt_callback], max_epochs=3, enable_progress_bar=False)
+ trainer = Trainer(
+ default_root_dir=tmp_path,
+ logger=False,
+ callbacks=[pruning_callback, ckpt_callback],
+ max_epochs=3,
+ enable_progress_bar=False,
+ )
with caplog.at_level(INFO):
trainer.fit(model)
diff --git a/tests/tests_pytorch/callbacks/test_rich_model_summary.py b/tests/tests_pytorch/callbacks/test_rich_model_summary.py
index f8ede0eb0239e..73709fd80a833 100644
--- a/tests/tests_pytorch/callbacks/test_rich_model_summary.py
+++ b/tests/tests_pytorch/callbacks/test_rich_model_summary.py
@@ -56,7 +56,13 @@ def example_input_array(self) -> Any:
summary = summarize(model)
summary_data = summary._get_summary_data()
- model_summary.summarize(summary_data=summary_data, total_parameters=1, trainable_parameters=1, model_size=1)
+ model_summary.summarize(
+ summary_data=summary_data,
+ total_parameters=1,
+ trainable_parameters=1,
+ model_size=1,
+ total_training_modes=summary.total_training_modes,
+ )
# ensure that summary was logged + the breakdown of model parameters
assert mock_console.call_count == 2
diff --git a/tests/tests_pytorch/callbacks/test_spike.py b/tests/tests_pytorch/callbacks/test_spike.py
index f4d0c946cefa0..5634feaf221cd 100644
--- a/tests/tests_pytorch/callbacks/test_spike.py
+++ b/tests/tests_pytorch/callbacks/test_spike.py
@@ -213,6 +213,8 @@ def test_trainer_spike_detection_integration(tmp_path, global_rank_spike, num_de
cb.should_raise = spike_value is None or finite_only or spike_value == float("inf")
trainer = Trainer(
+ default_root_dir=tmp_path,
+ logger=False,
callbacks=[cb],
accelerator="cpu",
devices=num_devices,
diff --git a/tests/tests_pytorch/callbacks/test_throughput_monitor.py b/tests/tests_pytorch/callbacks/test_throughput_monitor.py
index 9467e45e2fa80..a74efba75813b 100644
--- a/tests/tests_pytorch/callbacks/test_throughput_monitor.py
+++ b/tests/tests_pytorch/callbacks/test_throughput_monitor.py
@@ -8,10 +8,7 @@
from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor
from lightning.pytorch.demos.boring_classes import BoringModel
-from tests_pytorch.helpers.runif import RunIf
-
-@RunIf(min_torch="2.1")
def test_measure_flops():
with torch.device("meta"):
model = BoringModel()
diff --git a/tests/tests_pytorch/callbacks/test_timer.py b/tests/tests_pytorch/callbacks/test_timer.py
index 3a62acca2a026..e6359a2e9a5e1 100644
--- a/tests/tests_pytorch/callbacks/test_timer.py
+++ b/tests/tests_pytorch/callbacks/test_timer.py
@@ -26,24 +26,24 @@
from tests_pytorch.helpers.runif import RunIf
-def test_trainer_flag(caplog):
+def test_trainer_flag(caplog, tmp_path):
class TestModel(BoringModel):
def on_fit_start(self):
raise SystemExit()
- trainer = Trainer(max_time={"seconds": 1337})
+ trainer = Trainer(default_root_dir=tmp_path, logger=False, max_time={"seconds": 1337})
with pytest.raises(SystemExit):
trainer.fit(TestModel())
timer = [c for c in trainer.callbacks if isinstance(c, Timer)][0]
assert timer._duration == 1337
- trainer = Trainer(max_time={"seconds": 1337}, callbacks=[Timer()])
+ trainer = Trainer(default_root_dir=tmp_path, logger=False, max_time={"seconds": 1337}, callbacks=[Timer()])
with pytest.raises(SystemExit), caplog.at_level(level=logging.INFO):
trainer.fit(TestModel())
assert "callbacks list already contains a Timer" in caplog.text
# Make sure max_time still honored even if max_epochs == -1
- trainer = Trainer(max_time={"seconds": 1}, max_epochs=-1)
+ trainer = Trainer(default_root_dir=tmp_path, logger=False, max_time={"seconds": 1}, max_epochs=-1)
with pytest.raises(SystemExit):
trainer.fit(TestModel())
timer = [c for c in trainer.callbacks if isinstance(c, Timer)][0]
diff --git a/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py b/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py
index 44a284d92bbb2..ff8f3c95e43c5 100644
--- a/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py
+++ b/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py
@@ -24,7 +24,7 @@
def test_disabled_checkpointing():
# no callback
- trainer = Trainer(max_epochs=3, enable_checkpointing=False)
+ trainer = Trainer(logger=False, max_epochs=3, enable_checkpointing=False)
assert not trainer.checkpoint_callbacks
trainer.fit(BoringModel())
assert not trainer.checkpoint_callbacks
diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py
index 9463935e8f644..be754d3911ade 100644
--- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py
+++ b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py
@@ -78,7 +78,7 @@ def load_model():
from lightning.pytorch.utilities.migration import pl_legacy_patch
with pl_legacy_patch():
- _ = torch.load(path_ckpt)
+ _ = torch.load(path_ckpt, weights_only=False)
with patch("sys.path", [PATH_LEGACY] + sys.path):
t1 = ThreadExceptionHandler(target=load_model)
diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index c911885117e29..97d8d3c4d0e4a 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -17,7 +17,9 @@
import re
import time
from argparse import Namespace
+from contextlib import nullcontext
from datetime import timedelta
+from inspect import signature
from pathlib import Path
from typing import Union
from unittest import mock
@@ -28,7 +30,9 @@
import pytest
import torch
import yaml
+from jsonargparse import ArgumentParser
from lightning.fabric.utilities.cloud_io import _load as pl_load
+from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -348,11 +352,13 @@ def test_pickling(tmp_path):
ckpt = ModelCheckpoint(dirpath=tmp_path)
ckpt_pickled = pickle.dumps(ckpt)
- ckpt_loaded = pickle.loads(ckpt_pickled)
+ with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
+ ckpt_loaded = pickle.loads(ckpt_pickled)
assert vars(ckpt) == vars(ckpt_loaded)
ckpt_pickled = cloudpickle.dumps(ckpt)
- ckpt_loaded = cloudpickle.loads(ckpt_pickled)
+ with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
+ ckpt_loaded = cloudpickle.loads(ckpt_pickled)
assert vars(ckpt) == vars(ckpt_loaded)
@@ -918,8 +924,8 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmp_path):
assert os.path.isfile(path_last_epoch)
assert os.path.isfile(path_last)
- ckpt_last_epoch = torch.load(path_last_epoch)
- ckpt_last = torch.load(path_last)
+ ckpt_last_epoch = torch.load(path_last_epoch, weights_only=True)
+ ckpt_last = torch.load(path_last, weights_only=True)
assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"]
assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"]
@@ -1166,7 +1172,7 @@ def training_step(self, *args):
)
trainer.fit(TestModel())
assert model_checkpoint.current_score == 0.3
- ckpts = [torch.load(ckpt) for ckpt in tmp_path.iterdir()]
+ ckpts = [torch.load(ckpt, weights_only=True) for ckpt in tmp_path.iterdir()]
ckpts = [
ckpt["callbacks"][
"ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
@@ -1450,7 +1456,7 @@ def test_save_last_saves_correct_last_model_path(tmp_path):
expected = "foo=1-last.ckpt"
assert os.listdir(tmp_path) == [expected]
full_path = tmp_path / expected
- ckpt = torch.load(full_path)
+ ckpt = torch.load(full_path, weights_only=True)
assert ckpt["callbacks"][mc.state_key]["last_model_path"] == str(full_path)
@@ -1482,7 +1488,7 @@ def test_none_monitor_saves_correct_best_model_path(tmp_path):
expected = "epoch=0-step=0.ckpt"
assert os.listdir(tmp_path) == [expected]
full_path = str(tmp_path / expected)
- ckpt = torch.load(full_path)
+ ckpt = torch.load(full_path, weights_only=True)
assert ckpt["callbacks"][mc.state_key]["best_model_path"] == full_path
@@ -1601,3 +1607,24 @@ def test_expand_home():
# it is possible to have a folder with the name `~`
checkpoint = ModelCheckpoint(dirpath="./~/checkpoints")
assert checkpoint.dirpath == str(Path.cwd() / "~" / "checkpoints")
+
+
+@pytest.mark.parametrize(
+ ("val", "expected"),
+ [
+ ("yes", True),
+ ("True", True),
+ ("true", True),
+ ("no", False),
+ ("false", False),
+ ("False", False),
+ ("link", "link"),
+ ],
+)
+def test_save_last_cli(val, expected):
+ """Test that the CLI can parse the `save_last` argument correctly (composed type)."""
+ annot = signature(ModelCheckpoint).parameters["save_last"].annotation
+ parser = ArgumentParser()
+ parser.add_argument("--a", type=annot)
+ args = parser.parse_args(["--a", val])
+ assert args.a == expected
diff --git a/tests/tests_pytorch/checkpointing/test_torch_saving.py b/tests/tests_pytorch/checkpointing/test_torch_saving.py
index 1e0db893d36b3..4422a4063c719 100644
--- a/tests/tests_pytorch/checkpointing/test_torch_saving.py
+++ b/tests/tests_pytorch/checkpointing/test_torch_saving.py
@@ -30,7 +30,7 @@ def test_model_torch_save(tmp_path):
# Ensure these do not fail
torch.save(trainer.model, temp_path)
torch.save(trainer, temp_path)
- trainer = torch.load(temp_path)
+ trainer = torch.load(temp_path, weights_only=False)
@RunIf(skip_windows=True)
diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py
index 9365dfb8d9bd3..78e81c7c5fa26 100644
--- a/tests/tests_pytorch/conftest.py
+++ b/tests/tests_pytorch/conftest.py
@@ -27,7 +27,7 @@
import torch.distributed
from lightning.fabric.plugins.environments.lightning import find_free_network_port
from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver
-from lightning.fabric.utilities.distributed import _distributed_is_initialized
+from lightning.fabric.utilities.distributed import _destroy_dist_connection, _distributed_is_initialized
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.pytorch.accelerators import XLAAccelerator
from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector
@@ -88,6 +88,7 @@ def restore_env_variables():
"KMP_DUPLICATE_LIB_OK", # leaked by PyTorch
"CRC32C_SW_MODE", # leaked by tensorboardX
"TRITON_CACHE_DIR", # leaked by torch.compile
+ "_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR", # leaked by torch.compile
"OMP_NUM_THREADS", # set by our launchers
# leaked by XLA
"ALLOW_MULTIPLE_LIBTPU_LOAD",
@@ -122,8 +123,7 @@ def restore_signal_handlers():
def teardown_process_group():
"""Ensures that the distributed process group gets closed before the next test runs."""
yield
- if _distributed_is_initialized():
- torch.distributed.destroy_process_group()
+ _destroy_dist_connection()
@pytest.fixture(autouse=True)
@@ -153,7 +153,11 @@ def thread_police_duuu_daaa_duuu_daaa():
assert not thread.is_alive()
elif isinstance(thread, _ChildProcessObserver):
thread.join(timeout=10)
- elif thread.name == "QueueFeederThread": # tensorboardX
+ elif (
+ thread.name == "QueueFeederThread" # tensorboardX
+ or thread.name == "QueueManagerThread" # torch.compile
+ or "(_read_thread)" in thread.name # torch.compile
+ ):
thread.join(timeout=20)
elif isinstance(thread, TMonitor):
thread.exit()
@@ -308,6 +312,17 @@ def single_process_pg():
os.environ.update(orig_environ)
+@pytest.fixture(autouse=True)
+def leave_no_artifacts_behind():
+ tests_root = Path(__file__).parent.parent
+ files_before = {p for p in tests_root.rglob("*") if "__pycache__" not in p.parts}
+ yield
+ files_after = {p for p in tests_root.rglob("*") if "__pycache__" not in p.parts}
+ difference = files_after - files_before
+ difference = {str(f.relative_to(tests_root)) for f in difference}
+ assert not difference, f"Test left artifacts behind: {difference}"
+
+
def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.Config) -> None:
initial_size = len(items)
conditions = []
diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py
index 926cd3d9e3bf5..65fccb691a33d 100644
--- a/tests/tests_pytorch/core/test_datamodules.py
+++ b/tests/tests_pytorch/core/test_datamodules.py
@@ -208,7 +208,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# fit model
trainer.fit(model, datamodule=dm)
checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
- checkpoint = torch.load(checkpoint_path)
+ checkpoint = torch.load(checkpoint_path, weights_only=True)
assert dm.__class__.__qualname__ in checkpoint
assert checkpoint[dm.__class__.__qualname__] == {"my": "state_dict"}
@@ -218,7 +218,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
assert dm.my_state_dict == {"my": "state_dict"}
-@RunIf(sklearn=True)
+@RunIf(sklearn=True, skip_windows=True) # Flaky test on Windows for unknown reasons
def test_full_loop(tmp_path):
seed_everything(7)
@@ -452,11 +452,12 @@ class BoringDataModule2(LightningDataModule):
@RunIf(skip_windows=True) # TODO: all durations are 0 on Windows
-def test_datamodule_hooks_are_profiled():
+def test_datamodule_hooks_are_profiled(tmp_path):
"""Test that `LightningDataModule` hooks are profiled."""
def get_trainer():
return Trainer(
+ default_root_dir=tmp_path,
max_steps=1,
limit_val_batches=0,
profiler="simple",
diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py
index d5aec835ad581..5ee91e82689f4 100644
--- a/tests/tests_pytorch/core/test_lightning_module.py
+++ b/tests/tests_pytorch/core/test_lightning_module.py
@@ -18,7 +18,6 @@
import pytest
import torch
from lightning.fabric import Fabric
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.core.module import _TrainerFabricShim
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -444,9 +443,6 @@ def test_trainer_reference_recursively():
ensemble.trainer = trainer
# references match
assert ensemble.trainer is inner.trainer
- if not _TORCH_GREATER_EQUAL_2_0:
- # and the trainer was weakly referenced
- assert inner.trainer is weakref.proxy(trainer)
def test_fabric_reference_recursively():
diff --git a/tests/tests_pytorch/core/test_lightning_optimizer.py b/tests/tests_pytorch/core/test_lightning_optimizer.py
index 33a66308271ca..8ab6eca907ce6 100644
--- a/tests/tests_pytorch/core/test_lightning_optimizer.py
+++ b/tests/tests_pytorch/core/test_lightning_optimizer.py
@@ -14,26 +14,23 @@
from copy import deepcopy
from unittest.mock import DEFAULT, Mock, patch
-import pytest
import torch
-from lightning.pytorch import Trainer
+from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loops.optimization.automatic import Closure
from lightning.pytorch.tuner.tuning import Tuner
from torch.optim import SGD, Adam, Optimizer
+from tests_pytorch.helpers.runif import RunIf
-@pytest.mark.parametrize("auto", [True, False])
-def test_lightning_optimizer(tmp_path, auto):
+
+def test_lightning_optimizer(tmp_path):
"""Test that optimizer are correctly wrapped by our LightningOptimizer."""
class TestModel(BoringModel):
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
- if not auto:
- # note: this is not recommended, only done for coverage
- optimizer = LightningOptimizer(optimizer)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]
@@ -232,10 +229,13 @@ def configure_optimizers(self):
assert sgd["zero_grad"].call_count == limit_train_batches
+@RunIf(mps=False) # mps does not support LBFGS
def test_lightning_optimizer_automatic_optimization_lbfgs_zero_grad(tmp_path):
"""Test zero_grad is called the same number of times as LBFGS requires for reevaluation of the loss in
automatic_optimization."""
+ seed_everything(0)
+
class TestModel(BoringModel):
def configure_optimizers(self):
return torch.optim.LBFGS(self.parameters())
diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py
index 3d49190c9b4bb..ef340d1e17ea9 100644
--- a/tests/tests_pytorch/core/test_metric_result_integration.py
+++ b/tests/tests_pytorch/core/test_metric_result_integration.py
@@ -19,6 +19,7 @@
import lightning.pytorch as pl
import pytest
import torch
+from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import OnExceptionCheckpoint
@@ -253,11 +254,12 @@ def lightning_log(fx, *args, **kwargs):
}
# make sure can be pickled
- pickle.loads(pickle.dumps(result))
+ with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
+ pickle.loads(pickle.dumps(result))
# make sure can be torch.loaded
filepath = str(tmp_path / "result")
torch.save(result, filepath)
- torch.load(filepath)
+ torch.load(filepath, weights_only=False)
# assert metric state reset to default values
result.reset()
@@ -395,7 +397,7 @@ def on_train_epoch_end(self) -> None:
@pytest.mark.parametrize(
"kwargs",
[
- {},
+ pytest.param({}, marks=RunIf(mps=False)),
pytest.param({"strategy": "ddp", "accelerator": "gpu", "devices": 1}, marks=RunIf(min_cuda_gpus=1)),
pytest.param(
{"strategy": "ddp", "accelerator": "gpu", "devices": 2}, marks=RunIf(min_cuda_gpus=2, standalone=True)
diff --git a/tests/tests_pytorch/core/test_saving.py b/tests/tests_pytorch/core/test_saving.py
index e20d5cb803b85..c7e48239754c5 100644
--- a/tests/tests_pytorch/core/test_saving.py
+++ b/tests/tests_pytorch/core/test_saving.py
@@ -13,6 +13,8 @@
def create_boring_checkpoint(tmp_path, model, accelerator="cuda"):
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="checkpoint")
trainer = pl.Trainer(
+ default_root_dir=tmp_path,
+ logger=False,
devices=1,
accelerator=accelerator,
max_epochs=1,
@@ -37,7 +39,7 @@ def test_load_from_checkpoint_map_location_automatic(accelerator, tmp_path, monk
create_boring_checkpoint(tmp_path, BoringModel(), accelerator=accelerator)
# The checkpoint contains tensors with storage tag on the accelerator
- checkpoint = torch.load(f"{tmp_path}/checkpoint.ckpt")
+ checkpoint = torch.load(f"{tmp_path}/checkpoint.ckpt", weights_only=True)
assert checkpoint["state_dict"]["layer.weight"].device.type.startswith(accelerator)
# Pretend that the accelerator is not available
@@ -111,7 +113,7 @@ def test_load_from_checkpoint_warn_on_empty_state_dict(tmp_path):
"""Test that checkpoints can be loaded with an empty state dict and that the appropriate warning is raised."""
create_boring_checkpoint(tmp_path, BoringModel(), accelerator="cpu")
# Now edit so the state_dict is empty
- checkpoint = torch.load(tmp_path / "checkpoint.ckpt")
+ checkpoint = torch.load(tmp_path / "checkpoint.ckpt", weights_only=True)
checkpoint["state_dict"] = {}
torch.save(checkpoint, tmp_path / "checkpoint.ckpt")
diff --git a/tests/tests_pytorch/demos/lstm.py b/tests/tests_pytorch/demos/lstm.py
new file mode 100644
index 0000000000000..ba424ef317586
--- /dev/null
+++ b/tests/tests_pytorch/demos/lstm.py
@@ -0,0 +1,11 @@
+from lightning.pytorch.demos.lstm import SequenceSampler
+
+
+def test_sequence_sampler():
+ dataset = list(range(103))
+ sampler = SequenceSampler(dataset, batch_size=4)
+ assert len(sampler) == 25
+ batches = list(sampler)
+ assert batches[0] == [0, 25, 50, 75]
+ assert batches[1] == [1, 26, 51, 76]
+ assert batches[24] == [24, 49, 74, 99]
diff --git a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py
index d12254c6794fe..e6da72c777dbb 100644
--- a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py
+++ b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py
@@ -35,14 +35,10 @@ def test_ddp_is_distributed():
_ = strategy.is_distributed
-def test_fsdp_activation_checkpointing(monkeypatch):
+def test_fsdp_activation_checkpointing():
with pytest.raises(ValueError, match="cannot set both `activation_checkpointing"):
FSDPStrategy(activation_checkpointing=torch.nn.Linear, activation_checkpointing_policy=lambda *_: True)
- monkeypatch.setattr(lightning.fabric.strategies.fsdp, "_TORCH_GREATER_EQUAL_2_1", True)
- with pytest.deprecated_call(match=r"use `FSDPStrategy\(activation_checkpointing_policy"):
- FSDPStrategy(activation_checkpointing=torch.nn.Linear)
-
def test_double_precision_wrapper():
with pytest.deprecated_call(match=r"The `LightningDoublePrecisionModule` is deprecated and no longer needed"):
diff --git a/tests/tests_pytorch/helpers/datasets.py b/tests/tests_pytorch/helpers/datasets.py
index 8860160d6f30e..9b1d4ec7353cb 100644
--- a/tests/tests_pytorch/helpers/datasets.py
+++ b/tests/tests_pytorch/helpers/datasets.py
@@ -39,14 +39,6 @@ class MNIST(Dataset):
download: If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
-
- Examples:
- >>> dataset = MNIST(".", download=True)
- >>> len(dataset)
- 60000
- >>> torch.bincount(dataset.targets)
- tensor([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949])
-
"""
RESOURCES = (
@@ -114,7 +106,7 @@ def _try_load(path_data, trials: int = 30, delta: float = 1.0):
assert os.path.isfile(path_data), f"missing file: {path_data}"
for _ in range(trials):
try:
- res = torch.load(path_data)
+ res = torch.load(path_data, weights_only=True)
# todo: specify the possible exception
except Exception as ex:
exception = ex
@@ -141,15 +133,6 @@ class TrialMNIST(MNIST):
digits: list selected MNIST digits/classes
kwargs: Same as MNIST
- Examples:
- >>> dataset = TrialMNIST(".", download=True)
- >>> len(dataset)
- 300
- >>> sorted(set([d.item() for d in dataset.targets]))
- [0, 1, 2]
- >>> torch.bincount(dataset.targets)
- tensor([100, 100, 100])
-
"""
def __init__(self, root: str, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2), **kwargs):
diff --git a/tests/tests_pytorch/helpers/test_datasets.py b/tests/tests_pytorch/helpers/test_datasets.py
index 5731b857303e2..98d77a6d9a8ad 100644
--- a/tests/tests_pytorch/helpers/test_datasets.py
+++ b/tests/tests_pytorch/helpers/test_datasets.py
@@ -12,14 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
+from contextlib import nullcontext
import cloudpickle
import pytest
+import torch
+from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
from tests_pytorch import _PATH_DATASETS
from tests_pytorch.helpers.datasets import MNIST, AverageDataset, TrialMNIST
+def test_mnist(tmp_path):
+ dataset = MNIST(tmp_path, download=True)
+ assert len(dataset) == 60000
+ assert torch.bincount(dataset.targets).tolist() == [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]
+
+
+def test_trial_mnist(tmp_path):
+ dataset = TrialMNIST(tmp_path, download=True)
+ assert len(dataset) == 300
+ assert set(dataset.targets.tolist()) == {0, 1, 2}
+ assert torch.bincount(dataset.targets).tolist() == [100, 100, 100]
+
+
@pytest.mark.parametrize(
("dataset_cls", "args"),
[(MNIST, {"root": _PATH_DATASETS}), (TrialMNIST, {"root": _PATH_DATASETS}), (AverageDataset, {})],
@@ -28,9 +44,9 @@ def test_pickling_dataset_mnist(dataset_cls, args):
mnist = dataset_cls(**args)
mnist_pickled = pickle.dumps(mnist)
- pickle.loads(mnist_pickled)
- # assert vars(mnist) == vars(mnist_loaded)
+ with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
+ pickle.loads(mnist_pickled)
mnist_pickled = cloudpickle.dumps(mnist)
- cloudpickle.loads(mnist_pickled)
- # assert vars(mnist) == vars(mnist_loaded)
+ with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
+ cloudpickle.loads(mnist_pickled)
diff --git a/tests/tests_pytorch/helpers/utils.py b/tests/tests_pytorch/helpers/utils.py
index 418e4328d2c64..f3a9663452981 100644
--- a/tests/tests_pytorch/helpers/utils.py
+++ b/tests/tests_pytorch/helpers/utils.py
@@ -18,24 +18,20 @@
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import TensorBoardLogger
-from tests_pytorch import _TEMP_PATH
-
def get_default_logger(save_dir, version=None):
# set up logger object without actually saving logs
return TensorBoardLogger(save_dir, name="lightning_logs", version=version)
-def get_data_path(expt_logger, path_dir=None):
+def get_data_path(expt_logger, path_dir):
# some calls contain only experiment not complete logger
# each logger has to have these attributes
name, version = expt_logger.name, expt_logger.version
# the other experiments...
- if not path_dir:
- path_dir = expt_logger.save_dir if hasattr(expt_logger, "save_dir") and expt_logger.save_dir else _TEMP_PATH
- path_expt = os.path.join(path_dir, name, "version_%s" % version)
+ path_expt = os.path.join(path_dir, name, f"version_{version}")
# try if the new sub-folder exists, typical case for test-tube
if not os.path.isdir(path_expt):
diff --git a/tests/tests_pytorch/loggers/conftest.py b/tests/tests_pytorch/loggers/conftest.py
index 0a0923b6902a6..7cc5cc94fe8cc 100644
--- a/tests/tests_pytorch/loggers/conftest.py
+++ b/tests/tests_pytorch/loggers/conftest.py
@@ -38,7 +38,8 @@ def mlflow_mock(monkeypatch):
mlflow.tracking = mlflow_tracking
mlflow.entities = mlflow_entities
- (monkeypatch.setattr("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", True),)
+ monkeypatch.setattr("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", True)
+ monkeypatch.setattr("lightning.pytorch.loggers.mlflow._MLFLOW_SYNCHRONOUS_AVAILABLE", True)
return mlflow
diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py
index 62ef2e6b6db52..c5b07562afb0a 100644
--- a/tests/tests_pytorch/loggers/test_all.py
+++ b/tests/tests_pytorch/loggers/test_all.py
@@ -14,11 +14,13 @@
import inspect
import os
import pickle
+from contextlib import nullcontext
from unittest import mock
from unittest.mock import ANY, Mock
import pytest
import torch
+from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0, _TORCH_GREATER_EQUAL_2_4_1
from lightning.pytorch import Callback, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import (
@@ -70,8 +72,9 @@ def _instantiate_logger(logger_class, save_dir, **override_kwargs):
@mock.patch.dict(os.environ, {})
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
@pytest.mark.parametrize("logger_class", ALL_LOGGER_CLASSES)
-def test_loggers_fit_test_all(logger_class, mlflow_mock, wandb_mock, comet_mock, neptune_mock, tmp_path):
+def test_loggers_fit_test_all(logger_class, mlflow_mock, wandb_mock, comet_mock, neptune_mock, tmp_path, monkeypatch):
"""Verify that basic functionality of all loggers."""
+ monkeypatch.chdir(tmp_path)
class CustomModel(BoringModel):
def training_step(self, batch, batch_idx):
@@ -116,12 +119,12 @@ def log_metrics(self, metrics, step):
model = CustomModel()
trainer = Trainer(
+ default_root_dir=tmp_path,
max_epochs=1,
logger=logger,
limit_train_batches=1,
limit_val_batches=1,
log_every_n_steps=1,
- default_root_dir=tmp_path,
)
trainer.fit(model)
trainer.test()
@@ -160,7 +163,7 @@ def test_loggers_pickle_all(tmp_path, monkeypatch, logger_class):
pytest.xfail(f"pickle test requires {logger_class.__class__} dependencies to be installed.")
-def _test_loggers_pickle(tmp_path, monkeypatch, logger_class):
+def _test_loggers_pickle(tmp_path, monkeypatch, logger_class: Logger):
"""Verify that pickling trainer with logger works."""
_patch_comet_atexit(monkeypatch)
@@ -181,7 +184,12 @@ def _test_loggers_pickle(tmp_path, monkeypatch, logger_class):
trainer = Trainer(max_epochs=1, logger=logger)
pkl_bytes = pickle.dumps(trainer)
- trainer2 = pickle.loads(pkl_bytes)
+ with (
+ pytest.warns(FutureWarning, match="`weights_only=False`")
+ if _TORCH_EQUAL_2_4_0 or (_TORCH_GREATER_EQUAL_2_4_1 and logger_class not in (CSVLogger, TensorBoardLogger))
+ else nullcontext()
+ ):
+ trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({"acc": 1.0})
# make sure we restored properly
diff --git a/tests/tests_pytorch/loggers/test_comet.py b/tests/tests_pytorch/loggers/test_comet.py
index 791089c47cbbe..e467c63543ede 100644
--- a/tests/tests_pytorch/loggers/test_comet.py
+++ b/tests/tests_pytorch/loggers/test_comet.py
@@ -66,6 +66,20 @@ def test_comet_logger_online(comet_mock):
api.assert_called_once_with("rest")
+@mock.patch.dict(os.environ, {})
+def test_comet_experiment_resets_if_not_alive(comet_mock):
+ """Test that the CometLogger creates a new experiment if the old one is not alive anymore."""
+ logger = CometLogger()
+ assert logger._experiment is None
+ alive_experiment = Mock(alive=True)
+ logger._experiment = alive_experiment
+ assert logger.experiment is alive_experiment
+
+ unalive_experiment = Mock(alive=False)
+ logger._experiment = unalive_experiment
+ assert logger.experiment is not unalive_experiment
+
+
@mock.patch.dict(os.environ, {})
def test_comet_logger_no_api_key_given(comet_mock):
"""Test that CometLogger fails to initialize if both api key and save_dir are missing."""
diff --git a/tests/tests_pytorch/loggers/test_csv.py b/tests/tests_pytorch/loggers/test_csv.py
index a03b8a7d62f78..27b85bb4ad745 100644
--- a/tests/tests_pytorch/loggers/test_csv.py
+++ b/tests/tests_pytorch/loggers/test_csv.py
@@ -168,7 +168,7 @@ def test_metrics_reset_after_save(tmp_path):
# Mock the existance check, so we can simulate appending to the metrics file
"lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists"
)
-def test_append_metrics_file(tmp_path):
+def test_append_metrics_file(_, tmp_path):
"""Test that the logger appends to the file instead of rewriting it on every save."""
logger = CSVLogger(tmp_path, name="test", version=0, flush_logs_every_n_steps=1)
diff --git a/tests/tests_pytorch/loggers/test_logger.py b/tests/tests_pytorch/loggers/test_logger.py
index 4d74e046c590f..de0028000cd9f 100644
--- a/tests/tests_pytorch/loggers/test_logger.py
+++ b/tests/tests_pytorch/loggers/test_logger.py
@@ -13,6 +13,7 @@
# limitations under the License.
import pickle
from argparse import Namespace
+from contextlib import nullcontext
from copy import deepcopy
from typing import Any, Dict, Optional
from unittest.mock import patch
@@ -20,6 +21,7 @@
import numpy as np
import pytest
import torch
+from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
from lightning.fabric.utilities.logger import _convert_params, _sanitize_params
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
@@ -122,7 +124,8 @@ def test_multiple_loggers_pickle(tmp_path):
trainer = Trainer(logger=[logger1, logger2])
pkl_bytes = pickle.dumps(trainer)
- trainer2 = pickle.loads(pkl_bytes)
+ with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
+ trainer2 = pickle.loads(pkl_bytes)
for logger in trainer2.loggers:
logger.log_metrics({"acc": 1.0}, 0)
diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py
index 372491a34c249..14af36680904c 100644
--- a/tests/tests_pytorch/loggers/test_mlflow.py
+++ b/tests/tests_pytorch/loggers/test_mlflow.py
@@ -18,7 +18,11 @@
import pytest
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
-from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE, MLFlowLogger, _get_resolve_tags
+from lightning.pytorch.loggers.mlflow import (
+ _MLFLOW_AVAILABLE,
+ MLFlowLogger,
+ _get_resolve_tags,
+)
def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, run_id=None):
@@ -260,6 +264,56 @@ def test_mlflow_logger_experiment_calls(mlflow_mock, tmp_path):
)
+@pytest.mark.parametrize("synchronous", [False, True])
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+def test_mlflow_logger_experiment_calls_with_synchronous(mlflow_mock, tmp_path, synchronous):
+ """Test that the logger calls methods on the mlflow experiment with the specified synchronous flag."""
+
+ time = mlflow_mock.entities.time
+ metric = mlflow_mock.entities.Metric
+ param = mlflow_mock.entities.Param
+ time.return_value = 1
+
+ mlflow_client = mlflow_mock.tracking.MlflowClient.return_value
+ mlflow_client.get_experiment_by_name.return_value = None
+ logger = MLFlowLogger(
+ "test", save_dir=str(tmp_path), artifact_location="my_artifact_location", synchronous=synchronous
+ )
+
+ params = {"test": "test_param"}
+ logger.log_hyperparams(params)
+
+ mlflow_client.log_batch.assert_called_once_with(
+ run_id=logger.run_id, params=[param(key="test", value="test_param")], synchronous=synchronous
+ )
+ param.assert_called_with(key="test", value="test_param")
+
+ metrics = {"some_metric": 10}
+ logger.log_metrics(metrics)
+
+ mlflow_client.log_batch.assert_called_with(
+ run_id=logger.run_id,
+ metrics=[metric(key="some_metric", value=10, timestamp=1000, step=0)],
+ synchronous=synchronous,
+ )
+ metric.assert_called_with(key="some_metric", value=10, timestamp=1000, step=0)
+
+ mlflow_client.create_experiment.assert_called_once_with(name="test", artifact_location="my_artifact_location")
+
+
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+@mock.patch.dict("lightning.pytorch.loggers.mlflow.__dict__", {"_MLFLOW_SYNCHRONOUS_AVAILABLE": False})
+def test_mlflow_logger_no_synchronous_support(mlflow_mock, tmp_path):
+ """Test that requesting `synchronous` raises an error when the installed mlflow does not support it."""
+ time = mlflow_mock.entities.time
+ time.return_value = 1
+
+ mlflow_client = mlflow_mock.tracking.MlflowClient.return_value
+ mlflow_client.get_experiment_by_name.return_value = None
+ with pytest.raises(ModuleNotFoundError):
+ MLFlowLogger("test", save_dir=str(tmp_path), artifact_location="my_artifact_location", synchronous=True)
+
+
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_with_long_param_value(mlflow_mock, tmp_path):
"""Test that long parameter values are truncated to 250 characters."""
diff --git a/tests/tests_pytorch/loggers/test_neptune.py b/tests/tests_pytorch/loggers/test_neptune.py
index 13941c18db8e8..0a39337ac5c16 100644
--- a/tests/tests_pytorch/loggers/test_neptune.py
+++ b/tests/tests_pytorch/loggers/test_neptune.py
@@ -149,15 +149,17 @@ def test_neptune_additional_methods(neptune_mock):
run_instance_mock.__getitem__().log.assert_called_once_with(torch.ones(1))
-def test_neptune_leave_open_experiment_after_fit(neptune_mock, tmp_path):
+def test_neptune_leave_open_experiment_after_fit(neptune_mock, tmp_path, monkeypatch):
"""Verify that neptune experiment was NOT closed after training."""
+ monkeypatch.chdir(tmp_path)
logger, run_instance_mock, _ = _get_logger_with_mocks(api_key="test", project="project")
_fit_and_test(logger=logger, model=BoringModel(), tmp_path=tmp_path)
assert run_instance_mock.stop.call_count == 0
-def test_neptune_log_metrics_on_trained_model(neptune_mock, tmp_path):
+def test_neptune_log_metrics_on_trained_model(neptune_mock, tmp_path, monkeypatch):
"""Verify that trained models do log data."""
+ monkeypatch.chdir(tmp_path)
class LoggingModel(BoringModel):
def on_validation_epoch_end(self):
@@ -305,9 +307,10 @@ def test_get_full_model_names_from_exp_structure():
assert NeptuneLogger._get_full_model_names_from_exp_structure(input_dict, "foo/bar") == expected_keys
-def test_inactive_run(neptune_mock, tmp_path):
+def test_inactive_run(neptune_mock, tmp_path, monkeypatch):
from neptune.exceptions import InactiveRunException
+ monkeypatch.chdir(tmp_path)
logger, run_instance_mock, _ = _get_logger_with_mocks(api_key="test", project="project")
run_instance_mock.__setitem__.side_effect = InactiveRunException
diff --git a/tests/tests_pytorch/loggers/test_tensorboard.py b/tests/tests_pytorch/loggers/test_tensorboard.py
index 61e689831ac1a..82ffff25cac7c 100644
--- a/tests/tests_pytorch/loggers/test_tensorboard.py
+++ b/tests/tests_pytorch/loggers/test_tensorboard.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
import os
from argparse import Namespace
from unittest import mock
@@ -109,7 +108,6 @@ def test_tensorboard_no_name(tmp_path, name):
assert os.listdir(tmp_path / "version_0")
-@mock.patch.dict(os.environ, {}, clear=True)
def test_tensorboard_log_sub_dir(tmp_path):
class TestLogger(TensorBoardLogger):
# for reproducibility
@@ -141,14 +139,15 @@ def name(self):
trainer = Trainer(**trainer_args, logger=logger)
assert trainer.logger.log_dir == os.path.join(explicit_save_dir, "name", "version", "sub_dir")
- # test env var (`$`) handling
- test_env_dir = "some_directory"
- os.environ["TEST_ENV_DIR"] = test_env_dir
- save_dir = "$TEST_ENV_DIR/tmp"
- explicit_save_dir = f"{test_env_dir}/tmp"
- logger = TestLogger(save_dir, sub_dir="sub_dir")
- trainer = Trainer(**trainer_args, logger=logger)
- assert trainer.logger.log_dir == os.path.join(explicit_save_dir, "name", "version", "sub_dir")
+ with mock.patch.dict(os.environ, {}):
+ # test env var (`$`) handling
+ test_env_dir = "some_directory"
+ os.environ["TEST_ENV_DIR"] = test_env_dir
+ save_dir = "$TEST_ENV_DIR/tmp"
+ explicit_save_dir = f"{test_env_dir}/tmp"
+ logger = TestLogger(save_dir, sub_dir="sub_dir")
+ trainer = Trainer(**trainer_args, logger=logger)
+ assert trainer.logger.log_dir == os.path.join(explicit_save_dir, "name", "version", "sub_dir")
@pytest.mark.parametrize("step_idx", [10, None])
@@ -312,8 +311,7 @@ def test_tensorboard_save_hparams_to_yaml_once(tmp_path):
assert not os.path.isfile(os.path.join(tmp_path, hparams_file))
-@mock.patch("lightning.pytorch.loggers.tensorboard.log")
-def test_tensorboard_with_symlink(log, tmp_path, monkeypatch):
+def test_tensorboard_with_symlink(tmp_path, monkeypatch):
"""Tests a specific failure case when tensorboard logger is used with empty name, symbolic link ``save_dir``, and
relative paths."""
monkeypatch.chdir(tmp_path) # need to use relative paths
@@ -325,16 +323,3 @@ def test_tensorboard_with_symlink(log, tmp_path, monkeypatch):
logger = TensorBoardLogger(save_dir=dest, name="")
_ = logger.version
-
- log.warning.assert_not_called()
-
-
-def test_tensorboard_missing_folder_warning(tmp_path, caplog):
- """Verify that the logger throws a warning for invalid directory."""
- name = "fake_dir"
- logger = TensorBoardLogger(save_dir=tmp_path, name=name)
-
- with caplog.at_level(logging.WARNING):
- assert logger.version == 0
-
- assert "Missing logger folder:" in caplog.text
diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py
index 16c69d2c6d773..4e3fbb287a1f9 100644
--- a/tests/tests_pytorch/loggers/test_wandb.py
+++ b/tests/tests_pytorch/loggers/test_wandb.py
@@ -13,10 +13,13 @@
# limitations under the License.
import os
import pickle
+from contextlib import nullcontext
+from pathlib import Path
from unittest import mock
import pytest
import yaml
+from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.cli import LightningCLI
@@ -25,6 +28,8 @@
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning_utilities.test.warning import no_warning_call
+from tests_pytorch.test_cli import _xfail_python_ge_3_11_9
+
def test_wandb_project_name(wandb_mock):
with mock.patch.dict(os.environ, {}):
@@ -111,9 +116,10 @@ def test_wandb_logger_init(wandb_mock):
wandb_mock.init().log.assert_called_with({"acc": 1.0, "trainer/global_step": 6})
# log hyper parameters
- hparams = {"test": None, "nested": {"a": 1}, "b": [2, 3, 4]}
+ hparams = {"none": None, "dict": {"a": 1}, "b": [2, 3, 4], "path": Path("path")}
+ expected = {"none": None, "dict": {"a": 1}, "b": [2, 3, 4], "path": "path"}
logger.log_hyperparams(hparams)
- wandb_mock.init().config.update.assert_called_once_with(hparams, allow_val_change=True)
+ wandb_mock.init().config.update.assert_called_once_with(expected, allow_val_change=True)
# watch a model
logger.watch("model", "log", 10, False)
@@ -156,7 +162,8 @@ def name(self):
assert trainer.logger.experiment, "missing experiment"
assert trainer.log_dir == logger.save_dir
pkl_bytes = pickle.dumps(trainer)
- trainer2 = pickle.loads(pkl_bytes)
+ with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
+ trainer2 = pickle.loads(pkl_bytes)
assert os.environ["WANDB_MODE"] == "dryrun"
assert trainer2.logger.__class__.__name__ == WandbLogger.__name__
@@ -548,6 +555,7 @@ def test_wandb_logger_download_artifact(wandb_mock, tmp_path):
wandb_mock.Api().artifact.assert_called_once_with("test_artifact", type="model")
+@_xfail_python_ge_3_11_9
@pytest.mark.parametrize(("log_model", "expected"), [("True", True), ("False", False), ("all", "all")])
def test_wandb_logger_cli_integration(log_model, expected, wandb_mock, monkeypatch, tmp_path):
"""Test that the WandbLogger can be used with the LightningCLI."""
diff --git a/tests/tests_pytorch/loops/optimization/test_optimizer_loop.py b/tests/tests_pytorch/loops/optimization/test_automatic_loop.py
similarity index 76%
rename from tests/tests_pytorch/loops/optimization/test_optimizer_loop.py
rename to tests/tests_pytorch/loops/optimization/test_automatic_loop.py
index 2111212de8901..0ea6290586f55 100644
--- a/tests/tests_pytorch/loops/optimization/test_optimizer_loop.py
+++ b/tests/tests_pytorch/loops/optimization/test_automatic_loop.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from contextlib import nullcontext
from typing import Dict, Generic, Iterator, Mapping, TypeVar
import pytest
@@ -82,3 +82,27 @@ def training_step(self, batch, batch_idx):
with pytest.raises(MisconfigurationException, match=match):
trainer.fit(model)
+
+
+@pytest.mark.parametrize("world_size", [1, 2])
+def test_skip_training_step_not_allowed(world_size, tmp_path):
+ """Test that skipping the training_step in distributed training is not allowed."""
+
+ class TestModel(BoringModel):
+ def training_step(self, batch, batch_idx):
+ return None
+
+ model = TestModel()
+ trainer = Trainer(
+ default_root_dir=tmp_path,
+ max_steps=1,
+ barebones=True,
+ )
+ trainer.strategy.world_size = world_size # mock world size without launching processes
+ error_context = (
+ pytest.raises(RuntimeError, match="Skipping the `training_step` .* is not supported")
+ if world_size > 1
+ else nullcontext()
+ )
+ with error_context:
+ trainer.fit(model)
diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py
index b605c44a15b50..ff317cd2e18ba 100644
--- a/tests/tests_pytorch/loops/test_loops.py
+++ b/tests/tests_pytorch/loops/test_loops.py
@@ -226,7 +226,7 @@ def val_dataloader(self):
trainer.fit(model)
ckpt_path = str(tmp_path / "on_exception.ckpt")
- checkpoint = torch.load(ckpt_path)["loops"]["fit_loop"]
+ checkpoint = torch.load(ckpt_path, weights_only=True)["loops"]["fit_loop"]
trainer.fit_loop.load_state_dict(checkpoint)
@@ -285,7 +285,7 @@ def training_step(self, batch, batch_idx):
ckpt_path = str(tmp_path / "on_exception.ckpt")
assert os.path.exists(ckpt_path)
- checkpoint = torch.load(ckpt_path)
+ checkpoint = torch.load(ckpt_path, weights_only=True)
optim_progress = trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress
sch_progress = trainer.fit_loop.epoch_loop.scheduler_progress
@@ -446,7 +446,7 @@ def train_dataloader(self):
ckpt_path = trainer.checkpoint_callback.best_model_path
assert os.path.exists(ckpt_path)
- checkpoint = torch.load(ckpt_path)
+ checkpoint = torch.load(ckpt_path, weights_only=True)
n_sch_steps_total = n_epochs
n_sch_steps_current = 1
@@ -547,7 +547,7 @@ def test_fit_loop_reset(tmp_path):
trainer.fit(model)
# reset state loaded from a checkpoint from mid-epoch
- mid_epoch_ckpt = torch.load(str(tmp_path / "epoch=0-step=2.ckpt"))
+ mid_epoch_ckpt = torch.load(str(tmp_path / "epoch=0-step=2.ckpt"), weights_only=True)
fit_loop = trainer.fit_loop
epoch_loop = fit_loop.epoch_loop
optimizer_loop = epoch_loop.automatic_optimization
@@ -578,7 +578,7 @@ def test_fit_loop_reset(tmp_path):
assert optimizer_loop.restarting
# reset state loaded from a checkpoint from the end of an epoch
- end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=0-step=4.ckpt"))
+ end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=0-step=4.ckpt"), weights_only=True)
fit_loop = trainer.fit_loop
epoch_loop = fit_loop.epoch_loop
fit_loop.restarting = False
@@ -696,7 +696,7 @@ def val_dataloader(self):
trainer.fit(model)
assert os.path.exists(ckpt_path)
- checkpoint = torch.load(ckpt_path)["loops"]["fit_loop"]
+ checkpoint = torch.load(ckpt_path, weights_only=True)["loops"]["fit_loop"]
per_val_train_batches = int(n_batches * val_check_interval)
assert checkpoint["epoch_loop.batch_progress"] == {
@@ -963,7 +963,7 @@ def train_dataloader(self):
# Save a checkpoint
trainer.save_checkpoint(tmp_path / "checkpoint.ckpt")
- checkpoint = torch.load(tmp_path / "checkpoint.ckpt")
+ checkpoint = torch.load(tmp_path / "checkpoint.ckpt", weights_only=True)
if has_state:
assert checkpoint["loops"]["fit_loop"]["state_dict"]["combined_loader"]
else:
diff --git a/tests/tests_pytorch/loops/test_training_epoch_loop.py b/tests/tests_pytorch/loops/test_training_epoch_loop.py
index 16ed3842e3a96..a110a20bfaf84 100644
--- a/tests/tests_pytorch/loops/test_training_epoch_loop.py
+++ b/tests/tests_pytorch/loops/test_training_epoch_loop.py
@@ -30,6 +30,7 @@ def test_no_val_on_train_epoch_loop_restart(tmp_path):
"limit_train_batches": 1,
"limit_val_batches": 1,
"num_sanity_val_steps": 0,
+ "logger": False,
"enable_checkpointing": False,
}
trainer = Trainer(**trainer_kwargs)
diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py
index aa56e8ca02ba4..5a175e181dd9e 100644
--- a/tests/tests_pytorch/models/test_hooks.py
+++ b/tests/tests_pytorch/models/test_hooks.py
@@ -18,7 +18,6 @@
import pytest
import torch
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, __version__
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
from lightning.pytorch.utilities.model_helpers import is_overridden
@@ -179,6 +178,8 @@ class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
assert batch.samples.device == self.device
assert isinstance(batch_idx, int)
+ # the actual training step is not needed for the assertions
+ return super().training_step(torch.rand(1, 32, device=self.device), batch_idx)
def train_dataloader(self):
return torch.utils.data.DataLoader(RandomDataset(32, 64), collate_fn=collate_fn)
@@ -479,7 +480,7 @@ def training_step(self, batch, batch_idx):
{"name": "configure_optimizers"},
{"name": "Callback.on_fit_start", "args": (trainer, model)},
{"name": "on_fit_start"},
- {"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
+ {"name": "zero_grad"},
{"name": "Callback.on_sanity_check_start", "args": (trainer, model)},
{"name": "val_dataloader"},
{"name": "train", "args": (False,)},
@@ -497,7 +498,7 @@ def training_step(self, batch, batch_idx):
{"name": "Callback.on_train_epoch_start", "args": (trainer, model)},
{"name": "on_train_epoch_start"},
*model._train_batch(trainer, model, train_batches, device=device, **kwargs),
- {"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
+ {"name": "zero_grad"},
{"name": "on_validation_model_zero_grad"},
{"name": "train", "args": (False,)},
{"name": "on_validation_model_eval"},
@@ -577,7 +578,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path):
{"name": "configure_optimizers"},
{"name": "Callback.on_fit_start", "args": (trainer, model)},
{"name": "on_fit_start"},
- {"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
+ {"name": "zero_grad"},
{"name": "train_dataloader"},
{"name": "Callback.on_train_start", "args": (trainer, model)},
{"name": "on_train_start"},
@@ -655,7 +656,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
{"name": "configure_optimizers"},
{"name": "Callback.on_fit_start", "args": (trainer, model)},
{"name": "on_fit_start"},
- {"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
+ {"name": "zero_grad"},
{"name": "train_dataloader"},
{"name": "Callback.on_train_start", "args": (trainer, model)},
{"name": "on_train_start"},
@@ -718,7 +719,7 @@ def test_trainer_model_hook_system_eval(tmp_path, override_on_x_model_train, bat
{"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": verb}},
{"name": "setup", "kwargs": {"stage": verb}},
{"name": "configure_model"},
- {"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
+ {"name": "zero_grad"},
*(hooks if batches else []),
{"name": "Callback.teardown", "args": (trainer, model), "kwargs": {"stage": verb}},
{"name": "teardown", "kwargs": {"stage": verb}},
@@ -741,7 +742,7 @@ def test_trainer_model_hook_system_predict(tmp_path):
{"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "predict"}},
{"name": "setup", "kwargs": {"stage": "predict"}},
{"name": "configure_model"},
- {"name": "zero_grad", **({} if _TORCH_GREATER_EQUAL_2_0 else {"kwargs": {"set_to_none": True}})},
+ {"name": "zero_grad"},
{"name": "predict_dataloader"},
{"name": "train", "args": (False,)},
{"name": "on_predict_model_eval"},
diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py
index 0d7fced3b8197..871b1cba673eb 100644
--- a/tests/tests_pytorch/models/test_hparams.py
+++ b/tests/tests_pytorch/models/test_hparams.py
@@ -108,7 +108,7 @@ def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwr
# make sure the raw checkpoint saved the properties
raw_checkpoint_path = _raw_checkpoint_path(trainer)
- raw_checkpoint = torch.load(raw_checkpoint_path)
+ raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=False)
assert cls.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
assert raw_checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14
@@ -242,7 +242,7 @@ def __init__(self, test_arg, test_arg2):
# make sure the raw checkpoint saved the properties
raw_checkpoint_path = _raw_checkpoint_path(trainer)
- raw_checkpoint = torch.load(raw_checkpoint_path)
+ raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=True)
assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14
@@ -393,7 +393,7 @@ def test_collect_init_arguments(tmp_path, cls):
raw_checkpoint_path = _raw_checkpoint_path(trainer)
- raw_checkpoint = torch.load(raw_checkpoint_path)
+ raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=False)
assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["batch_size"] == 179
@@ -507,7 +507,7 @@ def test_load_past_checkpoint(tmp_path, past_key):
# make sure the raw checkpoint saved the properties
raw_checkpoint_path = _raw_checkpoint_path(trainer)
- raw_checkpoint = torch.load(raw_checkpoint_path)
+ raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=True)
raw_checkpoint[past_key] = raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
raw_checkpoint["hparams_type"] = "Namespace"
raw_checkpoint[past_key]["batch_size"] = -17
@@ -552,7 +552,7 @@ def test_hparams_pickle_warning(tmp_path):
trainer.fit(model)
-def test_hparams_save_yaml(tmp_path):
+def test_save_hparams_to_yaml(tmp_path):
class Options(str, Enum):
option1name = "option1val"
option2name = "option2val"
@@ -590,6 +590,14 @@ def _compare_params(loaded_params, default_params: dict):
_compare_params(load_hparams_from_yaml(path_yaml), hparams)
+def test_save_hparams_to_yaml_warning(tmp_path):
+ """Test that we warn about unserializable parameters that need to be dropped."""
+ path_yaml = tmp_path / "hparams.yaml"
+ hparams = {"torch_type": torch.float32}
+ with pytest.warns(UserWarning, match="Skipping 'torch_type' parameter"):
+ save_hparams_to_yaml(path_yaml, hparams)
+
+
class NoArgsSubClassBoringModel(CustomBoringModel):
def __init__(self):
super().__init__()
@@ -764,7 +772,7 @@ def __init__(self, arg1, arg2, arg3):
# make sure the raw checkpoint saved the properties
raw_checkpoint_path = _raw_checkpoint_path(trainer)
- raw_checkpoint = torch.load(raw_checkpoint_path)
+ raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=True)
assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["arg1"] == 14
diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py
index 9bb20579b7162..ee670cd66e871 100644
--- a/tests/tests_pytorch/models/test_onnx.py
+++ b/tests/tests_pytorch/models/test_onnx.py
@@ -13,6 +13,8 @@
# limitations under the License.
import operator
import os
+from io import BytesIO
+from pathlib import Path
from unittest.mock import patch
import numpy as np
@@ -32,15 +34,22 @@
def test_model_saves_with_input_sample(tmp_path):
"""Test that ONNX model saves with input sample and size is greater than 3 MB."""
model = BoringModel()
- trainer = Trainer(fast_dev_run=True)
- trainer.fit(model)
-
- file_path = os.path.join(tmp_path, "model.onnx")
input_sample = torch.randn((1, 32))
+
+ file_path = os.path.join(tmp_path, "os.path.onnx")
model.to_onnx(file_path, input_sample)
assert os.path.isfile(file_path)
assert os.path.getsize(file_path) > 4e2
+ file_path = Path(tmp_path) / "pathlib.onnx"
+ model.to_onnx(file_path, input_sample)
+ assert os.path.isfile(file_path)
+ assert os.path.getsize(file_path) > 4e2
+
+ file_path = BytesIO()
+ model.to_onnx(file_path=file_path, input_sample=input_sample)
+ assert len(file_path.getvalue()) > 4e2
+
@pytest.mark.parametrize(
"accelerator", [pytest.param("mps", marks=RunIf(mps=True)), pytest.param("gpu", marks=RunIf(min_cuda_gpus=True))]
diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py
index 211dc49d42eda..fe7e3fbbab357 100644
--- a/tests/tests_pytorch/models/test_restore.py
+++ b/tests/tests_pytorch/models/test_restore.py
@@ -131,7 +131,7 @@ def configure_optimizers(self):
trainer.fit(model, datamodule=dm)
resume_ckpt = str(tmp_path / "last.ckpt")
- state_dict = torch.load(resume_ckpt)
+ state_dict = torch.load(resume_ckpt, weights_only=True)
trainer_args.update({"max_epochs": 3, "enable_checkpointing": False, "callbacks": []})
@@ -208,7 +208,7 @@ def test_correct_step_and_epoch(tmp_path):
ckpt_path = str(tmp_path / "model.ckpt")
trainer.save_checkpoint(ckpt_path)
- ckpt = torch.load(ckpt_path)
+ ckpt = torch.load(ckpt_path, weights_only=True)
assert ckpt["epoch"] == first_max_epochs
assert ckpt["global_step"] == first_max_epochs * train_batches
@@ -258,7 +258,7 @@ def on_train_epoch_end(self, *_):
def test_try_resume_from_non_existing_checkpoint(tmp_path):
"""Test that trying to resume from non-existing `ckpt_path` fails with an error."""
model = BoringModel()
- trainer = Trainer()
+ trainer = Trainer(logger=False)
with pytest.raises(FileNotFoundError, match="Checkpoint file not found"):
trainer.fit(model, ckpt_path=str(tmp_path / "non_existing.ckpt"))
@@ -461,7 +461,7 @@ def test_load_model_from_checkpoint(tmp_path, model_template):
last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]
# Since `BoringModel` has `_save_hparams = True` by default, check that ckpt has hparams
- ckpt = torch.load(last_checkpoint)
+ ckpt = torch.load(last_checkpoint, weights_only=True)
assert model_template.CHECKPOINT_HYPER_PARAMS_KEY in ckpt, "hyper_parameters missing from checkpoints"
# Ensure that model can be correctly restored from checkpoint
diff --git a/tests/tests_pytorch/models/test_torchscript.py b/tests/tests_pytorch/models/test_torchscript.py
index a5783f143f277..993085729e545 100644
--- a/tests/tests_pytorch/models/test_torchscript.py
+++ b/tests/tests_pytorch/models/test_torchscript.py
@@ -19,6 +19,7 @@
import torch
from fsspec.implementations.local import LocalFileSystem
from lightning.fabric.utilities.cloud_io import get_filesystem
+from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch.core.module import LightningModule
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -26,6 +27,7 @@
from tests_pytorch.helpers.runif import RunIf
+@pytest.mark.skipif(_IS_WINDOWS and _TORCH_GREATER_EQUAL_2_4, reason="outputs are not close on Windows with PyTorch >= 2.4")
@pytest.mark.parametrize("modelclass", [BoringModel, ParityModuleRNN, BasicGAN])
def test_torchscript_input_output(modelclass):
"""Test that scripted LightningModule forward works."""
@@ -45,6 +47,7 @@ def test_torchscript_input_output(modelclass):
assert torch.allclose(script_output, model_output)
+@pytest.mark.skipif(_IS_WINDOWS and _TORCH_GREATER_EQUAL_2_4, reason="outputs are not close on Windows with PyTorch >= 2.4")
@pytest.mark.parametrize("modelclass", [BoringModel, ParityModuleRNN, BasicGAN])
def test_torchscript_example_input_output_trace(modelclass):
"""Test that traced LightningModule forward works with example_input_array."""
diff --git a/tests/tests_pytorch/plugins/precision/test_amp_integration.py b/tests/tests_pytorch/plugins/precision/test_amp_integration.py
index 10257531d5e11..bc9f77907919a 100644
--- a/tests/tests_pytorch/plugins/precision/test_amp_integration.py
+++ b/tests/tests_pytorch/plugins/precision/test_amp_integration.py
@@ -15,6 +15,7 @@
import torch
from lightning.fabric import seed_everything
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.plugins.precision import MixedPrecision
@@ -28,7 +29,8 @@ def __init__(self, fused=False):
self.fused = fused
def configure_optimizers(self):
- assert isinstance(self.trainer.precision_plugin.scaler, torch.cuda.amp.GradScaler)
+ scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
+ assert isinstance(self.trainer.precision_plugin.scaler, scaler_cls)
return torch.optim.Adam(self.parameters(), lr=1.0, fused=self.fused)
diff --git a/tests/tests_pytorch/plugins/precision/test_double.py b/tests/tests_pytorch/plugins/precision/test_double.py
index f985c09888847..1ee89752fcbae 100644
--- a/tests/tests_pytorch/plugins/precision/test_double.py
+++ b/tests/tests_pytorch/plugins/precision/test_double.py
@@ -134,6 +134,7 @@ def training_step(self, batch, batch_idx):
return super().training_step(batch, batch_idx)
+@RunIf(mps=False) # mps does not support float64
@pytest.mark.parametrize(
"boring_model",
[
diff --git a/tests/tests_pytorch/plugins/precision/test_fsdp.py b/tests/tests_pytorch/plugins/precision/test_fsdp.py
index e4d652cb15864..8b595c2c74a32 100644
--- a/tests/tests_pytorch/plugins/precision/test_fsdp.py
+++ b/tests/tests_pytorch/plugins/precision/test_fsdp.py
@@ -26,25 +26,9 @@
[
("16-true", (torch.float16, torch.float16, torch.float16)),
("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
- pytest.param(
- "16-mixed", (torch.float32, torch.float16, torch.float16), marks=RunIf(min_torch="2.0"), id="16-mixed-ge2_0"
- ),
- pytest.param(
- "16-mixed", (None, torch.float16, torch.float16), marks=RunIf(max_torch="2.0"), id="16-mixed-lt2_0"
- ),
- pytest.param(
- "bf16-mixed",
- (torch.float32, torch.bfloat16, torch.bfloat16),
- marks=RunIf(min_torch="2.0"),
- id="bf16-mixed-ge2_0",
- ),
- pytest.param(
- "bf16-mixed", (None, torch.bfloat16, torch.bfloat16), marks=RunIf(max_torch="2.0"), id="bf16-mixed-lt2_0"
- ),
- pytest.param(
- "32-true", (torch.float32, torch.float32, torch.float32), marks=RunIf(min_torch="2.0"), id="32-true-ge2_0"
- ),
- pytest.param("32-true", (None, torch.float32, torch.float32), marks=RunIf(max_torch="2.0"), id="32-true-lt2_0"),
+ ("16-mixed", (torch.float32, torch.float16, torch.float16)),
+ ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
+ ("32-true", (torch.float32, torch.float32, torch.float32)),
],
)
def test_fsdp_precision_config(precision, expected):
@@ -74,8 +58,10 @@ def test_fsdp_precision_scaler_with_bf16():
@RunIf(min_cuda_gpus=1)
def test_fsdp_precision_forward_context():
"""Test to ensure that the context manager correctly is set to bfloat16."""
+ from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
precision = FSDPPrecision(precision="16-mixed")
- assert isinstance(precision.scaler, torch.cuda.amp.GradScaler)
+ assert isinstance(precision.scaler, ShardedGradScaler)
assert torch.get_default_dtype() == torch.float32
with precision.forward_context():
assert torch.get_autocast_gpu_dtype() == torch.float16
diff --git a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py
index 55334cd7fc42b..185a767d9e8c9 100644
--- a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py
+++ b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py
@@ -31,7 +31,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
torch.save(checkpoint, path)
def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]:
- return torch.load(path)
+ return torch.load(path, weights_only=True)
def remove_checkpoint(self, path: _PATH) -> None:
os.remove(path)
diff --git a/tests/tests_pytorch/profilers/test_profiler.py b/tests/tests_pytorch/profilers/test_profiler.py
index 4d5c9762cc6ff..5b0c13e605ee2 100644
--- a/tests/tests_pytorch/profilers/test_profiler.py
+++ b/tests/tests_pytorch/profilers/test_profiler.py
@@ -14,6 +14,7 @@
import logging
import os
import platform
+import sys
import time
from copy import deepcopy
from unittest.mock import patch
@@ -21,6 +22,7 @@
import numpy as np
import pytest
import torch
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch import Callback, Trainer
from lightning.pytorch.callbacks import EarlyStopping, StochasticWeightAveraging
from lightning.pytorch.demos.boring_classes import BoringModel, ManualOptimBoringModel
@@ -34,6 +36,13 @@
PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005
+# TODO: Nested profile calls are not supported and raise an error in Python 3.12+
+# https://github.com/Lightning-AI/pytorch-lightning/issues/19983
+skip_advanced_profiler_py312 = pytest.mark.skipif(
+ sys.version_info >= (3, 12), reason="Nested profiler calls not supported."
+)
+
+
def _get_python_cprofile_total_duration(profile):
return sum(x.inlinetime for x in profile.getstats())
@@ -299,6 +308,19 @@ def test_advanced_profiler_describe(tmp_path, advanced_profiler):
assert len(data) > 0
+def test_advanced_profiler_dump_stats(tmp_path):
+ """Ensure the profiler dumps stats during summary."""
+ advanced_profiler = AdvancedProfiler(dirpath=tmp_path, dump_stats=True)
+ # record at least one event
+ with advanced_profiler.profile(action_name := "test"):
+ pass
+ # dump_stats to file
+ advanced_profiler.describe()
+ path = advanced_profiler.dirpath / f"{action_name}.prof"
+ data = path.read_bytes()
+ assert len(data) > 0
+
+
def test_advanced_profiler_value_errors(advanced_profiler):
"""Ensure errors are raised where expected."""
action = "test"
@@ -332,6 +354,7 @@ def test_pytorch_profiler_describe(pytorch_profiler):
assert len(data) > 0
+@skip_advanced_profiler_py312
def test_advanced_profiler_cprofile_deepcopy(tmp_path):
"""Checks for pickle issue reported in #6522."""
model = BoringModel()
@@ -430,7 +453,8 @@ def test_pytorch_profiler_trainer(fn, step_name, boring_model_cls, tmp_path):
def test_pytorch_profiler_nested(tmp_path):
"""Ensure that the profiler handles nested context."""
- pytorch_profiler = PyTorchProfiler(use_cuda=False, dirpath=tmp_path, filename="profiler", schedule=None)
+ kwargs = {} if _TORCH_GREATER_EQUAL_2_4 else {"use_cuda": False}
+ pytorch_profiler = PyTorchProfiler(dirpath=tmp_path, filename="profiler", schedule=None, **kwargs)
with pytorch_profiler.profile("a"):
a = torch.ones(42)
@@ -475,13 +499,14 @@ def look_for_trace(trace_dir):
def test_register_record_function(tmp_path):
use_cuda = torch.cuda.is_available()
+ kwargs = {} if _TORCH_GREATER_EQUAL_2_4 else {"use_cuda": use_cuda}
pytorch_profiler = PyTorchProfiler(
export_to_chrome=False,
- use_cuda=use_cuda,
dirpath=tmp_path,
filename="profiler",
schedule=None,
on_trace_ready=None,
+ **kwargs,
)
class TestModel(BoringModel):
@@ -507,7 +532,14 @@ def __init__(self):
assert "[pl][module]torch.nn.modules.linear.Linear: layer.2" in event_names
-@pytest.mark.parametrize("cls", [SimpleProfiler, AdvancedProfiler, PyTorchProfiler])
+@pytest.mark.parametrize(
+ "cls",
+ [
+ SimpleProfiler,
+ PyTorchProfiler,
+ pytest.param(AdvancedProfiler, marks=skip_advanced_profiler_py312),
+ ],
+)
def test_profiler_teardown(tmp_path, cls):
"""This test checks if profiler teardown method is called when trainer is exiting."""
diff --git a/tests/tests_pytorch/serve/test_servable_module_validator.py b/tests/tests_pytorch/serve/test_servable_module_validator.py
index fce882852823b..7c883c419ea82 100644
--- a/tests/tests_pytorch/serve/test_servable_module_validator.py
+++ b/tests/tests_pytorch/serve/test_servable_module_validator.py
@@ -37,7 +37,7 @@ def test_servable_module_validator():
@pytest.mark.flaky(reruns=3)
-def test_servable_module_validator_with_trainer(tmp_path):
+def test_servable_module_validator_with_trainer(tmp_path, mps_count_0):
callback = ServableModuleValidator()
trainer = Trainer(
default_root_dir=tmp_path,
diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
index a3af440292306..b0462c0105a9f 100644
--- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
+++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
@@ -194,14 +194,16 @@ def on_fit_start(self) -> None:
assert torch.equal(self.layer.weight.data, self.tied_layer.weight.data)
-def test_memory_sharing_disabled():
+def test_memory_sharing_disabled(tmp_path):
"""Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
conditions on model updates."""
model = SimpleModel()
assert not model.layer.weight.is_shared()
assert model.layer.weight.data_ptr() == model.tied_layer.weight.data_ptr()
- trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp_spawn", max_steps=0)
+ trainer = Trainer(
+ default_root_dir=tmp_path, logger=False, accelerator="cpu", devices=2, strategy="ddp_spawn", max_steps=0
+ )
trainer.fit(model)
@@ -214,7 +216,7 @@ def test_check_for_missing_main_guard():
launcher.launch(function=Mock())
-def test_fit_twice_raises():
+def test_fit_twice_raises(mps_count_0):
model = BoringModel()
trainer = Trainer(
limit_train_batches=1,
diff --git a/tests/tests_pytorch/strategies/test_common.py b/tests/tests_pytorch/strategies/test_common.py
index f352ead871102..699424b3c53b9 100644
--- a/tests/tests_pytorch/strategies/test_common.py
+++ b/tests/tests_pytorch/strategies/test_common.py
@@ -15,7 +15,6 @@
import pytest
import torch
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.pytorch import Trainer
from lightning.pytorch.plugins import DoublePrecision, HalfPrecision, Precision
from lightning.pytorch.strategies import SingleDeviceStrategy
@@ -82,8 +81,7 @@ def test_module_init_context(device, precision, dtype, empty_init, monkeypatch):
with strategy.tensor_init_context(empty_init=empty_init):
module = torch.nn.Linear(2, 2)
- expected_device = device if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
- assert module.weight.device == module.bias.device == expected_device
+ assert module.weight.device == module.bias.device == device
assert module.weight.dtype == module.bias.dtype == dtype
if not empty_init:
init_mock.assert_called()
diff --git a/tests/tests_pytorch/strategies/test_ddp.py b/tests/tests_pytorch/strategies/test_ddp.py
index dadd49c359e06..b23d306b9d907 100644
--- a/tests/tests_pytorch/strategies/test_ddp.py
+++ b/tests/tests_pytorch/strategies/test_ddp.py
@@ -18,7 +18,6 @@
import pytest
import torch
from lightning.fabric.plugins.environments import LightningEnvironment
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.plugins import DoublePrecision, HalfPrecision, Precision
@@ -102,7 +101,7 @@ def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs, mps_count_
def test_tensor_init_context(precision_plugin, expected_dtype):
"""Test that the module under the init-context gets moved to the right device and dtype."""
parallel_devices = [torch.device("cuda", 0), torch.device("cuda", 1)]
- expected_device = parallel_devices[1] if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
+ expected_device = parallel_devices[1]
strategy = DDPStrategy(
parallel_devices=parallel_devices, precision_plugin=precision_plugin, cluster_environment=LightningEnvironment()
diff --git a/tests/tests_pytorch/strategies/test_ddp_integration.py b/tests/tests_pytorch/strategies/test_ddp_integration.py
index 0b841cde8de67..836072d36be83 100644
--- a/tests/tests_pytorch/strategies/test_ddp_integration.py
+++ b/tests/tests_pytorch/strategies/test_ddp_integration.py
@@ -20,7 +20,6 @@
import torch
from lightning.fabric.plugins.environments import ClusterEnvironment, LightningEnvironment
from lightning.fabric.utilities.distributed import _distributed_is_initialized
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Callback, EarlyStopping
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
@@ -112,9 +111,7 @@ class CustomCallback(Callback):
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
assert isinstance(trainer.strategy.model, DistributedDataParallel)
expected = ["something"]
- assert (
- trainer.strategy.model.parameters_to_ignore == set(expected) if _TORCH_GREATER_EQUAL_2_0 else expected
- )
+ assert trainer.strategy.model.parameters_to_ignore == set(expected)
assert trainer.strategy.model.module._ddp_params_and_buffers_to_ignore == expected
model = CustomModel()
@@ -285,12 +282,14 @@ def configure_optimizers(self):
return ZeroRedundancyOptimizer(self.layer.parameters(), optimizer_class=torch.optim.Adam, lr=0.1)
+# ZeroRedundancyOptimizer internally calls `torch.load` with `weights_only` not set, triggering the FutureWarning
+@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, skip_windows=True)
@pytest.mark.parametrize("strategy", [pytest.param("ddp", marks=RunIf(standalone=True)), "ddp_spawn"])
-def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(tmp_path, strategy):
+def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(strategy, tmp_path):
"""Test to ensure that checkpoint is saved correctly when using zero redundancy optimizer."""
model = BoringZeroRedundancyOptimizerModel()
- trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_steps=1)
+ trainer = Trainer(default_root_dir=tmp_path, accelerator="gpu", devices=2, strategy=strategy, max_steps=1)
trainer.fit(model)
diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py
index 93c9ee9f25183..be9428ff7533c 100644
--- a/tests/tests_pytorch/strategies/test_deepspeed.py
+++ b/tests/tests_pytorch/strategies/test_deepspeed.py
@@ -28,7 +28,7 @@
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.plugins import DeepSpeedPrecision
-from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE, DeepSpeedStrategy
+from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11
from torch import Tensor, nn
@@ -38,11 +38,6 @@
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.runif import RunIf
-if _DEEPSPEED_AVAILABLE:
- import deepspeed
- from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
- from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
-
class ModelParallelBoringModel(BoringModel):
def __init__(self):
@@ -245,6 +240,7 @@ def configure_model(self) -> None:
def test_deepspeed_run_configure_optimizers(tmp_path):
"""Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation), whilst using
configure_optimizers for optimizers and schedulers."""
+ from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
class TestCB(Callback):
def on_train_start(self, trainer, pl_module) -> None:
@@ -284,6 +280,7 @@ def configure_optimizers(self):
def test_deepspeed_config(tmp_path, deepspeed_zero_config):
"""Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers
and saves the model weights to load correctly."""
+ from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
class TestCB(Callback):
def on_train_start(self, trainer, pl_module) -> None:
@@ -397,6 +394,8 @@ def test_deepspeed_custom_activation_checkpointing_params():
def test_deepspeed_custom_activation_checkpointing_params_forwarded(tmp_path):
"""Ensure if we modify the activation checkpointing parameters, we pass these to deepspeed.checkpointing.configure
correctly."""
+ import deepspeed
+
ds = DeepSpeedStrategy(
partition_activations=True,
cpu_checkpointing=True,
@@ -453,6 +452,8 @@ def setup(self, trainer, pl_module, stage=None) -> None:
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multigpu(tmp_path):
"""Test to ensure that DeepSpeed with multiple GPUs works and deepspeed distributed is initialized correctly."""
+ import deepspeed
+
model = BoringModel()
trainer = Trainer(
default_root_dir=tmp_path,
@@ -978,6 +979,8 @@ def test_deepspeed_strategy_env_variables(mock_deepspeed_distributed, tmp_path,
def _assert_save_model_is_equal(model, tmp_path, trainer):
+ from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+
checkpoint_path = os.path.join(tmp_path, "model.pt")
checkpoint_path = trainer.strategy.broadcast(checkpoint_path)
trainer.save_checkpoint(checkpoint_path)
@@ -986,7 +989,7 @@ def _assert_save_model_is_equal(model, tmp_path, trainer):
if trainer.is_global_zero:
single_ckpt_path = os.path.join(tmp_path, "single_model.pt")
convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, single_ckpt_path)
- state_dict = torch.load(single_ckpt_path)
+ state_dict = torch.load(single_ckpt_path, weights_only=False)
model = model.cpu()
# Assert model parameters are identical after loading
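
Reviewer note: the hunks above move the `deepspeed` imports from module level into the test bodies, so the test module can be imported (and collected) even when the optional package is missing. A minimal sketch of that pattern, using only the standard library, might look like the following. `optional_module_available` and `dumps_with_lazy_import` are hypothetical helpers, and `json` merely stands in for an optional dependency.

```python
import importlib.util


def optional_module_available(name: str) -> bool:
    # Check whether a top-level module can be found, without importing it.
    return importlib.util.find_spec(name) is not None


def dumps_with_lazy_import(payload):
    # The import runs only when the function is called, so importing the
    # enclosing module never fails because of this dependency.
    import json  # stand-in for an optional dependency such as `deepspeed`

    return json.dumps(payload, sort_keys=True)
```

In the actual tests, a `RunIf(deepspeed=True)` marker skips the test before the function-level import would ever run.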
diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py
index 751a42b96b42f..aec01b83e956a 100644
--- a/tests/tests_pytorch/strategies/test_fsdp.py
+++ b/tests/tests_pytorch/strategies/test_fsdp.py
@@ -14,11 +14,7 @@
import torch.nn as nn
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint
-from lightning.fabric.utilities.imports import (
- _TORCH_GREATER_EQUAL_2_0,
- _TORCH_GREATER_EQUAL_2_1,
- _TORCH_GREATER_EQUAL_2_2,
-)
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.load import _load_distributed_checkpoint
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
@@ -28,18 +24,12 @@
from lightning.pytorch.strategies import FSDPStrategy
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.consolidate_checkpoint import _format_checkpoint
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
-from torch.distributed.fsdp.wrap import always_wrap_policy, size_based_auto_wrap_policy, wrap
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy, always_wrap_policy, size_based_auto_wrap_policy, wrap
from torchmetrics import Accuracy
from tests_pytorch.helpers.runif import RunIf
-if _TORCH_GREATER_EQUAL_2_0:
- from torch.distributed.fsdp.wrap import ModuleWrapPolicy
-else:
- ModuleWrapPolicy = object
-
class TestFSDPModel(BoringModel):
def __init__(self):
@@ -88,10 +78,10 @@ def _assert_layer_fsdp_instance(self) -> None:
assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)
if self.trainer.precision == "16-mixed":
- param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
+ param_dtype = torch.float32
reduce_dtype = buffer_dtype = torch.float16
elif self.trainer.precision == "bf16-mixed":
- param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
+ param_dtype = torch.float32
reduce_dtype = buffer_dtype = torch.bfloat16
elif self.trainer.precision == "16-true":
param_dtype = reduce_dtype = buffer_dtype = torch.float16
@@ -120,10 +110,8 @@ def __init__(self, wrap_min_params: int = 2):
self.should_be_wrapped = [wrap_min_params < (32 * 32 + 32), None, wrap_min_params < (32 * 2 + 2)]
def configure_optimizers(self):
- parameters = self.parameters() if _TORCH_GREATER_EQUAL_2_0 else self.trainer.model.parameters()
-
# SGD's FSDP optimizer state is fixed in https://github.com/pytorch/pytorch/pull/99214
- return torch.optim.AdamW(parameters, lr=0.1)
+ return torch.optim.AdamW(self.parameters(), lr=0.1)
class TestFSDPModelAutoWrapped(TestBoringModel):
@@ -151,10 +139,10 @@ def _assert_layer_fsdp_instance(self) -> None:
assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)
if self.trainer.precision == "16-mixed":
- param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
+ param_dtype = torch.float32
reduce_dtype = buffer_dtype = torch.float16
elif self.trainer.precision == "bf16-mixed":
- param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
+ param_dtype = torch.float32
reduce_dtype = buffer_dtype = torch.bfloat16
elif self.trainer.precision == "16-true":
param_dtype = reduce_dtype = buffer_dtype = torch.float16
@@ -216,24 +204,22 @@ def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
def test_invalid_on_cpu(tmp_path, cuda_count_0):
"""Test to ensure that we raise Misconfiguration for FSDP on CPU."""
- with pytest.raises(
- MisconfigurationException,
- match=f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but GPU accelerator is not used.",
- ):
+ with pytest.raises(ValueError, match="The strategy `fsdp` requires a GPU accelerator"):
trainer = Trainer(accelerator="cpu", default_root_dir=tmp_path, fast_dev_run=True, strategy="fsdp")
assert isinstance(trainer.strategy, FSDPStrategy)
trainer.strategy.setup_environment()
-def test_fsdp_custom_mixed_precision():
+def test_custom_mixed_precision():
"""Test to ensure that passing a custom mixed precision config works."""
config = MixedPrecision()
strategy = FSDPStrategy(mixed_precision=config)
assert strategy.mixed_precision_config == config
+@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
-def test_fsdp_strategy_sync_batchnorm(tmp_path):
+def test_strategy_sync_batchnorm(tmp_path):
"""Test to ensure that sync_batchnorm works when using FSDP and GPU, and all stages can be run."""
model = TestFSDPModel()
trainer = Trainer(
@@ -248,8 +234,9 @@ def test_fsdp_strategy_sync_batchnorm(tmp_path):
_run_multiple_stages(trainer, model, os.path.join(tmp_path, "last.ckpt"))
+@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=1, skip_windows=True)
-def test_fsdp_modules_without_parameters(tmp_path):
+def test_modules_without_parameters(tmp_path):
"""Test that TorchMetrics get moved to the device despite not having any parameters."""
class MetricsModel(BoringModel):
@@ -278,10 +265,11 @@ def training_step(self, batch, batch_idx):
trainer.fit(model)
+@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
@pytest.mark.parametrize("state_dict_type", ["sharded", "full"])
-def test_fsdp_strategy_checkpoint(state_dict_type, precision, tmp_path):
+def test_strategy_checkpoint(state_dict_type, precision, tmp_path):
"""Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
model = TestFSDPModel()
strategy = FSDPStrategy(state_dict_type=state_dict_type)
@@ -291,28 +279,18 @@ def test_fsdp_strategy_checkpoint(state_dict_type, precision, tmp_path):
_run_multiple_stages(trainer, model, os.path.join(tmp_path, "last.ckpt"))
-if _TORCH_GREATER_EQUAL_2_0:
-
- def custom_auto_wrap_policy(
- module,
- recurse,
- nonwrapped_numel: int,
- ) -> bool:
- return nonwrapped_numel >= 2
-
-else:
-
- def custom_auto_wrap_policy(
- module,
- recurse,
- unwrapped_params: int,
- ) -> bool:
- return unwrapped_params >= 2
+def custom_auto_wrap_policy(
+ module,
+ recurse,
+ nonwrapped_numel: int,
+) -> bool:
+ return nonwrapped_numel >= 2
+@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
-def test_fsdp_strategy_full_state_dict(tmp_path, wrap_min_params):
+def test_strategy_full_state_dict(tmp_path, wrap_min_params):
"""Test to ensure that the full state dict is extracted when using FSDP strategy.
Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all.
@@ -345,6 +323,7 @@ def test_fsdp_strategy_full_state_dict(tmp_path, wrap_min_params):
assert all(_ex == _co for _ex, _co in zip(full_state_dict.keys(), correct_state_dict.keys()))
+@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize(
("model", "strategy", "strategy_cfg"),
@@ -354,29 +333,20 @@ def test_fsdp_strategy_full_state_dict(tmp_path, wrap_min_params):
TestFSDPModelAutoWrapped(),
FSDPStrategy,
{"auto_wrap_policy": custom_auto_wrap_policy},
- marks=RunIf(max_torch="2.0.0"),
- id="autowrap_1x",
- ),
- pytest.param(
- TestFSDPModelAutoWrapped(),
- FSDPStrategy,
- {"auto_wrap_policy": custom_auto_wrap_policy},
- marks=RunIf(min_torch="2.0.0"),
id="autowrap_2x",
),
pytest.param(
TestFSDPModelAutoWrapped(),
FSDPStrategy,
{
- "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}) if _TORCH_GREATER_EQUAL_2_1 else None,
+ "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}),
"use_orig_params": True,
},
- marks=RunIf(min_torch="2.1.0"),
id="autowrap_use_orig_params",
),
],
)
-def test_fsdp_checkpoint_multi_gpus(tmp_path, model, strategy, strategy_cfg):
+def test_checkpoint_multi_gpus(tmp_path, model, strategy, strategy_cfg):
"""Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run."""
ck = ModelCheckpoint(save_last=True)
@@ -404,7 +374,7 @@ def test_fsdp_checkpoint_multi_gpus(tmp_path, model, strategy, strategy_cfg):
@pytest.mark.parametrize("use_orig_params", [None, False, True])
def test_invalid_parameters_in_optimizer(use_orig_params):
fsdp_kwargs = {}
- if _TORCH_GREATER_EQUAL_2_0 and use_orig_params is not None:
+ if use_orig_params is not None:
fsdp_kwargs = {"use_orig_params": use_orig_params}
trainer = Trainer(
@@ -414,19 +384,12 @@ def test_invalid_parameters_in_optimizer(use_orig_params):
fast_dev_run=1,
)
- error_context = (
- nullcontext()
- if _TORCH_GREATER_EQUAL_2_0 and (_TORCH_GREATER_EQUAL_2_1 or use_orig_params is not False)
- else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters")
- )
-
class EmptyParametersModel(BoringModel):
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-2)
model = EmptyParametersModel()
- with error_context:
- trainer.fit(model)
+ trainer.fit(model)
class NoFlatParametersModel(BoringModel):
def configure_optimizers(self):
@@ -435,7 +398,7 @@ def configure_optimizers(self):
error_context = (
nullcontext()
- if _TORCH_GREATER_EQUAL_2_0 and use_orig_params is not False
+ if use_orig_params is not False
else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters")
)
@@ -444,7 +407,7 @@ def configure_optimizers(self):
trainer.fit(model)
-def test_fsdp_forbidden_precision_raises():
+def test_forbidden_precision_raises():
with pytest.raises(TypeError, match="can only work with the `FSDPPrecision"):
FSDPStrategy(precision_plugin=HalfPrecision())
@@ -453,7 +416,7 @@ def test_fsdp_forbidden_precision_raises():
strategy.precision_plugin = HalfPrecision()
-def test_fsdp_activation_checkpointing():
+def test_activation_checkpointing():
"""Test that the FSDP strategy can apply activation checkpointing to the given layers."""
class Block1(nn.Linear):
@@ -469,28 +432,13 @@ def __init__(self):
self.layer1 = Block2(2, 2)
self.layer2 = nn.Linear(3, 3)
- if _TORCH_GREATER_EQUAL_2_1:
- from torch.distributed.fsdp.wrap import ModuleWrapPolicy
-
- strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
- assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
- assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
-
- strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2}))
- assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
- assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
- else:
- strategy = FSDPStrategy(activation_checkpointing=Block1)
- assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
+ strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
+ assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
+ assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
- strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2])
- assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
-
- strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
- assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
-
- strategy = FSDPStrategy(activation_checkpointing_policy={Block1, Block2})
- assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
+ strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2}))
+ assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
+ assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
model = Model()
strategy._parallel_devices = [torch.device("cuda", 0)]
@@ -503,7 +451,7 @@ def __init__(self):
apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs)
-def test_fsdp_strategy_cpu_offload():
+def test_strategy_cpu_offload():
"""Test the different ways cpu offloading can be enabled."""
# bool
strategy = FSDPStrategy(cpu_offload=True)
@@ -515,7 +463,7 @@ def test_fsdp_strategy_cpu_offload():
assert strategy.cpu_offload == config
-def test_fsdp_sharding_strategy():
+def test_sharding_strategy():
"""Test the different ways the sharding strategy can be set."""
from torch.distributed.fsdp import ShardingStrategy
@@ -534,9 +482,8 @@ def test_fsdp_sharding_strategy():
assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD
-@RunIf(min_torch="2.0")
@pytest.mark.parametrize("sharding_strategy", ["HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"])
-def test_fsdp_hybrid_sharding_strategy(sharding_strategy):
+def test_hybrid_shard_configuration(sharding_strategy, monkeypatch):
"""Test that the hybrid sharding strategies can only be used with automatic wrapping or a manually specified pg."""
with pytest.raises(RuntimeError, match="The hybrid sharding strategy requires you to pass at least one of"):
FSDPStrategy(sharding_strategy=sharding_strategy)
@@ -549,6 +496,11 @@ def test_fsdp_hybrid_sharding_strategy(sharding_strategy):
assert strategy.sharding_strategy.name == sharding_strategy
assert strategy.kwargs["process_group"] is process_group
+ monkeypatch.setattr("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", False)
+ with pytest.raises(ValueError, match="`device_mesh` argument is only supported in torch >= 2.2."):
+ FSDPStrategy(device_mesh=Mock())
+
+ monkeypatch.setattr("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", True)
device_mesh = Mock()
strategy = FSDPStrategy(sharding_strategy=sharding_strategy, device_mesh=device_mesh)
assert strategy.sharding_strategy.name == sharding_strategy
@@ -558,17 +510,12 @@ def test_fsdp_hybrid_sharding_strategy(sharding_strategy):
FSDPStrategy(sharding_strategy=sharding_strategy, process_group=process_group, device_mesh=device_mesh)
-def test_fsdp_use_orig_params():
- """Test that Lightning enables `use_orig_params` in PyTorch >= 2.0."""
- with mock.patch("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_0", False):
- strategy = FSDPStrategy()
- assert "use_orig_params" not in strategy.kwargs
-
- with mock.patch("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_0", True):
- strategy = FSDPStrategy()
- assert strategy.kwargs["use_orig_params"]
- strategy = FSDPStrategy(use_orig_params=False)
- assert not strategy.kwargs["use_orig_params"]
+def test_use_orig_params():
+ """Test that Lightning enables `use_orig_params` automatically."""
+ strategy = FSDPStrategy()
+ assert strategy.kwargs["use_orig_params"]
+ strategy = FSDPStrategy(use_orig_params=False)
+ assert not strategy.kwargs["use_orig_params"]
@mock.patch("torch.distributed.init_process_group")
@@ -587,9 +534,8 @@ def test_set_timeout(init_process_group_mock):
)
-@RunIf(min_torch="2.0")
@mock.patch("lightning.pytorch.strategies.fsdp._load_raw_module_state")
-def test_fsdp_strategy_load_optimizer_states_multiple(_, tmp_path):
+def test_strategy_load_optimizer_states_multiple(_, tmp_path):
strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")], state_dict_type="full")
trainer = Trainer()
trainer.state.fn = TrainerFn.FITTING
@@ -611,9 +557,10 @@ def test_fsdp_strategy_load_optimizer_states_multiple(_, tmp_path):
strategy.load_checkpoint(tmp_path / "one-state.ckpt")
+@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
-def test_fsdp_strategy_save_optimizer_states(tmp_path, wrap_min_params):
+def test_strategy_save_optimizer_states(tmp_path, wrap_min_params):
"""Test to ensure that the full state dict and optimizer states are saved when using the FSDP strategy.
Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the model can
@@ -644,12 +591,9 @@ def test_fsdp_strategy_save_optimizer_states(tmp_path, wrap_min_params):
if trainer.global_rank != 0:
assert len(model_state_dict) == 0
- if trainer.global_rank != 0 and _TORCH_GREATER_EQUAL_2_1 or not _TORCH_GREATER_EQUAL_2_0:
+ if trainer.global_rank != 0:
assert len(optimizer_state_dict) == 0
- if not _TORCH_GREATER_EQUAL_2_0:
- return
-
# restore model to ddp
model = TestBoringModel()
trainer = Trainer(default_root_dir=tmp_path, accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)
@@ -672,9 +616,10 @@ def test_fsdp_strategy_save_optimizer_states(tmp_path, wrap_min_params):
trainer.strategy.barrier()
+@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
-def test_fsdp_strategy_load_optimizer_states(tmp_path, wrap_min_params):
+def test_strategy_load_optimizer_states(wrap_min_params, tmp_path):
"""Test to ensure that the full state dict and optimizer states can be loaded when using the FSDP strategy.
Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the DDP model
@@ -718,10 +663,10 @@ def test_fsdp_strategy_load_optimizer_states(tmp_path, wrap_min_params):
if trainer.global_rank != 0:
assert len(restored_model_state_dict) == 0
- if trainer.global_rank != 0 and _TORCH_GREATER_EQUAL_2_1 or not _TORCH_GREATER_EQUAL_2_0:
+ if trainer.global_rank != 0:
assert len(restored_optimizer_state_dict) == 0
- if trainer.global_rank == 0 and _TORCH_GREATER_EQUAL_2_0:
+ if trainer.global_rank == 0:
# assert everything is the same
assert len(model_state_dict) == len(restored_model_state_dict)
assert len(optimizer_state_dict) == len(restored_optimizer_state_dict)
@@ -738,14 +683,17 @@ def test_fsdp_strategy_load_optimizer_states(tmp_path, wrap_min_params):
("32-true", torch.float32),
],
)
-def test_configure_model(precision, expected_dtype):
+def test_configure_model(precision, expected_dtype, tmp_path):
"""Test that the module under configure_model gets moved to the right device and dtype."""
trainer = Trainer(
+ default_root_dir=tmp_path,
accelerator="cuda",
devices=2,
strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
precision=precision,
max_epochs=1,
+ enable_checkpointing=False,
+ logger=False,
)
class MyModel(BoringModel):
@@ -770,33 +718,6 @@ def on_fit_start(self):
trainer.fit(model)
-@mock.patch("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_0", False)
-@mock.patch("lightning.pytorch.strategies.fsdp.torch.load")
-@mock.patch("lightning.pytorch.strategies.fsdp._load_raw_module_state")
-def test_load_save_optimizer_torch_lt_2_0(_, __, tmp_path):
- strategy = FSDPStrategy(state_dict_type="full")
- with pytest.warns(UserWarning, match="does not support saving the optimizer state"):
- strategy.optimizer_state(Mock())
-
- file = tmp_path / "test.ckpt"
- file.touch()
- trainer = Trainer()
- trainer.state.fn = TrainerFn.FITTING
- strategy._lightning_module = Mock(trainer=trainer)
- with pytest.warns(UserWarning, match="does not support loading the optimizer state"):
- strategy.load_checkpoint(file)
-
-
-@mock.patch("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_0", False)
-def test_sharded_state_dict_type_support():
- """Test that the sharded state dict type is supported."""
- with pytest.raises(
- NotImplementedError,
- match=escape("`FSDPStrategy(state_dict_type='sharded')` is not supported in PyTorch < 2.0"),
- ):
- FSDPStrategy(state_dict_type="sharded")
-
-
def test_save_checkpoint_storage_options(tmp_path):
"""Test that the FSDP strategy does not accept storage options for saving checkpoints."""
strategy = FSDPStrategy()
@@ -804,13 +725,12 @@ def test_save_checkpoint_storage_options(tmp_path):
strategy.save_checkpoint(filepath=tmp_path, checkpoint=Mock(), storage_options=Mock())
-@RunIf(min_torch="2.0.0")
@mock.patch("lightning.pytorch.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
@mock.patch("lightning.pytorch.strategies.fsdp._get_full_state_dict_context")
@mock.patch("lightning.pytorch.strategies.fsdp._get_sharded_state_dict_context")
@mock.patch("lightning.fabric.plugins.io.torch_io._atomic_save")
@mock.patch("lightning.pytorch.strategies.fsdp.shutil")
-def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, tmp_path):
+def test_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, tmp_path):
strategy = FSDPStrategy(state_dict_type="full")
# state_dict_type='full', path exists, path is not a sharded checkpoint: error
@@ -826,16 +746,12 @@ def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___,
path.mkdir()
(path / "meta.pt").touch()
assert _is_sharded_checkpoint(path)
- model = Mock(spec=FullyShardedDataParallel)
- model.modules.return_value = [model]
strategy.save_checkpoint(Mock(), filepath=path)
shutil_mock.rmtree.assert_called_once_with(path)
# state_dict_type='full', path exists, path is a file: no error (overwrite)
path = tmp_path / "file.pt"
path.touch()
- model = Mock(spec=FullyShardedDataParallel)
- model.modules.return_value = [model]
torch_save_mock.reset_mock()
strategy.save_checkpoint(Mock(), filepath=path)
torch_save_mock.assert_called_once()
@@ -852,8 +768,6 @@ def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___,
path = tmp_path / "not-empty-2"
path.mkdir()
(path / "file").touch()
- model = Mock(spec=FullyShardedDataParallel)
- model.modules.return_value = [model]
with save_mock:
strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
assert (path / "file").exists()
@@ -861,21 +775,19 @@ def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___,
# state_dict_type='sharded', path exists, path is a file: no error (overwrite)
path = tmp_path / "file-2.pt"
path.touch()
- model = Mock(spec=FullyShardedDataParallel)
- model.modules.return_value = [model]
with save_mock:
strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
assert path.is_dir()
@mock.patch("lightning.pytorch.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
-def test_fsdp_save_checkpoint_unknown_state_dict_type(tmp_path):
+def test_save_checkpoint_unknown_state_dict_type(tmp_path):
strategy = FSDPStrategy(state_dict_type="invalid")
with pytest.raises(ValueError, match="Unknown state_dict_type"):
strategy.save_checkpoint(checkpoint=Mock(), filepath=tmp_path)
-def test_fsdp_load_unknown_checkpoint_type(tmp_path):
+def test_load_unknown_checkpoint_type(tmp_path):
"""Test that the strategy validates the contents at the checkpoint path."""
strategy = FSDPStrategy()
strategy.model = Mock()
@@ -903,7 +815,8 @@ def on_train_start(self):
torch.testing.assert_close(p0, p1, atol=0, rtol=0, equal_nan=True)
-@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
+@pytest.mark.filterwarnings("ignore::FutureWarning")
+@RunIf(min_cuda_gpus=2, standalone=True)
def test_save_load_sharded_state_dict(tmp_path):
"""Test FSDP saving and loading with the sharded state dict format."""
strategy = FSDPStrategy(auto_wrap_policy={nn.Linear}, state_dict_type="sharded")
@@ -926,7 +839,7 @@ def test_save_load_sharded_state_dict(tmp_path):
checkpoint_path = Path(trainer.strategy.broadcast(trainer.checkpoint_callback.best_model_path))
assert set(os.listdir(checkpoint_path)) == {"meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"}
- metadata = torch.load(checkpoint_path / "meta.pt")
+ metadata = torch.load(checkpoint_path / "meta.pt", weights_only=True)
assert "pytorch-lightning_version" in metadata
assert len(metadata["callbacks"]) == 1 # model checkpoint callback
assert "state_dict" not in metadata
@@ -943,7 +856,7 @@ def test_save_load_sharded_state_dict(tmp_path):
@mock.patch("lightning.pytorch.strategies.fsdp.torch.load")
@mock.patch("lightning.pytorch.strategies.fsdp._lazy_load")
@mock.patch("lightning.pytorch.strategies.fsdp._load_raw_module_state")
-def test_fsdp_lazy_load_full_state_dict(_, lazy_load_mock, torch_load_mock, tmp_path):
+def test_lazy_load_full_state_dict(_, lazy_load_mock, torch_load_mock, tmp_path):
"""Test that loading a single file (full state) is lazy to reduce peak CPU memory usage."""
model = BoringModel()
checkpoint = {"state_dict": model.state_dict()}
@@ -959,10 +872,7 @@ def test_fsdp_lazy_load_full_state_dict(_, lazy_load_mock, torch_load_mock, tmp_
file.touch()
strategy.load_checkpoint(checkpoint_path=file)
- if _TORCH_GREATER_EQUAL_2_0:
- lazy_load_mock.assert_called_once()
- else:
- torch_load_mock.assert_called_once()
+ lazy_load_mock.assert_called_once()
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@@ -974,7 +884,7 @@ def test_fsdp_lazy_load_full_state_dict(_, lazy_load_mock, torch_load_mock, tmp_
pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)),
],
)
-def test_module_init_context(precision, expected_dtype):
+def test_module_init_context(precision, expected_dtype, tmp_path):
"""Test that the module under the init-context gets moved to the right device and dtype."""
class Model(BoringModel):
@@ -990,12 +900,15 @@ def on_train_start(self):
def _run_setup_assertions(empty_init, expected_device):
trainer = Trainer(
+ default_root_dir=tmp_path,
accelerator="cuda",
devices=2,
strategy=FSDPStrategy(auto_wrap_policy={torch.nn.Linear}),
precision=precision,
max_steps=1,
barebones=True,
+ enable_checkpointing=False,
+ logger=False,
)
with trainer.init_module(empty_init=empty_init):
model = Model()
@@ -1008,19 +921,24 @@ def _run_setup_assertions(empty_init, expected_device):
# Case 1: No empty init
_run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))
- if _TORCH_GREATER_EQUAL_2_1:
- # Case 2: Empty-init with PyTorch >= 2.1 supports meta device
- _run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
- else:
- # Case 2: Empty-init with PyTorch < 2.1 only supports `torch.empty()`-init
- _run_setup_assertions(empty_init=True, expected_device=torch.device("cpu"))
+ # Case 2: Empty-init with meta device
+ _run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
+@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0")
def test_save_sharded_and_consolidate_and_load(tmp_path):
"""Test the consolidation of a FSDP-sharded checkpoint into a single file."""
- model = BoringModel()
+ class CustomModel(BoringModel):
+ def configure_optimizers(self):
+ # Use Adam instead of SGD for this test because it has state
+ # In PyTorch >= 2.4, saving an optimizer with empty state would result in a `KeyError: 'state'`
+ # when loading the optimizer state-dict back.
+ # TODO: To resolve this, switch to the new `torch.distributed.checkpoint` APIs in FSDPStrategy
+ return torch.optim.Adam(self.parameters(), lr=0.1)
+
+ model = CustomModel()
trainer = Trainer(
default_root_dir=tmp_path,
accelerator="cuda",
@@ -1041,7 +959,7 @@ def test_save_sharded_and_consolidate_and_load(tmp_path):
torch.save(checkpoint, checkpoint_path_full)
trainer.strategy.barrier()
- model = BoringModel()
+ model = CustomModel()
trainer = Trainer(
default_root_dir=tmp_path,
accelerator="cuda",
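
Reviewer note: `test_hybrid_shard_configuration` above toggles `_TORCH_GREATER_EQUAL_2_2` with `monkeypatch.setattr` to exercise both sides of a version gate. The sketch below shows the same idea with `unittest.mock.patch.object` on a throwaway module. Everything in it (`strategy_module`, `_FEATURE_AVAILABLE`, `make_strategy`, `check_flag_gate`) is hypothetical and only mirrors the shape of the real check.

```python
import types
from unittest import mock

# Hypothetical module holding a module-level feature flag.
strategy_module = types.ModuleType("strategy_module")
strategy_module._FEATURE_AVAILABLE = True


def make_strategy(device_mesh=None):
    # Reject the argument when the (hypothetical) feature flag is off.
    if device_mesh is not None and not strategy_module._FEATURE_AVAILABLE:
        raise ValueError("`device_mesh` argument is not supported in this version.")
    return {"device_mesh": device_mesh}


def check_flag_gate():
    # Temporarily flip the flag; patch.object restores it when the block exits.
    with mock.patch.object(strategy_module, "_FEATURE_AVAILABLE", False):
        try:
            make_strategy(device_mesh=object())
        except ValueError as err:
            message = str(err)
        else:
            message = None
    return message, strategy_module._FEATURE_AVAILABLE
```

Because the patch is scoped, later assertions in the same test see the original flag value again, which is what allows one test to cover both the unsupported and the supported path.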
diff --git a/tests/tests_pytorch/strategies/test_model_parallel.py b/tests/tests_pytorch/strategies/test_model_parallel.py
new file mode 100644
index 0000000000000..731da66d4a61f
--- /dev/null
+++ b/tests/tests_pytorch/strategies/test_model_parallel.py
@@ -0,0 +1,249 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from datetime import timedelta
+from re import escape
+from unittest import mock
+from unittest.mock import Mock
+
+import pytest
+import torch
+import torch.nn as nn
+from lightning.fabric.strategies.model_parallel import _is_sharded_checkpoint
+from lightning.pytorch import LightningModule
+from lightning.pytorch.plugins.environments import LightningEnvironment
+from lightning.pytorch.strategies import ModelParallelStrategy
+
+from tests_pytorch.helpers.runif import RunIf
+
+
+@mock.patch("lightning.pytorch.strategies.model_parallel._TORCH_GREATER_EQUAL_2_4", False)
+def test_torch_greater_equal_2_4():
+ with pytest.raises(ImportError, match="ModelParallelStrategy requires PyTorch 2.4 or higher"):
+ ModelParallelStrategy()
+
+
+@RunIf(min_torch="2.4")
+def test_device_mesh_access():
+ strategy = ModelParallelStrategy()
+ with pytest.raises(RuntimeError, match="Accessing the device mesh .* not allowed"):
+ _ = strategy.device_mesh
+
+
+@RunIf(min_torch="2.4")
+@pytest.mark.parametrize(
+ ("num_nodes", "devices", "invalid_dp_size", "invalid_tp_size"),
+ [
+ (1, 4, 1, 1),
+ (1, 4, 2, 3),
+ (1, 4, 4, 2),
+ (2, 4, 1, 4),
+ (2, 4, 2, 1),
+ ],
+)
+def test_validate_device_mesh_dimensions(num_nodes, devices, invalid_dp_size, invalid_tp_size):
+ """Test that passing sizes that don't multiply to the world size raises an error."""
+ strategy = ModelParallelStrategy(
+ data_parallel_size=invalid_dp_size,
+ tensor_parallel_size=invalid_tp_size,
+ )
+ strategy._setup_distributed = Mock()
+ strategy._accelerator = Mock()
+ strategy.cluster_environment = Mock(
+ world_size=Mock(return_value=(num_nodes * devices)), local_rank=Mock(return_value=1)
+ )
+ strategy.parallel_devices = [torch.device("cpu")] * devices
+ strategy.num_nodes = num_nodes
+ with pytest.raises(RuntimeError, match="multiplied should equal the world size"):
+ strategy.setup_environment()
+
+
+@RunIf(min_torch="2.4")
+def test_fsdp_v1_modules_unsupported():
+ """Test that the strategy won't allow setting up a module wrapped with the legacy FSDP API."""
+ from torch.distributed.fsdp import FullyShardedDataParallel
+
+ class Model(LightningModule):
+ def configure_model(self):
+ pass
+
+ model = Model()
+ model.modules = Mock(return_value=[Mock(spec=FullyShardedDataParallel)])
+ strategy = ModelParallelStrategy()
+ strategy.model = model
+ strategy._lightning_module = model
+ strategy._accelerator = Mock()
+
+ with pytest.raises(TypeError, match="only supports the new FSDP2 APIs in PyTorch >= 2.4"):
+ strategy.setup(Mock())
+
+
+@RunIf(min_torch="2.4")
+def test_configure_model_required():
+ class Model1(LightningModule):
+ pass
+
+ class Model2(LightningModule):
+ def configure_model(self):
+ pass
+
+ model = Model1()
+ strategy = ModelParallelStrategy()
+ strategy.model = model
+ strategy._lightning_module = model
+ strategy._accelerator = Mock()
+ strategy._parallel_devices = [torch.device("cpu")]
+
+ with pytest.raises(TypeError, match="you are required to override the `configure_model"):
+ strategy.setup(Mock())
+
+ model = Model2()
+ strategy.model = model
+ strategy._lightning_module = model
+ strategy.setup(Mock())
+
+
+@RunIf(min_torch="2.4")
+def test_save_checkpoint_storage_options(tmp_path):
+ """Test that the strategy does not accept storage options for saving checkpoints."""
+ strategy = ModelParallelStrategy()
+ with pytest.raises(
+ TypeError, match=escape("ModelParallelStrategy.save_checkpoint(..., storage_options=...)` is not")
+ ):
+ strategy.save_checkpoint(checkpoint=Mock(), filepath=tmp_path, storage_options=Mock())
+
+
+@RunIf(min_torch="2.4")
+@mock.patch("lightning.pytorch.strategies.model_parallel.ModelParallelStrategy.broadcast", lambda _, x: x)
+@mock.patch("lightning.fabric.plugins.io.torch_io._atomic_save")
+@mock.patch("lightning.pytorch.strategies.model_parallel.shutil")
+def test_save_checkpoint_path_exists(shutil_mock, torch_save_mock, tmp_path):
+ strategy = ModelParallelStrategy(save_distributed_checkpoint=False)
+
+ # save_distributed_checkpoint=False, path exists, path is not a sharded checkpoint: error
+ path = tmp_path / "not-empty"
+ path.mkdir()
+ (path / "file").touch()
+ assert not _is_sharded_checkpoint(path)
+ with pytest.raises(IsADirectoryError, match="exists and is a directory"):
+ strategy.save_checkpoint(Mock(), filepath=path)
+
+ # save_distributed_checkpoint=False, path exists, path is a sharded checkpoint: no error (overwrite)
+ path = tmp_path / "sharded-checkpoint"
+ path.mkdir()
+ (path / "meta.pt").touch()
+ assert _is_sharded_checkpoint(path)
+ strategy.save_checkpoint(Mock(), filepath=path)
+ shutil_mock.rmtree.assert_called_once_with(path)
+
+ # save_distributed_checkpoint=False, path exists, path is a file: no error (overwrite)
+ path = tmp_path / "file.pt"
+ path.touch()
+ torch_save_mock.reset_mock()
+ strategy.save_checkpoint(Mock(), filepath=path)
+ torch_save_mock.assert_called_once()
+
+ strategy = ModelParallelStrategy(save_distributed_checkpoint=True)
+
+ save_mock = mock.patch("torch.distributed.checkpoint.save")
+
+ # save_distributed_checkpoint=True, path exists, path is a folder: no error (overwrite)
+ path = tmp_path / "not-empty-2"
+ path.mkdir()
+ (path / "file").touch()
+ with save_mock:
+ strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
+ assert (path / "file").exists()
+
+ # save_distributed_checkpoint=True, path exists, path is a file: no error (overwrite)
+ path = tmp_path / "file-2.pt"
+ path.touch()
+ with save_mock:
+ strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
+ assert path.is_dir()
+
+
+@RunIf(min_torch="2.4")
+@mock.patch("lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True)
+def test_load_unknown_checkpoint_type(_, tmp_path):
+ """Test that the strategy validates the contents at the checkpoint path."""
+ strategy = ModelParallelStrategy()
+ strategy.model = Mock()
+ strategy._lightning_module = Mock(strict_loading=True)
+ path = tmp_path / "empty_dir" # neither a single file nor a directory with meta file
+ path.mkdir()
+ with pytest.raises(ValueError, match="does not point to a valid checkpoint"):
+ strategy.load_checkpoint(checkpoint_path=path)
+
+
+@RunIf(min_torch="2.4")
+@mock.patch("lightning.pytorch.strategies.model_parallel._setup_device_mesh")
+@mock.patch("torch.distributed.init_process_group")
+def test_set_timeout(init_process_group_mock, _):
+ """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function."""
+ test_timedelta = timedelta(seconds=30)
+ strategy = ModelParallelStrategy(timeout=test_timedelta)
+ strategy._lightning_module = Mock()
+ strategy.parallel_devices = [torch.device("cpu")]
+ strategy.cluster_environment = LightningEnvironment()
+ strategy.accelerator = Mock()
+ strategy.setup_environment()
+ process_group_backend = strategy._get_process_group_backend()
+ global_rank = strategy.cluster_environment.global_rank()
+ world_size = strategy.cluster_environment.world_size()
+ init_process_group_mock.assert_called_with(
+ process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
+ )
+
+
+@RunIf(min_torch="2.4")
+def test_meta_device_materialization():
+ """Test that the `setup()` method materializes meta-device tensors in the LightningModule."""
+
+ class NoResetParameters(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(4, 4))
+
+ class CustomModel(LightningModule):
+ def __init__(self):
+ super().__init__()
+ # nn.Sequential as a parameterless module
+ self.layer1 = nn.Sequential(NoResetParameters(), NoResetParameters())
+ self.layer2 = nn.Linear(4, 4)
+ self.register_buffer("buffer", torch.rand(2))
+
+ def reset_parameters(self):
+ self.buffer.fill_(1.0)
+
+ def configure_model(self) -> None:
+ pass
+
+ with torch.device("meta"):
+ model = CustomModel()
+ assert model.layer1[0].weight.is_meta
+ assert model.layer2.weight.is_meta
+ assert model.buffer.is_meta
+
+ strategy = ModelParallelStrategy()
+ strategy._accelerator = Mock()
+ strategy._device_mesh = Mock()
+ strategy._parallel_devices = [torch.device("cpu")]
+ strategy._lightning_module = model
+ strategy.model = model
+
+ with pytest.warns(UserWarning, match=r"`reset_parameters\(\)` method for re-initialization: NoResetParameters"):
+ strategy.setup(Mock())
+ assert all(not p.is_meta for p in model.parameters())
+ assert all(not b.is_meta for b in model.buffers())
diff --git a/tests/tests_pytorch/strategies/test_model_parallel_integration.py b/tests/tests_pytorch/strategies/test_model_parallel_integration.py
new file mode 100644
index 0000000000000..57d273917573a
--- /dev/null
+++ b/tests/tests_pytorch/strategies/test_model_parallel_integration.py
@@ -0,0 +1,511 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from pathlib import Path
+
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from lightning.pytorch import LightningModule, Trainer, seed_everything
+from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
+from lightning.pytorch.strategies import ModelParallelStrategy
+from torch.utils.data import DataLoader, DistributedSampler
+from torchmetrics.classification import Accuracy
+
+from tests_pytorch.helpers.runif import RunIf
+
+
+class FeedForward(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.w1 = nn.Linear(32, 64)
+ self.w2 = nn.Linear(32, 64)
+ self.w3 = nn.Linear(64, 32)
+
+ def forward(self, x):
+ return self.w3(F.silu(self.w1(x)) * self.w2(x))
+
+
+def _parallelize_feed_forward_tp(model, device_mesh):
+ from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
+
+ tp_mesh = device_mesh["tensor_parallel"]
+ tp_plan = {
+ "w1": ColwiseParallel(),
+ "w2": ColwiseParallel(),
+ "w3": RowwiseParallel(),
+ }
+ parallelize_module(model, tp_mesh, tp_plan)
+ return model
+
+
+def _parallelize_feed_forward_fsdp2(model, device_mesh):
+ from torch.distributed._composable.fsdp.fully_shard import fully_shard
+
+ dp_mesh = device_mesh["data_parallel"]
+ assert dp_mesh.ndim == 1 # Hybrid-sharding not supported
+
+ # Fully-shard each layer
+ fully_shard(model.w1, mesh=dp_mesh)
+ fully_shard(model.w2, mesh=dp_mesh)
+ fully_shard(model.w3, mesh=dp_mesh)
+
+ # TODO: Re-enable activation checkpointing
+ # Currently, state dict keys get prefixed with '_checkpoint_wrapper',
+ # which leads to mismatches when loading weights into a checkpoint-wrapped module.
+ # PyTorch should handle this automatically.
+
+ # model = checkpoint_wrapper(model)
+
+ return model
+
+
+def _parallelize_feed_forward_fsdp2_tp(model, device_mesh):
+ model = _parallelize_feed_forward_tp(model, device_mesh)
+ model = _parallelize_feed_forward_fsdp2(model, device_mesh)
+ return model
+
+
+class TemplateModel(LightningModule):
+ def __init__(self):
+ super().__init__()
+ self.model = FeedForward()
+
+ def training_step(self, batch):
+ output = self.model(batch)
+ return output.sum()
+
+ def train_dataloader(self):
+ dataset_size = 8
+ dataset = RandomDataset(32, dataset_size)
+ return DataLoader(dataset, batch_size=2)
+
+ def configure_optimizers(self):
+ return torch.optim.AdamW(self.model.parameters())
+
+
+class FSDP2Model(TemplateModel):
+ def configure_model(self):
+ _parallelize_feed_forward_fsdp2(self.model, device_mesh=self.device_mesh)
+
+
+class TensorParallelModel(TemplateModel):
+ def configure_model(self):
+ _parallelize_feed_forward_tp(self.model, device_mesh=self.device_mesh)
+
+
+class FSDP2TensorParallelModel(TemplateModel):
+ def configure_model(self):
+ _parallelize_feed_forward_fsdp2_tp(self.model, device_mesh=self.device_mesh)
+
+
+@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
+def test_setup_device_mesh():
+ from torch.distributed.device_mesh import DeviceMesh
+
+ for dp_size, tp_size in ((1, 4), (4, 1), (2, 2)):
+ strategy = ModelParallelStrategy(
+ data_parallel_size=dp_size,
+ tensor_parallel_size=tp_size,
+ )
+ trainer = Trainer(
+ accelerator="auto",
+ devices=4,
+ strategy=strategy,
+ logger=False,
+ enable_checkpointing=False,
+ max_steps=1,
+ )
+
+ class Model(BoringModel):
+ def configure_model(self):
+ device_mesh = self.device_mesh
+ assert isinstance(device_mesh, DeviceMesh)
+ assert device_mesh.device_type == model.device.type
+ assert device_mesh.mesh_dim_names == ("data_parallel", "tensor_parallel")
+ assert device_mesh.size(0) == dp_size
+ assert device_mesh.size(1) == tp_size
+ assert device_mesh.ndim == 2
+
+ model = Model()
+ trainer.fit(model)
+
+ # Passing "auto" will select internode and intranode dimensions automatically
+ strategy = ModelParallelStrategy(
+ data_parallel_size="auto",
+ tensor_parallel_size="auto",
+ )
+ trainer = Trainer(
+ accelerator="auto",
+ devices=4,
+ num_nodes=1,
+ strategy=strategy,
+ logger=False,
+ enable_checkpointing=False,
+ max_steps=1,
+ )
+
+ class Model(BoringModel):
+ def configure_model(self):
+ device_mesh = self.device_mesh
+ assert device_mesh.mesh_dim_names == ("data_parallel", "tensor_parallel")
+ assert device_mesh.size(0) == 1
+ assert device_mesh.size(1) == 4
+
+ model = Model()
+ trainer.fit(model)
+
+
+@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2)
+def test_tensor_parallel():
+ from torch.distributed._tensor import DTensor
+
+ class Model(TensorParallelModel):
+ def on_train_start(self):
+ device_mesh = self.device_mesh
+ optimizer = self.optimizers()
+ assert all(
+ tensor.device_mesh == device_mesh["tensor_parallel"] for tensor in optimizer.param_groups[0]["params"]
+ )
+ assert all(isinstance(weight, DTensor) for weight in self.model.parameters())
+ assert self.model.w1.weight.device_mesh == device_mesh["tensor_parallel"]
+
+ # No data sharding, all GPUs get the same input inside a TP group
+ dataloader = self.trainer.train_dataloader
+ assert len(dataloader) == 8 // dataloader.batch_size
+ assert isinstance(dataloader.sampler, DistributedSampler)
+
+ def training_step(self, batch):
+ # All batches must be identical across TP group
+ batches = self.all_gather(batch)
+ assert all(torch.equal(batches[0], batches[i]) for i in range(1, len(batches)))
+ return super().training_step(batch)
+
+ trainer = Trainer(
+ accelerator="auto",
+ devices=2,
+ strategy=ModelParallelStrategy(),
+ max_steps=2,
+ enable_checkpointing=False,
+ logger=False,
+ )
+
+ seed_everything(0)
+ with trainer.init_module(empty_init=True):
+ model = Model()
+
+ trainer.fit(model)
+
+
+@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
+def test_fsdp2_tensor_parallel():
+ from torch.distributed._tensor import DTensor
+
+ class Model(FSDP2TensorParallelModel):
+ def on_train_start(self):
+ optimizer = self.optimizers()
+ assert all(isinstance(weight, DTensor) for weight in self.model.parameters())
+ assert all(isinstance(tensor, DTensor) for tensor in optimizer.param_groups[0]["params"])
+ assert self.model.w1.weight.device_mesh.ndim == 2
+ assert self.model.w1.weight.device_mesh.size(0) == 2
+ assert self.model.w1.weight.device_mesh.size(1) == 2
+ assert all(weight.device.type != "meta" for weight in self.model.parameters())
+ assert all(tensor.device_mesh.ndim == 2 for tensor in optimizer.param_groups[0]["params"])
+ assert all(tensor.device.type != "meta" for tensor in optimizer.param_groups[0]["params"])
+
+ # No data sharding across TP dimension, sharding across data-parallel dimension only
+ device_mesh = self.device_mesh
+ dp_mesh = device_mesh["data_parallel"]
+ dataloader = self.trainer.train_dataloader
+ assert len(dataloader) == 8 // dataloader.batch_size // dp_mesh.size()
+ assert isinstance(dataloader.sampler, DistributedSampler)
+
+ def training_step(self, batch):
+ batches = self.all_gather(batch)
+ dp_mesh = self.device_mesh["data_parallel"]
+ tp_mesh = self.device_mesh["tensor_parallel"]
+
+ # Batches across the TP dimension must be identical
+ batches_tp = batches[tp_mesh.mesh]
+ assert all(torch.equal(batches_tp[0], batches_tp[i]) for i in range(1, len(batches_tp)))
+ # Batches across the DP dimension must be different
+ batches_dp = batches[dp_mesh.mesh]
+ assert all(not torch.equal(batches_dp[0], batches_dp[i]) for i in range(1, len(batches_dp)))
+
+ return super().training_step(batch)
+
+ strategy = ModelParallelStrategy(
+ data_parallel_size=2,
+ tensor_parallel_size=2,
+ )
+ trainer = Trainer(
+ accelerator="auto",
+ devices=4,
+ strategy=strategy,
+ max_steps=2,
+ enable_checkpointing=False,
+ logger=False,
+ )
+
+ seed_everything(0)
+ with trainer.init_module(empty_init=True):
+ model = Model()
+
+ trainer.fit(model)
+
+
+@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
+def test_modules_without_parameters(tmp_path):
+ """Test that TorchMetrics get moved to the device despite not having any parameters."""
+
+ class MetricsModel(TensorParallelModel):
+ def __init__(self):
+ super().__init__()
+ self.metric = Accuracy("multiclass", num_classes=10)
+ assert self.metric.device == self.metric.tp.device == torch.device("cpu")
+
+ def setup(self, stage) -> None:
+ assert self.metric.device == self.metric.tp.device == torch.device("cpu")
+
+ def training_step(self, batch):
+ assert self.metric.device.type == self.metric.tp.device.type == "cuda"
+ self.metric(torch.rand(2, 10, device=self.device), torch.randint(0, 10, size=(2,), device=self.device))
+ return super().training_step(batch)
+
+ model = MetricsModel()
+ trainer = Trainer(
+ default_root_dir=tmp_path,
+ accelerator="cuda",
+ devices=2,
+ strategy=ModelParallelStrategy(),
+ max_steps=1,
+ enable_checkpointing=False,
+ logger=False,
+ )
+ trainer.fit(model)
+
+
+@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
+@pytest.mark.parametrize(
+ ("precision", "expected_dtype"),
+ [
+ ("32-true", torch.float32),
+ ("16-true", torch.float16),
+ pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)),
+ ],
+)
+def test_module_init_context(precision, expected_dtype, tmp_path):
+ """Test that the module under the init-context gets moved to the right device and dtype."""
+
+ class Model(FSDP2Model):
+ def on_train_start(self):
+ assert self.model.w1.weight.device == torch.device("cuda", self.local_rank)
+ assert self.model.w1.weight.dtype == expected_dtype
+ optimizer = self.optimizers(use_pl_optimizer=False)
+ assert optimizer.param_groups[0]["params"][0].device.type == "cuda"
+
+ def _run_setup_assertions(empty_init, expected_device):
+ trainer = Trainer(
+ default_root_dir=tmp_path,
+ accelerator="cuda",
+ devices=2,
+ strategy=ModelParallelStrategy(),
+ precision=precision,
+ max_steps=1,
+ barebones=True,
+ enable_checkpointing=False,
+ logger=False,
+ )
+ with trainer.init_module(empty_init=empty_init):
+ model = Model()
+
+ # The model is on the CPU/meta-device until after `ModelParallelStrategy.setup()`
+ assert model.model.w1.weight.device == expected_device
+ assert model.model.w1.weight.dtype == expected_dtype
+ trainer.fit(model)
+
+ # Case 1: No empty init
+ _run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))
+
+ # Case 2: Empty-init with meta device
+ _run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
+
+
+@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
+@pytest.mark.parametrize("save_distributed_checkpoint", [True, False])
+def test_strategy_state_dict(tmp_path, save_distributed_checkpoint):
+ """Test that the strategy returns the correct state dict of the LightningModule."""
+ model = FSDP2Model()
+ correct_state_dict = model.state_dict() # State dict before wrapping
+
+ strategy = ModelParallelStrategy(save_distributed_checkpoint=save_distributed_checkpoint)
+ trainer = Trainer(
+ default_root_dir=tmp_path,
+ accelerator="cuda",
+ devices=2,
+ strategy=strategy,
+ max_epochs=1,
+ barebones=True,
+ )
+ trainer.fit(model)
+
+ state_dict = trainer.strategy.lightning_module_state_dict()
+
+ if save_distributed_checkpoint:
+ # All ranks return a state dict
+ assert len(state_dict) > 0
+ # State dict should contain same keys as non-distributed state dict
+ assert list(state_dict.keys()) == list(correct_state_dict.keys())
+ else:
+ if trainer.global_rank != 0:
+ # The full state-dict is only returned on rank 0
+ assert len(state_dict) == 0
+ return
+ # State dict should contain same keys as non-distributed state dict
+ assert list(state_dict.keys()) == list(correct_state_dict.keys())
+
+
+@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
+def test_load_full_state_checkpoint_into_regular_model(tmp_path):
+ """Test that a full-state checkpoint saved from a distributed model can be loaded back into a regular model."""
+
+ # Save a regular full-state checkpoint from a distributed model
+ model = FSDP2Model()
+ strategy = ModelParallelStrategy(save_distributed_checkpoint=False)
+ trainer = Trainer(
+ default_root_dir=tmp_path,
+ accelerator="gpu",
+ devices=2,
+ strategy=strategy,
+ max_epochs=1,
+ barebones=True,
+ )
+ trainer.fit(model)
+ model_path = tmp_path / "last.ckpt"
+ model_path = trainer.strategy.broadcast(model_path)
+ trainer.save_checkpoint(model_path)
+ model_state_dict = trainer.strategy.lightning_module_state_dict()
+ optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())
+
+ if trainer.global_rank != 0:
+ assert len(model_state_dict) == 0
+ assert len(optimizer_state_dict) == 0
+
+ # Create a regular model and load the checkpoint into it
+ model = TemplateModel()
+ trainer = Trainer(default_root_dir=tmp_path, accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)
+ trainer.fit(model, ckpt_path=model_path)
+ restored_model_state_dict = trainer.strategy.lightning_module_state_dict()
+ restored_optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())
+
+ if trainer.global_rank == 0:
+ assert len(model_state_dict) == len(restored_model_state_dict)
+ assert len(optimizer_state_dict) == len(restored_optimizer_state_dict)
+ torch.testing.assert_close(model_state_dict, restored_model_state_dict, atol=0, rtol=0)
+ torch.testing.assert_close(optimizer_state_dict, restored_optimizer_state_dict, atol=0, rtol=0)
+ trainer.strategy.barrier()
+
+
+@pytest.mark.filterwarnings("ignore::FutureWarning")
+@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
+def test_load_standard_checkpoint_into_distributed_model(tmp_path):
+ """Test that a regular checkpoint (weights and optimizer states) can be loaded into a distributed model."""
+
+ # Save a regular DDP checkpoint
+ model = TemplateModel()
+ trainer = Trainer(default_root_dir=tmp_path, accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)
+ trainer.fit(model)
+ model_path = tmp_path / "last.ckpt"
+ model_path = trainer.strategy.broadcast(model_path)
+ trainer.save_checkpoint(model_path)
+ model_state_dict = trainer.strategy.lightning_module_state_dict()
+ optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())
+
+ # Create a distributed model and load the checkpoint into it
+ model = FSDP2Model()
+ strategy = ModelParallelStrategy(save_distributed_checkpoint=False)
+ trainer = Trainer(
+ default_root_dir=tmp_path,
+ accelerator="gpu",
+ devices=2,
+ strategy=strategy,
+ max_epochs=1,
+ barebones=True,
+ )
+ trainer.fit(model, ckpt_path=model_path)
+ restored_model_state_dict = trainer.strategy.lightning_module_state_dict()
+ restored_optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())
+
+ if trainer.global_rank != 0:
+ assert len(restored_model_state_dict) == 0
+ assert len(restored_optimizer_state_dict) == 0
+ if trainer.global_rank == 0:
+ assert len(model_state_dict) == len(restored_model_state_dict)
+ assert len(optimizer_state_dict) == len(restored_optimizer_state_dict)
+ torch.testing.assert_close(model_state_dict, restored_model_state_dict, atol=0, rtol=0)
+ torch.testing.assert_close(optimizer_state_dict, restored_optimizer_state_dict, atol=0, rtol=0)
+ trainer.strategy.barrier()
+
+
+@pytest.mark.filterwarnings("ignore::FutureWarning")
+@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
+def test_save_load_sharded_state_dict(tmp_path):
+ """Test saving and loading with the distributed state dict format."""
+
+ class CheckpointModel(FSDP2Model):
+ def __init__(self, params_to_compare=None):
+ super().__init__()
+ self.params_to_compare = params_to_compare
+
+ def on_train_start(self):
+ if self.params_to_compare is None:
+ return
+ for p0, p1 in zip(self.params_to_compare, self.trainer.model.parameters()):
+ assert torch.equal(p0, p1.full_tensor())
+
+ seed_everything(0)
+
+ strategy = ModelParallelStrategy(save_distributed_checkpoint=True)
+ trainer_kwargs = {
+ "default_root_dir": tmp_path,
+ "accelerator": "cuda",
+ "devices": 2,
+ "max_epochs": 1,
+ "enable_progress_bar": False,
+ "enable_model_summary": False,
+ "logger": False,
+ }
+
+ # Initial training
+ model = CheckpointModel()
+ trainer = Trainer(**trainer_kwargs, strategy=strategy)
+ trainer.fit(model)
+ params_before = [p.full_tensor() for p in trainer.model.parameters()]
+
+ checkpoint_path = Path(trainer.strategy.broadcast(trainer.checkpoint_callback.best_model_path))
+ assert set(os.listdir(checkpoint_path)) == {"meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"}
+
+ metadata = torch.load(checkpoint_path / "meta.pt", weights_only=True)
+ assert "pytorch-lightning_version" in metadata
+ assert len(metadata["callbacks"]) == 1 # model checkpoint callback
+ assert "state_dict" not in metadata
+ assert "optimizer_states" not in metadata
+
+ # Load checkpoint and continue training
+ trainer_kwargs.update(max_epochs=2)
+ model = CheckpointModel(params_to_compare=params_before)
+ strategy = ModelParallelStrategy(save_distributed_checkpoint=True)
+ trainer = Trainer(**trainer_kwargs, strategy=strategy)
+ trainer.fit(model, ckpt_path=checkpoint_path)
diff --git a/tests/tests_pytorch/strategies/test_registry.py b/tests/tests_pytorch/strategies/test_registry.py
index abc1c83ec5143..90e15638bfd06 100644
--- a/tests/tests_pytorch/strategies/test_registry.py
+++ b/tests/tests_pytorch/strategies/test_registry.py
@@ -40,7 +40,7 @@ def test_strategy_registry_with_deepspeed_strategies(strategy_name, init_params)
@RunIf(deepspeed=True)
@pytest.mark.parametrize("strategy", ["deepspeed", "deepspeed_stage_2_offload", "deepspeed_stage_3"])
-def test_deepspeed_strategy_registry_with_trainer(tmp_path, strategy):
+def test_deepspeed_strategy_registry_with_trainer(tmp_path, strategy, mps_count_0):
trainer = Trainer(default_root_dir=tmp_path, strategy=strategy, precision="16-mixed")
assert isinstance(trainer.strategy, DeepSpeedStrategy)
diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py
index 605f4ec3845e7..56b58d4d157a1 100644
--- a/tests/tests_pytorch/test_cli.py
+++ b/tests/tests_pytorch/test_cli.py
@@ -48,6 +48,7 @@
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
from lightning_utilities import compare_version
from lightning_utilities.test.warning import no_warning_call
+from packaging.version import Version
from tensorboard.backend.event_processing import event_accumulator
from tensorboard.plugins.hparams.plugin_data_pb2 import HParamsPluginData
from torch.optim import SGD
@@ -64,6 +65,14 @@ def lazy_instance(*args, **kwargs):
return None
+_xfail_python_ge_3_11_9 = pytest.mark.xfail(
+ # https://github.com/omni-us/jsonargparse/issues/484
+ Version(f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}") >= Version("3.11.9"),
+ strict=False,
+ reason="jsonargparse + Python 3.11.9 compatibility issue",
+)
+
+
@contextmanager
def mock_subclasses(baseclass, *subclasses):
"""Mocks baseclass so that it only has the given child subclasses."""
@@ -347,6 +356,7 @@ def test_save_to_log_dir_false_error():
)
+@_xfail_python_ge_3_11_9
def test_lightning_cli_logger_save_config(cleandir):
class LoggerSaveConfigCallback(SaveConfigCallback):
def __init__(self, *args, **kwargs) -> None:
@@ -456,7 +466,7 @@ def test_lightning_cli_help():
), pytest.raises(SystemExit):
any_model_any_data_cli()
- assert "--data.init_args.data_dir" in out.getvalue()
+ assert ("--data.data_dir" in out.getvalue()) or ("--data.init_args.data_dir" in out.getvalue())
def test_lightning_cli_print_config():
@@ -736,6 +746,7 @@ def add_arguments_to_parser(self, parser):
assert cli.trainer.lr_scheduler_configs[0].scheduler.step_size == 50
+@_xfail_python_ge_3_11_9
@pytest.mark.parametrize("use_generic_base_class", [False, True])
def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_generic_base_class):
class MyLightningCLI(LightningCLI):
@@ -782,7 +793,7 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)
-@pytest.mark.skipif(compare_version("jsonargparse", operator.lt, "4.21.3"), reason="vulnerability with failing imports")
+@_xfail_python_ge_3_11_9
def test_lightning_cli_optimizers_and_lr_scheduler_with_callable_type():
class TestModel(BoringModel):
def __init__(
@@ -870,7 +881,7 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection(cleandir):
checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
assert checkpoint_path.is_file()
- ckpt = torch.load(checkpoint_path)
+ ckpt = torch.load(checkpoint_path, weights_only=True)
assert ckpt["hyper_parameters"] == expected
model = TestModelSaveHparams.load_from_checkpoint(checkpoint_path)
@@ -889,17 +900,15 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection_subclass_mode(c
expected = {
"_instantiator": "lightning.pytorch.cli.instantiate_module",
- "class_path": f"{__name__}.TestModelSaveHparams",
- "init_args": {
- "optimizer": "torch.optim.Adam",
- "scheduler": "torch.optim.lr_scheduler.ConstantLR",
- "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
- },
+ "_class_path": f"{__name__}.TestModelSaveHparams",
+ "optimizer": "torch.optim.Adam",
+ "scheduler": "torch.optim.lr_scheduler.ConstantLR",
+ "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
}
checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
assert checkpoint_path.is_file()
- ckpt = torch.load(checkpoint_path)
+ ckpt = torch.load(checkpoint_path, weights_only=True)
assert ckpt["hyper_parameters"] == expected
model = LightningModule.load_from_checkpoint(checkpoint_path)
@@ -911,6 +920,38 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection_subclass_mode(c
assert isinstance(lr_scheduler, torch.optim.lr_scheduler.ConstantLR)
+class TestModelSaveHparamsUntyped(BoringModel):
+ def __init__(self, learning_rate, step_size=None, **kwargs):
+ super().__init__()
+ self.save_hyperparameters()
+ self.learning_rate = learning_rate
+ self.step_size = step_size
+ self.kwargs = kwargs
+
+
+def test_lightning_cli_save_hyperparameters_untyped_module(cleandir):
+ config = {
+ "model": {
+ "class_path": f"{__name__}.TestModelSaveHparamsUntyped",
+ "init_args": {"learning_rate": 1e-2},
+ "dict_kwargs": {"x": 1},
+ }
+ }
+ with mock.patch("sys.argv", ["any.py", f"--config={json.dumps(config)}", "--trainer.max_epochs=1"]):
+ cli = LightningCLI(BoringModel, run=False, auto_configure_optimizers=False, subclass_mode_model=True)
+ cli.trainer.fit(cli.model)
+ assert isinstance(cli.model, TestModelSaveHparamsUntyped)
+ assert cli.model.hparams["learning_rate"] == 1e-2
+ assert cli.model.hparams["step_size"] is None
+ assert cli.model.hparams["x"] == 1
+
+ checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
+ model = TestModelSaveHparamsUntyped.load_from_checkpoint(checkpoint_path)
+ assert model.learning_rate == 1e-2
+ assert model.step_size is None
+ assert model.kwargs == {"x": 1}
+
+
@pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
def test_lightning_cli_trainer_fn(fn):
class TestCLI(LightningCLI):
@@ -1031,6 +1072,7 @@ def __init__(self, foo, bar=5):
self.bar = bar
+@_xfail_python_ge_3_11_9
def test_lightning_cli_model_short_arguments():
with mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), mock.patch(
"lightning.pytorch.Trainer._fit_impl"
@@ -1055,6 +1097,7 @@ def __init__(self, foo, bar=5):
self.bar = bar
+@_xfail_python_ge_3_11_9
def test_lightning_cli_datamodule_short_arguments():
# with set model
with mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), mock.patch(
@@ -1100,6 +1143,7 @@ def test_lightning_cli_datamodule_short_arguments():
assert cli.parser.groups["data"].group_class is BoringDataModule
+@_xfail_python_ge_3_11_9
@pytest.mark.parametrize("use_class_path_callbacks", [False, True])
def test_callbacks_append(use_class_path_callbacks):
"""This test validates registries are used when simplified command line are being used."""
@@ -1143,6 +1187,7 @@ def test_callbacks_append(use_class_path_callbacks):
assert all(t in callback_types for t in expected)
+@_xfail_python_ge_3_11_9
def test_optimizers_and_lr_schedulers_reload(cleandir):
base = ["any.py", "--trainer.max_epochs=1"]
input = base + [
@@ -1174,6 +1219,7 @@ def test_optimizers_and_lr_schedulers_reload(cleandir):
LightningCLI(BoringModel, run=False)
+@_xfail_python_ge_3_11_9
def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload(cleandir):
class TestLightningCLI(LightningCLI):
def __init__(self, *args):
@@ -1427,6 +1473,7 @@ def test_cli_help_message():
assert "Implements Adam" in shorthand_help.getvalue()
+@_xfail_python_ge_3_11_9
def test_cli_reducelronplateau():
with mock.patch(
"sys.argv", ["any.py", "--optimizer=Adam", "--lr_scheduler=ReduceLROnPlateau", "--lr_scheduler.monitor=foo"]
@@ -1437,6 +1484,7 @@ def test_cli_reducelronplateau():
assert config["lr_scheduler"]["scheduler"].monitor == "foo"
+@_xfail_python_ge_3_11_9
def test_cli_configureoptimizers_can_be_overridden():
class MyCLI(LightningCLI):
def __init__(self):
@@ -1481,6 +1529,7 @@ def __init__(self, activation: torch.nn.Module = lazy_instance(torch.nn.LeakyReL
assert cli.model.activation is not model.activation
+@_xfail_python_ge_3_11_9
def test_ddpstrategy_instantiation_and_find_unused_parameters(mps_count_0):
strategy_default = lazy_instance(DDPStrategy, find_unused_parameters=True)
with mock.patch("sys.argv", ["any.py", "--trainer.strategy.process_group_backend=group"]):
@@ -1496,6 +1545,7 @@ def test_ddpstrategy_instantiation_and_find_unused_parameters(mps_count_0):
assert strategy_default is not cli.config_init.trainer.strategy
+@_xfail_python_ge_3_11_9
def test_cli_logger_shorthand():
with mock.patch("sys.argv", ["any.py"]):
cli = LightningCLI(TestModel, run=False, trainer_defaults={"logger": False})
@@ -1526,6 +1576,7 @@ def _test_logger_init_args(logger_name, init, unresolved=None):
assert data["dict_kwargs"] == unresolved
+@_xfail_python_ge_3_11_9
def test_comet_logger_init_args():
_test_logger_init_args(
"CometLogger",
@@ -1541,6 +1592,7 @@ def test_comet_logger_init_args():
strict=False,
reason="TypeError on Windows when parsing",
)
+@_xfail_python_ge_3_11_9
def test_neptune_logger_init_args():
_test_logger_init_args(
"NeptuneLogger",
@@ -1549,6 +1601,7 @@ def test_neptune_logger_init_args():
)
+@_xfail_python_ge_3_11_9
def test_tensorboard_logger_init_args():
_test_logger_init_args(
"TensorBoardLogger",
@@ -1560,6 +1613,7 @@ def test_tensorboard_logger_init_args():
)
+@_xfail_python_ge_3_11_9
def test_wandb_logger_init_args():
_test_logger_init_args(
"WandbLogger",
@@ -1644,6 +1698,7 @@ def __init__(self, a_func: Callable = torch.nn.Softmax):
assert "a_func: torch.nn.Softmax" in out.getvalue()
+@_xfail_python_ge_3_11_9
def test_pytorch_profiler_init_args():
from lightning.pytorch.profilers import Profiler, PyTorchProfiler
diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
index ee19c3951cb25..65c5777e28fed 100644
--- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
@@ -14,6 +14,7 @@
import inspect
import os
import sys
+from contextlib import nullcontext
from typing import Any, Dict
from unittest import mock
from unittest.mock import Mock
@@ -48,6 +49,7 @@
DDPStrategy,
DeepSpeedStrategy,
FSDPStrategy,
+ ModelParallelStrategy,
SingleDeviceStrategy,
SingleDeviceXLAStrategy,
XLAStrategy,
@@ -564,12 +566,19 @@ def test_strategy_choice_ddp_cpu_slurm(cuda_count_0, strategy):
def test_check_fsdp_strategy_and_fallback():
- with pytest.raises(
- MisconfigurationException,
- match=f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but GPU accelerator is not used.",
- ):
+ with pytest.raises(ValueError, match="The strategy `fsdp` requires a GPU accelerator"):
Trainer(accelerator="cpu", strategy="fsdp")
+ class FSDPStrategySubclass(FSDPStrategy):
+ pass
+
+ class AcceleratorSubclass(CPUAccelerator):
+ pass
+
+ # we allow subclasses of FSDPStrategy to be used with other accelerators
+ Trainer(accelerator="cpu", strategy=FSDPStrategySubclass())
+ Trainer(accelerator=AcceleratorSubclass(), strategy=FSDPStrategySubclass())
+
@mock.patch.dict(os.environ, {}, clear=True)
def test_unsupported_tpu_choice(xla_available, tpu_available):
@@ -1056,3 +1065,14 @@ def test_bitsandbytes_precision_cuda_required(monkeypatch):
monkeypatch.setitem(sys.modules, "bitsandbytes", Mock())
with pytest.raises(RuntimeError, match="Bitsandbytes is only supported on CUDA GPUs"):
_AcceleratorConnector(accelerator="cpu", plugins=BitsandbytesPrecision(mode="int8"))
+
+
+@RunIf(min_torch="2.4")
+@pytest.mark.parametrize(
+ ("precision", "raises"),
+ [("32-true", False), ("16-true", False), ("bf16-true", False), ("16-mixed", True), ("bf16-mixed", False)],
+)
+def test_precision_selection_model_parallel(precision, raises, mps_count_0):
+ error_context = pytest.raises(ValueError, match=f"does not support .*{precision}") if raises else nullcontext()
+ with error_context:
+ _AcceleratorConnector(precision=precision, strategy=ModelParallelStrategy())
diff --git a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py
index 440400357101f..eb09413cafcce 100644
--- a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py
@@ -14,11 +14,11 @@
import contextlib
import logging
from unittest import mock
-from unittest.mock import Mock
+from unittest.mock import MagicMock, Mock
import pytest
import torch
-from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
+from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_10_0
from lightning.pytorch import Callback, LightningModule, Trainer
from lightning.pytorch.callbacks import (
EarlyStopping,
@@ -144,7 +144,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path):
)
trainer.fit(model)
- ckpt = torch.load(str(tmp_path / "all_states.ckpt"))
+ ckpt = torch.load(str(tmp_path / "all_states.ckpt"), weights_only=True)
state0 = ckpt["callbacks"]["StatefulCallback0"]
state1 = ckpt["callbacks"]["StatefulCallback1{'unique': 'one'}"]
state2 = ckpt["callbacks"]["StatefulCallback1{'unique': 'two'}"]
@@ -293,20 +293,16 @@ def factory_multiple_callbacks_list():
@contextlib.contextmanager
def _make_entry_point_query_mock(callback_factory):
- query_mock = Mock()
+ query_mock = MagicMock()
entry_point = Mock()
entry_point.name = "mocked"
entry_point.load.return_value = callback_factory
if _PYTHON_GREATER_EQUAL_3_10_0:
query_mock.return_value = [entry_point]
- import_path = "importlib.metadata.entry_points"
- elif _PYTHON_GREATER_EQUAL_3_8_0:
- query_mock().get.return_value = [entry_point]
- import_path = "importlib.metadata.entry_points"
else:
- query_mock.return_value = [entry_point]
- import_path = "pkg_resources.iter_entry_points"
- with mock.patch(import_path, query_mock):
+ query_mock().get.return_value = [entry_point]
+
+ with mock.patch("lightning.fabric.utilities.registry.entry_points", query_mock):
yield
diff --git a/tests/tests_pytorch/trainer/flags/test_env_vars.py b/tests/tests_pytorch/trainer/flags/test_env_vars.py
index 62c94d4cc277e..b47bf2d5b03fb 100644
--- a/tests/tests_pytorch/trainer/flags/test_env_vars.py
+++ b/tests/tests_pytorch/trainer/flags/test_env_vars.py
@@ -25,7 +25,7 @@ def test_passing_no_env_variables():
assert trainer.logger is not None
assert trainer.max_steps == -1
assert trainer.max_epochs is None
- trainer = Trainer(logger=False, max_steps=1)
+ trainer = Trainer(max_steps=1, logger=False, enable_checkpointing=False)
trainer.fit(model)
assert trainer.logger is None
assert trainer.max_steps == 1
@@ -49,7 +49,7 @@ def test_passing_env_variables_defaults():
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_DEVICES": "2"})
-def test_passing_env_variables_devices(cuda_count_2):
+def test_passing_env_variables_devices(cuda_count_2, mps_count_0):
"""Testing overwriting trainer arguments."""
trainer = Trainer()
assert trainer.num_devices == 2
diff --git a/tests/tests_pytorch/trainer/flags/test_inference_mode.py b/tests/tests_pytorch/trainer/flags/test_inference_mode.py
index c262f0ca33806..bae7b66dbbd55 100644
--- a/tests/tests_pytorch/trainer/flags/test_inference_mode.py
+++ b/tests/tests_pytorch/trainer/flags/test_inference_mode.py
@@ -16,7 +16,6 @@
import pytest
import torch
-from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loops import _Loop
@@ -81,7 +80,5 @@ def run(self): ...
f.run()
no_grad_mock.assert_called_once_with()
f.inference_mode = True
- with mock.patch("torch.inference_mode") as inference_mode_mock:
+ with mock.patch("torch.inference_mode"):
f.run()
- if not _TORCH_EQUAL_2_0:
- inference_mode_mock.assert_called_once_with()
diff --git a/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py b/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py
index 73e8017dcfcfc..25aaeb8cff77e 100644
--- a/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py
+++ b/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py
@@ -36,7 +36,7 @@ def test_min_max_steps_epochs(tmp_path, min_epochs, max_epochs, min_steps, max_s
assert trainer.global_step == trainer.max_steps
-def test_max_epochs_not_set_warning():
+def test_max_epochs_not_set_warning(tmp_path):
"""Test that a warning is only emitted when `max_epochs` was not set by the user."""
class CustomModel(BoringModel):
@@ -46,7 +46,7 @@ def training_step(self, *args, **kwargs):
match = "`max_epochs` was not set. Setting it to 1000 epochs."
model = CustomModel()
- trainer = Trainer(max_epochs=None, limit_train_batches=1)
+ trainer = Trainer(logger=False, enable_checkpointing=False, max_epochs=None, limit_train_batches=1)
with pytest.warns(PossibleUserWarning, match=match):
trainer.fit(model)
diff --git a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py
index ac94cc0482c47..b776263e9953d 100644
--- a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py
+++ b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py
@@ -37,7 +37,13 @@ def on_validation_epoch_start(self) -> None:
self.val_epoch_calls += 1
model = TestModel()
- trainer = Trainer(max_epochs=max_epochs, val_check_interval=1 / denominator, logger=False)
+ trainer = Trainer(
+ default_root_dir=tmp_path,
+ enable_checkpointing=False,
+ logger=False,
+ max_epochs=max_epochs,
+ val_check_interval=1 / denominator,
+ )
trainer.fit(model)
assert model.train_epoch_calls == max_epochs
@@ -107,6 +113,8 @@ def test_validation_check_interval_exceed_data_length_wrong():
trainer = Trainer(
limit_train_batches=10,
val_check_interval=100,
+ logger=False,
+ enable_checkpointing=False,
)
model = BoringModel()
diff --git a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
index f005a41e0d75c..b90d767a23caf 100644
--- a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
+++ b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
@@ -24,7 +24,6 @@
import numpy as np
import pytest
import torch
-from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from lightning.pytorch import Trainer, callbacks
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
@@ -557,8 +556,7 @@ def test_step(self, batch, batch_idx):
]
def get_metrics_at_idx(idx):
- mock_call = mock_log_metrics.mock_calls[idx]
- return mock_call.kwargs["metrics"] if _PYTHON_GREATER_EQUAL_3_8_0 else mock_call[2]["metrics"]
+ return mock_log_metrics.mock_calls[idx].kwargs["metrics"]
assert get_metrics_at_idx(2)["valid_loss_0_step"] == model.val_losses[2]
assert get_metrics_at_idx(3)["valid_loss_0_step"] == model.val_losses[3]
@@ -736,8 +734,7 @@ def test_dataloader(self):
cb_metrics = set(trainer.callback_metrics)
assert cb_metrics == {"foo/dataloader_idx_0", "foo/dataloader_idx_1", "foobar"}
- mock_call = mock_log_metrics.mock_calls[0]
- logged_metrics = mock_call.kwargs["metrics"] if _PYTHON_GREATER_EQUAL_3_8_0 else mock_call[2]["metrics"]
+ logged_metrics = mock_log_metrics.mock_calls[0].kwargs["metrics"]
cb_metrics.add("epoch")
assert set(logged_metrics) == cb_metrics
diff --git a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
index 32923d444d4bb..e96857a6c192d 100644
--- a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
+++ b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
@@ -33,6 +33,7 @@
from torchmetrics import Accuracy, MeanAbsoluteError, MeanSquaredError, MetricCollection
from torchmetrics import AveragePrecision as AvgPre
+from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.models.test_hooks import get_members
@@ -639,3 +640,23 @@ def test_result_collection_no_batch_size_extraction():
assert results["training_step.epoch_log_val"].value == log_val * batch_size
assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size
assert results["training_step.epoch_sum_log_val"].value == log_val
+
+
+@RunIf(min_cuda_gpus=1)
+def test_result_collection_changes_device():
+ """Test that the keys in the ResultCollection are moved to the device together with the collection."""
+ results = _ResultCollection(training=True)
+ fx, name = "training_step", "step_log_val"
+ log_val = torch.tensor(7.0, device="cuda:0")
+
+ # same device as the original tensor
+ results.log(fx, name, log_val, on_step=True, on_epoch=False, reduce_fx="mean")
+ assert results[f"{fx}.{name}"].cumulated_batch_size.device == log_val.device
+
+ # moved to cpu
+ results.cpu()
+ assert results[f"{fx}.{name}"].cumulated_batch_size.device == torch.device("cpu")
+
+ # same device as the new tensor
+ results.log(fx, name, log_val, on_step=True, on_epoch=False, reduce_fx="mean")
+ assert results[f"{fx}.{name}"].cumulated_batch_size.device == log_val.device
diff --git a/tests/tests_pytorch/trainer/optimization/test_backward_calls.py b/tests/tests_pytorch/trainer/optimization/test_backward_calls.py
index e464f6dbac58d..b91dbff8c6d09 100644
--- a/tests/tests_pytorch/trainer/optimization/test_backward_calls.py
+++ b/tests/tests_pytorch/trainer/optimization/test_backward_calls.py
@@ -11,7 +11,7 @@
def test_backward_count_simple(torch_backward, num_steps):
"""Test that backward is called exactly once per step."""
model = BoringModel()
- trainer = Trainer(max_steps=num_steps)
+ trainer = Trainer(max_steps=num_steps, logger=False, enable_checkpointing=False)
trainer.fit(model)
assert torch_backward.call_count == num_steps
@@ -25,19 +25,21 @@ def test_backward_count_simple(torch_backward, num_steps):
def test_backward_count_with_grad_accumulation(torch_backward):
"""Test that backward is called the correct number of times when accumulating gradients."""
model = BoringModel()
- trainer = Trainer(max_epochs=1, limit_train_batches=6, accumulate_grad_batches=2)
+ trainer = Trainer(
+ max_epochs=1, limit_train_batches=6, accumulate_grad_batches=2, logger=False, enable_checkpointing=False
+ )
trainer.fit(model)
assert torch_backward.call_count == 6
torch_backward.reset_mock()
- trainer = Trainer(max_steps=6, accumulate_grad_batches=2)
+ trainer = Trainer(max_steps=6, accumulate_grad_batches=2, logger=False, enable_checkpointing=False)
trainer.fit(model)
assert torch_backward.call_count == 12
@patch("torch.Tensor.backward")
-def test_backward_count_with_closure(torch_backward):
+def test_backward_count_with_closure(torch_backward, tmp_path):
"""Using a closure (e.g. with LBFGS) should lead to no extra backward calls."""
class TestModel(BoringModel):
@@ -45,12 +47,12 @@ def configure_optimizers(self):
return torch.optim.LBFGS(self.parameters(), lr=0.1)
model = TestModel()
- trainer = Trainer(max_steps=5)
+ trainer = Trainer(max_steps=5, logger=False, enable_checkpointing=False)
trainer.fit(model)
assert torch_backward.call_count == 5
torch_backward.reset_mock()
- trainer = Trainer(max_steps=5, accumulate_grad_batches=2)
+ trainer = Trainer(max_steps=5, accumulate_grad_batches=2, logger=False, enable_checkpointing=False)
trainer.fit(model)
assert torch_backward.call_count == 10
diff --git a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py
index 6a9123f2980a6..f0ab8fe401633 100644
--- a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py
+++ b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py
@@ -22,7 +22,6 @@
import torch.distributed as torch_distrib
import torch.nn.functional as F
from lightning.fabric.utilities.exceptions import MisconfigurationException
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.demos.boring_classes import BoringModel, ManualOptimBoringModel
from lightning.pytorch.strategies import Strategy
@@ -31,11 +30,7 @@
def assert_emtpy_grad(grad):
- if _TORCH_GREATER_EQUAL_2_0:
- assert grad is None
- else:
- if grad is not None: # backward has been called
- assert torch.all(grad == 0)
+ assert grad is None
class ManualOptModel(BoringModel):
@@ -915,7 +910,7 @@ def configure_optimizers(self):
return [optimizer], [scheduler]
model = Model()
- trainer = Trainer(accelerator="cpu", max_epochs=0)
+ trainer = Trainer(accelerator="cpu", max_epochs=0, logger=False, enable_checkpointing=False)
if automatic_optimization:
with pytest.raises(MisconfigurationException, match="doesn't follow PyTorch's LRScheduler"):
trainer.fit(model)
diff --git a/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py
index fbba239303352..319eafeb0d0bb 100644
--- a/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py
+++ b/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py
@@ -36,7 +36,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
model = TestModel()
model.automatic_optimization = True
- trainer = pl.Trainer()
+ trainer = pl.Trainer(logger=False, enable_checkpointing=False)
with pytest.raises(RuntimeError, match="Remove the `optimizer_idx` argument from `training_step`"):
trainer.fit(model)
@@ -47,7 +47,7 @@ def configure_optimizers(self):
model = TestModel()
model.automatic_optimization = True
- trainer = pl.Trainer()
+ trainer = pl.Trainer(logger=False, enable_checkpointing=False)
with pytest.raises(RuntimeError, match="multiple optimizers is only supported with manual optimization"):
trainer.fit(model)
diff --git a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py
index f178f8395cd45..76c0c695b3c02 100644
--- a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py
+++ b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py
@@ -86,9 +86,9 @@ def test_num_stepping_batches_infinite_training():
@pytest.mark.parametrize("max_steps", [2, 100])
-def test_num_stepping_batches_with_max_steps(max_steps):
+def test_num_stepping_batches_with_max_steps(max_steps, tmp_path):
"""Test stepping batches with `max_steps`."""
- trainer = Trainer(max_steps=max_steps)
+ trainer = Trainer(max_steps=max_steps, default_root_dir=tmp_path, logger=False, enable_checkpointing=False)
model = BoringModel()
trainer.fit(model)
assert trainer.estimated_stepping_batches == max_steps
@@ -152,6 +152,6 @@ def on_train_start(self):
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_num_stepping_batches_with_tpu_multi():
"""Test stepping batches with the TPU strategy across multiple devices."""
- trainer = Trainer(accelerator="tpu", devices="auto", max_epochs=1)
+ trainer = Trainer(accelerator="tpu", devices="auto", max_epochs=1, logger=False, enable_checkpointing=False)
model = MultiprocessModel()
trainer.fit(model)
diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py
index cea398eb6f501..a2d29baa9fa6f 100644
--- a/tests/tests_pytorch/trainer/test_dataloaders.py
+++ b/tests/tests_pytorch/trainer/test_dataloaders.py
@@ -641,6 +641,8 @@ def __init__(self):
def training_step(self, batch, batch_idx):
self.batches_seen.append(batch)
+ # the actual training step is not needed for the assertions below
+ return super().training_step(torch.rand(1, 32, device=self.device), batch_idx)
def on_train_epoch_end(self):
world_size = 2
@@ -679,7 +681,11 @@ def test_warning_with_small_dataloader_and_logging_interval(tmp_path):
with pytest.warns(UserWarning, match=r"The number of training batches \(1\) is smaller than the logging interval"):
trainer = Trainer(
- default_root_dir=tmp_path, max_epochs=1, log_every_n_steps=2, limit_train_batches=1, logger=CSVLogger(".")
+ default_root_dir=tmp_path,
+ max_epochs=1,
+ log_every_n_steps=2,
+ limit_train_batches=1,
+ logger=CSVLogger(tmp_path),
)
trainer.fit(model)
@@ -727,7 +733,7 @@ def __len__(self):
@pytest.mark.parametrize("yield_at_all", [False, True])
-def test_iterable_dataset_stop_iteration_at_epoch_beginning(yield_at_all):
+def test_iterable_dataset_stop_iteration_at_epoch_beginning(yield_at_all, tmp_path):
"""Test that the training loop skips execution if the iterator is empty from the start."""
class TestDataset(IterableDataset):
@@ -748,7 +754,8 @@ def gen(self):
model = TestModel()
train_dataloader = DataLoader(TestDataset(model.gen), batch_size=2)
trainer = Trainer(
- default_root_dir=os.getcwd(),
+ default_root_dir=tmp_path,
+ logger=False,
max_epochs=2,
enable_model_summary=False,
)
@@ -805,8 +812,10 @@ def __init__(self):
super().__init__()
self.seen_samples = []
- def training_step(self, batch):
+ def training_step(self, batch, batch_idx):
self.seen_samples.extend(batch.tolist())
+ # the actual training step is not needed for the test
+ return super().training_step(torch.rand(1, 32, device=self.device), batch_idx)
def on_train_end(self):
seen_samples = self.all_gather(self.seen_samples)
diff --git a/tests/tests_pytorch/trainer/test_states.py b/tests/tests_pytorch/trainer/test_states.py
index bd5fd1c67e7b6..d89e99c9319c6 100644
--- a/tests/tests_pytorch/trainer/test_states.py
+++ b/tests/tests_pytorch/trainer/test_states.py
@@ -84,5 +84,6 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
trainer = Trainer(callbacks=[InterruptCallback()], default_root_dir=tmp_path, **extra_params)
- trainer.fit(model)
+ with pytest.raises(SystemExit):
+ trainer.fit(model)
assert trainer.interrupted
diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py
index 565971e1554b1..8946fb4ed9481 100644
--- a/tests/tests_pytorch/trainer/test_trainer.py
+++ b/tests/tests_pytorch/trainer/test_trainer.py
@@ -24,12 +24,11 @@
from unittest.mock import ANY, Mock, call, patch
import cloudpickle
-import lightning.fabric
-import lightning.pytorch
import pytest
import torch
import torch.nn as nn
from lightning.fabric.utilities.cloud_io import _load as pl_load
+from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.fabric.utilities.seed import seed_everything
from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.accelerators import CPUAccelerator, CUDAAccelerator
@@ -47,11 +46,10 @@
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSampler, _IndexBatchSamplerWrapper
from lightning.pytorch.strategies import DDPStrategy, SingleDeviceStrategy
-from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher
+from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher, _SubprocessScriptLauncher
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
-from lightning.pytorch.utilities.warnings import PossibleUserWarning
from torch.multiprocessing import ProcessRaisedException
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import SGD
@@ -105,7 +103,7 @@ def __init__(self, lr=1e-2):
trainer.save_checkpoint(new_weights_path)
# assert ckpt has hparams
- ckpt = torch.load(new_weights_path)
+ ckpt = torch.load(new_weights_path, weights_only=True)
assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in ckpt, "hyper_parameters missing from checkpoints"
# load new model
@@ -366,7 +364,7 @@ def test_model_checkpoint_only_weights(tmp_path):
checkpoint_path = trainer.checkpoint_callback.best_model_path
# assert saved checkpoint has no trainer data
- checkpoint = torch.load(checkpoint_path)
+ checkpoint = torch.load(checkpoint_path, weights_only=True)
assert "optimizer_states" not in checkpoint, "checkpoint should contain only model weights"
assert "lr_schedulers" not in checkpoint, "checkpoint should contain only model weights"
@@ -377,7 +375,7 @@ def test_model_checkpoint_only_weights(tmp_path):
new_weights_path = os.path.join(tmp_path, "save_test.ckpt")
trainer.save_checkpoint(new_weights_path, weights_only=True)
# assert saved checkpoint has no trainer data
- checkpoint = torch.load(new_weights_path)
+ checkpoint = torch.load(new_weights_path, weights_only=True)
assert "optimizer_states" not in checkpoint, "checkpoint should contain only model weights"
assert "lr_schedulers" not in checkpoint, "checkpoint should contain only model weights"
@@ -1010,7 +1008,8 @@ def on_exception(self, trainer, pl_module, exception):
)
assert not trainer.interrupted
assert handle_interrupt_callback.exception is None
- trainer.fit(model)
+ with pytest.raises(SystemExit):
+ trainer.fit(model)
assert trainer.interrupted
assert isinstance(handle_interrupt_callback.exception, KeyboardInterrupt)
with pytest.raises(MisconfigurationException):
@@ -1019,6 +1018,30 @@ def on_exception(self, trainer, pl_module, exception):
assert isinstance(handle_interrupt_callback.exception, MisconfigurationException)
+def test_keyboard_interrupt(tmp_path):
+ class InterruptCallback(Callback):
+ def __init__(self):
+ super().__init__()
+
+ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+ raise KeyboardInterrupt
+
+ model = BoringModel()
+ trainer = Trainer(
+ callbacks=[InterruptCallback()],
+ barebones=True,
+ default_root_dir=tmp_path,
+ )
+
+ trainer.strategy._launcher = Mock(spec=_SubprocessScriptLauncher)
+ trainer.strategy._launcher.launch = lambda function, *args, trainer, **kwargs: function(*args, **kwargs)
+
+ with pytest.raises(SystemExit) as exc_info:
+ trainer.fit(model)
+ assert exc_info.value.args[0] == 1
+ trainer.strategy._launcher.kill.assert_called_once_with(15 if _IS_WINDOWS else 9)
+
+
@pytest.mark.parametrize("precision", ["32-true", pytest.param("16-mixed", marks=RunIf(min_cuda_gpus=1))])
@RunIf(sklearn=True)
def test_gradient_clipping_by_norm(tmp_path, precision):
@@ -2035,7 +2058,7 @@ def on_fit_start(self):
@pytest.mark.parametrize("exception_type", [KeyboardInterrupt, RuntimeError])
-def test_trainer_calls_strategy_on_exception(exception_type):
+def test_trainer_calls_strategy_on_exception(exception_type, tmp_path):
"""Test that when an exception occurs, the Trainer lets the strategy process it."""
exception = exception_type("Test exception")
@@ -2043,16 +2066,16 @@ class ExceptionModel(BoringModel):
def on_fit_start(self):
raise exception
- trainer = Trainer()
+ trainer = Trainer(default_root_dir=tmp_path)
with mock.patch("lightning.pytorch.strategies.strategy.Strategy.on_exception") as on_exception_mock, suppress(
- Exception
+ Exception, SystemExit
):
trainer.fit(ExceptionModel())
on_exception_mock.assert_called_once_with(exception)
@pytest.mark.parametrize("exception_type", [KeyboardInterrupt, RuntimeError])
-def test_trainer_calls_datamodule_on_exception(exception_type):
+def test_trainer_calls_datamodule_on_exception(exception_type, tmp_path):
"""Test that when an exception occurs, the Trainer lets the data module process it."""
exception = exception_type("Test exception")
@@ -2062,9 +2085,9 @@ def on_fit_start(self):
datamodule = BoringDataModule()
datamodule.on_exception = Mock()
- trainer = Trainer()
+ trainer = Trainer(default_root_dir=tmp_path)
- with suppress(Exception):
+ with suppress(Exception, SystemExit):
trainer.fit(ExceptionModel(), datamodule=datamodule)
datamodule.on_exception.assert_called_once_with(exception)
@@ -2080,12 +2103,6 @@ def test_init_module_context(monkeypatch):
strategy.tensor_init_context.assert_called_once_with(empty_init=None)
strategy.tensor_init_context.reset_mock()
- # Pretend we are using PyTorch < 2.0
- monkeypatch.setattr(lightning.pytorch.trainer.trainer, "_TORCH_GREATER_EQUAL_2_0", False)
- with pytest.warns(PossibleUserWarning, match="can't place .* on the device"), trainer.init_module():
- pass
- strategy.tensor_init_context.assert_called_once()
-
def test_expand_home_trainer():
"""Test that the dirpath gets expanded if it contains `~`."""
diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py
index a0d1d70aa36c4..a31be67911409 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -434,6 +434,7 @@ def lr_find(self, trainer, pl_module) -> None:
super().lr_find(trainer, pl_module)
pl_module._expected_max_steps = None
assert not trainer.fit_loop.restarting
+ assert not trainer.fit_loop.epoch_loop.restarting
def on_train_epoch_start(self, trainer, pl_module):
if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py
index 6ab81c581afa3..8dd66fe9bfcff 100644
--- a/tests/tests_pytorch/tuner/test_scale_batch_size.py
+++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -438,7 +438,7 @@ class CustomModel(BoringModel):
def val_dataloader(self):
return [super().val_dataloader(), super().val_dataloader()]
- trainer = Trainer()
+ trainer = Trainer(logger=False, enable_checkpointing=False)
tuner = Tuner(trainer)
model = CustomModel()
diff --git a/tests/tests_pytorch/utilities/migration/test_migration.py b/tests/tests_pytorch/utilities/migration/test_migration.py
index 9aad2aebffec9..9680c90a94c5b 100644
--- a/tests/tests_pytorch/utilities/migration/test_migration.py
+++ b/tests/tests_pytorch/utilities/migration/test_migration.py
@@ -93,7 +93,7 @@ def test_migrate_loop_batches_that_stepped(tmp_path, model_class):
ckpt_path = trainer.checkpoint_callback.best_model_path
# pretend we have a checkpoint produced in < v1.6.5; the key "_batches_that_stepped" didn't exist back then
- ckpt = torch.load(ckpt_path)
+ ckpt = torch.load(ckpt_path, weights_only=True)
del ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"]
_set_version(ckpt, "1.6.4")
torch.save(ckpt, ckpt_path)
diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py
index d90cc90126195..56b8701cfcfc2 100644
--- a/tests/tests_pytorch/utilities/migration/test_utils.py
+++ b/tests/tests_pytorch/utilities/migration/test_utils.py
@@ -84,7 +84,7 @@ def test_test_patch_legacy_imports_standalone(pl_version):
path_ckpt = path_ckpts[-1]
with no_warning_call(match="Redirecting import of*"), pl_legacy_patch():
- torch.load(path_ckpt)
+ torch.load(path_ckpt, weights_only=False)
assert any(
key.startswith("pytorch_lightning") for key in sys.modules
@@ -117,7 +117,7 @@ def test_patch_legacy_imports_unified(pl_version):
else:
context = no_warning_call(match="Redirecting import of*")
with context, pl_legacy_patch():
- torch.load(path_ckpt)
+ torch.load(path_ckpt, weights_only=False)
assert any(
key.startswith("lightning." + "pytorch") for key in sys.modules
diff --git a/tests/tests_pytorch/utilities/test_compile.py b/tests/tests_pytorch/utilities/test_compile.py
index c363bbc94cf8a..67f992421f7ce 100644
--- a/tests/tests_pytorch/utilities/test_compile.py
+++ b/tests/tests_pytorch/utilities/test_compile.py
@@ -11,19 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
import sys
+from contextlib import nullcontext
from unittest import mock
import pytest
import torch
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities.compile import from_compiled, to_uncompiled
-from lightning_utilities.core import module_available
+from lightning_utilities.core.imports import RequirementCache
from tests_pytorch.conftest import mock_cuda_count
from tests_pytorch.helpers.runif import RunIf
+_PYTHON_GREATER_EQUAL_3_9_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 9)
+
# https://github.com/pytorch/pytorch/issues/95708
@pytest.mark.skipif(sys.platform == "darwin", reason="fatal error: 'omp.h' file not found")
@@ -64,10 +69,20 @@ def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):
assert trainer.model._compiler_ctx is None
# some strategies do not support it
- if module_available("deepspeed"):
+ if RequirementCache("deepspeed"):
compiled_model = torch.compile(model)
mock_cuda_count(monkeypatch, 2)
- trainer = Trainer(strategy="deepspeed", accelerator="cuda", **trainer_kwargs)
+
+ # TODO: Update deepspeed to avoid deprecation warning for `torch.cuda.amp.custom_fwd` on import
+ warn_context = (
+ pytest.warns(FutureWarning, match="torch.cuda.amp.*is deprecated")
+ if _TORCH_GREATER_EQUAL_2_4
+ else nullcontext()
+ )
+
+ with warn_context:
+ trainer = Trainer(strategy="deepspeed", accelerator="cuda", **trainer_kwargs)
+
with pytest.raises(RuntimeError, match="Using a compiled model is incompatible with the current strategy.*"):
trainer.fit(compiled_model)
@@ -114,7 +129,12 @@ def has_dynamo(fn):
# https://github.com/pytorch/pytorch/issues/95708
@pytest.mark.skipif(sys.platform == "darwin", reason="fatal error: 'omp.h' file not found")
+@pytest.mark.skipif(not _PYTHON_GREATER_EQUAL_3_9_0, reason="AssertionError: failed to reach fixed point")
+@pytest.mark.xfail(
+ sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, strict=False, reason="RuntimeError: Failed to import"
+)
@RunIf(dynamo=True)
+@mock.patch.dict(os.environ, {})
def test_trainer_compiled_model_that_logs(tmp_path):
class MyModel(BoringModel):
def training_step(self, batch, batch_idx):
@@ -140,7 +160,12 @@ def training_step(self, batch, batch_idx):
# https://github.com/pytorch/pytorch/issues/95708
@pytest.mark.skipif(sys.platform == "darwin", reason="fatal error: 'omp.h' file not found")
+@pytest.mark.skipif(not _PYTHON_GREATER_EQUAL_3_9_0, reason="AssertionError: failed to reach fixed point")
+@pytest.mark.xfail(
+ sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, strict=False, reason="RuntimeError: Failed to import"
+)
@RunIf(dynamo=True)
+@mock.patch.dict(os.environ, {})
def test_trainer_compiled_model_test(tmp_path):
model = BoringModel()
compiled_model = torch.compile(model)
diff --git a/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py b/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py
index 41f32f007f93e..a44ed655c2e61 100644
--- a/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py
+++ b/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py
@@ -49,7 +49,7 @@ def test_deepspeed_collate_checkpoint(tmp_path):
def _assert_checkpoint_equal(model, output_path):
assert os.path.exists(output_path)
- single_output = torch.load(output_path)
+ single_output = torch.load(output_path, weights_only=False)
state_dict = model.state_dict()
for orig_param, saved_model_param in zip(state_dict.values(), single_output["state_dict"].values()):
if model.dtype == torch.half:
diff --git a/tests/tests_pytorch/utilities/test_model_summary.py b/tests/tests_pytorch/utilities/test_model_summary.py
index 290dfb67faf7d..cced6546aab75 100644
--- a/tests/tests_pytorch/utilities/test_model_summary.py
+++ b/tests/tests_pytorch/utilities/test_model_summary.py
@@ -13,11 +13,11 @@
# limitations under the License.
from collections import OrderedDict
from typing import Any
+from unittest import mock
import pytest
import torch
import torch.nn as nn
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities.model_summary.model_summary import (
@@ -294,10 +294,6 @@ def __init__(self):
def forward(self, *args, **kwargs):
return self.layer(*args, **kwargs)
- if isinstance(example_input, dict) and not _TORCH_GREATER_EQUAL_2_0:
- # kwargs are not supported when torch < 2.0
- expected_size = UNKNOWN_SIZE
-
model = DummyLightningModule()
model.example_input_array = example_input
summary = summarize(model, max_depth=max_depth)
@@ -350,6 +346,18 @@ def test_lazy_model_summary():
assert summary.trainable_parameters == 0
+@mock.patch("lightning.pytorch.utilities.model_summary.model_summary._is_dtensor", return_value=True)
+def test_dtensor_model_summary(_):
+ """Test that the model summary can work with layers that have DTensor parameters."""
+ # We mock the `_is_dtensor` to pretend parameters are DTensors, because testing with real DTensors
+ # would require setting up distributed
+ dtensor_model = UnorderedModel()
+ summary = ModelSummary(dtensor_model)
+ assert summary.total_layer_params > 0
+ assert summary.total_parameters > 0
+ assert summary.trainable_parameters > 0
+
+
@pytest.mark.parametrize("max_depth", [-1, 0, 1, 3, 999])
def test_max_depth_param(max_depth):
"""Test that only the modules up to the desired depth are shown."""
@@ -428,6 +436,29 @@ def forward(self, x):
assert not model.layer2.training
+def test_total_training_modes():
+ """Test that the `total_training_modes` counts the modules in 'train' and 'eval' mode, excluding the root
+ module."""
+
+ class ModelWithoutChildren(LightningModule):
+ pass
+
+ summary = ModelSummary(ModelWithoutChildren())
+ assert summary.total_training_modes == {"train": 0, "eval": 0}
+
+ model = DeepNestedModel()
+ summary = ModelSummary(model)
+ assert summary.total_training_modes == {"train": 19, "eval": 0}
+ assert sum(summary.total_training_modes.values()) == len(list(model.modules())) - 1
+
+ model = DeepNestedModel()
+ summary = ModelSummary(model)
+ model.branch1[1][0].eval()
+ model.branch2.eval()
+ assert summary.total_training_modes == {"train": 17, "eval": 2}
+ assert sum(summary.total_training_modes.values()) == len(list(model.modules())) - 1
+
+
def test_summary_training_mode():
"""Test that the model summary captures the training mode on all submodules."""
model = DeepNestedModel()
@@ -441,6 +472,7 @@ def test_summary_training_mode():
"eval", # branch2
"train", # head
]
+ assert summary.total_training_modes == {"train": 17, "eval": 2}
summary = summarize(model, max_depth=-1)
expected_eval = {"branch1.1.0", "branch2"}
@@ -450,5 +482,7 @@ def test_summary_training_mode():
# A model with params not belonging to a layer
model = NonLayerParamsModel()
model.layer.eval()
- summary_data = OrderedDict(summarize(model)._get_summary_data())
+ summary = summarize(model)
+ summary_data = OrderedDict(summary._get_summary_data())
assert summary_data["Mode"] == ["eval", "n/a"]
+ assert summary.total_training_modes == {"train": 0, "eval": 1}
diff --git a/tests/tests_store/__init__.py b/tests/tests_store/__init__.py
deleted file mode 100644
index e69de29bb2d1d..0000000000000
diff --git a/tests/tests_store/test_store.py b/tests/tests_store/test_store.py
deleted file mode 100644
index 308ccda30c489..0000000000000
--- a/tests/tests_store/test_store.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import os
-from unittest import mock
-
-from lightning.store import download_model, list_models, upload_model
-from lightning_cloud.openapi import (
- V1DownloadModelResponse,
- V1GetUserResponse,
- V1ListMembershipsResponse,
- V1ListModelsResponse,
- V1Membership,
- V1Model,
- V1Project,
- V1UploadModelRequest,
- V1UploadModelResponse,
-)
-
-
-@mock.patch("lightning.store.store._Client")
-@mock.patch("lightning.store.store._upload_file_to_url")
-def test_upload_model(mock_upload_file_to_url, mock_client):
- mock_client = mock_client()
-
- mock_client.auth_service_get_user.return_value = V1GetUserResponse(username="test-username")
-
- # either one of these project APIs could be called
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(project_id="test-project-id")],
- )
- mock_client.projects_service_get_project.return_value = V1Project(id="test-project-id")
-
- mock_client.models_store_upload_model.return_value = V1UploadModelResponse(
- upload_url="https://test",
- )
-
- upload_model("test-model", "test.ckpt", version="0.0.1")
-
- mock_client.auth_service_get_user.assert_called_once()
- mock_client.models_store_upload_model.assert_called_once_with(
- V1UploadModelRequest(
- name="test-username/test-model",
- version="0.0.1",
- project_id="test-project-id",
- )
- )
-
- mock_upload_file_to_url.assert_called_once_with("https://test", "test.ckpt", progress_bar=True)
-
-
-@mock.patch("lightning.store.store._Client")
-@mock.patch("lightning.store.store._download_file_from_url")
-def test_download_model(mock_download_file_from_url, mock_client):
- mock_client = mock_client()
-
- mock_client.models_store_download_model.return_value = V1DownloadModelResponse(
- download_url="https://test",
- )
-
- download_model("test-username/test-model", "test.ckpt", version="0.0.1")
-
- mock_client.models_store_download_model.assert_called_once_with(
- name="test-username/test-model",
- version="0.0.1",
- )
-
- mock_download_file_from_url.assert_called_once_with("https://test", os.path.abspath("test.ckpt"), progress_bar=True)
-
-
-@mock.patch("lightning.store.store._Client")
-def test_list_models(mock_client):
- mock_client = mock_client()
-
- # either one of these project APIs could be called
- mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
- memberships=[V1Membership(project_id="test-project-id")],
- )
- mock_client.projects_service_get_project.return_value = V1Project(id="test-project-id")
-
- mock_client.models_store_list_models.return_value = V1ListModelsResponse(models=[V1Model(name="test-model")])
-
- res = list_models()
- assert res[0].name == "test-model"
-
- mock_client.models_store_list_models.assert_called_once_with(project_id="test-project-id")