1111# the standardised TOSA representation.
1212#
1313
14- from typing import Any , Sequence
14+ from typing import Any , Optional , Sequence
1515
1616import torch
17-
18- import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
19-
17+ from executorch .backends .arm .tosa_specification import (
18+ Tosa_0_80 ,
19+ Tosa_1_00 ,
20+ TosaSpecification ,
21+ )
2022
2123UNSUPPORTED_DTYPES = (
2224 torch .float64 ,
3032 torch .long ,
3133)
3234
33- DTYPE_MAP = {
34- torch .float32 : ts .DType .FP32 ,
35- torch .float : ts .DType .FP32 ,
36- torch .float16 : ts .DType .FP16 ,
37- torch .half : ts .DType .FP16 ,
38- torch .bfloat16 : ts .DType .BF16 ,
39- torch .int8 : ts .DType .INT8 ,
40- torch .int16 : ts .DType .INT16 ,
41- torch .short : ts .DType .INT16 ,
42- torch .int32 : ts .DType .INT32 ,
43- torch .int : ts .DType .INT32 ,
44- torch .bool : ts .DType .BOOL ,
45- }
46-
47-
48- def map_dtype (data_type : torch .dtype ) -> ts .DType :
35+
36+ def map_dtype (data_type : torch .dtype , tosa_spec : TosaSpecification ) -> Any :
4937 if data_type in UNSUPPORTED_DTYPES :
5038 raise ValueError (f"Unsupported type: { data_type } " )
51- if data_type not in DTYPE_MAP :
39+ if isinstance (tosa_spec , Tosa_0_80 ):
40+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
41+ elif isinstance (tosa_spec , Tosa_1_00 ):
42+ import serializer .tosa_serializer as ts # type: ignore
43+ else :
44+ raise RuntimeError (f"Unsupported tosa_spec: { tosa_spec } " )
45+
46+ dtype_map = {
47+ torch .float32 : ts .DType .FP32 ,
48+ torch .float : ts .DType .FP32 ,
49+ torch .float16 : ts .DType .FP16 ,
50+ torch .half : ts .DType .FP16 ,
51+ torch .bfloat16 : ts .DType .BF16 ,
52+ torch .int8 : ts .DType .INT8 ,
53+ torch .int16 : ts .DType .INT16 ,
54+ torch .short : ts .DType .INT16 ,
55+ torch .int32 : ts .DType .INT32 ,
56+ torch .int : ts .DType .INT32 ,
57+ torch .bool : ts .DType .BOOL ,
58+ }
59+ if data_type not in dtype_map :
5260 raise ValueError (f"Unknown type: { data_type } " )
53- return DTYPE_MAP [data_type ]
61+ return dtype_map [data_type ]
5462
5563
5664# Returns the shape and type of a node
5765# TODO: other types, can be
5866# SymInt, FakeTensor, a List[Union[FakeTensor, SymInt]], or None
59- def extract_tensor_meta (meta ):
67+ def extract_tensor_meta (meta , tosa_spec : TosaSpecification ):
6068 assert meta .get ("val" ) is not None
6169 val = meta ["val" ]
6270 if type (val ) is tuple :
@@ -67,7 +75,7 @@ def extract_tensor_meta(meta):
6775 raise ValueError (
6876 f"Expected first value in node.meta['val'] to be FakeTensor, got { val .__class__ } "
6977 )
70- dtype = map_dtype (val .dtype )
78+ dtype = map_dtype (val .dtype , tosa_spec )
7179 shape = tuple (val .size ())
7280
7381 if meta .get ("tosa_dim_order" ) is not None :
@@ -81,17 +89,28 @@ def extract_tensor_meta(meta):
8189class TosaArg :
8290 def __process_node (self , argument : torch .fx .Node ):
8391 self .name : str = argument .name
84- self .dtype , self .shape , self .dim_order = extract_tensor_meta (argument .meta )
92+ self .dtype , self .shape , self .dim_order = extract_tensor_meta (
93+ argument .meta , self .tosa_spec
94+ )
8595
8696 def __process_list (self , argument ):
8797 self .special : list = list (argument )
8898
8999 def __process_number (self , argument : float | int ):
90100 self .number : float | int = argument
91101
92- def __init__ (self , argument : Any ) -> None :
102+ def __init__ (
103+ self , argument : Any , tosa_spec : Optional [TosaSpecification ] = None
104+ ) -> None :
93105 if argument is None :
94106 return
107+ if tosa_spec is None :
108+ raise ValueError ("tosa_spec is None" )
109+ elif not isinstance (tosa_spec , TosaSpecification ):
110+ raise ValueError (
111+ f"Expected tosa_spec to be a TosaSpecification, but got { tosa_spec } "
112+ )
113+ self .tosa_spec = tosa_spec
95114
96115 if isinstance (argument , torch .fx .Node ):
97116 self .__process_node (argument )
@@ -116,6 +135,12 @@ def __repr__(self):
116135 if self .name is not None :
117136 attrs .append (f"name={ self .name !r} " )
118137 if self .dtype is not None :
138+ if isinstance (self .tosa_spec , Tosa_0_80 ):
139+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
140+ elif isinstance (self .tosa_spec , Tosa_1_00 ):
141+ import serializer .tosa_serializer as ts # type: ignore
142+ else :
143+ raise RuntimeError (f"Unsupported tosa_spec: { self .tosa_spec } " )
119144 attrs .append (f"dtype={ ts .DTypeNames [self .dtype ]} " )
120145 if self .shape is not None :
121146 attrs .append (f"shape={ self .shape !r} " )
@@ -125,4 +150,6 @@ def __repr__(self):
125150 attrs .append (f"special={ self .special !r} " )
126151 if hasattr (self , "number" ) and self .number is not None :
127152 attrs .append (f"number={ self .number !r} " )
153+ if hasattr (self , "tosa_spec" ) and self .tosa_spec is not None :
154+ attrs .append (f"tosa_spec={ self .tosa_spec !r} " )
128155 return f"{ self .__class__ .__name__ } ({ ', ' .join (attrs )} )"
0 commit comments