@@ -870,16 +870,25 @@ def test_strategy_str_passed_being_case_insensitive(_, strategy, strategy_cls):
870870 assert isinstance (connector .strategy , strategy_cls )
871871
872872
873- @pytest .mark .parametrize ("precision" , [None , "64-true" , "32-true" , "16-mixed" , "bf16-mixed" ])
873+ @pytest .mark .parametrize (
874+ ("precision" , "expected" ),
875+ [
876+ (None , Precision ),
877+ ("64-true" , DoublePrecision ),
878+ ("32-true" , Precision ),
879+ ("16-true" , HalfPrecision ),
880+ ("16-mixed" , MixedPrecision ),
881+ ],
882+ )
874883@mock .patch ("lightning.fabric.accelerators.cuda.num_cuda_devices" , return_value = 1 )
875- def test_precision_from_environment (_ , precision ):
884+ def test_precision_from_environment (_ , precision , expected ):
876885 """Test that the precision input can be set through the environment variable."""
877- env_vars = {}
886+ env_vars = {"LT_CLI_USED" : "1" }
878887 if precision is not None :
879888 env_vars ["LT_PRECISION" ] = precision
880889 with mock .patch .dict (os .environ , env_vars ):
881890 connector = _Connector (accelerator = "cuda" ) # need to use cuda, because AMP not available on CPU
882- assert isinstance (connector .precision , Precision )
891+ assert isinstance (connector .precision , expected )
883892
884893
885894@pytest .mark .parametrize (
@@ -897,7 +906,7 @@ def test_precision_from_environment(_, precision):
897906)
898907def test_accelerator_strategy_from_environment (accelerator , strategy , expected_accelerator , expected_strategy ):
899908 """Test that the accelerator and strategy input can be set through the environment variables."""
900- env_vars = {}
909+ env_vars = {"LT_CLI_USED" : "1" }
901910 if accelerator is not None :
902911 env_vars ["LT_ACCELERATOR" ] = accelerator
903912 if strategy is not None :
@@ -912,7 +921,7 @@ def test_accelerator_strategy_from_environment(accelerator, strategy, expected_a
912921@mock .patch ("lightning.fabric.accelerators.cuda.num_cuda_devices" , return_value = 8 )
913922def test_devices_from_environment (* _ ):
914923 """Test that the devices and number of nodes can be set through the environment variables."""
915- with mock .patch .dict (os .environ , {"LT_DEVICES" : "2" , "LT_NUM_NODES" : "3" }):
924+ with mock .patch .dict (os .environ , {"LT_DEVICES" : "2" , "LT_NUM_NODES" : "3" , "LT_CLI_USED" : "1" }):
916925 connector = _Connector (accelerator = "cuda" )
917926 assert isinstance (connector .accelerator , CUDAAccelerator )
918927 assert isinstance (connector .strategy , DDPStrategy )
0 commit comments