diff --git a/backends/arm/test/passes/test_rescale_pass.py b/backends/arm/test/passes/test_rescale_pass.py index 90ad502378c..5725e1884b3 100644 --- a/backends/arm/test/passes/test_rescale_pass.py +++ b/backends/arm/test/passes/test_rescale_pass.py @@ -13,7 +13,6 @@ from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester from parameterized import parameterized -from torch.testing._internal import optests def test_rescale_op(): @@ -64,7 +63,7 @@ def test_nonzero_zp_for_int32(): ), ] for sample_input in sample_inputs: - with pytest.raises(optests.generate_tests.OpCheckError): + with pytest.raises(Exception): torch.library.opcheck(torch.ops.tosa._rescale, sample_input) @@ -87,7 +86,7 @@ def test_zp_outside_range(): ), ] for sample_input in sample_inputs: - with pytest.raises(optests.generate_tests.OpCheckError): + with pytest.raises(Exception): torch.library.opcheck(torch.ops.tosa._rescale, sample_input) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 5a0bfe2c37c..0b1c5b05431 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -34,12 +34,32 @@ from torch.fx.node import Node from torch.overrides import TorchFunctionMode -from torch.testing._internal.common_utils import torch_to_numpy_dtype_dict from tosa import TosaGraph logger = logging.getLogger(__name__) logger.setLevel(logging.CRITICAL) +# Copied from PyTorch. +# From torch/testing/_internal/common_utils.py:torch_to_numpy_dtype_dict +# To avoid a dependency on _internal stuff. +_torch_to_numpy_dtype_dict = { + torch.bool : np.bool_, + torch.uint8 : np.uint8, + torch.uint16 : np.uint16, + torch.uint32 : np.uint32, + torch.uint64 : np.uint64, + torch.int8 : np.int8, + torch.int16 : np.int16, + torch.int32 : np.int32, + torch.int64 : np.int64, + torch.float16 : np.float16, + torch.float32 : np.float32, + torch.float64 : np.float64, + torch.bfloat16 : np.float32, + torch.complex32 : np.complex64, + torch.complex64 : np.complex64, + torch.complex128: np.complex128 +} class QuantizationParams: __slots__ = ["node_name", "zp", "scale", "qmin", "qmax", "dtype"] @@ -335,7 +355,7 @@ def run_corstone( output_dtype = node.meta["val"].dtype tosa_ref_output = np.fromfile( os.path.join(intermediate_path, f"out-{i}.bin"), - torch_to_numpy_dtype_dict[output_dtype], + _torch_to_numpy_dtype_dict[output_dtype], ) output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape)) @@ -349,7 +369,7 @@ def prep_data_for_save( ): if isinstance(data, torch.Tensor): data_np = np.array(data.detach(), order="C").astype( - torch_to_numpy_dtype_dict[data.dtype] + _torch_to_numpy_dtype_dict[data.dtype] ) else: data_np = np.array(data)