Skip to content

Commit d947b01

Browse files
carmoccaawaelchli
authored andcommitted
Fix precision default from environment (#18928)
Co-authored-by: awaelchli <[email protected]> (cherry picked from commit 466f772)
1 parent 4b81af6 commit d947b01

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

src/lightning/fabric/connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def __init__(
114114
strategy = self._argument_from_env("strategy", strategy, default="auto")
115115
devices = self._argument_from_env("devices", devices, default="auto")
116116
num_nodes = int(self._argument_from_env("num_nodes", num_nodes, default=1))
117-
precision = self._argument_from_env("precision", precision, default="32-true")
117+
precision = self._argument_from_env("precision", precision, default=None)
118118

119119
# 1. Parsing flags
120120
# Get registered strategies, built-in accelerators and precision plugins

tests/tests_fabric/test_connector.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -870,16 +870,25 @@ def test_strategy_str_passed_being_case_insensitive(_, strategy, strategy_cls):
870870
assert isinstance(connector.strategy, strategy_cls)
871871

872872

873-
@pytest.mark.parametrize("precision", [None, "64-true", "32-true", "16-mixed", "bf16-mixed"])
873+
@pytest.mark.parametrize(
874+
("precision", "expected"),
875+
[
876+
(None, Precision),
877+
("64-true", DoublePrecision),
878+
("32-true", Precision),
879+
("16-true", HalfPrecision),
880+
("16-mixed", MixedPrecision),
881+
],
882+
)
874883
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1)
875-
def test_precision_from_environment(_, precision):
884+
def test_precision_from_environment(_, precision, expected):
876885
"""Test that the precision input can be set through the environment variable."""
877-
env_vars = {}
886+
env_vars = {"LT_CLI_USED": "1"}
878887
if precision is not None:
879888
env_vars["LT_PRECISION"] = precision
880889
with mock.patch.dict(os.environ, env_vars):
881890
connector = _Connector(accelerator="cuda") # need to use cuda, because AMP not available on CPU
882-
assert isinstance(connector.precision, Precision)
891+
assert isinstance(connector.precision, expected)
883892

884893

885894
@pytest.mark.parametrize(
@@ -897,7 +906,7 @@ def test_precision_from_environment(_, precision):
897906
)
898907
def test_accelerator_strategy_from_environment(accelerator, strategy, expected_accelerator, expected_strategy):
899908
"""Test that the accelerator and strategy input can be set through the environment variables."""
900-
env_vars = {}
909+
env_vars = {"LT_CLI_USED": "1"}
901910
if accelerator is not None:
902911
env_vars["LT_ACCELERATOR"] = accelerator
903912
if strategy is not None:
@@ -912,7 +921,7 @@ def test_accelerator_strategy_from_environment(accelerator, strategy, expected_a
912921
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=8)
913922
def test_devices_from_environment(*_):
914923
"""Test that the devices and number of nodes can be set through the environment variables."""
915-
with mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_NUM_NODES": "3"}):
924+
with mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_NUM_NODES": "3", "LT_CLI_USED": "1"}):
916925
connector = _Connector(accelerator="cuda")
917926
assert isinstance(connector.accelerator, CUDAAccelerator)
918927
assert isinstance(connector.strategy, DDPStrategy)

0 commit comments

Comments
 (0)