
Commit 14d0745

Backend-agnostic implementation of quantized_layer_norm_per_tensor
Differential Revision: D81459333
Pull Request resolved: #13847
1 parent 30e7497 commit 14d0745

2 files changed (+144, −0 lines)


backends/cadence/aot/ref_implementations.py

Lines changed: 49 additions & 0 deletions
@@ -245,6 +245,55 @@ def quantized_linear(
    ).reshape(*leading_dims, N)


@impl(m, "quantized_layer_norm_per_tensor")
def quantized_layer_norm_per_tensor(
    input_tensor: torch.Tensor,
    X_scale: float,
    X_zero_point: int,
    normalized_shape: int,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    output_scale: float,
    output_zero_point: int,
) -> torch.Tensor:
    """
    Quantized layer norm operation.

    Args:
        - input_tensor (Tensor): The input activations, quantized per tensor
        - X_scale (float): The scale of the input
        - X_zero_point (int): The zero point of the input
        - normalized_shape (int): The size of the last dimension, over which
          normalization is applied
        - weight (Tensor): The elementwise scale applied after normalization
        - bias (Tensor): The elementwise offset applied after normalization
        - eps (float): Small constant added to the variance for numerical stability
        - output_scale (float): The scale of the output
        - output_zero_point (int): The zero point of the output
    """
    supported_dtypes = [torch.int8, torch.uint8]
    if input_tensor.dtype not in supported_dtypes:
        raise ValueError(
            f"Input dtype must be one of {supported_dtypes}. Got {input_tensor.dtype}"
        )

    # Compute in float: dequantize the input, apply the reference layer norm,
    # then requantize the result back to the input's integer dtype.
    float_input_tensor = dequantize_per_tensor(
        input_tensor, X_scale, X_zero_point, -128, 127, torch.float32
    )
    out = torch.nn.functional.layer_norm(
        float_input_tensor, (normalized_shape,), weight, bias, eps=eps
    )

    return quantize_per_tensor(
        out,
        1 / output_scale,
        output_zero_point,
        torch.iinfo(input_tensor.dtype).min,
        torch.iinfo(input_tensor.dtype).max,
        input_tensor.dtype,
    )


@impl(m, "requantize")
def requantize(
    input: torch.Tensor,
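For intuition, the added operator is simply dequantize, float layer norm, then requantize. Below is a minimal standalone sketch of that pipeline in plain PyTorch, without the Cadence reference helpers. The function name and the affine round(x / scale) + zero_point quantization arithmetic are assumptions of mine, not the helpers' documented contracts; note that the committed code passes 1 / output_scale to quantize_per_tensor, which suggests that helper multiplies by its scale argument rather than dividing.

import torch

# Hypothetical standalone sketch; the name and the affine quantization
# formulas below are assumptions, not the Cadence helpers' exact contracts.
def layer_norm_per_tensor_sketch(
    x: torch.Tensor,
    x_scale: float,
    x_zero_point: int,
    normalized_shape: int,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    out_scale: float,
    out_zero_point: int,
) -> torch.Tensor:
    # Dequantize: (q - zero_point) * scale
    xf = (x.to(torch.float32) - x_zero_point) * x_scale
    # Layer norm over the last `normalized_shape` elements, in float
    yf = torch.nn.functional.layer_norm(
        xf, (normalized_shape,), weight, bias, eps=eps
    )
    # Requantize: round(y / scale) + zero_point, clamped to the dtype's range
    info = torch.iinfo(x.dtype)
    q = torch.round(yf / out_scale) + out_zero_point
    return q.clamp(info.min, info.max).to(x.dtype)

# Mirrors test case 1 below: [-1, 1] at scale 0.1 dequantizes to [-0.1, 0.1],
# layer norm maps that to roughly [-1.0, 1.0], and requantizing at scale 0.1
# yields [-10, 10].
x = torch.tensor([[-1, 1]], dtype=torch.int8)
print(layer_norm_per_tensor_sketch(
    x, 0.1, 0, 2, torch.ones(2), torch.zeros(2), 1e-5, 0.1, 0
))  # tensor([[-10,  10]], dtype=torch.int8)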

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 95 additions & 0 deletions
@@ -15,6 +15,7 @@
    dequantize_per_tensor,
    quantize_per_tensor,
    quantized_add,
    quantized_layer_norm_per_tensor,
    quantized_linear,
)
from executorch.backends.cadence.aot.typing_stubs import expand

@@ -240,3 +241,97 @@ def test_quantized_linear(
            torch.equal(output, expected_output),
            f"Values don't match: got {output}, expected {expected_output}",
        )

    @expand(
        [
            # Test case 1: Simple case with int8, zero mean input
            (
                torch.tensor(
                    [[-1, 1]], dtype=torch.int8
                ),  # input: dequantized to [-0.1, 0.1]
                0.1,  # X_scale
                0,  # X_zero_point
                2,  # normalized_shape (last dimension)
                torch.tensor([1.0, 1.0]),  # weight
                torch.tensor([0.0, 0.0]),  # bias
                1e-5,  # eps
                0.1,  # output_scale
                0,  # output_zero_point
                torch.int8,  # dtype
                torch.tensor([[-10, 10]], dtype=torch.int8),  # expected_output
            ),
            # Test case 2: uint8 with zero_point offset
            (
                torch.tensor(
                    [[127, 129]], dtype=torch.uint8
                ),  # input: dequantized to [-0.05, 0.05]
                0.05,  # X_scale
                128,  # X_zero_point
                2,  # normalized_shape (last dimension)
                torch.tensor([1.0, 1.0]),  # weight
                torch.tensor([0.0, 0.0]),  # bias
                1e-5,  # eps
                0.05,  # output_scale
                128,  # output_zero_point
                torch.uint8,  # dtype
                torch.tensor([[108, 148]], dtype=torch.uint8),  # expected_output
            ),
            # Test case 3: Test with weight and bias scaling
            (
                torch.tensor(
                    [[-2, 2]], dtype=torch.int8
                ),  # input: dequantized to [-0.2, 0.2]
                0.1,  # X_scale
                0,  # X_zero_point
                2,  # normalized_shape (last dimension)
                torch.tensor(
                    [2.0, 0.5]
                ),  # weight: scale first element by 2, second by 0.5
                torch.tensor(
                    [0.1, -0.1]
                ),  # bias: add 0.1 to first, subtract 0.1 from second
                1e-5,  # eps
                0.1,  # output_scale
                0,  # output_zero_point
                torch.int8,  # dtype
                torch.tensor([[-19, 4]], dtype=torch.int8),  # expected_output
            ),
        ]
    )
    def test_quantized_layer_norm_per_tensor(
        self,
        input_tensor: torch.Tensor,
        X_scale: float,
        X_zero_point: int,
        normalized_shape: int,
        weight: torch.Tensor,
        bias: torch.Tensor,
        eps: float,
        output_scale: float,
        output_zero_point: int,
        dtype: torch.dtype,
        expected_output: torch.Tensor,
    ) -> None:
        output = quantized_layer_norm_per_tensor(
            input_tensor,
            X_scale,
            X_zero_point,
            normalized_shape,
            weight,
            bias,
            eps,
            output_scale,
            output_zero_point,
        )

        # Verify output properties
        self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
        self.assertEqual(
            output.shape, input_tensor.shape, "Output shape should match input shape"
        )

        # Verify output matches expected values
        self.assertTrue(
            torch.equal(output, expected_output),
            f"Output values don't match expected. Got {output}, expected {expected_output}",
        )
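As a sanity check on the expected outputs, test case 3 can be reproduced by hand with plain PyTorch. The affine dequantize and quantize arithmetic here is my assumption of what the reference helpers compute, not a quote of their implementations:

import torch

# Test case 3, step by step (assumed affine quantization arithmetic):
x = (torch.tensor([[-2.0, 2.0]]) - 0) * 0.1   # dequantize: [-0.2, 0.2]
y = torch.nn.functional.layer_norm(
    x, (2,), torch.tensor([2.0, 0.5]), torch.tensor([0.1, -0.1]), eps=1e-5
)
# Normalization gives roughly [-1.0, 1.0]; applying weight [2.0, 0.5] and
# bias [0.1, -0.1] yields roughly [-1.9, 0.4].
q = torch.round(y / 0.1).to(torch.int8)       # quantize at scale 0.1
print(q)  # tensor([[-19,   4]], dtype=torch.int8)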
