@@ -372,13 +372,15 @@ def plot_metric(result: List[float], metric_name: str):
372372
373373def  calculate_mse (ref_values : ProgramOutput , values : ProgramOutput ):
374374    def  mean_squared_error (a : torch .Tensor , b : torch .Tensor ):
375-         return  round ((torch .pow ((a  -  b ). to ( torch . float32 ) , 2 )).mean ().item (), 2 )
375+         return  round ((torch .pow ((a  -  b ), 2 )).mean ().item (), 2 )
376376
377377    results  =  []
378378    for  ref_value , value  in  zip (ref_values , values ):
379379        # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type 
380380        if  isinstance (ref_value , torch .Tensor ) and  isinstance (value , torch .Tensor ):
381-             results .append (mean_squared_error (ref_value , value ))
381+             results .append (
382+                 mean_squared_error (ref_value .to (torch .float32 ), value .to (torch .float32 ))
383+             )
382384        else :
383385            results .append (None )
384386
@@ -387,8 +389,6 @@ def mean_squared_error(a: torch.Tensor, b: torch.Tensor):
387389
388390def  calculate_snr (ref_values : ProgramOutput , values : ProgramOutput ):
389391    def  signal_to_noise (signal : torch .Tensor , noise : torch .Tensor ):
390-         signal  =  signal .type (torch .float32 )
391-         noise  =  noise .type (torch .float32 )
392392        signal_power  =  torch .mean (torch .pow (signal , 2 ))
393393        noise_power  =  torch .mean (torch .pow (noise , 2 ))
394394        snr  =  10  *  torch .log10 (signal_power  /  noise_power )
@@ -398,8 +398,10 @@ def signal_to_noise(signal: torch.Tensor, noise: torch.Tensor):
398398    for  ref_value , value  in  zip (ref_values , values ):
399399        # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type 
400400        if  isinstance (ref_value , torch .Tensor ) and  isinstance (value , torch .Tensor ):
401-             diff  =  ref_value  -  value 
402-             snr  =  signal_to_noise (ref_value , diff )
401+             ref_value_fp  =  ref_value .to (torch .float32 )
402+             value_fp  =  value .to (torch .float32 )
403+             diff  =  ref_value_fp  -  value_fp 
404+             snr  =  signal_to_noise (ref_value_fp , diff )
403405            results .append (snr )
404406        else :
405407            results .append (None )
@@ -429,7 +431,9 @@ def cosine_similarity(tensor1: torch.Tensor, tensor2: torch.Tensor):
429431    for  ref_value , value  in  zip (ref_values , values ):
430432        # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type 
431433        if  isinstance (ref_value , torch .Tensor ) and  isinstance (value , torch .Tensor ):
432-             results .append (cosine_similarity (ref_value , value ))
434+             results .append (
435+                 cosine_similarity (ref_value .to (torch .float32 ), value .to (torch .float32 ))
436+             )
433437        else :
434438            results .append (None )
435439
0 commit comments