Commit cc14b6e ("update")
1 parent bb88bb6

File tree

7 files changed: +221 -84 lines

examples/qualcomm/custom_op/custom_ops_fast_gelu.py

Lines changed: 51 additions & 23 deletions
@@ -50,24 +50,47 @@ def fast_gelu_impl(x: torch.Tensor) -> torch.Tensor:
 
 
 # registering the out variant.
-my_op_lib.define(
-    "fast_gelu.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)"
-)  # should print 'fast_gelu.out'
-
-
-# ------------------------------------------------------------------------------
-# 2. Simple model using custom op
-# ------------------------------------------------------------------------------
+my_op_lib.define("fast_gelu.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)")
 
 
 class Model(torch.nn.Module):
     def forward(self, a):
         return torch.ops.my_ops.fast_gelu.default(a)
 
 
-# ------------------------------------------------------------------------------
-# 3. Build + register custom op package
-# ------------------------------------------------------------------------------
+def annotate_custom(gm: torch.fx.GraphModule) -> None:
+    """
+    This function is specific for custom op.
+    The source_fn of the rewritten nn module turns out to be "my_ops.fast_gelu.default"
+    """
+    from executorch.backends.qualcomm.quantizer.annotators import _is_annotated
+    from executorch.backends.qualcomm.quantizer.qconfig import (
+        get_ptq_per_channel_quant_config,
+    )
+    from torch.fx import Node
+    from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
+    from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
+
+    quantization_config = get_ptq_per_channel_quant_config()
+    for node in gm.graph.nodes:
+        if node.target != torch.ops.my_ops.fast_gelu.default:
+            continue
+
+        # skip annotation if it is already annotated
+        if _is_annotated([node]):
+            continue
+
+        input_qspec_map = {}
+        input_act = node.args[0]
+        assert isinstance(input_act, Node)
+        input_spec = quantization_config.input_activation
+        input_qspec_map[input_act] = input_spec
+
+        node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
 
 
 def _run(cmd, cwd=None):
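
For reference (not part of the commit), a minimal sketch of how the new annotate_custom hook gets something to match on: exporting the example Model surfaces the custom op as a call_function node whose target is torch.ops.my_ops.fast_gelu.default, which is exactly the target the annotation pass checks before make_quantizer wires it into the quantizer. The export call below is an assumption about the surrounding script, not a line from this commit.

import torch

# Assumes Model and annotate_custom from this example file are in scope.
model = Model()
sample_input = (torch.randn(1, 16384),)
gm = torch.export.export(model, sample_input).module()

targets = [n.target for n in gm.graph.nodes if n.op == "call_function"]
print(torch.ops.my_ops.fast_gelu.default in targets)  # expected: True

# make_quantizer(quant_dtype=..., custom_annotations=(annotate_custom,)) later runs
# annotate_custom(gm) over this graph so the node receives input/output qspecs.
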
@@ -135,11 +158,6 @@ def prepare_op_package(
     return op_package_options, op_package_paths
 
 
-# ------------------------------------------------------------------------------
-# 4. Entrypoint — same pattern as custom_ops_1.py
-# ------------------------------------------------------------------------------
-
-
 def main(args):
     if args.build_op_package:
         if "HEXAGON_SDK_ROOT" not in os.environ:
@@ -158,7 +176,7 @@ def main(args):
         quant_dtype = None
 
     instance = Model()
-    sample_input = (torch.randn(1, 128),)
+    sample_input = (torch.randn(1, 16384),)
     pte_filename = "fastgelu_model"
     workspace = f"/data/local/tmp/executorch/{pte_filename}"
     soc_info: SocInfo = _soc_info_table[getattr(QcomChipset, args.model)]
@@ -169,9 +187,14 @@ def main(args):
         soc_info.htp_info.htp_arch,
         args.build_op_package,
     )
-    # quantizer = make_quantizer(
-    #     quant_dtype=quant_dtype, custom_annotations=(annotate_custom,)
-    # )
+    quant_dtype = QuantDtype.use_8a8w
+    if args.use_fp16:
+        quant_dtype = None
+    quantizer = None
+    if not args.use_fp16:
+        quantizer = make_quantizer(
+            quant_dtype=quant_dtype, custom_annotations=(annotate_custom,)
+        )
 
     build_executorch_binary(
         instance,
@@ -180,8 +203,8 @@ def main(args):
         f"{args.artifact}/{pte_filename}",
         sample_input,
         op_package_options=op_package_options,
-        # quant_dtype=quant_dtype,
-        # custom_quantizer=quantizer,
+        quant_dtype=quant_dtype,
+        custom_quantizer=quantizer,
     )
 
     if args.compile_only:
@@ -203,6 +226,7 @@ def main(args):
     adb.pull(output_path=args.artifact)
 
     # Compare results
+    model = Model()
     x86_golden = model(*sample_input)
     import numpy as np
 
@@ -211,10 +235,14 @@
             os.path.join(output_data_folder, "output_0_0.raw"), dtype=np.float32
         )
     ).reshape(x86_golden.size())
+    result = torch.all(torch.isclose(x86_golden, device_output, atol=1e-2)).item()
     print(
         "is_close?",
-        torch.all(torch.isclose(x86_golden, device_output, atol=1e-2)).item(),
+        result,
     )
+    if not result:
+        print(f"x86_golden {x86_golden}")
+        print(f"device_out {device_output}")
 
 
 if __name__ == "__main__":
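
As a quick offline sanity check (a sketch, not part of the commit), the golden output can also be compared against PyTorch's built-in tanh-approximated GELU, since the kernel implements the same 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))) formula; the tolerance mirrors the device comparison above.

import torch
import torch.nn.functional as F

x = torch.randn(1, 16384)
golden = 0.5 * x * (1.0 + torch.tanh(0.7978845608 * (x + 0.044715 * x**3)))
reference = F.gelu(x, approximate="tanh")
print(
    "matches F.gelu(tanh)?",
    torch.all(torch.isclose(golden, reference, atol=1e-2)).item(),
)
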

examples/qualcomm/custom_op/fastgelu_op_package_htp/FastGeluOpPackage/Makefile

Lines changed: 5 additions & 10 deletions
@@ -35,25 +35,20 @@ $(info "HEXAGON_SDK_ROOT is [${HEXAGON_SDK_ROOT}]")
 HEXAGON_SDK_ROOT_V68 := $(HEXAGON_SDK_BASE)/hexagon-sdk-4.2.0
 HEXAGON_SDK_ROOT_V69 := $(HEXAGON_SDK_BASE)/hexagon-sdk-4.3.0
 HEXAGON_SDK_ROOT_V73 := $(HEXAGON_SDK_BASE)/hexagon-sdk-5.4.0
-# HEXAGON_SDK_ROOT_V75 := $(HEXAGON_SDK_BASE)/hexagon-sdk-5.4.0
-# HEXAGON_SDK_ROOT_V79 := $(HEXAGON_SDK_BASE)/hexagon-sdk-6.0.0
-HEXAGON_SDK_ROOT_V75 := $(HEXAGON_SDK_BASE)
-HEXAGON_SDK_ROOT_V79 := $(HEXAGON_SDK_BASE)
+HEXAGON_SDK_ROOT_V75 := $(HEXAGON_SDK_BASE)/hexagon-sdk-5.4.0
+HEXAGON_SDK_ROOT_V79 := $(HEXAGON_SDK_BASE)/hexagon-sdk-6.0.0
 HEXAGON_SDK_ROOT_V81 := $(HEXAGON_SDK_BASE)/hexagon-sdk-6.2.0
 #Updated to point to latest sdk to match with libQnnHtp.so
-# HEXAGON_SDK_ROOT_X86 := $(HEXAGON_SDK_ROOT_V81)
-HEXAGON_SDK_ROOT_X86 := $(HEXAGON_SDK_BASE)
+HEXAGON_SDK_ROOT_X86 := $(HEXAGON_SDK_ROOT_V81)
 
 HEXAGON_TOOLS_VERSION_V68 := 8.4.09
 HEXAGON_TOOLS_VERSION_V69 := 8.5.03
 HEXAGON_TOOLS_VERSION_V73 := 8.6.02
 HEXAGON_TOOLS_VERSION_V75 := 8.7.03
-# HEXAGON_TOOLS_VERSION_V79 := 8.8.02
-HEXAGON_TOOLS_VERSION_V79 := 8.8.06
+HEXAGON_TOOLS_VERSION_V79 := 8.8.02
 HEXAGON_TOOLS_VERSION_V81 := 19.0.01
 #Updated to point to latest sdk to match with libQnnHtp.so
-# HEXAGON_TOOLS_VERSION_X86 := 19.0.01
-HEXAGON_TOOLS_VERSION_X86 := 8.8.06
+HEXAGON_TOOLS_VERSION_X86 := 19.0.01
 
 ifndef ANDROID_NDK_ROOT
 ifeq ($(MAKECMDGOALS),htp_aarch64)

examples/qualcomm/custom_op/fastgelu_op_package_htp/FastGeluOpPackage/src/ops/FastGelu.cpp

Lines changed: 165 additions & 12 deletions
@@ -2,6 +2,7 @@
 // Auto Generated Code for FastGeluOpPackage
 //==============================================================================
 
+#include <algorithm>
 #include <cmath>
 #include "HTP/core/constraints.h"
 #include "HTP/core/op_package_feature_support.h"
@@ -80,27 +81,179 @@ DEF_PACKAGE_OP((fastgeluImpl<Tensor>), "FastGelu")
 
 /* execute functions for ops */
 
+// template <typename TensorType>
+// GraphStatus fastgeluImpl(TensorType& y, const TensorType& x) {
+//   const uint32_t numElements = x.total_storage_elements();
+
+//   if (y.total_storage_elements() != numElements) {
+//     return GraphStatus::ErrorFatal;
+//   }
+
+//   const float kAlpha = 0.7978845608f; // sqrt(2/pi)
+//   const float kCoeff = 0.044715f;
+
+//   float* yData = reinterpret_cast<float*>(y.raw_data());
+//   const float* xData = reinterpret_cast<const float*>(x.raw_data_const());
+
+//   for (uint32_t i = 0; i < numElements; ++i) {
+//     const float v = xData[i];
+//     const float inner = kAlpha * (v + kCoeff * v * v * v);
+//     yData[i] = 0.5f * v * (1.0f + std::tanh(inner));
+//   }
+
+//   return GraphStatus::Success;
+// }
+
 template <typename TensorType>
 GraphStatus fastgeluImpl(TensorType& y, const TensorType& x) {
-  const uint32_t numElements = x.total_storage_elements();
+  const uint32_t N = x.total_storage_elements();
 
-  if (y.total_storage_elements() != numElements) {
+  if (y.total_storage_elements() != N) {
     return GraphStatus::ErrorFatal;
   }
 
-  const float kAlpha = 0.7978845608f; // sqrt(2/pi)
-  const float kCoeff = 0.044715f;
+  const auto in_info = x.get_dtype_intfc();
+  const auto out_info = y.get_dtype_intfc();
 
-  float* yData = reinterpret_cast<float*>(y.raw_data());
-  const float* xData = reinterpret_cast<const float*>(x.raw_data_const());
-
-  for (uint32_t i = 0; i < numElements; ++i) {
-    const float v = xData[i];
-    const float inner = kAlpha * (v + kCoeff * v * v * v);
-    yData[i] = 0.5f * v * (1.0f + std::tanh(inner));
+  // Reject input dtypes the kernel does not handle (FP32 and QUInt8 only).
+  if (in_info.dtype != DType::Float32 && in_info.dtype != DType::QUInt8) {
+    return GraphStatus::ErrorPrecision;
   }
+  if (in_info.dtype == DType::Float32 && out_info.dtype == DType::Float32) {
+    const float* xData = static_cast<const float*>(x.raw_data_const());
+    float* yData = static_cast<float*>(y.raw_data());
+
+    // --- Temporary FP16 buffers ---
+    std::vector<Float16> tmp_in(N);
+    std::vector<Float16> tmp_out(N);
+
+    for (uint32_t i = 0; i < N; ++i) {
+      tmp_in[i] = static_cast<Float16>(xData[i]);
+    }
+
+#ifdef __hexagon__
+    union {
+      Float16 f;
+      uint16_t b;
+    } kAlpha = {(Float16)0.7978845608f}; // sqrt(2/pi)
+    union {
+      Float16 f;
+      uint16_t b;
+    } kCoeff = {(Float16)0.044715f};
+    union {
+      Float16 f;
+      uint16_t b;
+    } kHalf = {(Float16)0.5f};
+    union {
+      Float16 f;
+      uint16_t b;
+    } kOne = {(Float16)1.0f};
+    union {
+      Float16 f;
+      uint16_t b;
+    } k27 = {(Float16)27.0f};
+    union {
+      Float16 f;
+      uint16_t b;
+    } kInv27 = {(Float16)(1.0f / 27.0f)};
+    union {
+      Float16 f;
+      uint16_t b;
+    } kOne3 = {(Float16)(1.0f / 3.0f)};
+    union {
+      Float16 f;
+      uint16_t b;
+    } kOne9 = {(Float16)(1.0f / 9.0f)};
+
+    HVX_Vector v_alpha = Q6_Vh_vsplat_R(kAlpha.b);
+    HVX_Vector v_coeff = Q6_Vh_vsplat_R(kCoeff.b);
+    HVX_Vector v_half = Q6_Vh_vsplat_R(kHalf.b);
+    HVX_Vector v_one = Q6_Vh_vsplat_R(kOne.b);
+    HVX_Vector v_27 = Q6_Vh_vsplat_R(k27.b);
+    HVX_Vector v_inv27 = Q6_Vh_vsplat_R(kInv27.b);
+    HVX_Vector v_1_3 = Q6_Vh_vsplat_R(kOne3.b);
+    HVX_Vector v_1_9 = Q6_Vh_vsplat_R(kOne9.b);
+
+    const int VBYTES = 128;
+    const int ELEMS = VBYTES / sizeof(Float16); // 64
 
-  return GraphStatus::Success;
+    // Note: this loop assumes N is a multiple of ELEMS (true for the 1x16384
+    // sample input); a remainder tail would need separate handling.
+    for (uint32_t i = 0; i < N; i += ELEMS) {
+      HVX_Vector vx = q6op_V_vldu_A(&tmp_in[i]); // x
+      HVX_Vector vx2 = Q6_Vhf_vmpy_VhfVhf(vx, vx); // x^2
+      HVX_Vector vx3 = Q6_Vhf_vmpy_VhfVhf(vx2, vx); // x^3
+
+      // z = α * (x + c*x^3)
+      HVX_Vector vcx3 = Q6_Vhf_vmpy_VhfVhf(vx3, v_coeff);
+      HVX_Vector vsum = Q6_Vhf_vadd_VhfVhf(vx, vcx3);
+      HVX_Vector vz = Q6_Vhf_vmpy_VhfVhf(vsum, v_alpha);
+
+      // z^2, z^4
+      HVX_Vector vz2 = Q6_Vhf_vmpy_VhfVhf(vz, vz);
+      HVX_Vector vz4 = Q6_Vhf_vmpy_VhfVhf(vz2, vz2);
+
+      // inv_den ≈ (1/27) * (1 - (1/3) z^2 + (1/9) z^4)
+      HVX_Vector term1 = Q6_Vhf_vmpy_VhfVhf(vz2, v_1_3); // (1/3) z^2
+      HVX_Vector one_m_t = Q6_Vhf_vsub_VhfVhf(v_one, term1); // 1 - (1/3) z^2
+      HVX_Vector term2 = Q6_Vhf_vmpy_VhfVhf(vz4, v_1_9); // (1/9) z^4
+      HVX_Vector poly =
+          Q6_Vhf_vadd_VhfVhf(one_m_t, term2); // 1 - 1/3 z^2 + 1/9 z^4
+      HVX_Vector inv_den = Q6_Vhf_vmpy_VhfVhf(poly, v_inv27); // * (1/27)
+
+      // num = z * (27 + z^2) = 27z + z^3
+      HVX_Vector z3 = Q6_Vhf_vmpy_VhfVhf(vz2, vz);
+      HVX_Vector t27z = Q6_Vhf_vmpy_VhfVhf(vz, v_27);
+      HVX_Vector num = Q6_Vhf_vadd_VhfVhf(t27z, z3);
+
+      // tanh(z) ≈ num * inv_den
+      HVX_Vector vtanh = Q6_Vhf_vmpy_VhfVhf(num, inv_den);
+
+      // y = 0.5 * x * (1 + tanh)
+      HVX_Vector one_plus_tanh = Q6_Vhf_vadd_VhfVhf(v_one, vtanh);
+      HVX_Vector t = Q6_Vhf_vmpy_VhfVhf(vx, one_plus_tanh);
+      HVX_Vector vy = Q6_Vhf_vmpy_VhfVhf(t, v_half);
+
+      q6op_vstu_AV(&tmp_out[i], vy);
+    }
+
+    // Copy the FP16 results back into the FP32 output tensor.
+    for (uint32_t i = 0; i < N; ++i) {
+      yData[i] = static_cast<float>(tmp_out[i]);
+    }
+#else
+    // Scalar fallback writes the FP32 output directly.
+    for (uint32_t i = 0; i < N; ++i) {
+      const float v = xData[i];
+      const float inner = 0.7978845608f * (v + 0.044715f * v * v * v);
+      yData[i] = 0.5f * v * (1.0f + std::tanh(inner));
+    }
+#endif
+
+    return GraphStatus::Success;
+  } else if (in_info.dtype == DType::QUInt8) {
+    const uint8_t* xData = static_cast<const uint8_t*>(x.raw_data_const());
+    uint8_t* yData = static_cast<uint8_t*>(y.raw_data());
+
+    const float x_scale = in_info.scale;
+    const float y_scale = out_info.scale;
+    const int32_t x_zero = in_info.offset;
+    const int32_t y_zero = out_info.offset;
+
+    // 256-entry lookup table: dequantize, apply fast GELU, requantize.
+    alignas(128) static uint8_t lut[256];
+    static bool lut_init = false;
+    if (!lut_init) {
+      for (int i = 0; i < 256; ++i) {
+        float x_f = (i - x_zero) * x_scale;
+        float inner = 0.7978845608f * (x_f + 0.044715f * x_f * x_f * x_f);
+        float y_f = 0.5f * x_f * (1.0f + std::tanh(inner));
+        int y_q = static_cast<int>(std::round(y_f / y_scale)) + y_zero;
+        lut[i] = static_cast<uint8_t>(std::clamp(y_q, 0, 255));
+      }
+      lut_init = true;
+    }
+    for (uint32_t i = 0; i < N; ++i) {
+      yData[i] = lut[xData[i]];
+    }
+    return GraphStatus::Success;
+  } else {
+    return GraphStatus::ErrorFatal;
+  }
 }
 
 __attribute__((unused)) static float fastgeluCostFunc(const Op* op) {
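
Two pieces of math in this kernel are easy to check on the host (a plain-Python sketch, not part of the commit): the HVX path approximates tanh(z) with the rational form z·(27 + z²)/(27 + 9z²), taking the reciprocal of the denominator as the truncated series (1/27)·(1 − z²/3 + z⁴/9), and the QUInt8 path precomputes a 256-entry dequantize, GELU, requantize lookup table. The scale and zero-point values below are illustrative, not taken from the commit.

import math

import numpy as np

def tanh_hvx_like(z: float) -> float:
    # (1/27) * (1 - z^2/3 + z^4/9) is a truncated series for 1 / (27 + 9*z^2)
    inv_den = (1.0 / 27.0) * (1.0 - z * z / 3.0 + z**4 / 9.0)
    return z * (27.0 + z * z) * inv_den

# Approximation error over two ranges of z = sqrt(2/pi) * (x + 0.044715 x^3).
for bound in (1.0, 2.0):
    zs = np.linspace(-bound, bound, 2001)
    err = max(abs(tanh_hvx_like(z) - math.tanh(z)) for z in zs)
    print(f"max |approx - tanh| on [-{bound}, {bound}]: {err:.4f}")

# QUInt8 path: build the same kind of lookup table from (scale, zero point).
x_scale, x_zero = 0.05, 128  # illustrative quantization parameters
y_scale, y_zero = 0.05, 128
lut = np.empty(256, dtype=np.uint8)
for i in range(256):
    x_f = (i - x_zero) * x_scale
    inner = 0.7978845608 * (x_f + 0.044715 * x_f**3)
    y_f = 0.5 * x_f * (1.0 + math.tanh(inner))
    lut[i] = np.clip(round(y_f / y_scale) + y_zero, 0, 255)

x_q = np.array([0, 64, 128, 192, 255], dtype=np.uint8)
print("fast_gelu on quantized inputs via LUT:", lut[x_q])
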

examples/qualcomm/custom_op/fastgelu_op_package_htp/FastGeluOpPackage_old/Makefile

Whitespace-only changes.
