@@ -102,24 +102,57 @@ def test_detect(monkeypatch):
102102@mock .patch .dict (os .environ , {}, clear = True )
103103@mock .patch ("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1" , True )
104104@mock .patch ("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1" , True )
105- def test_attributes_from_xla_greater_21_used (xla_available , monkeypatch ):
106- """Test XLA environment attributes when using XLA runtime >= 2.1."""
105+ def test_world_size_from_xla_runtime_greater_2_1 (xla_available ):
106+ """Test that world_size uses torch_xla.runtime when XLA >= 2.1."""
107+ env = XLAEnvironment ()
108+
109+ with mock .patch ("torch_xla.runtime.world_size" , return_value = 4 ) as mock_world_size :
110+ env .world_size .cache_clear ()
111+ assert env .world_size () == 4
112+ mock_world_size .assert_called_once ()
107113
114+
115+ @mock .patch .dict (os .environ , {}, clear = True )
116+ @mock .patch ("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1" , True )
117+ @mock .patch ("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1" , True )
118+ def test_global_rank_from_xla_runtime_greater_2_1 (xla_available ):
119+ """Test that global_rank uses torch_xla.runtime when XLA >= 2.1."""
120+ env = XLAEnvironment ()
121+
122+ with mock .patch ("torch_xla.runtime.global_ordinal" , return_value = 2 ) as mock_global_ordinal :
123+ env .global_rank .cache_clear ()
124+ assert env .global_rank () == 2
125+ mock_global_ordinal .assert_called_once ()
126+
127+
128+ @mock .patch .dict (os .environ , {}, clear = True )
129+ @mock .patch ("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1" , True )
130+ @mock .patch ("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1" , True )
131+ def test_local_rank_from_xla_runtime_greater_2_1 (xla_available ):
132+ """Test that local_rank uses torch_xla.runtime when XLA >= 2.1."""
133+ env = XLAEnvironment ()
134+
135+ with mock .patch ("torch_xla.runtime.local_ordinal" , return_value = 1 ) as mock_local_ordinal :
136+ env .local_rank .cache_clear ()
137+ assert env .local_rank () == 1
138+ mock_local_ordinal .assert_called_once ()
139+
140+
141+ @mock .patch .dict (os .environ , {}, clear = True )
142+ @mock .patch ("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1" , True )
143+ @mock .patch ("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1" , True )
144+ def test_setters_readonly_when_xla_runtime_greater_2_1 (xla_available ):
145+ """Test that set_world_size and set_global_rank don't affect values when using XLA runtime >= 2.1."""
108146 env = XLAEnvironment ()
109147
110148 with (
111149 mock .patch ("torch_xla.runtime.world_size" , return_value = 4 ),
112150 mock .patch ("torch_xla.runtime.global_ordinal" , return_value = 2 ),
113- mock .patch ("torch_xla.runtime.local_ordinal" , return_value = 1 ),
114151 ):
115152 env .world_size .cache_clear ()
116153 env .global_rank .cache_clear ()
117- env .local_rank .cache_clear ()
118-
119- assert env .world_size () == 4
120- assert env .global_rank () == 2
121- assert env .local_rank () == 1
122154
155+ # Values should come from XLA runtime and not be affected by setters
123156 env .set_world_size (100 )
124157 assert env .world_size () == 4
125158
0 commit comments