# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
67
78import torch
89from executorch .backends .cortex_m .passes .passes_utils import (
9- dequantize_per_tensor_cmsis ,
10- quantize_per_tensor_cmsis ,
10+ requantize_cmsis ,
11+ SHIFT_INT8 ,
1112)
1213from executorch .exir .dialects ._ops import ops as exir_ops
1314
@@ -111,52 +112,6 @@ def dequantize_per_tensor_impl(
111112 "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor"
112113)
113114
114-
115- @register_fake ("cortex_m::quantized_add" )
116- def quantized_add_meta (
117- self : torch .Tensor ,
118- self_zero_point : int ,
119- self_multiplier : int ,
120- self_shift : int ,
121- other : torch .Tensor ,
122- other_zero_point : int ,
123- other_multiplier : int ,
124- other_shift : int ,
125- output_zero_point : int ,
126- output_multiplier : int ,
127- output_shift : int ,
128- ) -> torch .Tensor :
129- broadcasted_shape = torch .broadcast_shapes (self .shape , other .shape )
130- return torch .empty (broadcasted_shape , dtype = torch .int8 , device = self .device )
131-
132-
133- @impl (lib , "quantized_add" , "CompositeExplicitAutograd" )
134- def quantized_add_impl (
135- self : torch .Tensor ,
136- self_zero_point : int ,
137- self_multiplier : int ,
138- self_shift : int ,
139- other : torch .Tensor ,
140- other_zero_point : int ,
141- other_multiplier : int ,
142- other_shift : int ,
143- output_zero_point : int ,
144- output_multiplier : int ,
145- output_shift : int ,
146- ) -> torch .Tensor :
147- self_fp = dequantize_per_tensor_cmsis (
148- self , self_zero_point , self_multiplier , self_shift
149- )
150- other_fp = dequantize_per_tensor_cmsis (
151- other , other_zero_point , other_multiplier , other_shift
152- )
153- result_fp = self_fp + other_fp
154- result_quantized = quantize_per_tensor_cmsis (
155- result_fp , output_zero_point , output_multiplier , output_shift
156- )
157- return result_quantized
158-
159-
160115# Define the operator schema with multipliers and shifts (11 args + out tensor)
161116lib .define (
162117 "quantized_add.out("
@@ -167,9 +122,8 @@ def quantized_add_impl(
167122)
168123
169124
# Fake (meta) kernel for shape and dtype inference during compilation.
@register_fake("cortex_m::quantized_add")
def quantized_add_meta(
    self: torch.Tensor,
    self_zero_point: int,
    self_multiplier: int,
    self_shift: int,
    other: torch.Tensor,
    other_zero_point: int,
    other_multiplier: int,
    other_shift: int,
    output_zero_point: int,
    output_multiplier: int,
    output_shift: int,
) -> torch.Tensor:
    """Infer the output of cortex_m::quantized_add without computing it.

    Returns an empty int8 tensor with the broadcasted shape of ``self``
    and ``other``, on ``self``'s device. The quantization parameters are
    part of the op signature but do not affect the inferred shape/dtype.
    """
    broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
    return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device)
192141
193142
# Reference implementation used for the CompositeExplicitAutograd key.
@impl(lib, "quantized_add", "CompositeExplicitAutograd")
def quantized_add_impl(
    self: torch.Tensor,
    self_zero_point: int,
    self_multiplier: int,
    self_shift: int,
    other: torch.Tensor,
    other_zero_point: int,
    other_multiplier: int,
    other_shift: int,
    output_zero_point: int,
    output_multiplier: int,
    output_shift: int,
) -> torch.Tensor:
    """Integer elementwise add of two int8 quantized tensors.

    Each input is offset by its zero point, pre-shifted left by
    ``SHIFT_INT8`` for extra precision, and rescaled with its
    (multiplier, shift) pair via ``requantize_cmsis``; the sum is then
    rescaled to the output scale, offset by the output zero point, and
    saturated to the int8 range. NOTE(review): this presumably mirrors
    the CMSIS-NN s8 elementwise-add fixed-point pipeline — confirm
    against the kernel this op lowers to.
    """
    # Offset by the zero point and widen to int32 before the pre-shift so
    # the left shift cannot overflow int8.
    self_shifted = (self.to(torch.int32) - self_zero_point) << SHIFT_INT8
    self_rescaled = requantize_cmsis(self_shifted, self_multiplier, self_shift)

    other_shifted = (other.to(torch.int32) - other_zero_point) << SHIFT_INT8
    other_rescaled = requantize_cmsis(other_shifted, other_multiplier, other_shift)

    # Sum in the shared intermediate scale, then requantize to the output scale.
    summed = self_rescaled + other_rescaled
    requantized = requantize_cmsis(summed, output_multiplier, output_shift)

    # Re-apply the output zero point and saturate to int8.
    return torch.clamp(requantized + output_zero_point, -128, 127).to(torch.int8)
226167
227168
228169# ===================================================================
0 commit comments