Skip to content

Commit f588893

Browse files
awaelchlilexierule
authored andcommitted
Set accelerator through CLI only if set explicitly (#16818)
1 parent 41a5c2a commit f588893

File tree

4 files changed

+18
-9
lines changed

4 files changed

+18
-9
lines changed

src/lightning_fabric/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2323
### Fixed
2424

2525
- Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806))
26+
- Fixed parsing of defaults for `--accelerator` and `--precision` in Fabric CLI when `accelerator` and `precision` are set to non-default values in the code ([#16818](https://github.com/Lightning-AI/lightning/pull/16818))
2627

2728

2829
## [1.9.2] - 2023-02-15

src/lightning_fabric/cli.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
@click.option(
4646
"--accelerator",
4747
type=click.Choice(_SUPPORTED_ACCELERATORS),
48-
default="cpu",
48+
default=None,
4949
help="The hardware accelerator to run on.",
5050
)
5151
@click.option(
@@ -97,7 +97,7 @@
9797
@click.option(
9898
"--precision",
9999
type=click.Choice(_SUPPORTED_PRECISION),
100-
default="32",
100+
default=None,
101101
help=(
102102
"Double precision (``64``), full precision (``32``), half precision (``16``) or bfloat16 precision"
103103
" (``'bf16'``)"
@@ -122,12 +122,14 @@ def _set_env_variables(args: Namespace) -> None:
122122
The Fabric connector will parse the arguments set here.
123123
"""
124124
os.environ["LT_CLI_USED"] = "1"
125-
os.environ["LT_ACCELERATOR"] = str(args.accelerator)
125+
if args.accelerator is not None:
126+
os.environ["LT_ACCELERATOR"] = str(args.accelerator)
126127
if args.strategy is not None:
127128
os.environ["LT_STRATEGY"] = str(args.strategy)
128129
os.environ["LT_DEVICES"] = str(args.devices)
129130
os.environ["LT_NUM_NODES"] = str(args.num_nodes)
130-
os.environ["LT_PRECISION"] = str(args.precision)
131+
if args.precision is not None:
132+
os.environ["LT_PRECISION"] = str(args.precision)
131133

132134

133135
def _get_num_processes(accelerator: str, devices: str) -> int:

tests/tests_fabric/test_cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ def test_cli_env_vars_defaults(monkeypatch, fake_script):
3636
_run_model.main([fake_script])
3737
assert e.value.code == 0
3838
assert os.environ["LT_CLI_USED"] == "1"
39-
assert os.environ["LT_ACCELERATOR"] == "cpu"
39+
assert "LT_ACCELERATOR" not in os.environ
4040
assert "LT_STRATEGY" not in os.environ
4141
assert os.environ["LT_DEVICES"] == "1"
4242
assert os.environ["LT_NUM_NODES"] == "1"
43-
assert os.environ["LT_PRECISION"] == "32"
43+
assert "LT_PRECISION" not in os.environ
4444

4545

4646
@pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", pytest.param("mps", marks=RunIf(mps=True))])

tests/tests_fabric/test_connector.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -799,18 +799,22 @@ def test_strategy_str_passed_being_case_insensitive(_, strategy, strategy_cls):
799799
assert isinstance(connector.strategy, strategy_cls)
800800

801801

802-
@pytest.mark.parametrize("precision", ["64", "32", "16", pytest.param("bf16", marks=RunIf(min_torch="1.10"))])
802+
@pytest.mark.parametrize("precision", [None, "64", "32", "16", pytest.param("bf16", marks=RunIf(min_torch="1.10"))])
803803
@mock.patch("lightning_fabric.accelerators.cuda.num_cuda_devices", return_value=1)
804804
def test_precision_from_environment(_, precision):
805805
"""Test that the precision input can be set through the environment variable."""
806-
with mock.patch.dict(os.environ, {"LT_PRECISION": precision}):
806+
env_vars = {}
807+
if precision is not None:
808+
env_vars["LT_PRECISION"] = precision
809+
with mock.patch.dict(os.environ, env_vars):
807810
connector = _Connector(accelerator="cuda") # need to use cuda, because AMP not available on CPU
808811
assert isinstance(connector.precision, Precision)
809812

810813

811814
@pytest.mark.parametrize(
812815
"accelerator, strategy, expected_accelerator, expected_strategy",
813816
[
817+
(None, None, CPUAccelerator, SingleDeviceStrategy),
814818
("cpu", None, CPUAccelerator, SingleDeviceStrategy),
815819
("cpu", "ddp", CPUAccelerator, DDPStrategy),
816820
pytest.param("mps", None, MPSAccelerator, SingleDeviceStrategy, marks=RunIf(mps=True)),
@@ -822,7 +826,9 @@ def test_precision_from_environment(_, precision):
822826
)
823827
def test_accelerator_strategy_from_environment(accelerator, strategy, expected_accelerator, expected_strategy):
824828
"""Test that the accelerator and strategy input can be set through the environment variables."""
825-
env_vars = {"LT_ACCELERATOR": accelerator}
829+
env_vars = {}
830+
if accelerator is not None:
831+
env_vars["LT_ACCELERATOR"] = accelerator
826832
if strategy is not None:
827833
env_vars["LT_STRATEGY"] = strategy
828834

0 commit comments

Comments
 (0)