3434from  torch .fx .node  import  Node 
3535
3636from  torch .overrides  import  TorchFunctionMode 
37- from  torch .testing ._internal .common_utils  import  torch_to_numpy_dtype_dict 
3837from  tosa  import  TosaGraph 
3938
4039logger  =  logging .getLogger (__name__ )
4140logger .setLevel (logging .CRITICAL )
4241
42+ # Copied from PyTorch. 
43+ # From torch/testing/_internal/common_utils.py:torch_to_numpy_dtype_dict 
44+ # To avoid a dependency on _internal stuff. 
45+ _torch_to_numpy_dtype_dict  =  {
46+     torch .bool : np .bool_ ,
47+     torch .uint8 : np .uint8 ,
48+     torch .uint16 : np .uint16 ,
49+     torch .uint32 : np .uint32 ,
50+     torch .uint64 : np .uint64 ,
51+     torch .int8 : np .int8 ,
52+     torch .int16 : np .int16 ,
53+     torch .int32 : np .int32 ,
54+     torch .int64 : np .int64 ,
55+     torch .float16 : np .float16 ,
56+     torch .float32 : np .float32 ,
57+     torch .float64 : np .float64 ,
58+     torch .bfloat16 : np .float32 ,
59+     torch .complex32 : np .complex64 ,
60+     torch .complex64 : np .complex64 ,
61+     torch .complex128 : np .complex128 ,
62+ }
63+ 
4364
4465class  QuantizationParams :
4566    __slots__  =  ["node_name" , "zp" , "scale" , "qmin" , "qmax" , "dtype" ]
@@ -335,7 +356,7 @@ def run_corstone(
335356        output_dtype  =  node .meta ["val" ].dtype 
336357        tosa_ref_output  =  np .fromfile (
337358            os .path .join (intermediate_path , f"out-{ i }  .bin" ),
338-             torch_to_numpy_dtype_dict [output_dtype ],
359+             _torch_to_numpy_dtype_dict [output_dtype ],
339360        )
340361
341362        output_np .append (torch .from_numpy (tosa_ref_output ).reshape (output_shape ))
@@ -349,7 +370,7 @@ def prep_data_for_save(
349370):
350371    if  isinstance (data , torch .Tensor ):
351372        data_np  =  np .array (data .detach (), order = "C" ).astype (
352-             torch_to_numpy_dtype_dict [data .dtype ]
373+             _torch_to_numpy_dtype_dict [data .dtype ]
353374        )
354375    else :
355376        data_np  =  np .array (data )
0 commit comments