@@ -119,9 +119,10 @@ class InsertTableOpsPass(ArmPass):
119119
120120    _passes_required_after : Set [Type [ExportPass ]] =  set ()
121121
122-     def  __init__ (self , exported_program : ExportedProgram ) ->  None :
122+     def  __init__ (self , exported_program : ExportedProgram ,  tosa_spec = None ) ->  None :
123123        super ().__init__ ()
124124        self .exported_program  =  exported_program 
125+         self .tosa_spec  =  tosa_spec 
125126        self .table_ops  =  TableOps (exported_program )
126127
127128    def  register_buffer (self , buffer_name : str , buffer : torch .Tensor ) ->  None :
@@ -157,6 +158,82 @@ def f(x: torch.Tensor) -> torch.Tensor:
157158            0 ,
158159        )
159160
161+     def  generate_16_bit_table_values_u55_tflite_style (
162+         self ,
163+         torch_op : Callable [[torch .Tensor ], torch .Tensor ],
164+         in_quantargs : QuantArgs ,
165+         out_quantargs : QuantArgs ,
166+     ) ->  tuple [torch .Tensor , int ]:
167+         """ 
168+         Generate table values for U55 using U55-style bias correction. 
169+ 
170+         1. Evaluate function at base, midpoint, and next for each interval IN FLOAT SPACE 
171+         2. Quantize all three output values 
172+         3. Calculate bias = (interpolated_midpoint - actual_midpoint) / 2 
173+         4. Apply bias correction to base value 
174+         5. Store corrected base values (513 values total) 
175+         """ 
176+         import  math 
177+ 
178+         # Debug: Check if this function is being called 
179+ 
180+         # Calculate input range in FLOAT space (like TFLite) 
181+         qmin_in  =  in_quantargs .qmin 
182+         qmax_in  =  in_quantargs .qmax 
183+         qmin_out  =  out_quantargs .qmin 
184+         qmax_out  =  out_quantargs .qmax 
185+ 
186+         input_min  =  in_quantargs .scale  *  (qmin_in  -  in_quantargs .zp )
187+         input_max  =  in_quantargs .scale  *  (qmax_in  -  in_quantargs .zp )
188+         output_min  =  out_quantargs .scale  *  (qmin_out  -  out_quantargs .zp )
189+         output_max  =  out_quantargs .scale  *  (qmax_out  -  out_quantargs .zp )
190+ 
191+         steps  =  512 
192+         step  =  (input_max  -  input_min ) /  steps 
193+         half_step  =  step  /  2.0 
194+         output_scaling_inv  =  (qmax_out  -  qmin_out  +  1 ) /  (output_max  -  output_min )
195+ 
196+ 
197+         def  f (x_float : float ) ->  float :
198+             """Evaluate torch_op at x_float, handling NaN/inf.""" 
199+             x_tensor  =  torch .tensor ([x_float ], dtype = torch .float32 )
200+             result  =  torch_op (x_tensor ).item ()
201+ 
202+             if  math .isnan (result ) or  math .isinf (result ):
203+                 return  input_max   # Will quantize to qmax_out 
204+ 
205+             return  result 
206+ 
207+         lut_values  =  []
208+ 
209+         for  i  in  range (steps  +  1 ):  # 513 values (0 to 512) 
210+             val  =  f (input_min  +  i  *  step )
211+             sample_val  =  round (val  *  output_scaling_inv )
212+ 
213+             if  i  <  steps :
214+                 val_midpoint  =  f (input_min  +  i  *  step  +  half_step )
215+                 val_next  =  f (input_min  +  (i  +  1 ) *  step )
216+ 
217+                 midpoint_interp_val  =  round (
218+                     (val_next  *  output_scaling_inv  +  sample_val ) /  2.0 
219+                 )
220+                 midpoint_val  =  round (val_midpoint  *  output_scaling_inv )
221+                 midpoint_err  =  midpoint_interp_val  -  midpoint_val 
222+                 bias  =  round (midpoint_err  /  2.0 )
223+ 
224+                 clamped_lut_result  =  max (qmin_out , min (qmax_out , sample_val  -  bias ))
225+                 lut_result  =  int (clamped_lut_result )
226+ 
227+                 lut_values .append (lut_result )
228+             else :
229+                 # Last value (i == steps): no bias correction, just quantize and clamp 
230+                 clamped  =  max (qmin_out , min (qmax_out , sample_val ))
231+                 lut_values .append (int (clamped ))
232+ 
233+         buffer  =  torch .tensor (lut_values , dtype = torch .int16 ).contiguous ()
234+ 
235+         return  buffer , 0 
236+ 
160237    def  generate_16_bit_table_values (
161238        self ,
162239        torch_op : Callable [[torch .Tensor ], torch .Tensor ],
@@ -178,6 +255,12 @@ def generate_16_bit_table_values(
178255        The function returns rescale_lshift which says how much to rescale after the table. This value can negative. 
179256        """ 
180257
258+         # U55 needs TFLite-style table generation with bias correction 
259+         if  self .tosa_spec  is  not None  and  self .tosa_spec .is_U55_subset :
260+             return  self .generate_16_bit_table_values_u55_tflite_style (
261+                 torch_op , in_quantargs , out_quantargs 
262+             )
263+ 
181264        def  f (x : torch .Tensor ) ->  torch .Tensor :
182265            x  =  x .clamp (in_quantargs .qmin , in_quantargs .qmax ).to (
183266                dtype = in_quantargs .dtype 
@@ -280,7 +363,15 @@ def call(self, graph_module: GraphModule) -> PassResult:
280363                )
281364                output_node  =  table_op_node 
282365
283-                 if  lshift  !=  0 :
366+                 if  (
367+                     self .tosa_spec 
368+                     and  self .tosa_spec .is_U55_subset 
369+                     and  input_qparams [0 ].dtype  ==  torch .int16 
370+                 ):
371+                     # U55: NO RESCALE needed - use table output directly 
372+                     # Adding RESCALE creates a second operation that overwrites the table output! 
373+                     output_node  =  table_op_node   # Use table output directly! 
374+                 elif  lshift  !=  0 :
284375                    scale  =  2.0 ** lshift 
285376                    rescale_node  =  create_node (
286377                        graph = graph_module .graph ,
0 commit comments