Skip to content

Commit 03ea8c2

Browse files
committed
[ExecuTorch] Arm Ethos: Do not depend on torch.testing._internal
Pull Request resolved: pytorch/executorch#8839 This can cuase issues with `disable_global_flags` and internal state of the library, this is something which is set when importing this. Differential Revision: [D70402061](https://our.internmc.facebook.com/intern/diff/D70402061/) ghstack-source-id: 269021579
1 parent 4df0ade commit 03ea8c2

File tree

2 files changed

+25
-6
lines changed

2 files changed

+25
-6
lines changed

backends/arm/test/passes/test_rescale_pass.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from executorch.backends.arm.test import common, conftest
1414
from executorch.backends.arm.test.tester.arm_tester import ArmTester
1515
from parameterized import parameterized
16-
from torch.testing._internal import optests
1716

1817

1918
def test_rescale_op():
@@ -64,7 +63,7 @@ def test_nonzero_zp_for_int32():
6463
),
6564
]
6665
for sample_input in sample_inputs:
67-
with pytest.raises(optests.generate_tests.OpCheckError):
66+
with pytest.raises(Exception):
6867
torch.library.opcheck(torch.ops.tosa._rescale, sample_input)
6968

7069

@@ -87,7 +86,7 @@ def test_zp_outside_range():
8786
),
8887
]
8988
for sample_input in sample_inputs:
90-
with pytest.raises(optests.generate_tests.OpCheckError):
89+
with pytest.raises(Exception):
9190
torch.library.opcheck(torch.ops.tosa._rescale, sample_input)
9291

9392

backends/arm/test/runner_utils.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,32 @@
3434
from torch.fx.node import Node
3535

3636
from torch.overrides import TorchFunctionMode
37-
from torch.testing._internal.common_utils import torch_to_numpy_dtype_dict
3837
from tosa import TosaGraph
3938

4039
logger = logging.getLogger(__name__)
4140
logger.setLevel(logging.CRITICAL)
4241

42+
# Copied from PyTorch.
43+
# From torch/testing/_internal/common_utils.py:torch_to_numpy_dtype_dict
44+
# To avoid a dependency on _internal stuff.
45+
_torch_to_numpy_dtype_dict = {
46+
torch.bool : np.bool_,
47+
torch.uint8 : np.uint8,
48+
torch.uint16 : np.uint16,
49+
torch.uint32 : np.uint32,
50+
torch.uint64 : np.uint64,
51+
torch.int8 : np.int8,
52+
torch.int16 : np.int16,
53+
torch.int32 : np.int32,
54+
torch.int64 : np.int64,
55+
torch.float16 : np.float16,
56+
torch.float32 : np.float32,
57+
torch.float64 : np.float64,
58+
torch.bfloat16 : np.float32,
59+
torch.complex32 : np.complex64,
60+
torch.complex64 : np.complex64,
61+
torch.complex128: np.complex128
62+
}
4363

4464
class QuantizationParams:
4565
__slots__ = ["node_name", "zp", "scale", "qmin", "qmax", "dtype"]
@@ -335,7 +355,7 @@ def run_corstone(
335355
output_dtype = node.meta["val"].dtype
336356
tosa_ref_output = np.fromfile(
337357
os.path.join(intermediate_path, f"out-{i}.bin"),
338-
torch_to_numpy_dtype_dict[output_dtype],
358+
_torch_to_numpy_dtype_dict[output_dtype],
339359
)
340360

341361
output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape))
@@ -349,7 +369,7 @@ def prep_data_for_save(
349369
):
350370
if isinstance(data, torch.Tensor):
351371
data_np = np.array(data.detach(), order="C").astype(
352-
torch_to_numpy_dtype_dict[data.dtype]
372+
_torch_to_numpy_dtype_dict[data.dtype]
353373
)
354374
else:
355375
data_np = np.array(data)

0 commit comments

Comments
 (0)