Skip to content

Commit 8a8e876

Browse files
lucylqfacebook-github-bot
authored andcommitted
Register preprocess in pytorch (pytorch#5350)
Summary: - Following https://pytorch.org/executorch/stable/kernel-library-custom-aten-kernel.html, use WRAP_TO_ATEN to register preprocess in pytorch - Create a separate `op_tile_crop_aot.py` that registers the C++ aot library into Python. Inside export_preprocess, use `op_tile_crop_aot.py` instead of `preprocess_custom_ops.py`, which is the pure python lib. Otherwise, we end up loading the C++ library when the python one already exists. Note, include these PyTorch changes for AOTI export: pytorch/pytorch#135933 Pull Request resolved: pytorch#5350 Test Plan: ``` >>> import torch >>> from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip >>> x = torch._export.aot_load("/home/lfq/local/executorch/aoti_preprocess.so", "cpu") >>> img = torch.ones([3, 600, 800]) >>> canvas_size = torch.tensor([448, 448]) >>> target_size = torch.tensor([336, 448]) >>> res = x(img, target_size, canvas_size) >>> res[0].shape torch.Size([4, 3, 224, 224]) >>> res[1] tensor([2, 2]) >>> ``` Reviewed By: larryliu0820 Differential Revision: D62651605 Pulled By: lucylq fbshipit-source-id: bdf5b46033ebbd73d10307ab58219743a73fd6fd
1 parent f7954f6 commit 8a8e876

File tree

5 files changed

+93
-1
lines changed

5 files changed

+93
-1
lines changed

examples/models/flamingo/preprocess/export_preprocess_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
1212
from executorch.exir.program._program import ExecutorchProgramManager
1313

14-
from executorch.extension.llm.custom_ops import preprocess_custom_ops # noqa
14+
from executorch.extension.llm.custom_ops import op_tile_crop_aot # noqa
1515

1616
from torch.export import Dim, ExportedProgram
1717
from torchtune.models.clip.inference._transform import _CLIPImageTransform

extension/llm/custom_ops/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT)
7575
add_library(
7676
custom_ops_aot_lib SHARED
7777
${_custom_ops__srcs} ${CMAKE_CURRENT_SOURCE_DIR}/op_sdpa_aot.cpp
78+
${CMAKE_CURRENT_SOURCE_DIR}/op_tile_crop_aot.cpp
7879
${CMAKE_CURRENT_SOURCE_DIR}/op_tile_crop.cpp
7980
)
8081
target_include_directories(
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
10+
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
11+
#include <executorch/extension/llm/custom_ops/op_tile_crop.h>
12+
13+
#include <torch/library.h>
14+
15+
namespace torch {
16+
namespace executor {
17+
18+
namespace native {
19+
20+
Tensor&
21+
tile_crop_out_no_context(const Tensor& input, int64_t tile_size, Tensor& out) {
22+
exec_aten::RuntimeContext context{};
23+
return tile_crop_out_impl(context, input, tile_size, out);
24+
}
25+
26+
at::Tensor tile_crop_aten(const at::Tensor& input, int64_t tile_size) {
27+
// max_num_tiles = 4, num_channels = 3.
28+
auto output = at::empty({4, 3, tile_size, tile_size});
29+
30+
WRAP_TO_ATEN(torch::executor::native::tile_crop_out_no_context, 2)
31+
(input, tile_size, output);
32+
return output;
33+
}
34+
35+
} // namespace native
36+
} // namespace executor
37+
} // namespace torch
38+
39+
TORCH_LIBRARY(preprocess, m) {
40+
m.def("tile_crop(Tensor input, int tile_size) -> Tensor");
41+
m.def(
42+
"tile_crop.out(Tensor input, int tile_size, *, Tensor(a!) out) -> Tensor(a!)");
43+
}
44+
45+
TORCH_LIBRARY_IMPL(preprocess, CompositeExplicitAutograd, m) {
46+
m.impl("tile_crop", torch::executor::native::tile_crop_aten);
47+
m.impl(
48+
"tile_crop.out",
49+
WRAP_TO_ATEN(torch::executor::native::tile_crop_out_no_context, 2));
50+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
from pathlib import Path
9+
10+
import torch
11+
12+
try:
13+
tile_crop = torch.ops.preprocess.tile_crop.default
14+
assert tile_crop is not None
15+
except:
16+
libs = list(Path(__file__).parent.resolve().glob("libcustom_ops_aot_lib.*"))
17+
assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
18+
logging.info(f"Loading custom ops library: {libs[0]}")
19+
torch.ops.load_library(libs[0])
20+
tile_crop = torch.ops.preprocess.tile_crop.default
21+
assert tile_crop is not None
22+
23+
preprocess_ops_lib = torch.library.Library("preprocess", "IMPL")
24+
25+
MAX_NUM_TILES = 4
26+
27+
28+
# Register meta kernel to prevent export tracing into the tile_crop impl.
29+
@torch.library.register_fake("preprocess::tile_crop")
30+
def tile_crop(output: torch.Tensor, tile_size: int) -> torch.Tensor:
31+
# Returned tensor is of size [n, 3, 224, 224], where n = number of tiles.
32+
# Use an unbacked symint to create an upper-bounded dynamic shape output.
33+
# Otherwise, output is set to a static shape, and we can only output
34+
# tensors of shape [MAX_NUM_TILES, 3, 224, 224].
35+
ctx = torch._custom_ops.get_ctx()
36+
s0 = ctx.create_unbacked_symint()
37+
torch._constrain_as_size(s0, 0, MAX_NUM_TILES)
38+
return torch.empty([s0, output.size(0), tile_size, tile_size])

extension/llm/custom_ops/targets.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@ def define_common_targets():
3535
name = "custom_ops_aot_lib" + mkl_dep,
3636
srcs = [
3737
"op_sdpa_aot.cpp",
38+
"op_tile_crop_aot.cpp",
39+
"op_tile_crop.cpp",
3840
],
41+
headers = ["op_tile_crop.h"],
3942
visibility = [
4043
"//executorch/...",
4144
"@EXECUTORCH_CLIENTS",

0 commit comments

Comments
 (0)