@@ -372,13 +372,13 @@ def plot_metric(result: List[float], metric_name: str):
372372
373373def calculate_mse (ref_values : ProgramOutput , values : ProgramOutput ):
374374 def mean_squared_error (a : torch .Tensor , b : torch .Tensor ):
375- return round ((torch .pow ((a - b ). to ( torch . float32 ) , 2 )).mean ().item (), 2 )
375+ return round ((torch .pow ((a - b ), 2 )).mean ().item (), 2 )
376376
377377 results = []
378378 for ref_value , value in zip (ref_values , values ):
379379 # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
380380 if isinstance (ref_value , torch .Tensor ) and isinstance (value , torch .Tensor ):
381- results .append (mean_squared_error (ref_value , value ))
381+ results .append (mean_squared_error (ref_value . to ( torch . float32 ) , value . to ( torch . float32 ) ))
382382 else :
383383 results .append (None )
384384
@@ -387,8 +387,6 @@ def mean_squared_error(a: torch.Tensor, b: torch.Tensor):
387387
388388def calculate_snr (ref_values : ProgramOutput , values : ProgramOutput ):
389389 def signal_to_noise (signal : torch .Tensor , noise : torch .Tensor ):
390- signal = signal .type (torch .float32 )
391- noise = noise .type (torch .float32 )
392390 signal_power = torch .mean (torch .pow (signal , 2 ))
393391 noise_power = torch .mean (torch .pow (noise , 2 ))
394392 snr = 10 * torch .log10 (signal_power / noise_power )
@@ -398,8 +396,10 @@ def signal_to_noise(signal: torch.Tensor, noise: torch.Tensor):
398396 for ref_value , value in zip (ref_values , values ):
399397 # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
400398 if isinstance (ref_value , torch .Tensor ) and isinstance (value , torch .Tensor ):
401- diff = ref_value - value
402- snr = signal_to_noise (ref_value , diff )
399+ ref_value_fp = ref_value .to (torch .float32 )
400+ value_fp = value .to (torch .float32 )
401+ diff = ref_value_fp - value_fp
402+ snr = signal_to_noise (ref_value_fp , diff )
403403 results .append (snr )
404404 else :
405405 results .append (None )
@@ -429,7 +429,7 @@ def cosine_similarity(tensor1: torch.Tensor, tensor2: torch.Tensor):
429429 for ref_value , value in zip (ref_values , values ):
430430 # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
431431 if isinstance (ref_value , torch .Tensor ) and isinstance (value , torch .Tensor ):
432- results .append (cosine_similarity (ref_value , value ))
432+ results .append (cosine_similarity (ref_value . to ( torch . float32 ) , value . to ( torch . float32 ) ))
433433 else :
434434 results .append (None )
435435
0 commit comments