Skip to content

Commit fb03b6f

Browse files
Add safety check to generated kernels
Differential Revision: D79123304 Pull Request resolved: pytorch#12945
1 parent 08c7636 commit fb03b6f

File tree

5 files changed

+16
-2
lines changed

5 files changed

+16
-2
lines changed

codegen/gen.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,10 @@ def __call__(
243243
argument_type_gen=argument_type_gen
244244
).convert_arguments(arguments)
245245

246+
# +1 for the return value
247+
num_boxed_args = len(binding_list) + 1
248+
# This safety check does not account for optional args with default values. ET itself doesnt support default args, but when supported is added this check can be relaxed to >= # of non default arg.
249+
safety_check = f"""ET_KERNEL_CHECK_MSG(context, stack.size() == {num_boxed_args}, InvalidProgram, /*void*/, \"Expected %\" ET_PRIsize_t \"args received %\" ET_PRIsize_t, (size_t){num_boxed_args}, stack.size());"""
246250
# for each C++ argument, generate the conversion code
247251
code_connector = "\n\t"
248252
arg_connector = ", "
@@ -292,12 +296,13 @@ def __call__(
292296
{indent} context.fail(torch::executor::Error::Internal);
293297
{indent}}}"""
294298
newline = "\n "
295-
return "\n".join(
299+
temp = "\n".join(
296300
[
297301
f"""
298302
Kernel(
299303
"{f.namespace}::{f.func.name}",{newline + '"' + (k + '",') if k != "default" else ""}
300304
[]({contextArg.defn()}, Span<EValue*> stack) {{
305+
{safety_check}
301306
{code_connector.join(code_list)}
302307
303308
{exception_boundary_begin}
@@ -313,6 +318,7 @@ def __call__(
313318
for k in used_kernel_keys
314319
]
315320
)
321+
return temp
316322

317323

318324
def gen_unboxing(
@@ -534,6 +540,7 @@ def gen_headers(
534540
"headers": [
535541
"#include <executorch/runtime/core/exec_aten/exec_aten.h> // at::Tensor etc.",
536542
"#include <executorch/runtime/kernel/kernel_runtime_context.h>",
543+
"#include <executorch/runtime/core/error.h>",
537544
],
538545
}
539546
if use_aten_lib:

codegen/templates/RegisterCodegenUnboxedKernels.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/runtime/core/evalue.h>
1010
#include <executorch/runtime/core/exec_aten/exec_aten.h>
11+
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
1112
#include <executorch/runtime/core/span.h>
1213
#include <executorch/runtime/kernel/operator_registry.h>
1314
#include <executorch/runtime/platform/profiler.h>

codegen/templates/RegisterKernels.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
// This implements register_all_kernels() API that is declared in
1111
// RegisterKernels.h
1212
#include "RegisterKernels.h"
13+
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
1314
#include "${fn_header}" // Generated Function import headers
1415

1516
namespace torch {

codegen/test/test_executorch_gen.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,7 @@ def test_codegen_unboxed_specialized(self) -> None:
508508
"custom_1::op_1",
509509
"v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3",
510510
[](torch::executor::KernelRuntimeContext & context, Span<EValue*> stack) {
511+
ET_KERNEL_CHECK_MSG(context, stack.size() == 1, InvalidProgram, /*void*/, \"Expected %\" ET_PRIsize_t \"args received %\" ET_PRIsize_t, (size_t)1, stack.size());
511512
"""
512513
+ """
513514
@@ -606,6 +607,7 @@ def test_codegen_unboxed_default(self) -> None:
606607
Kernel(
607608
"custom_1::op_1",
608609
[](torch::executor::KernelRuntimeContext & context, Span<EValue*> stack) {
610+
ET_KERNEL_CHECK_MSG(context, stack.size() == 1, InvalidProgram, /*void*/, \"Expected %\" ET_PRIsize_t \"args received %\" ET_PRIsize_t, (size_t)1, stack.size());
609611
"""
610612
+ """
611613
@@ -621,7 +623,6 @@ def test_codegen_unboxed_default(self) -> None:
621623
),
622624
"""
623625
)
624-
625626
self.assertEqual(expected_str, result)
626627

627628
result = ComputeCodegenUnboxedKernels(
@@ -633,6 +634,7 @@ def test_codegen_unboxed_default(self) -> None:
633634
Kernel(
634635
"custom_1::op_1",
635636
[](torch::executor::KernelRuntimeContext & context, Span<EValue*> stack) {
637+
ET_KERNEL_CHECK_MSG(context, stack.size() == 1, InvalidProgram, /*void*/, "Expected %" ET_PRIsize_t "args received %" ET_PRIsize_t, (size_t)1, stack.size());
636638
"""
637639
+ """
638640
@@ -676,6 +678,7 @@ def test_codegen_unboxed_default_kernel_key_selected(self) -> None:
676678
Kernel(
677679
"custom_1::op_1",
678680
[](torch::executor::KernelRuntimeContext & context, Span<EValue*> stack) {
681+
ET_KERNEL_CHECK_MSG(context, stack.size() == 1, InvalidProgram, /*void*/, "Expected %" ET_PRIsize_t "args received %" ET_PRIsize_t, (size_t)1, stack.size());
679682
"""
680683
+ """
681684

shim_et/xplat/executorch/codegen/codegen.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,7 @@ def executorch_generated_lib(
896896
exported_deps = [
897897
"//executorch/codegen:macros",
898898
"//executorch/runtime/kernel:kernel_runtime_context" + aten_suffix,
899+
"//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
899900
],
900901
feature = feature,
901902
)
@@ -933,6 +934,7 @@ def executorch_generated_lib(
933934
exported_deps = [
934935
"//executorch/runtime/core/exec_aten:lib" + aten_suffix,
935936
"//executorch/runtime/kernel:kernel_runtime_context" + aten_suffix,
937+
"//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
936938
],
937939
xplat_deps = xplat_deps,
938940
fbcode_deps = fbcode_deps,

0 commit comments

Comments
 (0)