1313import torch
1414
1515
16- # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
17- def distance (fn : Callable [[np .ndarray , np .ndarray ], float ]) -> Callable [
16+ def distance (
17+ fn : Callable [[np .ndarray , np .ndarray ], float ],
18+ ) -> Callable [
1819 [
19- # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
2020 typing .Union [np .ndarray , torch ._tensor .Tensor ],
21- # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
2221 typing .Union [np .ndarray , torch ._tensor .Tensor ],
2322 ],
2423 float ,
@@ -27,9 +26,7 @@ def distance(fn: Callable[[np.ndarray, np.ndarray], float]) -> Callable[
2726 # the distance between two N-D tensors given a function. This can be a RMS
2827 # function, maximum abs diff, or any kind of distance function.
2928 def wrapper (
30- # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
3129 a : Union [np .ndarray , torch .Tensor ],
32- # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
3330 b : Union [np .ndarray , torch .Tensor ],
3431 ) -> float :
3532 # convert a and b to np.ndarray type fp64
@@ -68,24 +65,20 @@ def wrapper(
6865
6966
7067@distance
71- # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
7268def rms (a : np .ndarray , b : np .ndarray ) -> float :
7369 return ((a - b ) ** 2 ).mean () ** 0.5
7470
7571
7672@distance
77- # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
7873def max_abs_diff (a : np .ndarray , b : np .ndarray ) -> float :
7974 return np .abs (a - b ).max ()
8075
8176
8277@distance
83- # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
8478def max_rel_diff (x : np .ndarray , x_ref : np .ndarray ) -> float :
8579 return np .abs ((x - x_ref ) / x_ref ).max ()
8680
8781
88- # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
8982def to_np_arr_fp64 (x : Union [np .ndarray , torch .Tensor ]) -> np .ndarray :
9083 if isinstance (x , torch .Tensor ):
9184 x = x .detach ().cpu ().numpy ()
@@ -94,11 +87,8 @@ def to_np_arr_fp64(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
9487 return x
9588
9689
97- # pyre-fixme[3]: Return type must be annotated.
9890def normalized_rms (
99- # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
10091 predicted : Union [np .ndarray , torch .Tensor ],
101- # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
10292 ground_truth : Union [np .ndarray , torch .Tensor ],
10393):
10494 num = rms (predicted , ground_truth )
0 commit comments