Merged
18 changes: 18 additions & 0 deletions exir/passes/dim_order_ops_registry.py
@@ -28,6 +28,14 @@
    "_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
Contributor
It is ok to leave it here since we are gonna need it in the future, but when we talk about adding portable kernels we mainly focus on the kernels in the runtime, specifically under executorch/kernels/portable.

Contributor Author
That makes sense, but I needed to register the operator here, otherwise the tests I added fail since there is no Python-side reference to _clone_dim_order.
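For reference, this is roughly what those tests need on the Python side (a minimal sketch; it assumes the registry module above has been imported so its lib.define/@impl calls have run, and uses the module path as laid out in this repo):

```python
import torch

# Importing the registry module runs lib.define/@impl for dim_order_ops.
import executorch.exir.passes.dim_order_ops_registry  # noqa: F401

x = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)

# Without the registration above, this lookup fails because torch.ops has
# no Python-side entry for dim_order_ops._clone_dim_order.
y = torch.ops.dim_order_ops._clone_dim_order.default(x)
assert y.is_contiguous(memory_format=torch.channels_last)
```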

Contributor Author
Oh, should all the tests for this PR have been on the kernel side and not in test_memory_format_ops_pass.py?

Contributor
This PR should only focus on our runtime changes, so no Python side will refer to our new operator.

> should all the tests for this PR have been on the kernel side and not in test_memory_format_ops_pass.py

Yes, absolutely correct! This PR is only for the portable kernel and its tests. Sorry for any confusion!

Contributor Author
@keyprocedure Aug 5, 2025

I’ve added the runtime test and all tests passed locally. I couldn’t run the DynamicShapeUnbound test since it depends on SupportedFeatures and supported_features.h doesn’t seem to be generated in OSS builds.

Contributor Author
Please disregard my previous comment about the missing SupportedFeatures dependency; the issue was with my local build setup. All tests pass now.
    "_clone_dim_order(Tensor self, *, bool non_blocking=False, int[]? dim_order=None) -> Tensor"
)

lib.define(
    "_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
)


def _op_impl(target, *args, **kwargs):
    kwargs["memory_format"] = get_memory_format(kwargs.get("dim_order", None))
@@ -57,6 +65,16 @@ def _empty_dim_order_out_impl(*args, **kwargs):
    return _op_impl(torch.ops.aten.empty.out, *args, **kwargs)


@impl(lib, "_clone_dim_order", "CompositeImplicitAutograd")
def _clone_dim_order_impl(*args, **kwargs):
    return _op_impl(torch.ops.aten.clone.default, *args, **kwargs)


@impl(lib, "_clone_dim_order.out", "CompositeImplicitAutograd")
def _clone_dim_order_out_impl(*args, **kwargs):
    return _op_impl(torch.ops.aten.clone.out, *args, **kwargs)


"""
Defines a map of edge ops to the corresponding dim_order ops for quick lookup
"""
34 changes: 34 additions & 0 deletions exir/tests/test_memory_format_ops_pass.py
@@ -91,6 +91,40 @@ def test_op_empty_replacement_contiguous(self) -> None:
),
)

def test_op_clone_dim_order_preserves_channels_last(self):
    x = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
    y = torch.ops.dim_order_ops._clone_dim_order.default(x)

    assert y.is_contiguous(
        memory_format=torch.channels_last
    ), "_clone_dim_order output is not in channels_last memory format."
    assert torch.allclose(x, y)

def test_op_clone_dim_order_to_contiguous(self):
    x = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
    contiguous_dim_order = get_dim_order(torch.contiguous_format, x.dim())
    y = torch.ops.dim_order_ops._clone_dim_order.default(
        x, dim_order=contiguous_dim_order
    )

    assert (
        y.is_contiguous()
    ), "_clone_dim_order output is not in contiguous memory format"
    assert torch.allclose(x, y)

def test_op_clone_dim_order_out_to_channels_last(self):
    x = torch.randn(2, 3, 4, 5).contiguous()
    y = torch.empty_like(x, memory_format=torch.channels_last)
    channels_last_dim_order = get_dim_order(torch.channels_last, y.dim())
    torch.ops.dim_order_ops._clone_dim_order.out(
        x, dim_order=channels_last_dim_order, out=y
    )

    assert y.is_contiguous(
        memory_format=torch.channels_last
    ), "_clone_dim_order output is not in channels_last memory format"
    assert torch.allclose(x, y)

def test_op_dim_order_update(self) -> None:
    MemoryFormatOpsPassTestUtils.memory_format_test_runner(
        self,
80 changes: 80 additions & 0 deletions kernels/portable/cpu/op__clone_dim_order.cpp
@@ -0,0 +1,80 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = executorch::aten::Tensor;

template <typename T>
using OptionalArrayRef = executorch::aten::OptionalArrayRef<T>;

/**
 * _clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]?
 * dim_order=None, Tensor(a!) out) -> Tensor(a!)
 *
 * Clones via element-wise copy while preserving dim_order.
 */
Tensor& _clone_dim_order_out(
    KernelRuntimeContext& ctx,
    const Tensor& self,
    bool non_blocking,
    OptionalArrayRef<int64_t> dim_order,
    Tensor& out) {
  (void)ctx;

  // Ensure input and output dtype match.
  ET_KERNEL_CHECK(
      ctx, self.scalar_type() == out.scalar_type(), InvalidArgument, out);

  // Ensure output has the same layout as input or matches dim_order.
  ET_KERNEL_CHECK(
      ctx,
      check__to_dim_order_copy_args(self, non_blocking, dim_order, out),
      InvalidArgument,
      out);

  // Ensure input and output shapes match, resizing if necessary.
  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, self.sizes()) == torch::executor::Error::Ok,
      InvalidArgument,
      out);

  if (self.numel() == 0) {
    return out;
  }

  // Select the correct input dtype and copy the tensors.
  ET_SWITCH_REALHBBF16_TYPES(
      self.scalar_type(),
      ctx,
      "dim_order_ops::_clone_dim_order.out",
      CTYPE,
      [&] { _to_dim_order_copy_impl<CTYPE, CTYPE>(self, out); });

  return out;
}

Tensor& _clone_dim_order_out(
    const Tensor& self,
    bool non_blocking,
    OptionalArrayRef<int64_t> dim_order,
    Tensor& out) {
  executorch::ET_RUNTIME_NAMESPACE::KernelRuntimeContext context{};
  return _clone_dim_order_out(context, self, non_blocking, dim_order, out);
}

} // namespace native
} // namespace executor
} // namespace torch
23 changes: 0 additions & 23 deletions kernels/portable/cpu/op__to_dim_order_copy.cpp
@@ -29,29 +29,6 @@ using OptionalArrayRef = executorch::aten::OptionalArrayRef<T>;
template <typename T>
using Optional = std::optional<T>;

namespace {

template <typename SELF_CTYPE, typename OUT_CTYPE>
void _to_dim_order_copy_impl(const Tensor& self, Tensor& out) {
  auto self_data = self.mutable_data_ptr<SELF_CTYPE>();
  auto out_data = out.mutable_data_ptr<OUT_CTYPE>();

  // Here we make a slightly off-label use of
  // BroadcastIndexesRange. It always assumes it doesn't have to care
  // about different dim_order between input and output, but we can
  // just force it to respect strides (and thus dim_order) for its
  // inputs using support_noncontiguous_input_tensors=true, and then pretend
  // the output is just another input.
  for (const auto [unused_index, self_data_index, out_data_index] :
       BroadcastIndexesRange<2, /*support_noncontiguous_input_tensors=*/true>(
           /*dummy output*/ self, self, out)) {
    (void)unused_index;
    out_data[out_data_index] =
        static_cast<OUT_CTYPE>(self_data[self_data_index]);
  }
}
} // namespace

// _to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]?
// dim_order=None, Tensor(a!) out) -> Tensor(a!)
Tensor& _to_dim_order_copy_out(
24 changes: 24 additions & 0 deletions kernels/portable/cpu/util/copy_ops_util.h
@@ -9,6 +9,7 @@
#pragma once
#include <c10/util/irange.h>

#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
@@ -77,6 +78,29 @@ void as_strided_copy(
}
}

/**
 * Copies and casts a tensor while preserving input dim_order.
 */
template <typename SELF_CTYPE, typename OUT_CTYPE>
void _to_dim_order_copy_impl(const Tensor& self, Tensor& out) {
  auto self_data = self.mutable_data_ptr<SELF_CTYPE>();
  auto out_data = out.mutable_data_ptr<OUT_CTYPE>();

  // Here we make a slightly off-label use of
  // BroadcastIndexesRange. It always assumes it doesn't have to care
  // about different dim_order between input and output, but we can
  // just force it to respect strides (and thus dim_order) for its
  // inputs using support_noncontiguous_input_tensors=true, and then pretend
  // the output is just another input.
  for (const auto [unused_index, self_data_index, out_data_index] :
       BroadcastIndexesRange<2, /*support_noncontiguous_input_tensors=*/true>(
           /*dummy output*/ self, self, out)) {
    (void)unused_index;
    out_data[out_data_index] =
        static_cast<OUT_CTYPE>(self_data[self_data_index]);
  }
}

bool check_cat_args(
    executorch::aten::ArrayRef<Tensor> tensors,
    int64_t dim,
5 changes: 5 additions & 0 deletions kernels/portable/functions.yaml
@@ -1009,3 +1009,8 @@
  kernels:
    - arg_meta: null
      kernel_name: torch::executor::_to_dim_order_copy_out

- func: dim_order_ops::_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: torch::executor::_clone_dim_order_out