Skip to content

Commit fba8bf3

Browse files
awaelchlipre-commit-ci[bot]Borda
authored andcommitted
Fix failing lightning cli entry point (#18821)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <[email protected]> (cherry picked from commit 9e75bc9)
1 parent fb37826 commit fba8bf3

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

src/lightning/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Root package info."""
22
import logging
3+
import sys
34

45
# explicitly don't set root logger's propagation and leave this to subpackages to manage
56
_logger = logging.getLogger(__name__)
@@ -28,3 +29,19 @@
2829
"seed_everything",
2930
"Fabric",
3031
]
32+
33+
34+
def _cli_entry_point() -> None:
35+
from lightning_utilities.core.imports import ModuleAvailableCache, RequirementCache
36+
37+
if not (
38+
ModuleAvailableCache("lightning.app")
39+
if RequirementCache("lightning-utilities<0.10.0")
40+
else RequirementCache(module="lightning.app") # type: ignore[call-arg]
41+
):
42+
print("The `lightning` command requires additional dependencies: `pip install lightning[app]`")
43+
sys.exit(1)
44+
45+
from lightning.app.cli.lightning_cli import main
46+
47+
main()

src/lightning/__setup__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def _setup_args() -> Dict[str, Any]:
115115
"python_requires": ">=3.8", # todo: take the lowes based on all packages
116116
"entry_points": {
117117
"console_scripts": [
118-
"lightning = lightning.app.cli.lightning_cli:main",
118+
"lightning = lightning:_cli_entry_point",
119119
],
120120
},
121121
"setup_requires": [],

tests/tests_fabric/test_cli.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
# limitations under the License.
1414
import contextlib
1515
import os
16+
import subprocess
1617
from io import StringIO
1718
from unittest import mock
1819
from unittest.mock import Mock
1920

2021
import pytest
2122
import torch.distributed.run
2223
from lightning.fabric.cli import _get_supported_strategies, _run_model
24+
from lightning_utilities.core.imports import ModuleAvailableCache
2325

2426
from tests_fabric.helpers.runif import RunIf
2527

@@ -172,3 +174,15 @@ def test_cli_torchrun_num_processes_launched(_, devices, expected, monkeypatch,
172174
fake_script,
173175
]
174176
)
177+
178+
179+
@pytest.mark.skipif("lightning.fabric" == "lightning_fabric", reason="standalone package")
180+
def test_cli_through_lightning_entry_point():
181+
result = subprocess.run("lightning run model --help", capture_output=True, text=True, shell=True)
182+
if not ModuleAvailableCache("lightning.app"):
183+
message = "The `lightning` command requires additional dependencies"
184+
assert message in result.stdout or message in result.stderr
185+
assert result.returncode != 0
186+
else:
187+
message = "Usage: lightning run model [OPTIONS] SCRIPT [SCRIPT_ARGS]"
188+
assert message in result.stdout or message in result.stderr

0 commit comments

Comments
 (0)