Commit 1b12971

migrate quant layer norm hifi ops to oss
Differential Revision: D64780546
Pull Request resolved: #6479

File tree

3 files changed: 40 additions & 30 deletions

backends/cadence/hifi/operators/operators.h

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+#pragma once
+
+#define ET_FORALL_CADENCE_QUANTIZED_TYPES(_) \
+  _(uint8_t, Byte)                           \
+  _(int8_t, Char)
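
ET_FORALL_CADENCE_QUANTIZED_TYPES is an X-macro: it applies a caller-supplied macro once per (C type, ScalarType name) pair, so a single list of supported quantized dtypes can generate switch cases, registrations, or template instantiations. A minimal standalone sketch of how the pattern expands (the ScalarType enum and HANDLE_DTYPE visitor below are hypothetical stand-ins for illustration, not part of this commit):

#include <cstdint>
#include <cstdio>

#define ET_FORALL_CADENCE_QUANTIZED_TYPES(_) \
  _(uint8_t, Byte)                           \
  _(int8_t, Char)

// Stand-in for executorch::aten::ScalarType.
enum class ScalarType { Byte, Char };

// Visitor macro: each list entry expands to one switch case.
#define HANDLE_DTYPE(ctype, dtype)             \
  case ScalarType::dtype:                      \
    std::printf("dispatch as " #ctype "\n");   \
    break;

void dispatch(ScalarType t) {
  switch (t) {
    ET_FORALL_CADENCE_QUANTIZED_TYPES(HANDLE_DTYPE)
  }
}

#undef HANDLE_DTYPE

With this pattern, adding a new quantized dtype to the backend means extending the single list in the header rather than editing every operator's if/else chain, which is exactly the change made to quantized_layer_norm_out below.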

backends/cadence/hifi/operators/quantized_layer_norm.cpp

Lines changed: 32 additions & 30 deletions

@@ -7,14 +7,17 @@
  */

 #include <executorch/backends/cadence/hifi/kernels/kernels.h>
+#include <executorch/backends/cadence/hifi/operators/operators.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <algorithm>
 #include <cmath>
 #include <tuple>

-using executorch::aten::Tensor;
-using executorch::runtime::getLeadingDims;
-using executorch::runtime::KernelRuntimeContext;
+using ::executorch::aten::IntArrayRef;
+using ::executorch::aten::ScalarType;
+using ::executorch::aten::Tensor;
+using ::executorch::runtime::getLeadingDims;
+using ::executorch::runtime::KernelRuntimeContext;

 namespace cadence {
 namespace impl {

@@ -77,10 +80,10 @@ void quantized_layer_norm_(
     for (size_t j = 0; j < last_dim; ++j) {
       // Since X is quantized, we dequantize it, compute fp32 result, and
       // quantize the result to an int8/uint8 value.
-      float val = cadence::impl::HiFi::kernels::dequantize<T>(
+      float val = ::cadence::impl::HiFi::kernels::dequantize<T>(
           x[j], input_scale, input_zero_point);
       val = (val - mean) * inv_std * weight_data[j] + bias_data[j];
-      y[j] = cadence::impl::HiFi::kernels::quantize<T>(
+      y[j] = ::cadence::impl::HiFi::kernels::quantize<T>(
           val, output_inv_scale, output_zero_point);
     }
   }

@@ -121,38 +124,37 @@ void quantized_layer_norm_out(
     const Tensor& input,
     const Tensor& in_scale,
     const Tensor& in_zero_point,
-    const executorch::aten::IntArrayRef normalized_shape,
+    __ET_UNUSED const IntArrayRef normalized_shape,
     const Tensor& weight,
     const Tensor& bias,
     double eps,
     double output_scale,
     int64_t output_zero_point,
     Tensor& out) {
-  if (input.scalar_type() == executorch::aten::ScalarType::Byte) {
-    quantized_layer_norm_<uint8_t>(
-        input,
-        in_scale,
-        in_zero_point,
-        weight,
-        bias,
-        eps,
-        output_scale,
-        output_zero_point,
-        out);
-  } else if (input.scalar_type() == executorch::aten::ScalarType::Char) {
-    quantized_layer_norm_<int8_t>(
-        input,
-        in_scale,
-        in_zero_point,
-        weight,
-        bias,
-        eps,
-        output_scale,
-        output_zero_point,
-        out);
-  } else {
-    ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type());
+#define typed_quantized_layer_norm(ctype, dtype) \
+  case ScalarType::dtype: {                      \
+    quantized_layer_norm_<ctype>(                \
+        input,                                   \
+        in_scale,                                \
+        in_zero_point,                           \
+        weight,                                  \
+        bias,                                    \
+        eps,                                     \
+        output_scale,                            \
+        output_zero_point,                       \
+        out);                                    \
+    break;                                       \
   }
+
+  ScalarType dtype = input.scalar_type();
+  switch (dtype) {
+    ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_layer_norm)
+    default:
+      ET_DCHECK_MSG(
+          false, "Unhandled dtype %s", torch::executor::toString(dtype));
+  }
+
+#undef typed_quantized_layer_norm
 }

 }; // namespace native
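
For reference, the dequantize/quantize calls in the inner loop above follow the usual affine quantization scheme. A minimal sketch of that arithmetic, assuming real = (q - zero_point) * scale and q = round(val * inv_scale) + zero_point with saturation; the actual HiFi kernels may implement this differently:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Assumed affine dequantize: map a quantized value back to fp32.
template <typename T>
float dequantize(T q, float scale, int32_t zero_point) {
  return (static_cast<int32_t>(q) - zero_point) * scale;
}

// Assumed affine quantize: the caller passes the reciprocal of the
// output scale (output_inv_scale in the diff), so the hot loop
// multiplies instead of divides; the result is clamped to T's range.
template <typename T>
T quantize(float val, float inv_scale, int32_t zero_point) {
  float q = std::round(val * inv_scale) + static_cast<float>(zero_point);
  q = std::max(q, static_cast<float>(std::numeric_limits<T>::min()));
  q = std::min(q, static_cast<float>(std::numeric_limits<T>::max()));
  return static_cast<T>(q);
}

Under this scheme, each output element is y[j] = quantize((dequantize(x[j]) - mean) * inv_std * weight[j] + bias[j]), matching the fp32 computation in quantized_layer_norm_.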

backends/cadence/hifi/operators/targets.bzl

Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@ def define_common_targets():
         srcs = glob([
             "*.cpp",
         ]),
+        exported_headers = glob(["*.h"]),
         platforms = CXX,
         deps = [
             "//executorch/kernels/portable/cpu/util:all_deps",
