3434from torch .fx .node import Node
3535
3636from torch .overrides import TorchFunctionMode
37- from torch .testing ._internal .common_utils import torch_to_numpy_dtype_dict
3837from tosa import TosaGraph
3938
4039logger = logging .getLogger (__name__ )
4140logger .setLevel (logging .CRITICAL )
4241
42+ # Copied from PyTorch.
43+ # From torch/testing/_internal/common_utils.py:torch_to_numpy_dtype_dict
44+ # To avoid a dependency on _internal stuff.
45+ _torch_to_numpy_dtype_dict = {
46+ torch .bool : np .bool_ ,
47+ torch .uint8 : np .uint8 ,
48+ torch .uint16 : np .uint16 ,
49+ torch .uint32 : np .uint32 ,
50+ torch .uint64 : np .uint64 ,
51+ torch .int8 : np .int8 ,
52+ torch .int16 : np .int16 ,
53+ torch .int32 : np .int32 ,
54+ torch .int64 : np .int64 ,
55+ torch .float16 : np .float16 ,
56+ torch .float32 : np .float32 ,
57+ torch .float64 : np .float64 ,
58+ torch .bfloat16 : np .float32 ,
59+ torch .complex32 : np .complex64 ,
60+ torch .complex64 : np .complex64 ,
61+ torch .complex128 : np .complex128
62+ }
4363
4464class QuantizationParams :
4565 __slots__ = ["node_name" , "zp" , "scale" , "qmin" , "qmax" , "dtype" ]
@@ -335,7 +355,7 @@ def run_corstone(
335355 output_dtype = node .meta ["val" ].dtype
336356 tosa_ref_output = np .fromfile (
337357 os .path .join (intermediate_path , f"out-{ i } .bin" ),
338- torch_to_numpy_dtype_dict [output_dtype ],
358+ _torch_to_numpy_dtype_dict [output_dtype ],
339359 )
340360
341361 output_np .append (torch .from_numpy (tosa_ref_output ).reshape (output_shape ))
@@ -349,7 +369,7 @@ def prep_data_for_save(
349369):
350370 if isinstance (data , torch .Tensor ):
351371 data_np = np .array (data .detach (), order = "C" ).astype (
352- torch_to_numpy_dtype_dict [data .dtype ]
372+ _torch_to_numpy_dtype_dict [data .dtype ]
353373 )
354374 else :
355375 data_np = np .array (data )
0 commit comments