Skip to content

Commit 5fe3df2

Browse files
SS-JIA authored and facebook-github-bot committed
Migrate existing check functions to return bool pattern
Summary: Migrate existing `check_` functions to the new pattern of using `ET_LOG_AND_RETURN_IF_FALSE` and returning a boolean instead of `ET_CHECK`. Within kernel code, replace `ET_CHECK` with `ET_KERNEL_CHECK`. Reviewed By: manuelcandales Differential Revision: D48242981 fbshipit-source-id: 04a664dbbce79a26584ffb454ec8de21f8af07c7
1 parent 9cb9ba0 commit 5fe3df2

File tree

8 files changed

+86
-51
lines changed

8 files changed

+86
-51
lines changed

kernels/portable/cpu/op_cat.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ Tensor& cat_out(
2626
dim += out.dim();
2727
}
2828

29-
check_cat_args(tensors, dim, out);
29+
ET_KERNEL_CHECK(ctx, check_cat_args(tensors, dim, out), InvalidArgument, out);
3030

3131
Tensor::SizesType expected_out_size[kTensorDimensionLimit];
3232
size_t expected_out_dim = 0;

kernels/portable/cpu/op_native_batch_norm.cpp

Lines changed: 19 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -33,12 +33,26 @@ std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_training_out(
3333
Tensor& var_out) {
3434
(void)ctx;
3535

36-
ET_CHECK(resize_tensor(out, in.sizes()) == Error::Ok);
36+
std::tuple<Tensor&, Tensor&, Tensor&> ret_val(out, mean_out, var_out);
3737

38-
check_batch_norm_args(
39-
in, weight, bias, running_mean, running_var, momentum, eps, out);
38+
ET_KERNEL_CHECK(
39+
ctx,
40+
resize_tensor(out, in.sizes()) == Error::Ok,
41+
InvalidArgument,
42+
ret_val);
43+
44+
ET_KERNEL_CHECK(
45+
ctx,
46+
check_batch_norm_args(
47+
in, weight, bias, running_mean, running_var, momentum, eps, out),
48+
InvalidArgument,
49+
ret_val);
4050
// For now, only support the default dim order
41-
ET_CHECK(is_default_dim_order(in.dim_order().data(), in.dim_order().size()));
51+
ET_KERNEL_CHECK(
52+
ctx,
53+
is_default_dim_order(in.dim_order().data(), in.dim_order().size()),
54+
InvalidArgument,
55+
ret_val);
4256

4357
size_t C_dim = in.dim() >= 1 ? 1 : 0;
4458
size_t C = in.size(C_dim);
@@ -75,7 +89,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_training_out(
7589
}
7690
});
7791

78-
return {out, mean_out, var_out};
92+
return ret_val;
7993
}
8094

8195
} // namespace native

kernels/portable/cpu/op_permute_copy.cpp

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -42,14 +42,19 @@ Tensor& permute_copy_out(
4242
IntArrayRef dims,
4343
Tensor& out) {
4444
(void)ctx;
45-
check_permute_copy_args(in, dims, out);
45+
46+
ET_KERNEL_CHECK(
47+
ctx, check_permute_copy_args(in, dims, out), InvalidArgument, out);
4648

4749
Tensor::SizesType expected_out_size[kTensorDimensionLimit];
4850
size_t expected_out_dim = 0;
4951
get_permute_copy_out_target_size(
5052
in, dims, expected_out_size, &expected_out_dim);
51-
ET_CHECK(
52-
resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok);
53+
ET_KERNEL_CHECK(
54+
ctx,
55+
resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
56+
InvalidArgument,
57+
out);
5358

5459
const auto in_type = out.scalar_type();
5560
// in and out must be the same dtype

kernels/portable/cpu/op_stack.cpp

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -28,13 +28,17 @@ Tensor& stack_out(
2828
dim += out.dim();
2929
}
3030

31-
check_stack_args(tensors, dim, out);
31+
ET_KERNEL_CHECK(
32+
ctx, check_stack_args(tensors, dim, out), InvalidArgument, out);
3233

3334
Tensor::SizesType expected_out_size[kTensorDimensionLimit];
3435
size_t expected_out_dim = 0;
3536
get_stack_out_target_size(tensors, dim, expected_out_size, &expected_out_dim);
36-
ET_CHECK(
37-
resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok);
37+
ET_KERNEL_CHECK(
38+
ctx,
39+
resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
40+
InvalidArgument,
41+
out);
3842

3943
const size_t outer = getLeadingDims(out, dim);
4044
const size_t inner = getTrailingDims(out, dim);

kernels/portable/cpu/util/copy_ops_util.cpp

Lines changed: 29 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -8,22 +8,19 @@
88

99
#include <cstring>
1010

11-
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12-
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
13-
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
14-
#include <executorch/runtime/platform/assert.h>
11+
#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
1512

1613
namespace torch {
1714
namespace executor {
1815

1916
using Tensor = exec_aten::Tensor;
2017

21-
void check_cat_args(
18+
bool check_cat_args(
2219
exec_aten::ArrayRef<Tensor> tensors,
2320
int64_t dim,
2421
Tensor& out) {
2522
// Ensure the input tensors list is non-empty
26-
ET_CHECK(tensors.size() > 0);
23+
ET_LOG_AND_RETURN_IF_FALSE(tensors.size() > 0);
2724

2825
// Find the first non-empty tensor in the list to use as a reference
2926
size_t ref_i = 0;
@@ -39,25 +36,30 @@ void check_cat_args(
3936
// https://pytorch.org/docs/stable/generated/torch.cat.html
4037
for (size_t i = 0; i < tensors.size(); ++i) {
4138
// All input dtypes must be castable to the output dtype.
42-
ET_CHECK(canCast(tensors[i].scalar_type(), out.scalar_type()));
39+
ET_LOG_AND_RETURN_IF_FALSE(
40+
canCast(tensors[i].scalar_type(), out.scalar_type()));
4341

4442
// Empty tensors have no shape constraints.
4543
if (tensors[i].numel() == 0) {
4644
continue;
4745
}
4846

4947
// All input tensors must have the same number of dimensions.
50-
ET_CHECK(tensors[i].dim() == tensors[ref_i].dim());
48+
ET_LOG_AND_RETURN_IF_FALSE(
49+
tensor_is_rank(tensors[ref_i], tensors[i].dim()));
5150

5251
for (size_t d = 0; d < tensors[i].dim(); ++d) {
5352
if (d != dim) {
54-
ET_CHECK(tensors[i].size(d) == tensors[ref_i].size(d));
53+
ET_LOG_AND_RETURN_IF_FALSE(
54+
tensors_have_same_size_at_dims(tensors[i], d, tensors[ref_i], d));
5555
}
5656
}
5757
}
5858

5959
// Ensure dim is in range.
60-
ET_CHECK(dim >= 0 && dim < tensors[ref_i].dim());
60+
ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(tensors[ref_i], dim));
61+
62+
return true;
6163
}
6264

6365
void get_cat_out_target_size(
@@ -86,9 +88,9 @@ void get_cat_out_target_size(
8688
}
8789
}
8890

89-
void check_permute_copy_args(const Tensor& in, IntArrayRef dims, Tensor& out) {
90-
ET_CHECK(in.dim() == dims.size());
91-
ET_CHECK_SAME_DTYPE2(in, out);
91+
bool check_permute_copy_args(const Tensor& in, IntArrayRef dims, Tensor& out) {
92+
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, dims.size()));
93+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
9294

9395
// Make sure no dimensions are duplicated and all in the range [-in.dim(),
9496
// in.dim() - 1]. Use gaussian sum to check this.
@@ -98,13 +100,15 @@ void check_permute_copy_args(const Tensor& in, IntArrayRef dims, Tensor& out) {
98100
// Convert dimension to a non-negative number. dim_base is in the range
99101
// [0 .. in.dim() - 1].
100102
size_t dim = dims[i] > -1 ? dims[i] : in.dim() + dims[i];
101-
ET_CHECK(dim >= 0 && dim < in.dim());
103+
ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
102104
gauss_sum += dim + 1;
103105
}
104106

105-
ET_CHECK_MSG(
107+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
106108
gauss_sum == expected_sum,
107109
"The dims passed to permute_copy must contain one of each dim!");
110+
111+
return true;
108112
}
109113

110114
void get_permute_copy_out_target_size(
@@ -119,28 +123,32 @@ void get_permute_copy_out_target_size(
119123
}
120124
}
121125

122-
void check_stack_args(
126+
bool check_stack_args(
123127
exec_aten::ArrayRef<Tensor> tensors,
124128
int64_t dim,
125129
Tensor& out) {
126130
// Ensure the input tensors list is non-empty
127-
ET_CHECK(tensors.size() > 0);
131+
ET_LOG_AND_RETURN_IF_FALSE(tensors.size() > 0);
128132

129133
// All input tensors need to be of the same size
130134
// https://pytorch.org/docs/stable/generated/torch.stack.html
131135
for (size_t i = 0; i < tensors.size(); i++) {
132136
// All input dtypes must be castable to the output dtype.
133-
ET_CHECK(canCast(tensors[i].scalar_type(), out.scalar_type()));
137+
ET_LOG_AND_RETURN_IF_FALSE(
138+
canCast(tensors[i].scalar_type(), out.scalar_type()));
134139

135-
ET_CHECK(tensors[i].dim() == tensors[0].dim());
140+
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(tensors[i], tensors[0].dim()));
136141
for (size_t d = 0; d < tensors[i].dim(); d++) {
137-
ET_CHECK(tensors[i].size(d) == tensors[0].size(d));
142+
ET_LOG_AND_RETURN_IF_FALSE(
143+
tensors_have_same_size_at_dims(tensors[i], d, tensors[0], d));
138144
}
139145
}
140146

141147
// The output tensor will have a dimension inserted, so dim should be between
142148
// 0 and ndim_of_inputs + 1
143-
ET_CHECK(dim >= 0 && dim < tensors[0].dim() + 1);
149+
ET_LOG_AND_RETURN_IF_FALSE(dim >= 0 && dim < tensors[0].dim() + 1);
150+
151+
return true;
144152
}
145153

146154
void get_stack_out_target_size(

kernels/portable/cpu/util/copy_ops_util.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
namespace torch {
1414
namespace executor {
1515

16-
void check_cat_args(
16+
bool check_cat_args(
1717
exec_aten::ArrayRef<Tensor> tensors,
1818
int64_t dim,
1919
Tensor& out);
@@ -24,15 +24,15 @@ void get_cat_out_target_size(
2424
Tensor::SizesType* out_sizes,
2525
size_t* out_ndim);
2626

27-
void check_permute_copy_args(const Tensor& in, IntArrayRef dims, Tensor& out);
27+
bool check_permute_copy_args(const Tensor& in, IntArrayRef dims, Tensor& out);
2828

2929
void get_permute_copy_out_target_size(
3030
const Tensor& in,
3131
IntArrayRef dims,
3232
Tensor::SizesType* out_sizes,
3333
size_t* out_ndim);
3434

35-
void check_stack_args(
35+
bool check_stack_args(
3636
exec_aten::ArrayRef<Tensor> tensors,
3737
int64_t dim,
3838
Tensor& out);

kernels/portable/cpu/util/normalization_ops_util.cpp

Lines changed: 18 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -8,17 +8,14 @@
88

99
#include <cstring>
1010

11-
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12-
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
13-
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
14-
#include <executorch/runtime/platform/assert.h>
11+
#include <executorch/kernels/portable/cpu/util/normalization_ops_util.h>
1512

1613
namespace torch {
1714
namespace executor {
1815

1916
using Tensor = exec_aten::Tensor;
2017

21-
void check_batch_norm_args(
18+
bool check_batch_norm_args(
2219
const Tensor& in,
2320
const exec_aten::optional<Tensor>& weight,
2421
const exec_aten::optional<Tensor>& bias,
@@ -28,27 +25,34 @@ void check_batch_norm_args(
2825
double eps,
2926
Tensor& out) {
3027
// All tensors must be the same dtype
31-
ET_CHECK_SAME_DTYPE3(in, running_mean, running_var);
32-
ET_CHECK_SAME_DTYPE2(in, out);
28+
ET_LOG_AND_RETURN_IF_FALSE(
29+
tensors_have_same_dtype(in, running_mean, running_var));
30+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
3331
if (weight.has_value()) {
34-
ET_CHECK_SAME_DTYPE2(in, weight.value());
32+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight.value()));
3533
}
3634
if (bias.has_value()) {
37-
ET_CHECK_SAME_DTYPE2(in, bias.value());
35+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, bias.value()));
3836
}
3937

4038
size_t C_dim = in.dim() >= 1 ? 1 : 0;
4139
// All parameter tensors must be of dim 1 and have length equal to the
4240
// channels dim of in
43-
ET_CHECK(running_mean.dim() == 1 && running_mean.size(0) == in.size(C_dim));
44-
ET_CHECK(running_var.dim() == 1 && running_var.size(0) == in.size(C_dim));
41+
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(running_mean, 1));
42+
ET_LOG_AND_RETURN_IF_FALSE(
43+
tensors_have_same_size_at_dims(running_mean, 0, in, C_dim));
4544
if (weight.has_value()) {
46-
ET_CHECK(
47-
weight.value().dim() == 1 && weight.value().size(0) == in.size(C_dim));
45+
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight.value(), 1));
46+
ET_LOG_AND_RETURN_IF_FALSE(
47+
tensors_have_same_size_at_dims(weight.value(), 0, in, C_dim));
4848
}
4949
if (bias.has_value()) {
50-
ET_CHECK(bias.value().dim() == 1 && bias.value().size(0) == in.size(C_dim));
50+
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(bias.value(), 1));
51+
ET_LOG_AND_RETURN_IF_FALSE(
52+
tensors_have_same_size_at_dims(bias.value(), 0, in, C_dim));
5153
}
54+
55+
return true;
5256
}
5357

5458
} // namespace executor

kernels/portable/cpu/util/normalization_ops_util.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
namespace torch {
1414
namespace executor {
1515

16-
void check_batch_norm_args(
16+
bool check_batch_norm_args(
1717
const Tensor& in,
1818
const exec_aten::optional<Tensor>& weight,
1919
const exec_aten::optional<Tensor>& bias,

0 commit comments

Comments (0)