Skip to content

Commit bbabd28

Browse files
Implement tile_crop custom op (#4622)
Summary: Pull Request resolved: #4622. Reviewed By: lucylq. Differential Revision: D61000109.
1 parent 3f9b39e commit bbabd28

File tree

3 files changed

+204
-1
lines changed

3 files changed

+204
-1
lines changed

extension/llm/custom_ops/op_tile_crop.cpp

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,95 @@
1313
namespace torch {
1414
namespace executor {
1515
namespace native {
16+
namespace {
17+
18+
bool check_tile_crop_out_args(
19+
const Tensor& in,
20+
int64_t tile_size,
21+
Tensor& out) {
22+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
23+
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 3));
24+
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(out, 4));
25+
ET_LOG_AND_RETURN_IF_FALSE(tile_size > 0);
26+
ET_LOG_AND_RETURN_IF_FALSE(in.size(in.dim() - 1) % tile_size == 0);
27+
ET_LOG_AND_RETURN_IF_FALSE(in.size(in.dim() - 2) % tile_size == 0);
28+
return true;
29+
}
30+
31+
void get_tile_crop_out_target_size(
32+
const Tensor& in,
33+
int64_t tile_size,
34+
exec_aten::SizesType* out_sizes,
35+
size_t* out_ndim) {
36+
*out_ndim = in.dim() + 1;
37+
38+
out_sizes[0] = in.size(1) * in.size(2) / (tile_size * tile_size);
39+
out_sizes[1] = in.size(0);
40+
out_sizes[2] = tile_size;
41+
out_sizes[3] = tile_size;
42+
}
43+
44+
template <typename CTYPE>
45+
void tile_crop_impl(const Tensor& in, int64_t tile_size, Tensor& out) {
46+
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
47+
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
48+
49+
const auto channels = in.size(0);
50+
const auto height = in.size(1);
51+
const auto width = in.size(2);
52+
53+
const auto HdivS = height / tile_size;
54+
const auto WdivS = width / tile_size;
55+
56+
size_t out_ix = 0;
57+
for (size_t bH = 0; bH < HdivS; bH++) {
58+
for (size_t bW = 0; bW < WdivS; bW++) {
59+
for (size_t c = 0; c < channels; c++) {
60+
for (size_t h = 0; h < tile_size; h++) {
61+
for (size_t w = 0; w < tile_size; w++) {
62+
size_t in_h = bH * tile_size + h;
63+
size_t in_w = bW * tile_size + w;
64+
size_t in_ix = c * height * width + in_h * width + in_w;
65+
66+
out_data[out_ix++] = in_data[in_ix];
67+
}
68+
}
69+
}
70+
}
71+
}
72+
}
73+
74+
} // namespace
1675

1776
// tile_crop.out entry point: splits a CHW image into non-overlapping
// square tiles of side `tile_size` and writes them to `out` with shape
// [num_tiles, C, tile_size, tile_size]. `out` is resized to the target
// shape; fails with InvalidArgument when arguments are inconsistent or
// the image cannot be tiled evenly.
Tensor& tile_crop_out_impl(
    RuntimeContext& ctx,
    const Tensor& input, // NOLINT
    const int64_t tile_size, // NOLINT
    Tensor& out) {
  ET_KERNEL_CHECK(
      ctx,
      check_tile_crop_out_args(input, tile_size, out),
      InvalidArgument,
      out);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  Tensor::SizesType target_sizes[kTensorDimensionLimit];
  size_t target_dim = 0;
  get_tile_crop_out_target_size(input, tile_size, target_sizes, &target_dim);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {target_sizes, target_dim}) == Error::Ok,
      InvalidArgument,
      out);

  constexpr auto op_name = "tile_crop.out";

  // Dispatch on the output dtype; in/out dtypes already verified equal.
  ET_SWITCH_ALL_TYPES(out.scalar_type(), ctx, op_name, CTYPE, [&]() {
    tile_crop_impl<CTYPE>(input, tile_size, out);
  });

  return out;
}
25107

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/llm/custom_ops/op_tile_crop.h>
10+
#include <executorch/kernels/test/TestUtil.h>
11+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
12+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
13+
#include <gtest/gtest.h>
14+
15+
using namespace ::testing;
16+
using exec_aten::ScalarType;
17+
using exec_aten::Tensor;
18+
using torch::executor::testing::TensorFactory;
19+
20+
class OpTileCropOutTest : public OperatorTest {
21+
protected:
22+
Tensor& op_tile_crop_out(const Tensor& self, int64_t tile_size, Tensor& out) {
23+
return torch::executor::native::tile_crop_out_impl(
24+
context_, self, tile_size, out);
25+
}
26+
27+
template <ScalarType DTYPE_IN>
28+
void test_tile_crop() {
29+
TensorFactory<DTYPE_IN> tf_in;
30+
31+
const std::vector<int32_t> sizes = {1, 4, 4};
32+
const std::vector<int32_t> out_sizes = {4, 1, 2, 2};
33+
34+
Tensor out = tf_in.zeros(out_sizes);
35+
36+
// clang-format off
37+
op_tile_crop_out(
38+
tf_in.make(
39+
sizes, { 0, 1, 2, 3,
40+
4, 5, 6, 7,
41+
8, 9, 10, 11,
42+
12, 13, 14, 15}),
43+
2,
44+
out);
45+
EXPECT_TENSOR_EQ(
46+
out,
47+
tf_in.make(
48+
out_sizes, {0, 1, 4, 5,
49+
2, 3, 6, 7,
50+
8, 9, 12, 13,
51+
10, 11, 14, 15}));
52+
// clang-format on
53+
}
54+
};
55+
56+
//
57+
// Correctness Tests
58+
//
59+
60+
/**
61+
* Uses the function templates above to test all input dtypes.
62+
*/
63+
TEST_F(OpTileCropOutTest, AllRealDtypesSupported){
// Instantiates test_tile_crop<DTYPE>() once for every real scalar type.
#define ENUMERATE_TEST_ENTRY(ctype, dtype) test_tile_crop<ScalarType::dtype>();
  ET_FORALL_REAL_TYPES(ENUMERATE_TEST_ENTRY)
#undef ENUMERATE_TEST_ENTRY
}
68+
69+
// Mismatched shape tests.
70+
// A spatial dim that is not divisible by tile_size must be rejected.
TEST_F(OpTileCropOutTest, InvalidInputShapeDies) {
  TensorFactory<ScalarType::Int> tf;

  // 7 rows cannot be split into tiles of side 2.
  Tensor bad_input = tf.ones(/*sizes=*/{1, 7, 8});
  Tensor out = tf.zeros(/*sizes=*/{16, 1, 2, 2});

  ET_EXPECT_KERNEL_FAILURE(context_, op_tile_crop_out(bad_input, 2, out));
}
79+
80+
// tile_crop only accepts rank-3 (CHW) inputs; a 2-D tensor must fail.
TEST_F(OpTileCropOutTest, WrongInputRankDies) {
  TensorFactory<ScalarType::Int> tf;

  Tensor rank2_input = tf.ones(/*sizes=*/{1, 2});
  Tensor out = tf.zeros(/*sizes=*/{1, 2});

  ET_EXPECT_KERNEL_FAILURE(context_, op_tile_crop_out(rank2_input, 2, out));
}
89+
90+
// Input and output tensors must share a dtype.
TEST_F(OpTileCropOutTest, DifferentDtypeDies) {
  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Float> tf_float;

  Tensor int_input = tf_int.ones(/*sizes=*/{2, 12, 12});

  // A float output paired with an int input must fail the dtype check.
  Tensor float_out = tf_float.zeros(/*sizes=*/{9, 2, 4, 4});

  ET_EXPECT_KERNEL_FAILURE(context_, op_tile_crop_out(int_input, 3, float_out));
}
101+
102+
// Non-positive tile sizes are invalid arguments.
TEST_F(OpTileCropOutTest, NegativeTileSizeDies) {
  TensorFactory<ScalarType::Int> tf;

  Tensor input = tf.ones(/*sizes=*/{2, 12, 12});
  Tensor out = tf.zeros(/*sizes=*/{9, 2, 4, 4});

  ET_EXPECT_KERNEL_FAILURE(context_, op_tile_crop_out(input, -3, out));
}

extension/llm/custom_ops/targets.bzl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,17 @@ def define_common_targets():
119119
link_whole = True,
120120
force_static = True,
121121
)
122+
123+
    # Unit test for the tile_crop custom op (op_tile_crop_test.cpp).
    runtime.cxx_test(
        name = "op_tile_crop_test",
        srcs = [
            "op_tile_crop_test.cpp",
        ],
        visibility = ["//executorch/..."],
        deps = [
            "//executorch/runtime/core/exec_aten:lib",
            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
            "//executorch/kernels/test:test_util",
            ":op_tile_crop",
        ],
    )

0 commit comments

Comments
 (0)