@@ -97,3 +97,34 @@ def test_detect(monkeypatch):
9797
9898 monkeypatch .setattr (lightning .fabric .accelerators .xla .XLAAccelerator , "is_available" , lambda : True )
9999 assert XLAEnvironment .detect ()
100+
101+
102+ @mock .patch .dict (os .environ , {}, clear = True )
103+ def test_attributes_from_xla_greater_21_used (xla_available , monkeypatch ):
104+ """Test XLA environment attributes when using XLA runtime >= 2.1."""
105+ monkeypatch .setattr (lightning .fabric .accelerators .xla , "_XLA_GREATER_EQUAL_2_1" , True )
106+ monkeypatch .setattr (lightning .fabric .plugins .environments .xla , "_XLA_GREATER_EQUAL_2_1" , True )
107+
108+ env = XLAEnvironment ()
109+
110+ with (
111+ mock .patch ("torch_xla.runtime.world_size" , return_value = 4 ),
112+ mock .patch ("torch_xla.runtime.global_ordinal" , return_value = 2 ),
113+ mock .patch ("torch_xla.runtime.local_ordinal" , return_value = 1 ),
114+ ):
115+ env .world_size .cache_clear ()
116+ env .global_rank .cache_clear ()
117+ env .local_rank .cache_clear ()
118+
119+ assert env .world_size () == 4
120+ assert env .global_rank () == 2
121+ assert env .local_rank () == 1
122+
123+ env .set_world_size (100 )
124+ assert env .world_size () == 4
125+
126+ env .set_global_rank (100 )
127+ assert env .global_rank () == 2
128+
129+ env .set_local_rank (100 )
130+ assert env .local_rank () == 1
0 commit comments