Commit 3e0cd20

Andrew Grebenisan authored and facebook-github-bot committed
Backend-agnostic implementation of quantized_layer_norm_per_tensor
Summary: Continuing support for backend-agnostic Cadence custom ops.

Differential Revision: D81459333
1 parent ee3c9a0 commit 3e0cd20

File tree

backends/cadence/aot/ref_implementations.py
backends/cadence/aot/tests/test_ref_implementations.py

2 files changed: +174 −0 lines changed

backends/cadence/aot/ref_implementations.py

Lines changed: 79 additions & 0 deletions
@@ -239,6 +239,85 @@ def quantized_linear(
     return out.reshape(*leading_dims, N)
 
 
+@impl(m, "quantized_layer_norm_per_tensor")
+def quantized_layer_norm_per_tensor(
+    input: torch.Tensor,
+    X_scale: float,
+    X_zero_point: int,
+    normalized_shape: int,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    eps: float,
+    output_scale: float,
+    output_zero_point: int,
+) -> torch.Tensor:
+    """
+    Quantized layer norm operation.
+
+    Args:
+    - input (Tensor): The activations tensor
+    - X_scale (float): The scale of the input
+    - X_zero_point (int): The zero point of the input
+    - normalized_shape (int): The shape of the input (unused)
+    - weight (Tensor): The weight tensor
+    - bias (Tensor): The bias tensor
+    - eps (float): The epsilon value
+    - output_scale (float): The scale of the output
+    - output_zero_point (int): The zero point of the output
+    """
+    supported_dtypes = [torch.int8, torch.uint8]
+    if input.dtype not in supported_dtypes:
+        raise ValueError(
+            f"Input dtype must be one of {supported_dtypes}. Got {input.dtype}"
+        )
+
+    # Get dimensions
+    last_dim = input.size(-1)
+    leading_dims = input.numel() // last_dim
+
+    # Reshape input to process as 1D vectors
+    input_flat = input.view(leading_dims, last_dim)
+    output = torch.empty_like(input)
+    output_flat = output.view(leading_dims, last_dim)
+
+    output_inv_scale = 1.0 / output_scale
+
+    # Process each 1D vector
+    for i in range(leading_dims):
+        x = input_flat[i]
+
+        # Compute sum and squared sum in quantized space,
+        # following the C++ implementation logic
+        sum_val = torch.sum(x.to(torch.int32))
+        sq_sum = last_dim * X_zero_point * X_zero_point + torch.sum(
+            x.to(torch.int32) * x.to(torch.int32)
+        )
+        sq_sum -= 2 * sum_val * X_zero_point
+        sum_val -= last_dim * X_zero_point
+
+        # Convert to floating-point mean and variance
+        mean = (X_scale * sum_val) / last_dim
+        variance = (sq_sum * X_scale * X_scale) / last_dim - mean * mean
+        inv_std = 1.0 / torch.sqrt(torch.tensor(variance + eps))  # type: ignore[arg-type]
+
+        # Apply layer norm: (x - mean) / std * weight + bias
+        for j in range(last_dim):
+            # Dequantize input value
+            val = dequantize_per_tensor(
+                x[j], X_scale, X_zero_point, -128, 127, torch.float32
+            )
+
+            # Apply layer norm formula
+            val = (val - mean) * inv_std * weight[j] + bias[j]
+
+            # Quantize result
+            output_flat[i, j] = quantize_per_tensor(
+                val, output_inv_scale, output_zero_point, -128, 127, input.dtype
+            )
+
+    return output
+
+
 @impl(m, "requantize")
 def requantize(
     input: torch.Tensor,
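For intuition, the new reference op is the quantized counterpart of dequantize → torch.nn.functional.layer_norm → quantize. The sketch below is not part of the commit; it re-derives the expected_output of test case 1 using only public torch APIs (the function name and the round-then-clamp requantization rule are illustrative assumptions, not the module's quantize_per_tensor/dequantize_per_tensor internals):

    import torch

    def float_layer_norm_reference(
        x_q: torch.Tensor, x_scale: float, x_zp: int,
        weight: torch.Tensor, bias: torch.Tensor, eps: float,
        out_scale: float, out_zp: int, dtype: torch.dtype,
    ) -> torch.Tensor:
        # Dequantize: real_value = scale * (quantized - zero_point)
        x = (x_q.to(torch.float32) - x_zp) * x_scale
        # Plain floating-point layer norm over the last dimension
        y = torch.nn.functional.layer_norm(x, x.shape[-1:], weight, bias, eps)
        # Requantize (assumed rule): round(real / scale) + zero_point, clamped to dtype range
        info = torch.iinfo(dtype)
        y_q = torch.round(y / out_scale) + out_zp
        return y_q.clamp(info.min, info.max).to(dtype)

    # Test case 1 from the new unit test: int8 input [-1, 1], scale 0.1, zero point 0
    x_q = torch.tensor([[-1, 1]], dtype=torch.int8)
    out = float_layer_norm_reference(
        x_q, 0.1, 0, torch.ones(2), torch.zeros(2), 1e-5, 0.1, 0, torch.int8
    )
    print(out)  # tensor([[-10,  10]], dtype=torch.int8), matching expected_output

The committed version differs in that it computes the row statistics without dequantizing each element first, mirroring the fixed-point arithmetic of the C++ kernel its comments reference.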

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 95 additions & 0 deletions
@@ -15,6 +15,7 @@
     dequantize_per_tensor,
     quantize_per_tensor,
     quantized_add,
+    quantized_layer_norm_per_tensor,
     quantized_linear,
 )
 from executorch.backends.cadence.aot.typing_stubs import expand
@@ -232,3 +233,97 @@ def test_quantized_linear(
             torch.equal(output, expected_output),
             f"Values don't match: got {output}, expected {expected_output}",
         )
+
+    @expand(
+        [
+            # Test case 1: Simple case with int8, zero mean input
+            (
+                torch.tensor(
+                    [[-1, 1]], dtype=torch.int8
+                ),  # input: dequantized to [-0.1, 0.1]
+                0.1,  # X_scale
+                0,  # X_zero_point
+                2,  # normalized_shape (last dimension)
+                torch.tensor([1.0, 1.0]),  # weight
+                torch.tensor([0.0, 0.0]),  # bias
+                1e-5,  # eps
+                0.1,  # output_scale
+                0,  # output_zero_point
+                torch.int8,  # dtype
+                torch.tensor([[-10, 10]], dtype=torch.int8),  # expected_output
+            ),
+            # Test case 2: uint8 with zero_point offset
+            (
+                torch.tensor(
+                    [[127, 129]], dtype=torch.uint8
+                ),  # input: dequantized to [-0.05, 0.05]
+                0.05,  # X_scale
+                128,  # X_zero_point
+                2,  # normalized_shape (last dimension)
+                torch.tensor([1.0, 1.0]),  # weight
+                torch.tensor([0.0, 0.0]),  # bias
+                1e-5,  # eps
+                0.05,  # output_scale
+                128,  # output_zero_point
+                torch.uint8,  # dtype
+                torch.tensor([[108, 148]], dtype=torch.uint8),  # expected_output
+            ),
+            # Test case 3: Test with weight and bias scaling
+            (
+                torch.tensor(
+                    [[-2, 2]], dtype=torch.int8
+                ),  # input: dequantized to [-0.2, 0.2]
+                0.1,  # X_scale
+                0,  # X_zero_point
+                2,  # normalized_shape (last dimension)
+                torch.tensor(
+                    [2.0, 0.5]
+                ),  # weight: scale first element by 2, second by 0.5
+                torch.tensor(
+                    [0.1, -0.1]
+                ),  # bias: add 0.1 to first, subtract 0.1 from second
+                1e-5,  # eps
+                0.1,  # output_scale
+                0,  # output_zero_point
+                torch.int8,  # dtype
+                torch.tensor([[-19, 4]], dtype=torch.int8),  # expected_output
+            ),
+        ]
+    )
+    def test_quantized_layer_norm_per_tensor(
+        self,
+        input_tensor: torch.Tensor,
+        X_scale: float,
+        X_zero_point: int,
+        normalized_shape: int,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        eps: float,
+        output_scale: float,
+        output_zero_point: int,
+        dtype: torch.dtype,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = quantized_layer_norm_per_tensor(
+            input_tensor,
+            X_scale,
+            X_zero_point,
+            normalized_shape,
+            weight,
+            bias,
+            eps,
+            output_scale,
+            output_zero_point,
+        )
+
+        # Verify output properties
+        self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
+        self.assertEqual(
+            output.shape, input_tensor.shape, "Output shape should match input shape"
+        )
+
+        # Verify output matches expected values
+        self.assertTrue(
+            torch.equal(output, expected_output),
+            f"Output values don't match expected. Got {output}, expected {expected_output}",
+        )
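Test case 2 exercises the nonzero-zero-point path. The statistics loop in the new op stays in integer space via the identities sum(x - zp) = sum(x) - n*zp and sum((x - zp)^2) = sum(x^2) - 2*zp*sum(x) + n*zp^2. A quick standalone check of that accumulation against direct float math, using test case 2's row (variable names here are illustrative, not from the commit):

    import torch

    x = torch.tensor([127, 129], dtype=torch.int32)  # quantized uint8 row from test case 2
    zp, scale, n = 128, 0.05, x.numel()

    # Integer-space accumulation, mirroring the new reference op
    sum_val = int(x.sum()) - n * zp                                     # sum(x - zp)
    sq_sum = n * zp * zp + int((x * x).sum()) - 2 * int(x.sum()) * zp   # sum((x - zp)^2)
    mean = scale * sum_val / n
    var = sq_sum * scale * scale / n - mean * mean

    # Direct float-space statistics on the dequantized values
    x_f = (x.to(torch.float32) - zp) * scale
    assert abs(mean - x_f.mean().item()) < 1e-6
    assert abs(var - x_f.var(unbiased=False).item()) < 1e-6  # biased variance, as layer norm uses

The new test can then be run on its own, e.g. with python -m pytest backends/cadence/aot/tests/test_ref_implementations.py -k quantized_layer_norm (assuming a standard pytest setup for the repo).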
