From 60bb4ef10df46e207fc57d1b6f3368b0bd86c0e0 Mon Sep 17 00:00:00 2001 From: Gregory Comer Date: Sat, 26 Apr 2025 04:06:00 -0700 Subject: [PATCH] Add direct copy fast path for portable copy op (#10487) Summary: The PR adds a direct memcpy fast-path for portable copy and copy_ ops. This speeds up copy significantly in cases where no broadcasting is needed. This is most noticable when copying buffer mutations back, such as transformer KV cache when managing the cache as a mutable buffer. Prior to this change, an encoder/decoder model was taking ~25% of the total runtime copying KV cache back after permuting. With this change, the copy becomes significantly cheaper. I benchmarked a simple model on S23 and Pixel 5: ``` class TestModel(torch.nn.Module): def __init__(self): super().__init__() self.register_buffer("buffer", torch.zeros((2, 10, 1024, 1024))) def forward(self, x): self.buffer.add_(x) return self.buffer model = TestModel() inputs = (torch.randn(2, 10, 1024, 1024),) lowered = to_edge_transform_and_lower( torch.export.export(model, inputs), partitioner=[XnnpackPartitioner()], ).to_executorch() ``` S23, average of 50 runs, time in copy_: 4.1ms vs 22.3ms Pixel 5, average of 50 runs, time in copy_: 12.1ms vs 66.6ms This is approximately a ~5.5x speedup of the copy operator. Reviewed By: swolchok Differential Revision: D73656456 Pulled By: GregoryComer --- kernels/portable/cpu/op_copy.cpp | 58 ++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 22 deletions(-) diff --git a/kernels/portable/cpu/op_copy.cpp b/kernels/portable/cpu/op_copy.cpp index 19b0c3a2f6a..069800cc399 100644 --- a/kernels/portable/cpu/op_copy.cpp +++ b/kernels/portable/cpu/op_copy.cpp @@ -46,17 +46,24 @@ Tensor& copy_out( // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "copy.out"; - ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() { - utils::apply_bitensor_elementwise_fn( - [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; }, - ctx, - in, - utils::SupportedTensorDtypes::REALHBBF16, - src, - utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBBF16); - }); + // Use direct copy fast path if broadcast is not needed and tensors are + // non-empty + if (internal::sizes_match_ignoring_leading_1s(out.sizes(), src.sizes()) && + src.numel() > 0) { + std::memcpy(out.mutable_data_ptr(), src.const_data_ptr(), src.nbytes()); + } else { + ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() { + utils::apply_bitensor_elementwise_fn( + [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; }, + ctx, + in, + utils::SupportedTensorDtypes::REALHBBF16, + src, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBBF16); + }); + } return out; } @@ -79,17 +86,24 @@ Tensor& copy_( // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "copy_"; - ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() { - utils::apply_bitensor_elementwise_fn( - [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; }, - ctx, - in, - utils::SupportedTensorDtypes::REALHBBF16, - src, - utils::SupportedTensorDtypes::REALHBBF16, - in, - utils::SupportedTensorDtypes::REALHBBF16); - }); + // Use direct copy fast path if broadcast is not needed and tensors are + // non-empty + if (internal::sizes_match_ignoring_leading_1s(in.sizes(), src.sizes()) && + src.numel() > 0) { + std::memcpy(in.mutable_data_ptr(), src.const_data_ptr(), in.nbytes()); + } else { + ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() { + utils::apply_bitensor_elementwise_fn( + [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; }, + ctx, + in, + utils::SupportedTensorDtypes::REALHBBF16, + src, + utils::SupportedTensorDtypes::REALHBBF16, + in, + utils::SupportedTensorDtypes::REALHBBF16); + }); + } return in; }