11
11
# the standardised TOSA representation.
12
12
#
13
13
14
- from typing import Any , Sequence
14
+ from typing import Any , Optional , Sequence
15
15
16
16
import torch
17
-
18
- import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
19
-
17
+ from executorch .backends .arm .tosa_specification import (
18
+ Tosa_0_80 ,
19
+ Tosa_1_00 ,
20
+ TosaSpecification ,
21
+ )
20
22
21
23
UNSUPPORTED_DTYPES = (
22
24
torch .float64 ,
30
32
torch .long ,
31
33
)
32
34
33
- DTYPE_MAP = {
34
- torch .float32 : ts .DType .FP32 ,
35
- torch .float : ts .DType .FP32 ,
36
- torch .float16 : ts .DType .FP16 ,
37
- torch .half : ts .DType .FP16 ,
38
- torch .bfloat16 : ts .DType .BF16 ,
39
- torch .int8 : ts .DType .INT8 ,
40
- torch .int16 : ts .DType .INT16 ,
41
- torch .short : ts .DType .INT16 ,
42
- torch .int32 : ts .DType .INT32 ,
43
- torch .int : ts .DType .INT32 ,
44
- torch .bool : ts .DType .BOOL ,
45
- }
46
-
47
-
48
- def map_dtype (data_type : torch .dtype ) -> ts .DType :
35
+
36
+ def map_dtype (data_type : torch .dtype , tosa_spec : TosaSpecification ) -> Any :
49
37
if data_type in UNSUPPORTED_DTYPES :
50
38
raise ValueError (f"Unsupported type: { data_type } " )
51
- if data_type not in DTYPE_MAP :
39
+ if isinstance (tosa_spec , Tosa_0_80 ):
40
+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
41
+ elif isinstance (tosa_spec , Tosa_1_00 ):
42
+ import serializer .tosa_serializer as ts # type: ignore
43
+ else :
44
+ raise RuntimeError (f"Unsupported tosa_spec: { tosa_spec } " )
45
+
46
+ dtype_map = {
47
+ torch .float32 : ts .DType .FP32 ,
48
+ torch .float : ts .DType .FP32 ,
49
+ torch .float16 : ts .DType .FP16 ,
50
+ torch .half : ts .DType .FP16 ,
51
+ torch .bfloat16 : ts .DType .BF16 ,
52
+ torch .int8 : ts .DType .INT8 ,
53
+ torch .int16 : ts .DType .INT16 ,
54
+ torch .short : ts .DType .INT16 ,
55
+ torch .int32 : ts .DType .INT32 ,
56
+ torch .int : ts .DType .INT32 ,
57
+ torch .bool : ts .DType .BOOL ,
58
+ }
59
+ if data_type not in dtype_map :
52
60
raise ValueError (f"Unknown type: { data_type } " )
53
- return DTYPE_MAP [data_type ]
61
+ return dtype_map [data_type ]
54
62
55
63
56
64
# Returns the shape and type of a node
57
65
# TODO: other types, can be
58
66
# SymInt, FakeTensor, a List[Union[FakeTensor, SymInt]], or None
59
- def extract_tensor_meta (meta ):
67
+ def extract_tensor_meta (meta , tosa_spec : TosaSpecification ):
60
68
assert meta .get ("val" ) is not None
61
69
val = meta ["val" ]
62
70
if type (val ) is tuple :
@@ -67,7 +75,7 @@ def extract_tensor_meta(meta):
67
75
raise ValueError (
68
76
f"Expected first value in node.meta['val'] to be FakeTensor, got { val .__class__ } "
69
77
)
70
- dtype = map_dtype (val .dtype )
78
+ dtype = map_dtype (val .dtype , tosa_spec )
71
79
shape = tuple (val .size ())
72
80
73
81
if meta .get ("tosa_dim_order" ) is not None :
@@ -81,17 +89,28 @@ def extract_tensor_meta(meta):
81
89
class TosaArg :
82
90
def __process_node (self , argument : torch .fx .Node ):
83
91
self .name : str = argument .name
84
- self .dtype , self .shape , self .dim_order = extract_tensor_meta (argument .meta )
92
+ self .dtype , self .shape , self .dim_order = extract_tensor_meta (
93
+ argument .meta , self .tosa_spec
94
+ )
85
95
86
96
def __process_list (self , argument ):
87
97
self .special : list = list (argument )
88
98
89
99
def __process_number (self , argument : float | int ):
90
100
self .number : float | int = argument
91
101
92
- def __init__ (self , argument : Any ) -> None :
102
+ def __init__ (
103
+ self , argument : Any , tosa_spec : Optional [TosaSpecification ] = None
104
+ ) -> None :
93
105
if argument is None :
94
106
return
107
+ if tosa_spec is None :
108
+ raise ValueError ("tosa_spec is None" )
109
+ elif not isinstance (tosa_spec , TosaSpecification ):
110
+ raise ValueError (
111
+ f"Expected tosa_spec to be a TosaSpecification, but got { tosa_spec } "
112
+ )
113
+ self .tosa_spec = tosa_spec
95
114
96
115
if isinstance (argument , torch .fx .Node ):
97
116
self .__process_node (argument )
@@ -116,6 +135,12 @@ def __repr__(self):
116
135
if self .name is not None :
117
136
attrs .append (f"name={ self .name !r} " )
118
137
if self .dtype is not None :
138
+ if isinstance (self .tosa_spec , Tosa_0_80 ):
139
+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
140
+ elif isinstance (self .tosa_spec , Tosa_1_00 ):
141
+ import serializer .tosa_serializer as ts # type: ignore
142
+ else :
143
+ raise RuntimeError (f"Unsupported tosa_spec: { self .tosa_spec } " )
119
144
attrs .append (f"dtype={ ts .DTypeNames [self .dtype ]} " )
120
145
if self .shape is not None :
121
146
attrs .append (f"shape={ self .shape !r} " )
@@ -125,4 +150,6 @@ def __repr__(self):
125
150
attrs .append (f"special={ self .special !r} " )
126
151
if hasattr (self , "number" ) and self .number is not None :
127
152
attrs .append (f"number={ self .number !r} " )
153
+ if hasattr (self , "tosa_spec" ) and self .tosa_spec is not None :
154
+ attrs .append (f"tosa_spec={ self .tosa_spec !r} " )
128
155
return f"{ self .__class__ .__name__ } ({ ', ' .join (attrs )} )"
0 commit comments