Skip to content

Commit b34f13c

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
Backend-agnostic implementation of quantized_layer_norm_per_tensor (pytorch#13847)
Summary: Continuing support for supporting backend-agnostic Cadence custom ops. Differential Revision: D81459333
1 parent 0489aa7 commit b34f13c

File tree

2 files changed

+174
-0
lines changed

2 files changed

+174
-0
lines changed

backends/cadence/aot/ref_implementations.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,85 @@ def quantized_linear(
241241
return out.reshape(*leading_dims, N)
242242

243243

244+
@impl(m, "quantized_layer_norm_per_tensor")
245+
def quantized_layer_norm_per_tensor(
246+
input: torch.Tensor,
247+
X_scale: float,
248+
X_zero_point: int,
249+
normalized_shape: int,
250+
weight: torch.Tensor,
251+
bias: torch.Tensor,
252+
eps: float,
253+
output_scale: float,
254+
output_zero_point: int,
255+
) -> torch.Tensor:
256+
"""
257+
Quantized layer norm operation.
258+
259+
Args:
260+
- input (Tensor): The activations tensor
261+
- X_scale (float): The scale of the input
262+
- X_zero_point (int): The zero point of the input
263+
- normalized_shape (int): The shape of the input (unused)
264+
- weight (Tensor): The weight tensor
265+
- bias (Tensor): The bias tensor
266+
- eps (float): The epsilon value
267+
- output_scale (float): The scale of the output
268+
- output_zero_point (int): The zero point of the output
269+
"""
270+
supported_dtypes = [torch.int8, torch.uint8]
271+
if input.dtype not in supported_dtypes:
272+
raise ValueError(
273+
f"Input dtype must be one of {supported_dtypes}. Got {input.dtype}"
274+
)
275+
276+
# Get dimensions
277+
last_dim = input.size(-1)
278+
leading_dims = input.numel() // last_dim
279+
280+
# Reshape input to process as 1D vectors
281+
input_flat = input.view(leading_dims, last_dim)
282+
output = torch.empty_like(input)
283+
output_flat = output.view(leading_dims, last_dim)
284+
285+
output_inv_scale = 1.0 / output_scale
286+
287+
# Process each 1D vector
288+
for i in range(leading_dims):
289+
x = input_flat[i]
290+
291+
# Compute sum and squared sum in quantized space
292+
# Following the C++ implementation logic
293+
sum_val = torch.sum(x.to(torch.int32))
294+
sq_sum = last_dim * X_zero_point * X_zero_point + torch.sum(
295+
x.to(torch.int32) * x.to(torch.int32)
296+
)
297+
sq_sum -= 2 * sum_val * X_zero_point
298+
sum_val -= last_dim * X_zero_point
299+
300+
# Convert to floating point mean and variance
301+
mean = (X_scale * sum_val) / last_dim
302+
variance = (sq_sum * X_scale * X_scale) / last_dim - mean * mean
303+
inv_std = 1.0 / torch.sqrt(torch.tensor(variance + eps)) # type: ignore[arg-type]
304+
305+
# Apply layer norm: (x - mean) / std * weight + bias
306+
for j in range(last_dim):
307+
# Dequantize input value
308+
val = dequantize_per_tensor(
309+
x[j], X_scale, X_zero_point, -128, 127, torch.float32
310+
)
311+
312+
# Apply layer norm formula
313+
val = (val - mean) * inv_std * weight[j] + bias[j]
314+
315+
# Quantize result
316+
output_flat[i, j] = quantize_per_tensor(
317+
val, output_inv_scale, output_zero_point, -128, 127, input.dtype
318+
)
319+
320+
return output
321+
322+
244323
@impl(m, "requantize")
245324
def requantize(
246325
input: torch.Tensor,

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
dequantize_per_tensor,
1616
quantize_per_tensor,
1717
quantized_add,
18+
quantized_layer_norm_per_tensor,
1819
quantized_linear,
1920
)
2021
from executorch.backends.cadence.aot.typing_stubs import expand
@@ -232,3 +233,97 @@ def test_quantized_linear(
232233
torch.equal(output, expected_output),
233234
f"Values don't match: got {output}, expected {expected_output}",
234235
)
236+
237+
@expand(
238+
[
239+
# Test case 1: Simple case with int8, zero mean input
240+
(
241+
torch.tensor(
242+
[[-1, 1]], dtype=torch.int8
243+
), # input: dequantized to [-0.1, 0.1]
244+
0.1, # X_scale
245+
0, # X_zero_point
246+
2, # normalized_shape (last dimension)
247+
torch.tensor([1.0, 1.0]), # weight
248+
torch.tensor([0.0, 0.0]), # bias
249+
1e-5, # eps
250+
0.1, # output_scale
251+
0, # output_zero_point
252+
torch.int8, # dtype
253+
torch.tensor([[-10, 10]], dtype=torch.int8), # expected_output
254+
),
255+
# Test case 2: uint8 with zero_point offset
256+
(
257+
torch.tensor(
258+
[[127, 129]], dtype=torch.uint8
259+
), # input: dequantized to [-0.05, 0.05]
260+
0.05, # X_scale
261+
128, # X_zero_point
262+
2, # normalized_shape (last dimension)
263+
torch.tensor([1.0, 1.0]), # weight
264+
torch.tensor([0.0, 0.0]), # bias
265+
1e-5, # eps
266+
0.05, # output_scale
267+
128, # output_zero_point
268+
torch.uint8, # dtype
269+
torch.tensor([[108, 148]], dtype=torch.uint8), # expected_output
270+
),
271+
# Test case 3: Test with weight and bias scaling
272+
(
273+
torch.tensor(
274+
[[-2, 2]], dtype=torch.int8
275+
), # input: dequantized to [-0.2, 0.2]
276+
0.1, # X_scale
277+
0, # X_zero_point
278+
2, # normalized_shape (last dimension)
279+
torch.tensor(
280+
[2.0, 0.5]
281+
), # weight: scale first element by 2, second by 0.5
282+
torch.tensor(
283+
[0.1, -0.1]
284+
), # bias: add 0.1 to first, subtract 0.1 from second
285+
1e-5, # eps
286+
0.1, # output_scale
287+
0, # output_zero_point
288+
torch.int8, # dtype
289+
torch.tensor([[-19, 4]], dtype=torch.int8), # expected_output
290+
),
291+
]
292+
)
293+
def test_quantized_layer_norm_per_tensor(
294+
self,
295+
input_tensor: torch.Tensor,
296+
X_scale: float,
297+
X_zero_point: int,
298+
normalized_shape: int,
299+
weight: torch.Tensor,
300+
bias: torch.Tensor,
301+
eps: float,
302+
output_scale: float,
303+
output_zero_point: int,
304+
dtype: torch.dtype,
305+
expected_output: torch.Tensor,
306+
) -> None:
307+
output = quantized_layer_norm_per_tensor(
308+
input_tensor,
309+
X_scale,
310+
X_zero_point,
311+
normalized_shape,
312+
weight,
313+
bias,
314+
eps,
315+
output_scale,
316+
output_zero_point,
317+
)
318+
319+
# Verify output properties
320+
self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
321+
self.assertEqual(
322+
output.shape, input_tensor.shape, "Output shape should match input shape"
323+
)
324+
325+
# Verify output matches expected values
326+
self.assertTrue(
327+
torch.equal(output, expected_output),
328+
f"Output values don't match expected. Got {output}, expected {expected_output}",
329+
)

0 commit comments

Comments
 (0)