diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 3e9d3620cca..d2ee113a5d2 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -448,8 +448,11 @@ def run_tosa_ref_model( ), "There are no quantization parameters, check output parameters" tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale + if tosa_ref_output.dtype == np.double: + tosa_ref_output = tosa_ref_output.astype("float32") + # tosa_output is a numpy array, convert to torch tensor for comparison - tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output.astype("float32"))) + tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output)) return tosa_ref_outputs @@ -457,7 +460,9 @@ def run_tosa_ref_model( def prep_data_for_save( data, is_quantized: bool, input_name: str, quant_param: QuantizationParams ): - data_np = np.array(data.detach(), order="C").astype(np.float32) + data_np = np.array(data.detach(), order="C").astype( + f"{data.dtype}".replace("torch.", "") + ) if is_quantized: assert quant_param.node_name in input_name, (