@@ -372,13 +372,15 @@ def plot_metric(result: List[float], metric_name: str):
372372
373373def calculate_mse (ref_values : ProgramOutput , values : ProgramOutput ):
374374 def mean_squared_error (a : torch .Tensor , b : torch .Tensor ):
375- return round ((torch .pow ((a - b ). to ( torch . float32 ) , 2 )).mean ().item (), 2 )
375+ return round ((torch .pow ((a - b ), 2 )).mean ().item (), 2 )
376376
377377 results = []
378378 for ref_value , value in zip (ref_values , values ):
379379 # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
380380 if isinstance (ref_value , torch .Tensor ) and isinstance (value , torch .Tensor ):
381- results .append (mean_squared_error (ref_value , value ))
381+ results .append (
382+ mean_squared_error (ref_value .to (torch .float32 ), value .to (torch .float32 ))
383+ )
382384 else :
383385 results .append (None )
384386
@@ -387,8 +389,6 @@ def mean_squared_error(a: torch.Tensor, b: torch.Tensor):
387389
388390def calculate_snr (ref_values : ProgramOutput , values : ProgramOutput ):
389391 def signal_to_noise (signal : torch .Tensor , noise : torch .Tensor ):
390- signal = signal .type (torch .float32 )
391- noise = noise .type (torch .float32 )
392392 signal_power = torch .mean (torch .pow (signal , 2 ))
393393 noise_power = torch .mean (torch .pow (noise , 2 ))
394394 snr = 10 * torch .log10 (signal_power / noise_power )
@@ -398,8 +398,10 @@ def signal_to_noise(signal: torch.Tensor, noise: torch.Tensor):
398398 for ref_value , value in zip (ref_values , values ):
399399 # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
400400 if isinstance (ref_value , torch .Tensor ) and isinstance (value , torch .Tensor ):
401- diff = ref_value - value
402- snr = signal_to_noise (ref_value , diff )
401+ ref_value_fp = ref_value .to (torch .float32 )
402+ value_fp = value .to (torch .float32 )
403+ diff = ref_value_fp - value_fp
404+ snr = signal_to_noise (ref_value_fp , diff )
403405 results .append (snr )
404406 else :
405407 results .append (None )
@@ -429,7 +431,9 @@ def cosine_similarity(tensor1: torch.Tensor, tensor2: torch.Tensor):
429431 for ref_value , value in zip (ref_values , values ):
430432 # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
431433 if isinstance (ref_value , torch .Tensor ) and isinstance (value , torch .Tensor ):
432- results .append (cosine_similarity (ref_value , value ))
434+ results .append (
435+ cosine_similarity (ref_value .to (torch .float32 ), value .to (torch .float32 ))
436+ )
433437 else :
434438 results .append (None )
435439
0 commit comments