
Commit d3367e6

Introduce preprocess custom ops
Differential Revision: D60491675
Pull Request resolved: #4548
1 parent b671e24 commit d3367e6
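
For orientation, a minimal eager-mode usage sketch of the two new ops (not part of the commit; the Python import path is an assumption based on the file location extension/llm/custom_ops/preprocess_custom_ops.py, and the shapes mirror the tests added below):

import torch

# Importing the module runs the torch.library registrations for "preprocess".
from executorch.extension.llm.custom_ops import preprocess_custom_ops  # noqa: F401

image = torch.ones(3, 200, 300)
# pad appends zeros on the right (padding[1]) and bottom (padding[3]).
padded = torch.ops.preprocess.pad.default(image, [0, 148, 0, 248])  # -> (3, 448, 448)
# tile_crop splits a (C, H, W) image into square tiles of side tile_size.
tiles = torch.ops.preprocess.tile_crop.default(padded, 224)  # -> (4, 3, 224, 224)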

File tree

6 files changed: +269 -0 lines changed

extension/llm/custom_ops/TARGETS
extension/llm/custom_ops/op_tile_crop.cpp
extension/llm/custom_ops/op_tile_crop.h
extension/llm/custom_ops/preprocess_custom_ops.py
extension/llm/custom_ops/targets.bzl
extension/llm/custom_ops/test_preprocess_custom_ops.py


extension/llm/custom_ops/TARGETS

Lines changed: 13 additions & 0 deletions
@@ -21,3 +21,16 @@ runtime.python_test(
         "//caffe2:torch",
     ],
 )
+
+runtime.python_test(
+    name = "test_preprocess_custom_ops",
+    srcs = [
+        "test_preprocess_custom_ops.py",
+    ],
+    preload_deps = [
+        ":preprocess_custom_ops_py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
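
The test target preloads :preprocess_custom_ops_py because the "preprocess" ops only resolve once the defining module has been loaded (the test also imports it directly). A small sketch of that ordering, with the same assumed import path as above:

import torch

# Before the import below, torch.ops.preprocess.pad does not resolve
# (accessing it raises AttributeError).
from executorch.extension.llm.custom_ops import preprocess_custom_ops  # noqa: F401

print(torch.ops.preprocess.pad.default(torch.ones(3, 2, 2), [0, 1, 0, 1]).shape)
# torch.Size([3, 3, 3])
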
extension/llm/custom_ops/op_tile_crop.cpp

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/extension/llm/custom_ops/op_tile_crop.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

Tensor& tile_crop_out_impl(
    RuntimeContext& ctx,
    const Tensor& input, // NOLINT
    const int64_t tile_size, // NOLINT
    Tensor& out) {
  (void)ctx;
  // Note: currently a stub that returns `out` unchanged; no tiling is done here.
  return out;
}

} // namespace native
} // namespace executor
} // namespace torch

EXECUTORCH_LIBRARY(
    preprocess,
    "tile_crop.out",
    torch::executor::native::tile_crop_out_impl);
extension/llm/custom_ops/op_tile_crop.h

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {

namespace native {

Tensor& tile_crop_out_impl(
    RuntimeContext& ctx,
    const Tensor& input,
    const int64_t tile_size,
    Tensor& out);

} // namespace native
} // namespace executor
} // namespace torch
extension/llm/custom_ops/preprocess_custom_ops.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


from typing import List

import torch

from torch.library import impl, Library

preprocess_op_lib = Library("preprocess", "DEF")

# Register and define pad and out variant.
# Note: pad doesn't require an explicit meta kernel because
# CompositeExplicitAutograd automatically registers the implementation to meta,
# and meta kernels do not go through functionalization. The implementation
# does not export due to issues during functionalization.
# See: https://github.com/pytorch/pytorch/issues/120288
preprocess_op_lib.define("pad(Tensor image, SymInt[] padding) -> Tensor")


@impl(preprocess_op_lib, "pad", dispatch_key="CompositeExplicitAutograd")
def pad_impl(
    image: torch.Tensor,
    padding: List[int],
) -> torch.Tensor:
    output = torch.empty(
        [image.shape[0], image.shape[1] + padding[3], image.shape[2] + padding[1]],
        dtype=image.dtype,
        device=image.device,
        requires_grad=False,
    )
    output = torch.fill(output, 0)
    output.narrow(1, 0, image.shape[1]).narrow(2, 0, image.shape[2]).copy_(image)
    return output


preprocess_op_lib.define(
    "pad.out(Tensor image, SymInt[] padding, *, Tensor(a!) out) -> Tensor(a!)"
)


@impl(preprocess_op_lib, "pad.out", dispatch_key="CompositeExplicitAutograd")
def pad_out_impl(
    image: torch.Tensor,
    padding: List[int],
    out: torch.Tensor,
) -> torch.Tensor:
    out = torch.empty(
        [image.shape[0], image.shape[1] + padding[3], image.shape[2] + padding[1]],
        dtype=image.dtype,
        device=image.device,
        requires_grad=False,
    )
    out = torch.fill(out, 0)
    out.narrow(1, 0, image.shape[1]).narrow(2, 0, image.shape[2]).copy_(image)
    return out


# Register and define tile_crop and out variant.
preprocess_op_lib.define("tile_crop(Tensor input, int tile_size) -> Tensor")


@impl(preprocess_op_lib, "tile_crop", dispatch_key="CompositeExplicitAutograd")
def tile_crop_impl(input: torch.Tensor, tile_size: int) -> torch.Tensor:
    c = input.shape[0]
    h = input.shape[1]
    w = input.shape[2]
    tiles_height = h // tile_size
    tiles_width = w // tile_size
    tile_cropped = input.view(c, tiles_height, tile_size, tiles_width, tile_size)
    transposed = tile_cropped.permute(1, 3, 0, 2, 4)
    tiles = transposed.contiguous().view(
        tiles_height * tiles_width, c, tile_size, tile_size
    )
    return tiles


preprocess_op_lib.define(
    "tile_crop.out(Tensor input, int tile_size, *, Tensor(a!) out) -> Tensor(a!)"
)


@impl(preprocess_op_lib, "tile_crop.out", dispatch_key="CompositeExplicitAutograd")
def tile_crop_out_impl(
    input: torch.Tensor, tile_size: int, out: torch.Tensor
) -> torch.Tensor:
    out = input.clone()
    c = out.shape[0]
    h = out.shape[1]
    w = out.shape[2]
    tiles_height = h // tile_size
    tiles_width = w // tile_size
    out = out.view(c, tiles_height, tile_size, tiles_width, tile_size)
    out = out.permute(1, 3, 0, 2, 4)
    out = out.contiguous().view(tiles_height * tiles_width, c, tile_size, tile_size)
    return out


# Register meta kernel to prevent export tracing into the tile_crop impl.
@torch.library.register_fake("preprocess::tile_crop")
def tile_crop(output: torch.Tensor, tile_size: int) -> torch.Tensor:
    # Returned tensor is of size [n, 3, 224, 224], where n is the number of tiles.
    # We should export with n = max_num_tiles. Set 50 for now.
    return torch.empty([50, output.size(0), 224, 224])
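
To illustrate the role of the registered fake kernel, an untested sketch (the wrapper module, shapes, and import path are illustrative assumptions, not part of this commit):

import torch
from torch.export import export

from executorch.extension.llm.custom_ops import preprocess_custom_ops  # noqa: F401


class TileCrop(torch.nn.Module):
    def forward(self, image: torch.Tensor) -> torch.Tensor:
        return torch.ops.preprocess.tile_crop.default(image, 224)


# Eager mode runs the CompositeExplicitAutograd impl: a 3x448x448 input yields
# 2x2 tiles, i.e. shape (4, 3, 224, 224).
eager_tiles = TileCrop()(torch.ones(3, 448, 448))

# torch.export traces with the fake kernel instead, so the recorded output
# carries the fixed upper-bound shape (50, 3, 224, 224) and tile_crop is kept
# as a preprocess custom-op call in the graph rather than being traced through.
exported = export(TileCrop(), (torch.ones(3, 448, 448),))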

extension/llm/custom_ops/targets.bzl

Lines changed: 33 additions & 0 deletions
@@ -86,3 +86,36 @@ def define_common_targets():
             ":custom_ops",
         ],
     )
+
+    ## For preprocess
+    runtime.python_library(
+        name = "preprocess_custom_ops_py",
+        srcs = [
+            "preprocess_custom_ops.py",
+        ],
+        visibility = [
+            "//executorch/...",
+            "@EXECUTORCH_CLIENTS",
+        ],
+        deps = [
+            "//caffe2:torch",
+        ],
+    )
+
+    runtime.cxx_library(
+        name = "op_tile_crop",
+        srcs = ["op_tile_crop.cpp"],
+        exported_headers = ["op_tile_crop.h"],
+        exported_deps = [
+            "//executorch/runtime/kernel:kernel_includes",
+            "//executorch/extension/kernel_util:kernel_util",
+        ],
+        compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
+        visibility = [
+            "//executorch/...",
+            "@EXECUTORCH_CLIENTS",
+        ],
+        # @lint-ignore BUCKLINT link_whole
+        link_whole = True,
+        force_static = True,
+    )
extension/llm/custom_ops/test_preprocess_custom_ops.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest
from typing import List, Tuple

import torch

from .preprocess_custom_ops import preprocess_op_lib  # noqa


class PreprocessTest(unittest.TestCase):

    def setUp(self):
        # pad
        self.pad_input = torch.ones(3, 200, 300)

        # tile_crop
        self.tile_size = 224

    def _compare_pad(self, image: torch.Tensor, padding: List[int]) -> None:
        output = torch.ops.preprocess.pad.default(image, padding)
        output_ref = torch.nn.functional.pad(image, padding)
        self.assertTrue(torch.allclose(output_ref, output, 1e-6))

    def _test_tile_crop(self, image: torch.Tensor, expected_shape: Tuple[int]) -> None:
        output = torch.ops.preprocess.tile_crop.default(image, self.tile_size)
        self.assertTrue(output.shape == expected_shape)

    def test_op_pad_without_padding(self):
        self._compare_pad(self.pad_input, [0, 0, 0, 0])

    def test_op_pad_with_right_bottom_padding(self):
        self._compare_pad(self.pad_input, [0, 124, 0, 148])

    def test_op_pad_with_right_padding(self):
        self._compare_pad(self.pad_input, [0, 124, 0, 0])

    def test_op_pad_with_bottom_padding(self):
        self._compare_pad(self.pad_input, [0, 0, 0, 148])

    def test_op_tile_crop_2x2(self):
        self._test_tile_crop(torch.ones(3, 448, 448), (4, 3, 224, 224))

    def test_op_tile_crop_1x3(self):
        self._test_tile_crop(torch.ones(3, 224, 672), (3, 3, 224, 224))

    def test_op_tile_crop_4x2(self):
        self._test_tile_crop(torch.ones(3, 896, 448), (8, 3, 224, 224))
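
The tile_crop tests above only check shapes; a value check against an unfold-based reference could complement them. A hedged sketch (not part of the commit) written against the same imports as the test file:

def _tile_crop_reference(image: torch.Tensor, tile_size: int) -> torch.Tensor:
    # (C, H, W) -> (C, H//ts, W//ts, ts, ts) via unfold, then reorder to
    # (num_tiles, C, ts, ts) in the same row-major tile order as tile_crop.
    c = image.shape[0]
    tiles = image.unfold(1, tile_size, tile_size).unfold(2, tile_size, tile_size)
    return tiles.permute(1, 2, 0, 3, 4).reshape(-1, c, tile_size, tile_size)


image = torch.arange(3 * 448 * 448, dtype=torch.float32).reshape(3, 448, 448)
expected = _tile_crop_reference(image, 224)
actual = torch.ops.preprocess.tile_crop.default(image, 224)
assert torch.equal(actual, expected)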
