11
11
12
12
from typing import Any , Tuple
13
13
14
- import executorch .backends .arm .tosa_specification as tosa_specification
15
-
14
+ import serializer .tosa_serializer as ts # type: ignore
16
15
import torch .fx
17
16
import torch .fx .node
18
17
@@ -247,25 +246,18 @@ def build_rescale_to_int32(
247
246
) -> Any :
248
247
input_A_rescaled_to_int32 = None
249
248
250
- if isinstance (tosa_spec , tosa_specification .Tosa_1_00 ):
251
- # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
252
- # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
253
- import serializer .tosa_serializer as ts # type: ignore
254
-
255
- input_A_rescaled_to_int32 = tosa_fb .addIntermediate (
256
- input_arg .shape , ts .DType .INT32
257
- )
249
+ input_A_rescaled_to_int32 = tosa_fb .addIntermediate (input_arg .shape , ts .DType .INT32 )
258
250
259
- build_rescale (
260
- tosa_fb ,
261
- [rescale_scale ],
262
- input_arg ,
263
- input_A_rescaled_to_int32 .name ,
264
- ts .DType .INT32 ,
265
- [input_zp ],
266
- [0 ],
267
- rounding_mode = RoundingMode .SINGLE_ROUND ,
268
- ) # type: ignore[call-arg]
251
+ build_rescale (
252
+ tosa_fb ,
253
+ [rescale_scale ],
254
+ input_arg ,
255
+ input_A_rescaled_to_int32 .name ,
256
+ ts .DType .INT32 ,
257
+ [input_zp ],
258
+ [0 ],
259
+ rounding_mode = RoundingMode .SINGLE_ROUND ,
260
+ ) # type: ignore[call-arg]
269
261
270
262
return input_A_rescaled_to_int32
271
263
@@ -281,21 +273,19 @@ def build_rescale_from_int32(
281
273
per_channel : bool = False ,
282
274
tosa_spec = None ,
283
275
) -> None :
284
- if isinstance (tosa_spec , tosa_specification .Tosa_1_00 ):
285
- import serializer .tosa_serializer as ts # type: ignore
286
-
287
- # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
288
- # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
289
- build_rescale (
290
- tosa_fb ,
291
- [rescale_scale ],
292
- input_node ,
293
- output_name = output_name ,
294
- output_type = ts .DType .INT8 ,
295
- input_zp = [0 ],
296
- output_zp = [output_zp ],
297
- rounding_mode = RoundingMode .SINGLE_ROUND ,
298
- ) # type: ignore[call-arg]
276
+ # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
277
+ # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
278
+ build_rescale (
279
+ tosa_fb ,
280
+ [rescale_scale ],
281
+ input_node ,
282
+ output_name = output_name ,
283
+ output_type = ts .DType .INT8 ,
284
+ input_zp = [0 ],
285
+ output_zp = [output_zp ],
286
+ rounding_mode = RoundingMode .SINGLE_ROUND ,
287
+ ) # type: ignore[call-arg]
288
+
299
289
return
300
290
301
291
@@ -318,18 +308,17 @@ def build_rescale_conv_output(
318
308
(inp * w ) / out for inp , w , out in zip (input_scale , weight_scale , output_scale )
319
309
]
320
310
321
- if isinstance (tosa_spec [0 ], tosa_specification .Tosa_1_00 ):
322
- # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
323
- # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
324
- build_rescale (
325
- tosa_fb = tosa_fb ,
326
- scale = post_conv2d_scale ,
327
- input_node = op ,
328
- output_name = output_name ,
329
- output_type = output_type ,
330
- input_zp = [0 ],
331
- output_zp = output_zp ,
332
- rounding_mode = RoundingMode .SINGLE_ROUND ,
333
- per_channel = isinstance (weight_scale , torch .Tensor ),
334
- ) # type: ignore[call-arg]
311
+ # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
312
+ # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
313
+ build_rescale (
314
+ tosa_fb = tosa_fb ,
315
+ scale = post_conv2d_scale ,
316
+ input_node = op ,
317
+ output_name = output_name ,
318
+ output_type = output_type ,
319
+ input_zp = [0 ],
320
+ output_zp = output_zp ,
321
+ rounding_mode = RoundingMode .SINGLE_ROUND ,
322
+ per_channel = isinstance (weight_scale , torch .Tensor ),
323
+ ) # type: ignore[call-arg]
335
324
return
0 commit comments