
Commit bb88bb6: use intrinsics for mul3

1 parent c3b6b02

File tree: 4 files changed, +97 -68 lines


examples/qualcomm/custom_op/custom_ops_1.py

Lines changed: 13 additions & 11 deletions
@@ -69,16 +69,13 @@ def annotate_custom(gm: torch.fx.GraphModule) -> None:
     This function is specific for custom op.
     The source_fn of the rewritten nn module turns out to be "my_ops.mul3.default"
     """
-    from executorch.backends.qualcomm.quantizer.annotators import (
-        _is_annotated,
-        QUANT_ANNOTATION_KEY,
-    )
-
+    from executorch.backends.qualcomm.quantizer.annotators import _is_annotated
     from executorch.backends.qualcomm.quantizer.qconfig import (
         get_ptq_per_channel_quant_config,
     )
     from torch.fx import Node
     from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
+    from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY

     quantization_config = get_ptq_per_channel_quant_config()
     for node in gm.graph.nodes:
@@ -95,7 +92,7 @@ def annotate_custom(gm: torch.fx.GraphModule) -> None:
         input_spec = quantization_config.input_activation
         input_qspec_map[input_act] = input_spec

-        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+        node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
             input_qspec_map=input_qspec_map,
             output_qspec=quantization_config.output_activation,
             _annotated=True,
@@ -180,7 +177,8 @@ def main(args):
     # ensure the working directory exist.
     os.makedirs(args.artifact, exist_ok=True)

-    quant_dtype = QuantDtype.use_8a8w
+    # quant_dtype: Literal[QuantDtype.use_16a16w] = QuantDtype.use_16a16w
+    quant_dtype: Literal[QuantDtype.use_16a16w] = QuantDtype.use_8a8w
     if args.use_fp16:
         quant_dtype = None

@@ -197,9 +195,11 @@
         soc_info.htp_info.htp_arch,
         args.build_op_package,
     )
-    quantizer = make_quantizer(
-        quant_dtype=quant_dtype, custom_annotations=(annotate_custom,)
-    )
+    quantizer = None
+    if not args.use_fp16:
+        quantizer = make_quantizer(
+            quant_dtype=quant_dtype, custom_annotations=(annotate_custom,)
+        )

     build_executorch_binary(
         instance,
@@ -228,13 +228,14 @@

     runner_cmd = " ".join(
         [
-            f"export LD_LIBRARY_PATH={qnn_sdk}/lib/{target}/:{args.build_folder}/lib &&",
+            f"export QNN_FARF_LEVEL=4 && export LD_LIBRARY_PATH={qnn_sdk}/lib/{target}/:{args.build_folder}/lib &&",
             f"./{args.build_folder}/examples/qualcomm/executor_runner/qnn_executor_runner",
             f"--model_path {args.artifact}/{pte_filename}.pte",
             f"--input_list_path {args.artifact}/{input_list_filename}",
             f"--output_folder_path {output_data_folder}",
         ]
     )
+
     subprocess.run(
         runner_cmd,
         # stdout=subprocess.PIPE,
@@ -258,6 +259,7 @@
         device_id=args.device,
         host_id=args.host,
         soc_model=args.model,
+        shared_buffer=args.shared_buffer,
     )
     adb.push(inputs=sample_input, files=op_package_paths)
     adb.execute()
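
The runner change exports QNN_FARF_LEVEL=4 before launching qnn_executor_runner, presumably to raise the on-device FARF log level so the diagnostics compiled into the op package become visible. As a reminder of what those diagnostics look like, here is a minimal sketch modeled on the FARF call already present in ExampleCustomOp.cpp; HAP_farf.h is the standard Hexagon SDK home of the macro, not something this commit adds:

// Minimal sketch of FARF logging inside an HTP op kernel. ALWAYS-level
// messages are compiled in by default; the runtime environment (here
// QNN_FARF_LEVEL=4, per the runner command above) governs what is emitted.
#include "HAP_farf.h" // Hexagon SDK logging macros

static void log_mul3_sample(size_t i, float in, float out) {
  FARF(ALWAYS,
       "[QNN ExecuTorch Op Package test] input0[%zu]=%f, output[%zu]=%f",
       i, in, i, out);
}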

examples/qualcomm/custom_op/example_op_package_htp/ExampleOpPackage/Makefile

Lines changed: 31 additions & 3 deletions
@@ -187,6 +187,32 @@ HEXAGON_CXX_FLAGS_V73 := $(HEXAGON_CXX_FLAGS) -mv73 -I$(HEXAGON_SDK_ROOT_V73)/rt
 HEXAGON_CXX_FLAGS_V75 := $(HEXAGON_CXX_FLAGS) -mv75 -I$(HEXAGON_SDK_ROOT_V75)/rtos/qurt/computev75/include/qurt -I$(HEXAGON_SDK_ROOT_V75)/rtos/qurt/computev75/include/posix -I$(HEXAGON_SDK_ROOT_V75)/incs -I$(HEXAGON_SDK_ROOT_V75)/incs/stddef
 HEXAGON_CXX_FLAGS_V79 := $(HEXAGON_CXX_FLAGS) -mv79 -I$(HEXAGON_SDK_ROOT_V79)/rtos/qurt/computev79/include/qurt -I$(HEXAGON_SDK_ROOT_V79)/rtos/qurt/computev79/include/posix -I$(HEXAGON_SDK_ROOT_V79)/incs -I$(HEXAGON_SDK_ROOT_V79)/incs/stddef

+QHL_HVX_DIR := $(HEXAGON_SDK_ROOT)/libs/qhl_hvx
+QHL_HVX_INC_DIRS := \
+	$(QHL_HVX_DIR)/inc/internal \
+	$(QHL_HVX_DIR)/inc/qhdsp_hvx \
+	$(QHL_HVX_DIR)/inc/qhblas_hvx \
+	$(QHL_HVX_DIR)/inc/qhmath_hvx
+
+HEXAGON_CXX_FLAGS_V79 += $(addprefix -I,$(QHL_HVX_INC_DIRS))
+
+QHL_DIR := $(HEXAGON_SDK_ROOT)/libs/qhl
+QHL_INC_DIRS := \
+	$(QHL_DIR)/inc/qhmath \
+	$(QHL_DIR)/inc/qhcomplex \
+	$(QHL_DIR)/inc/qhdsp \
+	$(QHL_DIR)/inc/qhblas
+
+QHL_LIBS := \
+	$(HEXAGON_SDK_ROOT)/libs/qhl/prebuilt/hexagon_toolv88_v79/libqhblas.a \
+	$(HEXAGON_SDK_ROOT)/libs/qhl/prebuilt/hexagon_toolv88_v79/libqhdsp.a \
+	$(HEXAGON_SDK_ROOT)/libs/qhl/prebuilt/hexagon_toolv88_v79/libqhmath.a \
+	$(HEXAGON_SDK_ROOT)/libs/qhl/prebuilt/hexagon_toolv88_v79/libqhcomplex.a \
+	$(HEXAGON_SDK_ROOT)/libs/qhl_hvx/prebuilt/hexagon_toolv88_v79/libqhdsp_hvx.a \
+	$(HEXAGON_SDK_ROOT)/libs/qhl_hvx/prebuilt/hexagon_toolv88_v79/libqhblas_hvx.a
+
+HEXAGON_CXX_FLAGS_V79 += $(addprefix -I,$(QHL_INC_DIRS))
+
 $(info "HEXAGON_TOOLS_VERSION_V68 is [${HEXAGON_TOOLS_VERSION_V68}]")
 $(info "HEXAGON_TOOLS_VERSION_V69 is [${HEXAGON_TOOLS_VERSION_V69}]")
 $(info "HEXAGON_TOOLS_VERSION_V73 is [${HEXAGON_TOOLS_VERSION_V73}]")
@@ -253,7 +279,6 @@ HEXAGON_BUILD_V75: $(WORK)/hexagon-v75/$(LIBRARY_NAME)
 HEXAGON_BUILD_V79: $(WORK)/hexagon-v79/$(LIBRARY_NAME)


-
 X86_BUILD: $(WORK)/x86_64-linux-clang/$(LIBRARY_NAME)


@@ -366,8 +391,11 @@ $(WORK)/hexagon-v79/ops/%.o: $(OP_SRC_DIR)/%.cpp | $(WORK)/hexagon-v79
 $(WORK)/hexagon-v79/ops/%.o: $(OP_SRC_DIR)/v79_asm/%.S | $(WORK)/hexagon-v79
 	$(HEXAGON_CXX_V79) $(HEXAGON_CXX_FLAGS_V79) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@

-$(WORK)/hexagon-v79/$(LIBRARY_NAME): $(hexagon-v79_objs) | $(HFILES)
-	$(HEXAGON_CXX_V79) -fPIC -std=c++17 -g -shared -o $@ $^ $(HEX_LDFLAGS)
+# $(WORK)/hexagon-v79/$(LIBRARY_NAME): $(hexagon-v79_objs) | $(HFILES)
+# 	$(HEXAGON_CXX_V79) -fPIC -std=c++17 -g -shared -o $@ $^ $(HEX_LDFLAGS)
+
+$(WORK)/hexagon-v79/$(LIBRARY_NAME): $(hexagon-v79_objs)
+	$(HEXAGON_CXX_V79) -fPIC -std=c++17 -g -shared -o $@ $^ $(QHL_LIBS) $(HEX_LDFLAGS)



examples/qualcomm/custom_op/example_op_package_htp/ExampleOpPackage/src/ExampleOpPackageInterface.cpp

Lines changed: 6 additions & 0 deletions
@@ -12,6 +12,12 @@
 #include "QnnOpPackage.h"
 #include "QnnSdkBuildId.h"

+#ifdef __hexagon__
+#include "qhblas_hvx.h" // may re-export symbols in qhblas
+#include "qhcomplex.h"
+#include "qhdsp_hvx.h" // still present under qhl/inc/qhdsp
+#endif
+
 DEFINE_UNIQ_TY()
 BEGIN_PKG_OPS_OPTS_LIST()

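The new includes are guarded by #ifdef __hexagon__ because the package is also compiled for the x86-64 simulator (X86_BUILD in the Makefile), where the QHL headers and prebuilt Hexagon archives are not available. A minimal sketch of the same pattern with a portable fallback; mul3_f32 is an illustrative helper, not part of the commit:

// Sketch: keep Hexagon-only headers behind __hexagon__ so the x86
// simulator target still compiles, and fall back to plain C++ there.
#ifdef __hexagon__
#include "qhmath_hvx.h" // QHL HVX math, device builds only
#endif

static void mul3_f32(const float* in, float* out, size_t n) {
  // A device build could dispatch to a QHL/HVX routine here; this
  // portable loop is correct on both targets and serves as the fallback.
  for (size_t i = 0; i < n; ++i) {
    out[i] = in[i] * 3.0f;
  }
}
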
examples/qualcomm/custom_op/example_op_package_htp/ExampleOpPackage/src/ops/ExampleCustomOp.cpp

Lines changed: 47 additions & 54 deletions
@@ -3,11 +3,13 @@
 //==============================================================================

 #include "HTP/core/constraints.h"
+#include "HTP/core/intrinsics.h"
 #include "HTP/core/op_package_feature_support.h"
 #include "HTP/core/op_register_ext.h"
 #include "HTP/core/optimize.h"
 #include "HTP/core/simple_reg.h"
 #include "QnnOpPackage.h"
+#include "hexagon_protos.h"
 #include "hexagon_types.h"
 #include "hvx_hexagon_protos.h"

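Of the new includes, hexagon_protos.h declares the scalar Q6_* intrinsics and hvx_hexagon_protos.h the HVX vector ones (Q6_Vh_vsplat_R and Q6_Vhf_vmpy_VhfVhf below); the unaligned-access helpers q6op_V_vldu_A/q6op_vstu_AV used in the next hunk presumably come from HTP/core/intrinsics.h or the QHL internal headers the Makefile now adds to the include path. Their usual SDK shape is just an unaligned vector access; an illustrative equivalent (not the SDK's code):

// Illustrative only: typical shape of unaligned HVX load/store helpers.
// HVX_UVector is the unaligned vector typedef from hexagon_types.h;
// the compiler lowers these accesses to vmemu instructions.
static inline HVX_Vector vldu(const void* addr) {
  return *reinterpret_cast<const HVX_UVector*>(addr);
}
static inline void vstu(void* addr, HVX_Vector v) {
  *reinterpret_cast<HVX_UVector*>(addr) = v;
}
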
@@ -163,7 +165,51 @@ GraphStatus examplecustomopImpl(TensorType& out_0, const TensorType& in_0)
   if (input_intfc.dtype == DType::Float32) {
     const float* p_input = static_cast<const float*>(in_0.raw_data_const());
     float* p_output = static_cast<float*>(out_0.raw_data());
-    const int multiplier = 3;
+    const size_t N = in_0.total_storage_elements();
+
+    // Allocate temporary FP16 buffers on stack or heap
+    std::vector<Float16> tmp_in(N);
+    std::vector<Float16> tmp_out(N);
+
+    // 1. Convert FP32 -> FP16
+    for (size_t i = 0; i < N; ++i) {
+      tmp_in[i] = static_cast<Float16>(p_input[i]);
+    }
+
+#ifdef __hexagon__
+    // 2. Run HVX multiply (FP16 domain)
+    union {
+      Float16 f16;
+      uint16_t bits;
+    } f3 = {static_cast<Float16>(3.0f)};
+    HVX_Vector v_mul = Q6_Vh_vsplat_R(f3.bits);
+
+    const int vector_bytes = 128;
+    const int elems_per_vec = vector_bytes / sizeof(Float16);
+
+    for (size_t i = 0; i < N; i += elems_per_vec) {
+      HVX_Vector vin = q6op_V_vldu_A(&tmp_in[i]);
+      HVX_Vector vout = Q6_Vhf_vmpy_VhfVhf(vin, v_mul);
+      q6op_vstu_AV(&tmp_out[i], vout);
+    }
+#else
+    // 2. Fallback scalar multiply
+    for (size_t i = 0; i < N; ++i) {
+      tmp_out[i] = static_cast<Float16>(tmp_in[i] * static_cast<Float16>(3.0f));
+    }
+#endif
+
+    // 3. Convert FP16 -> FP32
+    for (size_t i = 0; i < N; ++i) {
+      p_output[i] = static_cast<float>(tmp_out[i]);
+    }
+
+    return GraphStatus::Success;
+  } else if (input_intfc.dtype == DType::QUInt8) {
+    // printf("[QNN ExecuTorch Op Package test] input is QUInt8\n");
+    const uint8_t* p_input = static_cast<const uint8_t*>(in_0.raw_data_const());
+    uint8_t* p_output = static_cast<uint8_t*>(out_0.raw_data());
+    const int multiplier = 3 * input_intfc.scale / out_intfc.scale;
     for (size_t i = 0; i < input_num_elements; ++i) {
       p_output[i] = multiplier * p_input[i];

@@ -177,59 +223,6 @@ GraphStatus examplecustomopImpl(TensorType& out_0, const TensorType& in_0)
           i,
           p_output[i]);
     }
-  } else if (input_intfc.dtype == DType::QUInt8) {
-    // const uint8_t* p_input = static_cast<const
-    // uint8_t*>(in_0.raw_data_const()); uint8_t* p_output =
-    // static_cast<uint8_t*>(out_0.raw_data()); const int multiplier = 3 *
-    // input_intfc.scale / out_intfc.scale; for (size_t i = 0; i <
-    // input_num_elements; ++i) {
-    //   p_output[i] = multiplier * p_input[i];
-
-    //   FARF(
-    //       ALWAYS,
-    //       "[QNN ExecuTorch Op Package test]"
-    //       "input0[%zu]=%f, multiplier=%d, output[%zu]=%f",
-    //       i,
-    //       p_input[i],
-    //       multiplier,
-    //       i,
-    //       p_output[i]);
-    // }
-
-    const uint8_t* p_input = static_cast<const uint8_t*>(in_0.raw_data_const());
-    uint8_t* p_output = static_cast<uint8_t*>(out_0.raw_data());
-    const float multiplier_f = 3.0f * input_intfc.scale / out_intfc.scale;
-    const int multiplier =
-        static_cast<int>(multiplier_f * 128.0f); // fixed-point scale
-
-    const HVX_Vector* in_vec = reinterpret_cast<const HVX_Vector*>(p_input);
-    HVX_Vector* out_vec = reinterpret_cast<HVX_Vector*>(p_output);
-
-    HVX_Vector v_mult = Q6_V_vsplat_R(multiplier & 0xFF);
-    HVX_Vector vzero = Q6_V_vzero();
-
-    const size_t vec_elems = 128; // 128 bytes per HVX vector
-    const size_t nvecs = input_num_elements / vec_elems;
-
-    for (size_t i = 0; i < nvecs; ++i) {
-      HVX_Vector vin = Q6_V_vldu_A(in_vec + i);
-      HVX_Vector vout;
-
-#if defined(__HEXAGON_ARCH__)
-      // use available multiply intrinsic
-      vout = Q6_Vub_vmpy_VubRb_s1_rnd_sat(vin, v_mult);
-#else
-      // fallback scalar multiply for x86 simulation
-      alignas(128) uint8_t tmp_in[128], tmp_out[128];
-      memcpy(tmp_in, p_input + i * 128, 128);
-      for (int j = 0; j < 128; ++j)
-        tmp_out[j] = std::min(255, (tmp_in[j] * multiplier) >> 7);
-      memcpy(p_output + i * 128, tmp_out, 128);
-      continue;
-#endif
-
-      Q6_V_vstu_A(out_vec + i, vout);
-    }
   }

   return GraphStatus::Success;
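
Two details of the new Float32 path are worth spelling out. The union trick extracts the IEEE-754 half-precision bit pattern of 3.0, which is 0x4200 (sign 0, exponent field 0b10000, mantissa 0b1000000000); Q6_Vh_vsplat_R broadcasts it across the 64 halfword lanes of a 128-byte HVX vector. The vector loop, however, advances 64 elements per iteration, so when N is not a multiple of 64 the final q6op_V_vldu_A/q6op_vstu_AV pair reads and writes past the end of tmp_in/tmp_out. A tail-safe variant is sketched below, using the same intrinsics as the commit plus a scalar cleanup loop; this is a sketch, not the committed code:

#ifdef __hexagon__
// Sketch: HVX FP16 multiply-by-3 with a scalar tail, so N need not be
// a multiple of the 64 halfwords that fit in one 128-byte vector.
const size_t full = N - (N % elems_per_vec); // span covered by whole vectors
for (size_t i = 0; i < full; i += elems_per_vec) {
  HVX_Vector vin = q6op_V_vldu_A(&tmp_in[i]);
  HVX_Vector vout = Q6_Vhf_vmpy_VhfVhf(vin, v_mul);
  q6op_vstu_AV(&tmp_out[i], vout);
}
for (size_t i = full; i < N; ++i) { // ragged tail, scalar FP16 math
  tmp_out[i] = static_cast<Float16>(tmp_in[i] * static_cast<Float16>(3.0f));
}
#endif

Note also that the reinstated QUInt8 branch computes its requantization factor in integer arithmetic (3 * input_intfc.scale / out_intfc.scale truncates to int), so it is exact only when the scale ratio is integral, e.g. equal input and output scales giving exactly 3; the removed fixed-point variant instead kept 7 fractional bits (multiplier_f * 128.0f, with a >> 7 in its scalar fallback).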
