diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index 1b6f03512bd..d6e2f12884a 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -1,4 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2025 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -679,6 +680,9 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03): for i in range(len(model_output)): model = model_output[i] ref = ref_output[i] + assert ( + ref.shape == model.shape + ), f"Output {i} shape {model.shape} does not match reference output shape {ref.shape}" assert torch.allclose( model, ref,