@@ -128,14 +128,14 @@ def test_basic_arithmetic_operations(self):
128128 self .assertEqual (len (result_add ._local_tensors ), 2 )
129129
130130 # Verify the operation was applied to each local tensor
131- for rank in identical_local_tensors . keys () :
131+ for rank in identical_local_tensors :
132132 expected = identical_local_tensors [rank ] + identical_local_tensors [rank ]
133133 self .assertEqual (result_add ._local_tensors [rank ], expected )
134134
135135 # Test multiplication
136136 result_mul = lt1 * 2.0
137137 self .assertIsInstance (result_mul , LocalTensor )
138- for rank in identical_local_tensors . keys () :
138+ for rank in identical_local_tensors :
139139 expected = identical_local_tensors [rank ] * 2.0
140140 self .assertEqual (result_mul ._local_tensors [rank ], expected )
141141
@@ -163,7 +163,7 @@ def test_mixed_operations_with_regular_tensors(self):
163163 result = lt + regular_tensor
164164 self .assertIsInstance (result , LocalTensor )
165165
166- for rank in identical_local_tensors . keys () :
166+ for rank in identical_local_tensors :
167167 expected = identical_local_tensors [rank ] + regular_tensor
168168 self .assertEqual (result ._local_tensors [rank ], expected )
169169
@@ -212,14 +212,14 @@ def test_collectives_within_local_tensor_mode(self):
212212 dist .all_reduce (lt_sum , group = fake_pg )
213213
214214 expected_sum = torch .tensor ([[6.0 , 8.0 ], [10.0 , 12.0 ]])
215- for rank in test_tensors . keys () :
215+ for rank in test_tensors :
216216 self .assertEqual (lt_sum ._local_tensors [rank ], expected_sum )
217217
218218 # Test broadcast within mode
219219 lt_broadcast = LocalTensor ({k : v .clone () for k , v in test_tensors .items ()})
220220 dist .broadcast (lt_broadcast , src = 0 , group = fake_pg )
221221
222- for rank in test_tensors . keys () :
222+ for rank in test_tensors :
223223 self .assertEqual (lt_broadcast ._local_tensors [rank ], test_tensors [0 ])
224224
225225 # Test that regular operations still work
@@ -293,21 +293,21 @@ def test_collective_reduction_operations(self):
293293 lt_sum = LocalTensor ({k : v .clone () for k , v in test_tensors .items ()})
294294 dist .all_reduce (lt_sum , op = dist .ReduceOp .SUM , group = fake_pg )
295295 expected_sum = torch .tensor ([[6.0 , 7.0 ], [6.0 , 15.0 ]]) # Sum of all tensors
296- for rank in test_tensors . keys () :
296+ for rank in test_tensors :
297297 self .assertEqual (lt_sum ._local_tensors [rank ], expected_sum )
298298
299299 # Test MAX reduction
300300 lt_max = LocalTensor ({k : v .clone () for k , v in test_tensors .items ()})
301301 dist .all_reduce (lt_max , op = dist .ReduceOp .MAX , group = fake_pg )
302302 expected_max = torch .tensor ([[3.0 , 4.0 ], [3.0 , 6.0 ]]) # Max across all tensors
303- for rank in test_tensors . keys () :
303+ for rank in test_tensors :
304304 self .assertEqual (lt_max ._local_tensors [rank ], expected_max )
305305
306306 # Test MIN reduction
307307 lt_min = LocalTensor ({k : v .clone () for k , v in test_tensors .items ()})
308308 dist .all_reduce (lt_min , op = dist .ReduceOp .MIN , group = fake_pg )
309309 expected_min = torch .tensor ([[1.0 , 1.0 ], [1.0 , 4.0 ]]) # Min across all tensors
310- for rank in test_tensors . keys () :
310+ for rank in test_tensors :
311311 self .assertEqual (lt_min ._local_tensors [rank ], expected_min )
312312
313313 def test_all_reduce_collective (self ):
@@ -328,7 +328,7 @@ def test_all_reduce_collective(self):
328328
329329 # Verify all ranks have the sum of all tensors (after adding 1 to each)
330330 expected_sum = torch .tensor ([[114.0 , 225.0 , 336.0 ], [447.0 , 558.0 , 669.0 ]])
331- for rank in different_tensors . keys () :
331+ for rank in different_tensors :
332332 self .assertEqual (lt_sum ._local_tensors [rank ], expected_sum )
333333
334334 def test_broadcast_collective (self ):
@@ -348,7 +348,7 @@ def test_broadcast_collective(self):
348348
349349 # Verify all ranks have rank 1's original tensor
350350 expected_broadcast = different_tensors [1 ]
351- for rank in different_tensors . keys () :
351+ for rank in different_tensors :
352352 self .assertEqual (lt_broadcast ._local_tensors [rank ], expected_broadcast )
353353
354354 def test_all_gather_collective (self ):
0 commit comments