@@ -535,6 +535,29 @@ def update(self, output):
535535 pass
536536
537537
def _test_compute_with_sync_all_reduce_doesnt_change_attributes(device):
    """Check that calling a ``sync_all_reduce``-decorated ``compute`` leaves attributes as they were.

    A throwaway metric with a tensor counter ``a`` and a float counter ``b`` is
    updated once; both counters are asserted to equal 1.0 before and after
    ``compute()`` is invoked.

    Args:
        device: device the test runs against; when its type is ``"xla"`` the
            metric itself is created on ``"cpu"`` instead.
    """

    class DummyMetric3(Metric):
        # Minimal metric: one tensor attribute and one plain-float attribute,
        # both listed in the ``sync_all_reduce`` decorator on ``compute``.

        @reinit__is_reduced
        def reset(self):
            self.a = torch.tensor(0.0, device=self._device)
            self.b = 0.0

        def update(self, output):
            # ``output`` is ignored; each call bumps both counters by one.
            self.a += torch.tensor(1.0)
            self.b += 1.0

        @sync_all_reduce("a", "b")
        def compute(self):
            return self.a.item(), self.b

    def check_counters(m):
        # Both counters must agree with each other and equal exactly one update.
        assert m.a.item() == m.b == 1.0

    if torch.device(device).type == "xla":
        target_device = "cpu"
    else:
        target_device = device

    # NOTE(review): ``update`` is called without an explicit ``reset``; this
    # presumably relies on the ``Metric`` constructor calling ``reset`` - confirm.
    dummy = DummyMetric3(device=target_device)
    dummy.update(None)
    check_counters(dummy)

    dummy.compute()
    # After compute() the attributes must still hold their pre-compute values.
    check_counters(dummy)
559+
560+
538561def _test_invalid_sync_all_reduce (device ):
539562 class InvalidMetric (Metric ):
540563 @reinit__is_reduced
@@ -543,6 +566,7 @@ def reset(self):
543566 self .c = 0.0
544567 self .n = 0
545568 self .m = - 1
569+ self .d = "a string"
546570
547571 def compute (self ):
548572 pass
@@ -566,6 +590,14 @@ def invalid_reduction_op_3(self):
566590 def invalid_reduction_op_4 (self ):
567591 pass
568592
593+ @sync_all_reduce ("missingattr" )
594+ def invalid_reduction_op_5 (self ):
595+ pass
596+
597+ @sync_all_reduce ("d" )
598+ def invalid_reduction_op_6 (self ):
599+ pass
600+
569601 metric_device = device if torch .device (device ).type != "xla" else "cpu"
570602 m = InvalidMetric (device = metric_device )
571603 m .reset ()
@@ -583,6 +615,14 @@ def invalid_reduction_op_4(self):
583615 with pytest .raises (ValueError , match = r"Reduction operation is not valid" ):
584616 m .invalid_reduction_op_4 ()
585617
618+ with pytest .raises (ValueError , match = r"has no attribute named `missingattr`." ):
619+ m .invalid_reduction_op_5 ()
620+
621+ with pytest .raises (
622+ TypeError , match = r"Attribute provided to sync_all_reduce should be a number or tensor but `d`"
623+ ):
624+ m .invalid_reduction_op_6 ()
625+
586626
587627def _test_distrib_sync_all_reduce_decorator (device ):
588628 class DummyMetric (Metric ):
@@ -647,7 +687,7 @@ def update(self, output):
647687 m = DummyMetric (device = metric_device )
648688 m .update (None )
649689 m .compute ()
650- # check if can call compute multiple times without all reduce invocation
690+ # check if attributes are restored to their original values after previous `compute`
651691 m .compute ()
652692
653693
@@ -664,6 +704,7 @@ def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
664704 device = idist .device ()
665705 _test_distrib_sync_all_reduce_decorator (device )
666706 _test_invalid_sync_all_reduce (device )
707+ _test_compute_with_sync_all_reduce_doesnt_change_attributes (device )
667708
668709
669710@pytest .mark .distributed
@@ -673,6 +714,7 @@ def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
673714 device = idist .device ()
674715 _test_distrib_sync_all_reduce_decorator (device )
675716 _test_invalid_sync_all_reduce (device )
717+ _test_compute_with_sync_all_reduce_doesnt_change_attributes (device )
676718
677719
678720@pytest .mark .distributed
@@ -685,6 +727,7 @@ def test_distrib_hvd(gloo_hvd_executor):
685727
686728 gloo_hvd_executor (_test_distrib_sync_all_reduce_decorator , (device ,), np = nproc , do_init = True )
687729 gloo_hvd_executor (_test_invalid_sync_all_reduce , (device ,), np = nproc , do_init = True )
730+ gloo_hvd_executor (_test_compute_with_sync_all_reduce_doesnt_change_attributes , (device ,), np = nproc , do_init = True )
688731
689732
690733@pytest .mark .multinode_distributed
@@ -695,6 +738,7 @@ def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
695738 device = idist .device ()
696739 _test_distrib_sync_all_reduce_decorator (device )
697740 _test_invalid_sync_all_reduce (device )
741+ _test_compute_with_sync_all_reduce_doesnt_change_attributes (device )
698742
699743
700744@pytest .mark .multinode_distributed
@@ -705,6 +749,7 @@ def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
705749 device = idist .device ()
706750 _test_distrib_sync_all_reduce_decorator (device )
707751 _test_invalid_sync_all_reduce (device )
752+ _test_compute_with_sync_all_reduce_doesnt_change_attributes (device )
708753
709754
710755@pytest .mark .tpu
@@ -715,13 +760,15 @@ def test_distrib_single_device_xla():
715760 _test_distrib_sync_all_reduce_decorator (device )
716761 _test_creating_on_xla_fails (device )
717762 _test_invalid_sync_all_reduce (device )
763+ _test_compute_with_sync_all_reduce_doesnt_change_attributes (device )
718764
719765
def _test_distrib_xla_nprocs(index):
    """Run the XLA distributed checks on the device reported by ``idist.device()``.

    Args:
        index: unused in the body; presumably the process index supplied by
            the spawner that launches this function - confirm against caller.
    """
    current_device = idist.device()

    # Order matters only in that it mirrors the single-device XLA test;
    # the first failing check aborts the remaining ones.
    checks = (
        _test_distrib_sync_all_reduce_decorator,
        _test_creating_on_xla_fails,
        _test_invalid_sync_all_reduce,
        _test_compute_with_sync_all_reduce_doesnt_change_attributes,
    )
    for check in checks:
        check(current_device)
725772
726773
727774@pytest .mark .tpu
0 commit comments