From b67f59803a711b34354b27a1a4f8aa3d2f2ee73c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Fri, 25 Oct 2024 11:43:33 +0200 Subject: [PATCH] Fix type handling for output types from TOSA reference model MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Per Åstrand Change-Id: I80953a699e4861b901af4b2fb17d47d3d7efcedd --- backends/arm/test/runner_utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 3e9d3620cca..d2ee113a5d2 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -448,8 +448,11 @@ def run_tosa_ref_model( ), "There are no quantization parameters, check output parameters" tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale + if tosa_ref_output.dtype == np.double: + tosa_ref_output = tosa_ref_output.astype("float32") + # tosa_output is a numpy array, convert to torch tensor for comparison - tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output.astype("float32"))) + tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output)) return tosa_ref_outputs @@ -457,7 +460,9 @@ def run_tosa_ref_model( def prep_data_for_save( data, is_quantized: bool, input_name: str, quant_param: QuantizationParams ): - data_np = np.array(data.detach(), order="C").astype(np.float32) + data_np = np.array(data.detach(), order="C").astype( + f"{data.dtype}".replace("torch.", "") + ) if is_quantized: assert quant_param.node_name in input_name, (