Skip to content

Commit f533532

Browse files
authored
Merge branch 'release/0.7' into cherry-pick-12452-by-pytorch_bot_bot_
2 parents bd2e706 + 2d79be5 commit f533532

File tree

4 files changed

+98
-11
lines changed

4 files changed

+98
-11
lines changed
Submodule XNNPACK updated 7178 files

backends/xnnpack/third-party/xnnpack.buck.bzl

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,38 @@ def define_xnnpack():
274274
],
275275
)
276276

277+
SSE2_FMA_COMPILER_FLAGS = [
278+
"-msse2",
279+
"-mno-sse3",
280+
]
281+
282+
native.cxx_library(
283+
name = "ukernels_sse2fma",
284+
srcs = select({
285+
"DEFAULT": prod_srcs_for_arch_wrapper("sse2fma"),
286+
"ovr_config//cpu:arm32": DEFAULT_DUMMY_SRC,
287+
"ovr_config//cpu:arm64": DEFAULT_DUMMY_SRC,
288+
}),
289+
headers = get_xnnpack_headers(),
290+
header_namespace = "",
291+
compiler_flags = [
292+
"-O2",
293+
"-Wno-error=missing-braces", # required since the SGX toolchain does not have this by default
294+
] + select({
295+
"DEFAULT": SSE2_FMA_COMPILER_FLAGS,
296+
"ovr_config//cpu:arm32": [],
297+
"ovr_config//cpu:arm64": [],
298+
}),
299+
preferred_linkage = "static",
300+
preprocessor_flags = [
301+
"-DXNN_LOG_LEVEL=0",
302+
],
303+
exported_deps = [
304+
":FP16",
305+
":interface",
306+
],
307+
)
308+
277309
SSE3_COMPILER_FLAGS = ["-mssse3"]
278310

279311
# @lint-ignore BUCKLINT: native and fb_native are explicitly forbidden in fbcode.
@@ -961,6 +993,44 @@ def define_xnnpack():
961993
],
962994
)
963995

996+
AMD64_COMPILER_FLAGS = [
997+
"-mf16c",
998+
"-mfma",
999+
"-mavx512f",
1000+
"-mavx512cd",
1001+
"-mavx512bw",
1002+
"-mavx512dq",
1003+
"-mavx512vl",
1004+
"-mavx512vnni",
1005+
"-mgfni",
1006+
]
1007+
native.cxx_library(
1008+
name = "ukernels_amd64",
1009+
srcs = select({
1010+
"DEFAULT": prod_srcs_for_arch_wrapper("amd64"),
1011+
"ovr_config//cpu:arm32": DEFAULT_DUMMY_SRC,
1012+
"ovr_config//cpu:arm64": DEFAULT_DUMMY_SRC,
1013+
}),
1014+
headers = get_xnnpack_headers(),
1015+
header_namespace = "",
1016+
compiler_flags = [
1017+
"-O2",
1018+
"-Wno-error=missing-braces", # required since the SGX toolchain does not have this by default
1019+
] + select({
1020+
"DEFAULT": AMD64_COMPILER_FLAGS,
1021+
"ovr_config//cpu:arm32": [],
1022+
"ovr_config//cpu:arm64": [],
1023+
}),
1024+
preferred_linkage = "static",
1025+
preprocessor_flags = [
1026+
"-DXNN_LOG_LEVEL=0",
1027+
],
1028+
exported_deps = [
1029+
":FP16",
1030+
":interface",
1031+
],
1032+
)
1033+
9641034
AVX512VNNIGFNI_COMPILER_FLAGS = AVX512VNNI_COMPILER_FLAGS + [
9651035
"-mgfni",
9661036
]
@@ -1044,12 +1114,14 @@ def define_xnnpack():
10441114
":ukernels_fma3",
10451115
":ukernels_sse",
10461116
":ukernels_sse2",
1117+
":ukernels_sse2fma",
10471118
":ukernels_sse41",
10481119
":ukernels_ssse3",
10491120
":ukernels_avx512vbmi",
10501121
":ukernels_avx512vnnigfni",
10511122
":ukernels_avx512vnni",
10521123
":ukernels_avxvnni",
1124+
":ukernels_amd64",
10531125
]
10541126

10551127
ARM_XNNPACK_DEPS = [
@@ -1097,10 +1169,22 @@ def define_xnnpack():
10971169
"-DXNN_ENABLE_GEMM_M_SPECIALIZATION",
10981170
"-DXNN_ENABLE_ARM_DOTPROD",
10991171
"-DXNN_ENABLE_CPUINFO",
1100-
# "-DXNN_ENABLE_DWCONV_MULTIPLASS=1",
1172+
# "-DXNN_ENABLE_DWCONV_MULTIPLASS=0",
11011173
"-DXNN_ENABLE_ARM_I8MM=1",
11021174
"-DXNN_ENABLE_ARM_FP16_VECTOR=1",
1103-
"-DXNN_ENABLE_AVX512BF16=0"
1175+
"-DXNN_ENABLE_AVX512F=1",
1176+
"-DXNN_ENABLE_AVX512SKX=1",
1177+
"-DXNN_ENABLE_AVX512VNNI=1",
1178+
"-DXNN_ENABLE_AVX512VBMI=1",
1179+
"-DXNN_ENABLE_AVXVNNI=0",
1180+
"-DXNN_ENABLE_AVXVNNIINT8=0",
1181+
"-DXNN_ENABLE_AVX512FP16=0",
1182+
"-DXNN_ENABLE_AVX512VNNIGFNI=0",
1183+
"-DXNN_ENABLE_AVX512BF16=0",
1184+
"-DXNN_ENABLE_AVX256VNNIGFNI=0",
1185+
"-DXNN_ENABLE_AVX512AMX=0",
1186+
"-DXNN_ENABLE_AVX256SKX=0",
1187+
"-DXNN_ENABLE_AVX256VNNI=0",
11041188
],
11051189
visibility = ["PUBLIC"],
11061190
exported_deps = COMMON_XNNPACK_DEPS + [

extension/llm/runner/text_decoder_runner.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,22 +52,25 @@ ::executorch::runtime::Result<executorch::aten::Tensor> TextDecoderRunner::step(
5252
auto numel = sizes[0];
5353
std::vector<::executorch::aten::SizesType> sizes_vec = {numel};
5454

55-
// Assuming the last dimension is the one with the variable token length,
56-
// for example [1, S] or [1, 1, S]
57-
sizes_vec[sizes_vec.size() - 1] = numel;
5855
TensorPtr start_pos_tensor;
5956
if (numel > 1) {
60-
// Assuming model is exported with cache_positions, create a tensor with
61-
// the same size as cache_positions
57+
// If we are here, model is exported with cache_positions, create a tensor
58+
// with the same length as input_ids. Assuming the last dimension is the
59+
// one with the variable token length, for example [1, S] or [1, 1, S]
60+
sizes_vec[sizes_vec.size() - 1] = tokens->numel();
6261
start_pos_tensor = empty(sizes_vec, ::executorch::aten::ScalarType::Long);
6362
torch::executor::native::arange_out_impl(
64-
start_pos, start_pos + numel, 1.0, *start_pos_tensor);
63+
start_pos, start_pos + tokens->numel(), 1.0, *start_pos_tensor);
6564
} else {
6665
// Assuming model is exported with input_pos, create a tensor with size 1
6766
start_pos_tensor = from_blob(
6867
&start_pos, sizes_vec, ::executorch::aten::ScalarType::Long);
6968
}
70-
ET_LOG(Info, "Start pos tensor numel: %zu", start_pos_tensor->numel());
69+
ET_LOG(
70+
Info,
71+
"Start pos tensor numel: %zu, tokens numel: %zu",
72+
start_pos_tensor->numel(),
73+
tokens->numel());
7174
auto outputs_res = module_->forward({tokens, start_pos_tensor});
7275
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
7376
ET_CHECK_MSG(

extension/llm/runner/text_prefiller.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class ET_EXPERIMENTAL TextPrefiller {
2121
public:
2222
TextPrefiller(
2323
TextDecoderRunner* text_decoder_runner,
24-
bool use_kv_cache_,
24+
bool use_kv_cache,
2525
bool enable_parallel_prefill,
2626
int64_t max_seq_len = 128);
2727

0 commit comments

Comments (0)