
Commit cceac32

Update on "[Executorch] Renable operator optimization flags"
Previous attempt at this resulted in a revert due to an app size increase; much of that increase came from op_div exploding. Two diffs underneath solve this issue.

Differential Revision: [D65606666](https://our.internmc.facebook.com/intern/diff/D65606666/)

[ghstack-poisoned]
2 parents: 22a2235 + bd6a641

File tree

30 files changed (+606 −884 lines)


.github/workflows/ghstack_land.yml

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ on:
   branches:
     - 'gh/cccclai/[0-9]+/base'
     - 'gh/dbort/[0-9]+/base'
+    - 'gh/dvorjackz/[0-9]+/base'
     - 'gh/guangy10/[0-9]+/base'
     - 'gh/helunwencser/[0-9]+/base'
     - 'gh/jorgep31415/[0-9]+/base'

backends/arm/test/runner_utils.py

Lines changed: 7 additions & 2 deletions
@@ -448,16 +448,21 @@ def run_tosa_ref_model(
             ), "There are no quantization parameters, check output parameters"
             tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale

+        if tosa_ref_output.dtype == np.double:
+            tosa_ref_output = tosa_ref_output.astype("float32")
+
         # tosa_output is a numpy array, convert to torch tensor for comparison
-        tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output.astype("float32")))
+        tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output))

     return tosa_ref_outputs


 def prep_data_for_save(
     data, is_quantized: bool, input_name: str, quant_param: QuantizationParams
 ):
-    data_np = np.array(data.detach(), order="C").astype(np.float32)
+    data_np = np.array(data.detach(), order="C").astype(
+        f"{data.dtype}".replace("torch.", "")
+    )

     if is_quantized:
         assert quant_param.node_name in input_name, (
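Why the prep_data_for_save change works: a torch.dtype stringifies as "torch.<name>", and NumPy accepts the bare "<name>" as a dtype string, so the new code round-trips the input's dtype instead of forcing float32. A minimal standalone sketch (illustrative, not part of this diff; dtypes with no NumPy equivalent, such as torch.bfloat16, would still fail):

    import numpy as np
    import torch

    t = torch.zeros(4, dtype=torch.int8)
    # str(torch.int8) == "torch.int8"; stripping the prefix gives a NumPy dtype name.
    dtype_name = f"{t.dtype}".replace("torch.", "")  # "int8"
    data_np = np.array(t.detach(), order="C").astype(dtype_name)
    assert data_np.dtype == np.int8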

backends/cadence/aot/ops_registrations.py

Lines changed: 54 additions & 0 deletions
@@ -66,6 +66,12 @@
 lib.define(
     "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_conv.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
+)

 lib.define(
     "quantized_matmul(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)"
@@ -171,6 +177,54 @@ def quantized_conv_meta(
     return input.new_empty(output_size, dtype=input.dtype)


+@register_fake("cadence::quantized_conv.per_tensor")
+def quantized_conv_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+    channel_last: bool = False,
+) -> torch.Tensor:
+    if channel_last:
+        out_channels, *kernel_size, _ = weight.shape
+    else:
+        out_channels, _, *kernel_size = weight.shape
+
+    in_size = input.shape
+    # Assert that the input tensor has at least 3 dimensions, and at most 6
+    assert len(in_size) > 2
+    assert len(in_size) < 6
+
+    # Compute the output tensor size
+    output_size = (
+        get_conv1d_output_size(
+            in_size,
+            out_channels,
+            stride[1],
+            padding[1],
+            dilation[1],
+            kernel_size[0],
+            channel_last,
+        )
+        if len(in_size) == 3
+        else get_conv2d_output_size(
+            in_size, out_channels, stride, padding, dilation, kernel_size, channel_last
+        )
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
 @register_fake("cadence::quantized_layer_norm")
 def quantized_layer_norm_meta(
     input: torch.Tensor,
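The per_tensor variant takes plain int/float quantization parameters where quantized_conv takes Tensors, and the @register_fake registration gives the exporter a meta kernel that computes only output shape and dtype, never touching data. A minimal sketch of the register_fake pattern on a hypothetical op (the demo::scale names below are illustrative only, not part of the cadence library):

    import torch
    from torch.library import Library, register_fake

    demo_lib = Library("demo", "DEF")
    demo_lib.define("scale(Tensor x, float s) -> Tensor")

    @register_fake("demo::scale")
    def scale_meta(x: torch.Tensor, s: float) -> torch.Tensor:
        # A fake/meta kernel only propagates shape and dtype; it is what lets
        # export trace through a custom op without running a real implementation.
        return x.new_empty(x.shape, dtype=x.dtype)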

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 25 additions & 1 deletion
@@ -540,6 +540,7 @@ def __init__(
         env: Dict[Any, Any],
         glslc_path: Optional[str],
         glslc_flags: str = "",
+        replace_u16vecn: bool = False,
     ) -> None:
         if isinstance(src_dir_paths, str):
             self.src_dir_paths = [src_dir_paths]
@@ -549,6 +550,7 @@ def __init__(
         self.env = env
         self.glslc_path = glslc_path
         self.glslc_flags = glslc_flags
+        self.replace_u16vecn = replace_u16vecn

         self.glsl_src_files: Dict[str, str] = {}
         self.template_yaml_files: List[str] = []
@@ -705,6 +707,22 @@ def constructOutputMap(self) -> None:
             self.create_shader_params(),
         )

+    def maybe_replace_u16vecn(self, input_text: str) -> str:
+        """
+        There is a latency benefit to using u16vecn variables to store texture position
+        variables instead of ivecn, likely due to reduced register pressure. However,
+        SwiftShader does not support 16 bit integer types in shaders, so this is a crude
+        way to fallback to using ivecn to store texture positions so that testing with
+        SwiftShader is still possible.
+        """
+        if not self.replace_u16vecn:
+            return input_text
+        if "codegen-nosub" in input_text:
+            return input_text
+
+        input_text = input_text.replace("u16vec", "ivec")
+        return input_text
+
     def generateSPV(self, output_dir: str) -> Dict[str, str]:
         output_file_map = {}

@@ -716,6 +734,7 @@ def process_shader(shader_paths_pair):

             with codecs.open(source_glsl, "r", encoding="utf-8") as input_file:
                 input_text = input_file.read()
+                input_text = self.maybe_replace_u16vecn(input_text)
                 output_text = preprocess(input_text, shader_params)

             glsl_out_path = os.path.join(output_dir, f"{shader_name}.glsl")
@@ -1029,6 +1048,7 @@ def main(argv: List[str]) -> int:
     parser.add_argument("-c", "--glslc-path", required=True, help="")
     parser.add_argument("-t", "--tmp-dir-path", required=True, help="/tmp")
     parser.add_argument("-o", "--output-path", required=True, help="")
+    parser.add_argument("--replace-u16vecn", action="store_true", default=False)
     parser.add_argument("--optimize_size", action="store_true", help="")
     parser.add_argument("--optimize", action="store_true", help="")
     parser.add_argument(
@@ -1056,7 +1076,11 @@ def main(argv: List[str]) -> int:
         glslc_flags += "-O"

     shader_generator = SPVGenerator(
-        options.glsl_paths, env, options.glslc_path, glslc_flags
+        options.glsl_paths,
+        env,
+        options.glslc_path,
+        glslc_flags=glslc_flags,
+        replace_u16vecn=options.replace_u16vecn,
     )
     output_spv_files = shader_generator.generateSPV(options.tmp_dir_path)

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,8 @@
  * LICENSE file in the root directory of this source tree.
  */

+// codegen-nosub
+
 #version 450 core

 #define PRECISION ${PRECISION}

backends/vulkan/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ def vulkan_spv_shader_lib(name, spv_filegroups, is_fbcode = False):
         select({
             "DEFAULT": "",
             "ovr_config//os:android": "--optimize",
+            "ovr_config//os:linux": "--replace-u16vecn",
         })
     )

devtools/inspector/_inspector_utils.py

Lines changed: 1 addition & 0 deletions
@@ -112,6 +112,7 @@ def get_scalar_type_size(scalar_type: ScalarType) -> Tuple[torch.dtype, int]:
         ScalarType.BYTE: (torch.uint8, 1),
         ScalarType.CHAR: (torch.int8, 1),
         ScalarType.BOOL: (torch.bool, 1),
+        ScalarType.BITS16: (torch.uint16, 2),
         ScalarType.SHORT: (torch.int16, 2),
         ScalarType.HALF: (torch.float16, 2),
         ScalarType.INT: (torch.int, 4),
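With this entry, BITS16 data decodes as torch.uint16 at 2 bytes per element instead of failing the lookup. A usage sketch (assuming the surrounding get_scalar_type_size helper, whose signature appears in the hunk header above):

    dtype, elem_size = get_scalar_type_size(ScalarType.BITS16)
    # dtype == torch.uint16, elem_size == 2, so a [4, 8] BITS16 tensor
    # occupies 4 * 8 * 2 = 64 bytes of raw debug-output data.
    num_bytes = 4 * 8 * elem_size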

exir/passes/executorch_prim_ops_registry.py

Lines changed: 9 additions & 0 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import math
 import operator
 from typing import Dict, Set, Union

@@ -14,6 +15,8 @@
 from torch._ops import OpOverload
 from torch.library import Library

+# pyre-unsafe
+

 executorch_prims_lib = Library("executorch_prim", "DEF")

@@ -91,7 +94,13 @@ def neg(a: _SymScalar) -> _SymScalar:
     return -a  # pyre-ignore


+@bind_pattern_to_op(executorch_prims_lib, "trunc.Scalar(Scalar a) -> Scalar")
+def trunc(a: _SymScalar) -> _SymScalar:
+    return math.trunc(a)  # pyre-ignore
+
+
 _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[OpOverload, OpOverload] = {
+    math.trunc: ops.backend.executorch_prim.trunc.Scalar,
     operator.sub: ops.backend.executorch_prim.sub.Scalar,
     operator.mul: ops.backend.executorch_prim.mul.Scalar,
     operator.add: ops.backend.executorch_prim.add.Scalar,
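With this registration, math.trunc calls appearing in symbolic scalar computations can be lowered like the operator.* entries: the mapping table rewrites the call target to ops.backend.executorch_prim.trunc.Scalar while keeping Python's truncate-toward-zero semantics. A small sanity sketch of those semantics (illustrative only):

    import math

    assert math.trunc(2.7) == 2    # truncation, not floor ...
    assert math.trunc(-2.7) == -2  # ... so negatives round toward zero
    assert math.floor(-2.7) == -3  # contrast with floor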

extension/llm/custom_ops/targets.bzl

Lines changed: 7 additions & 2 deletions
@@ -1,10 +1,14 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load(
+    "@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl",
+    "get_vec_preprocessor_flags",
+    "get_vec_deps",
+)
 load(
     "@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl",
     "get_compiler_optimization_flags",
 )

-
 def define_common_targets():
     """Defines targets that should be shared between fbcode and xplat.

@@ -26,6 +30,7 @@ def define_common_targets():
             "op_sdpa.h",
             "op_update_quantized_cache.h",
         ],
+        preprocessor_flags = get_vec_preprocessor_flags(),
         exported_deps = [
             "//executorch/runtime/kernel:kernel_includes",
             "//executorch/kernels/portable/cpu:scalar_utils",
@@ -38,7 +43,7 @@ def define_common_targets():
         deps = [
             "//executorch/kernels/portable/cpu/util:reduce_util",
             "//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform",
-        ],
+        ] + get_vec_deps(),
         compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"] + get_compiler_optimization_flags(),
         visibility = [
             "//executorch/...",

kernels/optimized/cpu/op_add.cpp

Lines changed: 3 additions & 53 deletions
@@ -9,6 +9,7 @@
 #include <executorch/kernels/optimized/cpu/binary_ops.h>
 #include <executorch/kernels/optimized/vec/functional.h>
 #include <executorch/kernels/optimized/vec/vec.h>
+#include <executorch/kernels/portable/cpu/op_add_impl.h>
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
@@ -176,35 +177,7 @@ Tensor& opt_add_out(
           lhs->sizes()[lhs->dim() - 1]);
     });
   } else {
-    ScalarType common_type =
-        promoteTypes(a_type, b_type, /*half_to_float*/ true);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
-        using CTYPE_IN = typename torch::executor::
-            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
-          CTYPE_IN alpha_val;
-          ET_KERNEL_CHECK(
-              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-          AddInner<
-              can_cast<CTYPE_IN, CTYPE_OUT>::value,
-              CTYPE_A,
-              CTYPE_B,
-              CTYPE_IN,
-              CTYPE_OUT>::run(a, b, alpha_val, out);
-        });
-      });
-    });
+    add_out_impl(ctx, a, b, alpha, out);
   }

   return out;
@@ -255,30 +228,7 @@ Tensor& opt_add_scalar_out(
       });
     });
   } else {
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() {
-      ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
-        ET_SWITCH_REALB_TYPES(
-            common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
-              ET_SWITCH_REALHBBF16_TYPES(
-                  out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
-                    CTYPE_B b_val;
-                    ET_EXTRACT_SCALAR(b, b_val);
-                    CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-                    CTYPE_IN alpha_val;
-                    ET_EXTRACT_SCALAR(alpha, alpha_val);
-
-                    const size_t n = a.numel();
-                    const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                    CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-                    for (auto i = 0; i < n; ++i) {
-                      out_data[i] = static_cast<CTYPE_OUT>(
-                          static_cast<CTYPE_IN>(a_data[i]) +
-                          alpha_val * b_casted);
-                    }
-                  });
-            });
-      });
-    });
+    add_scalar_out_impl(ctx, a, b, alpha, out);
   }

   return out;
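Both else-branches, the slow paths taken when the fast vectorized path does not apply, now delegate to the shared portable implementations (add_out_impl / add_scalar_out_impl from op_add_impl.h) instead of expanding their own nested ET_SWITCH_* type-dispatch macros. Each such macro nest instantiates a template body per combination of input and output dtypes, so deduplicating it against the portable kernel is presumably part of how the size regression described in the commit message was addressed.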
