 import sys
 import unittest
 from functools import partial, wraps
+from unittest.mock import patch
 
 import torch
 import torch.distributed as dist
 import torch.distributed._functional_collectives as ft_c
 import torch.distributed.distributed_c10d as c10d
 import torch.distributed.tensor as dt
 from functorch import make_fx
+from torch._dynamo.metrics_context import MetricsContext
 from torch._inductor.utils import run_and_get_code
 from torch.testing import FileCheck
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
     instantiate_parametrized_tests,
     parametrize,
     run_tests,
-    skipIfHpu,
     TEST_CUDA,
     TEST_HPU,
     TestCase,
@@ -90,7 +91,7 @@ def new_subgroups(group_size: int, pg_tag=None):
     return cur_subgroup, subgroups
 
 
-@skipIfHpu
+@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
 class TestExpand(MultiThreadedTestCase):
     @property
     def world_size(self):
@@ -180,7 +181,7 @@ def test_expand_device_mesh_tuple(self):
         self.assertEqual(2, group_size)
 
 
-@skipIfHpu
+@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
 class TestPgTag(MultiThreadedTestCase):
     @property
     def world_size(self):
@@ -257,7 +258,7 @@ def test_find_root_pg(self):
 
 
 @instantiate_parametrized_tests
-@skipIfHpu
+@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
 class TestTraceableCollectives(MultiThreadedTestCase):
     @property
     def world_size(self):
@@ -403,7 +404,7 @@ def test_all_reduce(self):
         self.assertEqual(x.size(), out.size())
 
 
-@skipIfHpu
+@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
 class TestGradCollectives(MultiThreadedTestCase):
     @property
     def world_size(self):
@@ -656,7 +657,7 @@ def test_permute_tensor_with_sub_group(self, device):
 
 
 @instantiate_parametrized_tests
-@skipIfHpu
+@unittest.skipIf(TEST_HPU, "Unsupported on HPU")
 class TestFunctionalAutograd(MultiThreadedTestCase):
     def setUp(self):
         super().setUp()
@@ -666,6 +667,13 @@ def setUp(self):
     def world_size(self):
         return 2
 
+    # `compilation_metric` attempts to update the `is_forward` field of `metrics_context`. Since
+    # `metrics_context` is a singleton, a runtime error will occur if multiple threads try to update it,
+    # because `MetricsContext` does not allow updating existing fields when `overwrite` is False.
+    # So, we need to patch the `update` method of `MetricsContext`.
+    def _metrics_context_update(self, *args, **kwargs) -> None:
+        pass
+
     @parametrize("compile", [True, False])
     def test_all_to_all_single(self, compile: bool = True) -> None:
         group = dist.group.WORLD.group_name
@@ -691,7 +699,8 @@ def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
         self.assertIsNotNone(out.grad_fn)
         self.assertTrue(out.requires_grad)
         loss = out.sum()
-        loss.backward()
+        with patch.object(MetricsContext, "update", self._metrics_context_update):
+            loss.backward()
         self.assertEqual(t.grad, torch.full_like(t, 2.0))
 
     def test_all_to_all_single_inductor(self) -> None:
@@ -711,7 +720,8 @@ def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
 
         def run_with_backward():
             out = compiled(t, self.world_size)
-            out.backward()
+            with patch.object(MetricsContext, "update", self._metrics_context_update):
+                out.backward()
 
         _, codes = run_and_get_code(run_with_backward)
         for code in codes:
@@ -751,7 +761,8 @@ def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
         gathered_tensor = compiled(local_tensor, dim)
         self.assertEqual(gathered_tensor, torch.ones(output_size))
 
-        gathered_tensor.sum().backward()
+        with patch.object(MetricsContext, "update", self._metrics_context_update):
+            gathered_tensor.sum().backward()
         self.assertEqual(
             local_tensor.grad,
             torch.full((3, 3, 3), fill_value=float(self.world_size)),
@@ -786,7 +797,8 @@ def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
         rs_tensor = compiled(input_tensor, dim)
         res_num = 1 * group_size
         self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)
-        rs_tensor.sum().backward()
+        with patch.object(MetricsContext, "update", self._metrics_context_update):
+            rs_tensor.sum().backward()
         self.assertEqual(input_tensor.grad, torch.full(output_size, fill_value=1.0))
 
 
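For reference, the workaround applied in the hunks above can be exercised in isolation. The sketch below is illustrative only: the helper names `_noop_update` and `run_backward_without_metrics` are not part of this change, and the only assumption is that `torch._dynamo.metrics_context.MetricsContext` is importable, which the diff itself relies on.

    from unittest.mock import patch

    import torch
    from torch._dynamo.metrics_context import MetricsContext


    def _noop_update(self, *args, **kwargs) -> None:
        # Drop metrics updates instead of writing to the shared singleton,
        # so concurrent threads cannot trip the overwrite check.
        pass


    def run_backward_without_metrics(loss: torch.Tensor) -> None:
        # While the patch is active, any compilation-metrics bookkeeping that a
        # compiled backward triggers becomes a no-op.
        with patch.object(MetricsContext, "update", _noop_update):
            loss.backward()

In the tests, the same idea is expressed as `self._metrics_context_update`, a method on the test case, so every `backward()` call in the multi-threaded tests can share it.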