Skip to content

Commit c3b6b02

Browse files
committed
working fp gelu
1 parent f81d768 commit c3b6b02

File tree

11 files changed

+1209
-18
lines changed

11 files changed

+1209
-18
lines changed
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved.
3+
#
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
6+
"""
7+
Example: Custom FastGELU operator integrated with ExecuTorch Qualcomm backend (HTP).
8+
"""
9+
10+
import json
11+
import os
12+
import subprocess
13+
import sys
14+
from multiprocessing.connection import Client
15+
16+
import numpy as np
17+
import torch
18+
19+
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
20+
from executorch.backends.qualcomm.serialization.qc_schema import (
21+
_soc_info_table,
22+
HtpArch,
23+
QcomChipset,
24+
QnnExecuTorchOpPackageInfo,
25+
QnnExecuTorchOpPackageOptions,
26+
QnnExecuTorchOpPackagePlatform,
27+
QnnExecuTorchOpPackageTarget,
28+
)
29+
from executorch.examples.qualcomm.utils import (
30+
build_executorch_binary,
31+
generate_inputs,
32+
make_output_dir,
33+
make_quantizer,
34+
setup_common_args_and_variables,
35+
SimpleADB,
36+
)
37+
from torch.library import impl, Library
38+
39+
# ------------------------------------------------------------------------------
40+
# 1. Register PyTorch custom operator (FastGELU)
41+
# ------------------------------------------------------------------------------
42+
43+
# Library handle that declares new operators in the "my_ops" namespace.
my_op_lib = Library("my_ops", "DEF")

# Schema of the functional variant: one tensor in, one tensor out.
my_op_lib.define("fast_gelu(Tensor input) -> Tensor")


@impl(my_op_lib, "fast_gelu", "CompositeExplicitAutograd")
def fast_gelu_impl(x: torch.Tensor) -> torch.Tensor:
    """Evaluate the tanh-based GELU approximation on ``x``.

    Computes ``0.5 * x * (1 + tanh(0.7978845608 * (x + 0.044715 * x^3)))``
    using ordinary tensor arithmetic; ``0.7978845608`` is approximately
    ``sqrt(2 / pi)``.

    Args:
        x: Input tensor.

    Returns:
        A new tensor holding the approximation evaluated on ``x``.
    """
    # The cube is spelled as repeated multiplication, in the same order as
    # the one-line formula, so the floating-point result is unchanged.
    cubic_term = 0.044715 * x * x * x
    tanh_arg = 0.7978845608 * (x + cubic_term)
    return 0.5 * x * (1.0 + torch.tanh(tanh_arg))


# Schema of the out variant, which takes a caller-provided ``output`` tensor.
# Only the schema is declared here; no kernel is registered for it in this file.
my_op_lib.define(
    "fast_gelu.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)"
)
56+
57+
58+
# ------------------------------------------------------------------------------
59+
# 2. Simple model using custom op
60+
# ------------------------------------------------------------------------------
61+
62+
63+
class Model(torch.nn.Module):
    """Minimal module whose forward pass is one ``my_ops.fast_gelu`` call."""

    def forward(self, a):
        """Apply ``torch.ops.my_ops.fast_gelu.default`` to ``a`` and return the result."""
        fast_gelu_op = torch.ops.my_ops.fast_gelu.default
        return fast_gelu_op(a)
66+
67+
68+
# ------------------------------------------------------------------------------
69+
# 3. Build + register custom op package
70+
# ------------------------------------------------------------------------------
71+
72+
73+
def _run(cmd, cwd=None):
    """Run ``cmd`` as a subprocess, forwarding its stdout to this process.

    Args:
        cmd: Command and arguments, passed straight to ``subprocess.run``.
        cwd: Optional working directory for the child process.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status (``check=True``).
    """
    subprocess.run(
        cmd,
        cwd=cwd,
        stdout=sys.stdout,
        check=True,
    )
75+
76+
77+
def _make_op_package_info(op_package_path, target, platform):
    """Build one ``QnnExecuTorchOpPackageInfo`` entry for the FastGelu package.

    Every entry shares the same interface provider, package name, custom op
    name and QNN op type; only the library path, target and platform differ.

    Args:
        op_package_path: Path of the op package shared library for this entry.
        target: A ``QnnExecuTorchOpPackageTarget`` value.
        platform: A ``QnnExecuTorchOpPackagePlatform`` value.

    Returns:
        The populated ``QnnExecuTorchOpPackageInfo``.
    """
    info = QnnExecuTorchOpPackageInfo()
    info.interface_provider = "FastGeluOpPackageInterfaceProvider"
    info.op_package_name = "FastGeluOpPackage"
    info.op_package_path = op_package_path
    info.target = target
    info.custom_op_name = "my_ops.fast_gelu.default"
    info.qnn_op_type_name = "FastGelu"
    info.platform = platform
    return info


def prepare_op_package(
    workspace: str, op_package_dir: str, arch: HtpArch, build_op_package: bool
):
    """Optionally build the FastGelu op package and describe it for the backend.

    Args:
        workspace: Directory used as the prefix of the HTP and aarch64 library
            paths recorded in the returned options.
        op_package_dir: Directory of the op package sources; build outputs are
            read from ``<op_package_dir>/build``.
        arch: HTP architecture, used to select the ``htp_v<arch>`` make target
            and the ``hexagon-v<arch>`` build directory.
        build_op_package: When true, remove ``build`` and rebuild the x86,
            aarch64 and hexagon libraries with ``make`` before describing them.

    Returns:
        A ``(op_package_options, op_package_paths)`` tuple: the
        ``QnnExecuTorchOpPackageOptions`` listing the x86 CPU, aarch64 CPU and
        HTP entries (in that order), and the list of host-side library paths
        for the HTP and aarch64 builds.

    Raises:
        subprocess.CalledProcessError: If any build command fails.
    """
    hexagon_build_dir = f"{op_package_dir}/build/hexagon-v{arch}"
    htp_library = f"{hexagon_build_dir}/libQnnFastGeluOpPackage_HTP.so"

    if build_op_package:
        _run(["rm", "-rf", "build"], cwd=op_package_dir)
        _run(["make", "htp_x86", "htp_aarch64", f"htp_v{arch}"], cwd=op_package_dir)
        # The hexagon build shares its file name with the aarch64 build, so it
        # is copied to a distinct "_HTP" name (presumably so both can sit side
        # by side in the workspace -- confirm against the push step).
        _run(["cp", f"{hexagon_build_dir}/libQnnFastGeluOpPackage.so", htp_library])

    op_package_paths = [
        htp_library,
        f"{op_package_dir}/build/aarch64-android/libQnnFastGeluOpPackage.so",
    ]

    op_package_infos_HTP = _make_op_package_info(
        f"{workspace}/libQnnFastGeluOpPackage_HTP.so",
        QnnExecuTorchOpPackageTarget.HTP,
        QnnExecuTorchOpPackagePlatform.AARCH64_ANDROID,
    )
    op_package_infos_aarch64_CPU = _make_op_package_info(
        f"{workspace}/libQnnFastGeluOpPackage.so",
        QnnExecuTorchOpPackageTarget.CPU,
        QnnExecuTorchOpPackagePlatform.AARCH64_ANDROID,
    )
    # The x86 entry points at the host build output rather than the workspace.
    op_package_infos_x86_CPU = _make_op_package_info(
        f"{op_package_dir}/build/x86_64-linux-clang/libQnnFastGeluOpPackage.so",
        QnnExecuTorchOpPackageTarget.CPU,
        QnnExecuTorchOpPackagePlatform.X86_64,
    )

    op_package_options = QnnExecuTorchOpPackageOptions()
    op_package_options.op_package_infos = [
        op_package_infos_x86_CPU,
        op_package_infos_aarch64_CPU,
        op_package_infos_HTP,
    ]

    return op_package_options, op_package_paths
136+
137+
138+
# ------------------------------------------------------------------------------
139+
# 4. Entrypoint — same pattern as custom_ops_1.py
140+
# ------------------------------------------------------------------------------
141+
142+
143+
def main(args):
    """Build the FastGELU example, run it on device and compare with the host.

    Steps: optionally verify the build environment, prepare the op package,
    lower ``Model`` to an ExecuTorch binary, then (unless ``compile_only`` is
    set) push it to the device through ``SimpleADB``, execute it, pull the
    outputs and print whether they match the host result within ``atol=1e-2``.

    Args:
        args: Parsed command-line namespace. Reads ``build_op_package``,
            ``artifact``, ``model``, ``op_package_dir``, ``compile_only``,
            ``build_folder``, ``device`` and ``host``.

    Raises:
        RuntimeError: If ``build_op_package`` is set and ``HEXAGON_SDK_ROOT``
            or ``ANDROID_NDK_ROOT`` is missing from the environment.
        SystemExit: With status 0 after compilation when ``compile_only`` is set.
    """
    if args.build_op_package:
        # Both variables are needed by the op package build; fail early.
        for required_var in ("HEXAGON_SDK_ROOT", "ANDROID_NDK_ROOT"):
            if required_var not in os.environ:
                raise RuntimeError(
                    f"Environment variable {required_var} must be set"
                )
            print(f"{required_var}={os.getenv(required_var)}")

    # Ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    instance = Model()
    sample_input = (torch.randn(1, 128),)
    pte_filename = "fastgelu_model"
    workspace = f"/data/local/tmp/executorch/{pte_filename}"
    soc_info = _soc_info_table[getattr(QcomChipset, args.model)]

    op_package_options, op_package_paths = prepare_op_package(
        workspace,
        args.op_package_dir,
        soc_info.htp_info.htp_arch,
        args.build_op_package,
    )

    # NOTE: quantization is not wired up yet. No quant_dtype or custom
    # quantizer is passed below, so `--use_fp16` currently has no effect.
    build_executorch_binary(
        instance,
        sample_input,
        args.model,
        f"{args.artifact}/{pte_filename}",
        sample_input,
        op_package_options=op_package_options,
    )

    if args.compile_only:
        sys.exit(0)

    output_data_folder = os.path.join(args.artifact, "outputs")

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=args.build_folder,
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=workspace,
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
    )
    adb.push(inputs=sample_input, files=op_package_paths)
    adb.execute()
    adb.pull(output_path=args.artifact)

    # Compare results: the host reference is the same module instance that was
    # compiled, evaluated on the same sample input.
    x86_golden = instance(*sample_input)

    # The raw device output is read as flat float32 values and reshaped to the
    # host result's shape.
    device_output = torch.from_numpy(
        np.fromfile(
            os.path.join(output_data_folder, "output_0_0.raw"), dtype=np.float32
        )
    ).reshape(x86_golden.size())
    print(
        "is_close?",
        torch.allclose(x86_golden, device_output, atol=1e-2),
    )
218+
219+
220+
if __name__ == "__main__":
    # Script entry point: extend the shared example argument parser with the
    # options specific to this custom-op example, then run main().
    parser = setup_common_args_and_variables()
    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. Default ./custom_op",
        default="./custom_op",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--op_package_dir",
        help="Path to operator package which generates from QNN.",
        type=str,
        required=True,
    )

    parser.add_argument(
        "-F",
        "--use_fp16",
        help="If specified, will run in fp16 precision and discard ptq setting",
        action="store_true",
        default=False,
    )

    parser.add_argument(
        "--build_op_package",
        help="Build op package based on op_package_dir. Please set up "
        "`HEXAGON_SDK_ROOT` and `ANDROID_NDK_ROOT` environment variable. "
        "And add clang compiler into `PATH`. Please refer to Qualcomm AI Engine "
        "Direct SDK document to get more details",
        action="store_true",
        default=False,
    )

    args = parser.parse_args()
    # `validate`, `ip` and `port` are not defined in this file; presumably they
    # come from setup_common_args_and_variables() -- confirm there.
    args.validate(args)

    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            # A listener address was supplied: report the failure to it as a
            # JSON message instead of raising.
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Re-raise the original exception so its type and traceback are
            # preserved rather than wrapped in a generic Exception.
            raise

examples/qualcomm/custom_op/example_op_package_htp/ExampleOpPackage/Makefile

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,19 +44,23 @@ $(info "HEXAGON_SDK_ROOT is [${HEXAGON_SDK_ROOT}]")
4444
HEXAGON_SDK_ROOT_V68 := $(HEXAGON_SDK_BASE)/hexagon-sdk-4.2.0
4545
HEXAGON_SDK_ROOT_V69 := $(HEXAGON_SDK_BASE)/hexagon-sdk-4.3.0
4646
HEXAGON_SDK_ROOT_V73 := $(HEXAGON_SDK_BASE)/hexagon-sdk-5.4.0
47-
HEXAGON_SDK_ROOT_V75 := $(HEXAGON_SDK_BASE)/hexagon-sdk-5.4.0
48-
HEXAGON_SDK_ROOT_V79 := $(HEXAGON_SDK_BASE)/hexagon-sdk-6.0.0
47+
# HEXAGON_SDK_ROOT_V75 := $(HEXAGON_SDK_BASE)/hexagon-sdk-5.4.0
48+
HEXAGON_SDK_ROOT_V75 := $(HEXAGON_SDK_BASE)
49+
# HEXAGON_SDK_ROOT_V79 := $(HEXAGON_SDK_BASE)/hexagon-sdk-6.0.0
50+
HEXAGON_SDK_ROOT_V79 := $(HEXAGON_SDK_BASE)
4951

5052
#Updated to point to latest sdk to match with libQnnHtp.so
51-
HEXAGON_SDK_ROOT_X86 := $(HEXAGON_SDK_BASE)/hexagon-sdk-6.0.0
53+
HEXAGON_SDK_ROOT_X86 := $(HEXAGON_SDK_BASE)
5254
HEXAGON_TOOLS_VERSION_V68 := 8.4.09
5355
HEXAGON_TOOLS_VERSION_V69 := 8.5.03
5456
HEXAGON_TOOLS_VERSION_V73 := 8.6.02
5557
HEXAGON_TOOLS_VERSION_V75 := 8.7.03
56-
HEXAGON_TOOLS_VERSION_V79 := 8.8.02
58+
# HEXAGON_TOOLS_VERSION_V79 := 8.8.02
59+
HEXAGON_TOOLS_VERSION_V79 := 8.8.06
5760

5861
#Updated to point to latest sdk to match with libQnnHtp.so
59-
HEXAGON_TOOLS_VERSION_X86 := 8.8.02
62+
# HEXAGON_TOOLS_VERSION_X86 := 8.8.02
63+
HEXAGON_TOOLS_VERSION_X86 := 8.8.06
6064

6165
ifndef ANDROID_NDK_ROOT
6266
ifeq ($(MAKECMDGOALS),htp_aarch64)
@@ -87,6 +91,8 @@ COMMON_CXX_FLAGS += -DQNN_API="__attribute__((visibility(\"default\")))" -D__QA
8791

8892
X86_LIBNATIVE_RELEASE_DIR := $(HEXAGON_SDK_ROOT_X86)/tools/HEXAGON_Tools/$(HEXAGON_TOOLS_VERSION_X86)/Tools
8993

94+
$(info "HEXAGON_SDK_ROOT_X86 is [${HEXAGON_SDK_ROOT_X86}]")
95+
9096
# Ensure hexagon sdk tool version can be retrieved
9197
ifeq ($(wildcard $(X86_LIBNATIVE_RELEASE_DIR)/.),)
9298
$(error "Cannot retrieve hexagon tools from: $(X86_LIBNATIVE_RELEASE_DIR). \
@@ -119,6 +125,8 @@ $(error "ERROR: HEXAGON_SDK_ROOT_V75 is set incorrectly. Cannot retrieve $(HEXAG
119125
endif
120126
endif
121127

128+
$(info "HEXAGON_SDK_ROOT_V79 is [${HEXAGON_SDK_ROOT_V79}]")
129+
122130
#Check tools for hexagon_v79 are present.
123131
ifeq ($(MAKECMDGOALS),htp_v79)
124132
ifeq ($(wildcard $(HEXAGON_SDK_ROOT_V79)),)
@@ -165,6 +173,7 @@ ifeq ($(shell $(X86_CXX) -v 2>&1 | grep -c "clang version"), 0)
165173
X86_CXX := clang++
166174
endif
167175
X86_LDFLAGS:= -Wl,--whole-archive -L$(X86_LIBNATIVE_RELEASE_DIR)/libnative/lib -lnative -Wl,--no-whole-archive -lpthread
176+
168177
X86_C_FLAGS := -D__HVXDBL__ -I$(X86_LIBNATIVE_RELEASE_DIR)/libnative/include -ffast-math -DUSE_OS_LINUX
169178
X86_CXX_FLAGS = $(COMMON_CXX_FLAGS) $(X86_C_FLAGS) -fomit-frame-pointer -Wno-invalid-offsetof
170179
linux_objs =
@@ -178,12 +187,29 @@ HEXAGON_CXX_FLAGS_V73 := $(HEXAGON_CXX_FLAGS) -mv73 -I$(HEXAGON_SDK_ROOT_V73)/rt
178187
HEXAGON_CXX_FLAGS_V75 := $(HEXAGON_CXX_FLAGS) -mv75 -I$(HEXAGON_SDK_ROOT_V75)/rtos/qurt/computev75/include/qurt -I$(HEXAGON_SDK_ROOT_V75)/rtos/qurt/computev75/include/posix -I$(HEXAGON_SDK_ROOT_V75)/incs -I$(HEXAGON_SDK_ROOT_V75)/incs/stddef
179188
HEXAGON_CXX_FLAGS_V79 := $(HEXAGON_CXX_FLAGS) -mv79 -I$(HEXAGON_SDK_ROOT_V79)/rtos/qurt/computev79/include/qurt -I$(HEXAGON_SDK_ROOT_V79)/rtos/qurt/computev79/include/posix -I$(HEXAGON_SDK_ROOT_V79)/incs -I$(HEXAGON_SDK_ROOT_V79)/incs/stddef
180189

190+
$(info "HEXAGON_TOOLS_VERSION_V68 is [${HEXAGON_TOOLS_VERSION_V68}]")
191+
$(info "HEXAGON_TOOLS_VERSION_V69 is [${HEXAGON_TOOLS_VERSION_V69}]")
192+
$(info "HEXAGON_TOOLS_VERSION_V73 is [${HEXAGON_TOOLS_VERSION_V73}]")
193+
$(info "HEXAGON_TOOLS_VERSION_V75 is [${HEXAGON_TOOLS_VERSION_V75}]")
194+
$(info "HEXAGON_TOOLS_VERSION_V79 is [${HEXAGON_TOOLS_VERSION_V79}]")
195+
196+
$(info "HEXAGON_SDK_ROOT_V68 is [${HEXAGON_SDK_ROOT_V68}]")
197+
$(info "HEXAGON_SDK_ROOT_V69 is [${HEXAGON_SDK_ROOT_V69}]")
198+
$(info "HEXAGON_SDK_ROOT_V73 is [${HEXAGON_SDK_ROOT_V73}]")
199+
$(info "HEXAGON_SDK_ROOT_V75 is [${HEXAGON_SDK_ROOT_V75}]")
200+
$(info "HEXAGON_SDK_ROOT_V79 is [${HEXAGON_SDK_ROOT_V79}]")
201+
181202
HEXAGON_CXX_V68 := $(HEXAGON_SDK_ROOT_V68)/tools/HEXAGON_Tools/$(HEXAGON_TOOLS_VERSION_V68)/Tools/bin/hexagon-clang++
182203
HEXAGON_CXX_V69 := $(HEXAGON_SDK_ROOT_V69)/tools/HEXAGON_Tools/$(HEXAGON_TOOLS_VERSION_V69)/Tools/bin/hexagon-clang++
183204
HEXAGON_CXX_V73 := $(HEXAGON_SDK_ROOT_V73)/tools/HEXAGON_Tools/$(HEXAGON_TOOLS_VERSION_V73)/Tools/bin/hexagon-clang++
184205
HEXAGON_CXX_V75 := $(HEXAGON_SDK_ROOT_V75)/tools/HEXAGON_Tools/$(HEXAGON_TOOLS_VERSION_V75)/Tools/bin/hexagon-clang++
185206
HEXAGON_CXX_V79 := $(HEXAGON_SDK_ROOT_V79)/tools/HEXAGON_Tools/$(HEXAGON_TOOLS_VERSION_V79)/Tools/bin/hexagon-clang++
186207

208+
$(info "HEXAGON_CXX_V68[2] is [${HEXAGON_CXX_V68}]")
209+
$(info "HEXAGON_CXX_V69[2] is [${HEXAGON_CXX_V69}]")
210+
$(info "HEXAGON_CXX_V73[2] is [${HEXAGON_CXX_V73}]")
211+
$(info "HEXAGON_CXX_V75[2] is [${HEXAGON_CXX_V75}]")
212+
$(info "HEXAGON_CXX_V79[2] is [${HEXAGON_CXX_V79}]")
187213

188214
HEX_LDFLAGS =
189215
hexagon_objs =

0 commit comments

Comments
 (0)