11# Copyright 2024-2025 Arm Limited and/or its affiliates.
2- # All rights reserved.
32#
43# This source code is licensed under the BSD-style license found in the
54# LICENSE file in the root directory of this source tree.
1817
1918from executorch .exir .pass_base import ExportPass , PassResult
2019from torch .fx import GraphModule
20+
2121from torch .library import impl , Library
2222
2323lib = Library ("tosa" , "DEF" )
2626
@impl(lib, "_table")
def _table_impl(*args, **kwargs):  # pyre-ignore
    """Implementation registered for the ``_table`` op of ``lib``.

    Only the first positional argument is used; it is returned unchanged when
    its dtype is int8 and converted to int32 for any other dtype. All other
    positional and keyword arguments are ignored.
    """
    table_input = args[0]
    if table_input.dtype != torch.int8:
        return table_input.to(dtype=torch.int32)
    return table_input
3033
3134
3235class InsertTableOpsPass (ExportPass ):
@@ -59,29 +62,105 @@ def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None:
5962 """
6063 self .exported_program .state_dict [buffer_name ] = buffer
6164
62- def generate_table_values (
def generate_8bit_table_values(
    self,
    torch_op: Callable[[torch.Tensor], torch.Tensor],
    in_quantargs: QuantArgs,
    out_quantargs: QuantArgs,
) -> tuple[torch.Tensor, int]:
    """Compute the LUT values for an INT8 TOSA.TABLE.

    The INT8 table is a simple 256 value 1-1 LUT. Its entries come from an
    evenly spaced 256-point sweep between ``in_quantargs.qmin`` and
    ``in_quantargs.qmax``: every point is dequantized with ``in_quantargs``,
    passed through ``torch_op`` and quantized again with ``out_quantargs``.

    Returns:
        The table converted to int8, paired with 0 since no shifting is
        required after an 8-bit table.
    """
    # Sweep in torch.int64 to avoid overflow when dequantizing (subtracting zp).
    # e.g. torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8)
    quantized_inputs = torch.linspace(
        start=in_quantargs.qmin,
        end=in_quantargs.qmax,
        steps=256,
        dtype=torch.int64,
    )
    real_inputs = in_quantargs.dequantize_value(quantized_inputs)
    real_outputs = torch_op(real_inputs)
    table = out_quantargs.quantize_value(real_outputs).to(dtype=torch.int8)
    return table, 0
93+
def generate_16_bit_table_values(
    self,
    torch_op: Callable[[torch.Tensor], torch.Tensor],
    in_quantargs: QuantArgs,
    out_quantargs: QuantArgs,
) -> tuple[torch.Tensor, int]:
    """Compute LUT values for an INT16 TOSA.TABLE with 32 bit output.

    In practice the output is 23 bits that should be interpreted as 16 'whole' bits and 7 fractional bits, see
    the specification: https://www.mlplatform.org/tosa/tosa_spec.html#_table. This means that the output
    will be interpreted as 2**7=128 times too large unless accounted for by rescaling down the table output.

    Quantization can be either int16 or int32 which means that the op output could be larger than the 23 bits from
    the TOSA.TABLE output. In that case, we need to rescale up the output.

    To handle this we need to:
    1) Make sure that our table values fit within 16 bits.
    2) Insert a rescale after the table to handle the x128 from the fractional bits and match the quantization.

    Returns:
        A tuple ``(lut_values, rescale_lshift)``. ``lut_values`` is the 513-entry table as an int16 tensor,
        already shifted right so that every entry fits in 16 signed bits. ``rescale_lshift`` says how much to
        left-shift the table output afterwards; it is the applied right shift minus 7 and can be negative.
    """

    def f(x: torch.Tensor) -> torch.Tensor:
        # Don't use the 7 LSBs.
        x = in_quantargs.dequantize_value(x & ~0x7F)
        x = torch_op(x)
        return out_quantargs.quantize_value(x)

    lut_values = f(
        torch.linspace(
            start=in_quantargs.qmin,
            end=in_quantargs.qmax + 1,
            steps=513,
            # use torch.int64 to avoid overflow when dequantizing (subtracting zp).
            # e.g. torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8)
            dtype=torch.int64,
        )
    )

    # Calculate how much we need to shift table values to fit in 16 signed bits:
    # (bits needed for the largest magnitude) + 1 bit for signedness - 16.
    # Example:
    #   Max value in the table is 70 000 = 0b10001000101110000, which has 17 bits.
    #   Shifting by 17 - 16 = 1 bit leaves 16 bits (0b1000100010111000), but as a signed
    #   16-bit number that is negative, so one more bit of shift is needed: rshift = 2.
    # The bit counts are taken from exact Python integers rather than from a floating
    # point log2, which undercounts exact powers of two (65536 needs 17 bits, not 16)
    # and is undefined for an all-zero table.
    largest = int(lut_values.max())
    smallest = int(lut_values.min())
    magnitude_bits = max(
        largest.bit_length() if largest > 0 else 0,
        # A negative two's complement value v needs as many magnitude bits as ~v.
        (~smallest).bit_length() if smallest < 0 else 0,
    )
    # Never shift by a negative amount: values that already fit are left untouched.
    # Note: when all values already fit in int16, rshift == 0 and rescale_lshift == -7.
    rshift = max(magnitude_bits + 1 - 16, 0)
    # The 7 fractional bits are equivalent to a lshift of 7, so subtract 7 from the lshift we do.
    rescale_lshift = rshift - 7
    lut_values = lut_values >> rshift
    return lut_values.to(dtype=torch.int16), rescale_lshift
144+
def generate_table_values(
    self,
    torch_op: Callable[[torch.Tensor], torch.Tensor],
    in_quantargs: QuantArgs,
    out_quantargs: QuantArgs,
) -> tuple[torch.Tensor, int]:
    """Generate the table for ``torch_op`` based on the output dtype.

    An int8 output dtype uses ``generate_8bit_table_values``; int16 and int32
    output dtypes use ``generate_16_bit_table_values``. The chosen generator's
    ``(table, shift)`` tuple is returned as-is.

    Raises:
        ValueError: If ``out_quantargs.dtype`` is none of int8, int16, int32.
    """
    out_dtype = out_quantargs.dtype
    if out_dtype == torch.int8:
        return self.generate_8bit_table_values(
            torch_op, in_quantargs, out_quantargs
        )
    if out_dtype in (torch.int16, torch.int32):
        return self.generate_16_bit_table_values(
            torch_op, in_quantargs, out_quantargs
        )
    raise ValueError(
        f"Unsupported output dtype for table: {out_quantargs.dtype}"
    )
85164
86165 def call (self , graph_module : GraphModule ) -> PassResult :
87166 modified = False
@@ -100,10 +179,12 @@ def call(self, graph_module: GraphModule) -> PassResult:
100179 op_target = torch .ops .tosa ._table .default ,
101180 args = (node .args [0 ],),
102181 )
182+ output_node = table_node
103183 assert len (input_qparams ) == 1
104184 assert len (output_qparams ) == 1
105- # Generate table buffer
106- buffer = self .generate_table_values (
185+
186+ # Generate table buffer and how much to lshift the table output.
187+ buffer , lshift = self .generate_table_values (
107188 torch_op = self .table_ops [node .target ],
108189 in_quantargs = input_qparams [0 ],
109190 out_quantargs = output_qparams [0 ],
@@ -114,10 +195,20 @@ def call(self, graph_module: GraphModule) -> PassResult:
114195 self .register_buffer (
115196 buffer_name = table_node .name .replace ("_default" , "" ), buffer = buffer
116197 )
117- node .replace_all_uses_with (table_node )
198+
199+ if lshift != 0 :
200+ scale = 2.0 ** lshift
201+ rescale_node = create_node (
202+ graph = graph_module .graph ,
203+ op_target = torch .ops .tosa ._rescale .default ,
204+ args = (table_node , output_qparams [0 ].dtype , scale , 0 , 0 ),
205+ )
206+ output_node = rescale_node
207+
208+ node .replace_all_uses_with (output_node )
118209 graph_module .graph .erase_node (node )
119- table_node .meta ["input_qparams" ] = input_qparams
120- table_node .meta ["output_qparams" ] = output_qparams
210+ output_node .meta ["input_qparams" ] = input_qparams
211+ output_node .meta ["output_qparams" ] = output_qparams
121212 modified = True
122213
123214 if modified :
0 commit comments