# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
1817
1918from executorch .exir .pass_base import ExportPass , PassResult
2019from torch .fx import GraphModule
20+
2121from torch .library import impl , Library
2222
# Library holding the custom "tosa" namespace ops used by this pass.
# NOTE(review): the "_table" op schema is presumably registered with
# lib.define(...) in lines elided from this view — confirm against the file.
lib = Library("tosa", "DEF")

@impl(lib, "_table")
def _table_impl(*args, **kwargs):  # pyre-ignore
    """Reference implementation of the tosa::_table op used while tracing.

    Only the dtype contract is modelled here: an int8 input passes through
    unchanged, while any wider input is promoted to int32 (matching the
    32-bit output of a wide TOSA table lookup).
    """
    tensor = args[0]
    if tensor.dtype == torch.int8:
        return tensor
    return tensor.to(dtype=torch.int32)
3033
3134
3235class InsertTableOpsPass (ExportPass ):
@@ -59,29 +62,89 @@ def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None:
5962 """
6063 self .exported_program .state_dict [buffer_name ] = buffer
6164
def generate_8bit_table_values(
    self,
    torch_op: Callable[[torch.Tensor], torch.Tensor],
    in_quantargs: QuantArgs,
    out_quantargs: QuantArgs,
) -> tuple[torch.Tensor, int]:
    """Compute LUT values for an INT8 TOSA.TABLE.

    The INT8 table is a plain 256-entry 1:1 LUT, so no shifting is needed
    after the lookup: the second element of the returned tuple is always 0.
    """
    # int64 keeps the subtraction of the zero point during dequantization
    # from wrapping around,
    # e.g. torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8).
    quantized_range = torch.linspace(
        start=in_quantargs.qmin,
        end=in_quantargs.qmax,
        steps=256,
        dtype=torch.int64,
    )

    def lut(x: torch.Tensor) -> torch.Tensor:
        # dequantize -> apply the float op -> requantize to the output params.
        return out_quantargs.quantize_value(
            torch_op(in_quantargs.dequantize_value(x))
        )

    return lut(quantized_range).to(dtype=torch.int8), 0
93+
def generate_16_bit_table_values(
    self,
    torch_op: Callable[[torch.Tensor], torch.Tensor],
    in_quantargs: QuantArgs,
    out_quantargs: QuantArgs,
) -> tuple[torch.Tensor, int]:
    """Compute LUT values for an INT16 TOSA.TABLE with 32-bit output (in
    practice 23 bits, see the TOSA specification).

    The table output carries 7 fractional bits, i.e. it is interpreted as
    128x too large unless compensated for. The values are right-shifted to
    fit in 16 bits, and the returned left shift is (right shift - 7) to
    account for those fractional bits.
    """

    def lut(x: torch.Tensor) -> torch.Tensor:
        # Drop the 7 LSBs (interpolation bits) before dequantizing.
        masked = x & ~0x7F
        return out_quantargs.quantize_value(
            torch_op(in_quantargs.dequantize_value(masked))
        )

    # 513 entries spanning [qmin, qmax + 1]; int64 keeps the zero-point
    # subtraction during dequantization from wrapping around,
    # e.g. torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8).
    lut_values = lut(
        torch.linspace(
            start=in_quantargs.qmin,
            end=in_quantargs.qmax + 1,
            steps=513,
            dtype=torch.int64,
        )
    )
    # How far to shift so the table values fit in 16 bits:
    # ceil(log2(max absolute table value)) + 1 bit for signedness - 16.
    # Note: for out_quantargs.dtype=torch.int16, rshift == 0.
    rshift = int(torch.ceil(torch.log2(lut_values.abs().max()))) + 1 - 16
    lut_values = lut_values >> rshift
    return lut_values.to(dtype=torch.int16), rshift - 7
128+
def generate_table_values(
    self,
    torch_op: Callable[[torch.Tensor], torch.Tensor],
    in_quantargs: QuantArgs,
    out_quantargs: QuantArgs,
) -> tuple[torch.Tensor, int]:
    """Build the LUT for ``torch_op``, dispatched on the output dtype.

    Returns the table values together with how much the table output must
    be left-shifted afterwards to be correctly scaled.

    Raises:
        ValueError: if no table generator exists for the output dtype.
    """
    out_dtype = out_quantargs.dtype
    if out_dtype == torch.int8:
        return self.generate_8bit_table_values(
            torch_op, in_quantargs, out_quantargs
        )
    if out_dtype in (torch.int16, torch.int32):
        return self.generate_16_bit_table_values(
            torch_op, in_quantargs, out_quantargs
        )
    raise ValueError(f"Unsupported output dtype for table: {out_dtype}")
85148
86149 def call (self , graph_module : GraphModule ) -> PassResult :
87150 modified = False
@@ -100,10 +163,12 @@ def call(self, graph_module: GraphModule) -> PassResult:
100163 op_target = torch .ops .tosa ._table .default ,
101164 args = (node .args [0 ],),
102165 )
166+ output_node = table_node
103167 assert len (input_qparams ) == 1
104168 assert len (output_qparams ) == 1
105- # Generate table buffer
106- buffer = self .generate_table_values (
169+
170+ # Generate table buffer and how much to lshift the table output.
171+ buffer , lshift = self .generate_table_values (
107172 torch_op = self .table_ops [node .target ],
108173 in_quantargs = input_qparams [0 ],
109174 out_quantargs = output_qparams [0 ],
@@ -114,10 +179,20 @@ def call(self, graph_module: GraphModule) -> PassResult:
114179 self .register_buffer (
115180 buffer_name = table_node .name .replace ("_default" , "" ), buffer = buffer
116181 )
117- node .replace_all_uses_with (table_node )
182+
183+ if lshift != 0 :
184+ scale = 2.0 ** lshift
185+ rescale_node = create_node (
186+ graph = graph_module .graph ,
187+ op_target = torch .ops .tosa ._rescale .default ,
188+ args = (table_node , output_qparams [0 ].dtype , scale , 0 , 0 ),
189+ )
190+ output_node = rescale_node
191+
192+ node .replace_all_uses_with (output_node )
118193 graph_module .graph .erase_node (node )
119- table_node .meta ["input_qparams" ] = input_qparams
120- table_node .meta ["output_qparams" ] = output_qparams
194+ output_node .meta ["input_qparams" ] = input_qparams
195+ output_node .meta ["output_qparams" ] = output_qparams
121196 modified = True
122197
123198 if modified :
0 commit comments