@@ -78,41 +78,43 @@ def __init__(
7878 device : Union [str , torch .device ] = torch .device ("cpu" ),
7979 ) -> None :
8080 super (PrecisionRecallCurve , self ).__init__ (
81- precision_recall_curve_compute_fn ,
81+ precision_recall_curve_compute_fn , # type: ignore[arg-type]
8282 output_transform = output_transform ,
8383 check_compute_fn = check_compute_fn ,
8484 device = device ,
8585 )
8686
87- def compute (self ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
87+ def compute (self ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]: # type: ignore[override]
8888 if len (self ._predictions ) < 1 or len (self ._targets ) < 1 :
8989 raise NotComputableError ("PrecisionRecallCurve must have at least one example before it can be computed." )
9090
91- _prediction_tensor = torch .cat (self ._predictions , dim = 0 )
92- _target_tensor = torch .cat (self ._targets , dim = 0 )
93-
94- ws = idist .get_world_size ()
95- if ws > 1 and not self ._is_reduced :
96- # All gather across all processes
97- _prediction_tensor = cast (torch .Tensor , idist .all_gather (_prediction_tensor ))
98- _target_tensor = cast (torch .Tensor , idist .all_gather (_target_tensor ))
99- self ._is_reduced = True
100-
101- if idist .get_rank () == 0 :
102- # Run compute_fn on zero rank only
103- precision , recall , thresholds = self .compute_fn (_prediction_tensor , _target_tensor )
104- precision = torch .tensor (precision )
105- recall = torch .tensor (recall )
106- # thresholds can have negative strides, not compatible with torch tensors
107- # https://discuss.pytorch.org/t/negative-strides-in-tensor-error/134287/2
108- thresholds = torch .tensor (thresholds .copy ())
109- else :
110- precision , recall , thresholds = None , None , None
111-
112- if ws > 1 :
113- # broadcast result to all processes
114- precision = idist .broadcast (precision , src = 0 , safe_mode = True )
115- recall = idist .broadcast (recall , src = 0 , safe_mode = True )
116- thresholds = idist .broadcast (thresholds , src = 0 , safe_mode = True )
117-
118- return precision , recall , thresholds
91+ if self ._result is None :
92+ _prediction_tensor = torch .cat (self ._predictions , dim = 0 )
93+ _target_tensor = torch .cat (self ._targets , dim = 0 )
94+
95+ ws = idist .get_world_size ()
96+ if ws > 1 :
97+ # All gather across all processes
98+ _prediction_tensor = cast (torch .Tensor , idist .all_gather (_prediction_tensor ))
99+ _target_tensor = cast (torch .Tensor , idist .all_gather (_target_tensor ))
100+
101+ if idist .get_rank () == 0 :
102+ # Run compute_fn on zero rank only
103+ precision , recall , thresholds = cast (Tuple , self .compute_fn (_prediction_tensor , _target_tensor ))
104+ precision = torch .tensor (precision , device = _prediction_tensor .device )
105+ recall = torch .tensor (recall , device = _prediction_tensor .device )
106+ # thresholds can have negative strides, not compatible with torch tensors
107+ # https://discuss.pytorch.org/t/negative-strides-in-tensor-error/134287/2
108+ thresholds = torch .tensor (thresholds .copy (), device = _prediction_tensor .device )
109+ else :
110+ precision , recall , thresholds = None , None , None
111+
112+ if ws > 1 :
113+ # broadcast result to all processes
114+ precision = idist .broadcast (precision , src = 0 , safe_mode = True )
115+ recall = idist .broadcast (recall , src = 0 , safe_mode = True )
116+ thresholds = idist .broadcast (thresholds , src = 0 , safe_mode = True )
117+
118+ self ._result = (precision , recall , thresholds ) # type: ignore[assignment]
119+
120+ return cast (Tuple [torch .Tensor , torch .Tensor , torch .Tensor ], self ._result )
0 commit comments