Commit 5605954

dbort authored and facebook-github-bot committed
Fix unqualified uses of executorch functions (pytorch#5709)
Summary: Pull Request resolved: pytorch#5709

I'm not sure how this worked before, but these sites called functions under torch::executor without actually qualifying them. Qualify them explicitly, because the "can call without qualification" magic stops working when we move the etensor types in D63294217.

In a few places I used `namespace etrt = executorch::runtime;` instead of a using statement for a particular function, like `etrt::isIntegralType`. If I just say `using executorch::runtime::isIntegralType`, those files fail in aten mode because the unqualified call to `isIntegralType()` is deemed ambiguous in the presence of `c10::isIntegralType()` -- but afaict that `c10` version isn't `using`'d into the global namespace, so I don't know why it conflicts. It'd be good to figure that out at some point, but this works for now.

I also updated custom_kernel_example to stop using the `torch::` namespace.

Reviewed By: swolchok

Differential Revision: D63476419

fbshipit-source-id: 2300fc1fa5d5af6f6f747d3a0c31db5099389dcc
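The ambiguity described above can be reproduced in a minimal, self-contained sketch. Everything below is illustrative: the `c10` and `executorch::runtime` namespaces are stand-ins with toy `isIntegralType()` overloads, not the real headers. One plausible (unconfirmed) explanation for the conflict is argument-dependent lookup: in ATen mode `ScalarType` is `c10::ScalarType`, so an unqualified `isIntegralType(...)` call also considers `c10::isIntegralType()` regardless of which using-declarations are in scope, while a namespace alias such as `etrt` sidesteps the issue by keeping the call qualified.

// Illustrative stand-ins only -- not the real c10/executorch headers.
#include <iostream>

namespace c10 {
enum class ScalarType { Int, Float };
// Toy stand-in for c10::isIntegralType(ScalarType, bool).
inline bool isIntegralType(ScalarType t, bool /*includeBool*/) {
  return t == ScalarType::Int;
}
} // namespace c10

namespace executorch::runtime {
using ScalarType = c10::ScalarType; // in ATen mode the scalar type comes from c10
// Toy stand-in for executorch::runtime::isIntegralType(ScalarType, bool).
inline bool isIntegralType(ScalarType t, bool /*includeBool*/) {
  return t == ScalarType::Int;
}
} // namespace executorch::runtime

// Option A -- a using-declaration for the specific function. With these toy
// definitions, the unqualified call is ambiguous: ADL also finds
// c10::isIntegralType because the argument type lives in namespace c10.
//
//   using executorch::runtime::isIntegralType;
//   bool bad = isIntegralType(c10::ScalarType::Int, false); // error: ambiguous
//
// Option B -- the namespace alias used in this commit keeps the call
// qualified, so only one overload set is ever considered.
namespace etrt = executorch::runtime;

int main() {
  bool ok = etrt::isIntegralType(c10::ScalarType::Int, false);
  std::cout << std::boolalpha << ok << '\n'; // prints "true"
  return 0;
}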
1 parent 51e79a0 · commit 5605954

File tree: 10 files changed (+46, -28 lines)


backends/cadence/reference/operators/quantized_layer_norm.cpp

Lines changed: 2 additions & 1 deletion
@@ -11,7 +11,8 @@

 #include <cmath>

-using Tensor = exec_aten::Tensor;
+using executorch::aten::Tensor;
+using executorch::runtime::getLeadingDims;
 using executorch::runtime::KernelRuntimeContext;

 namespace impl {

backends/cadence/reference/operators/quantized_linear_out.cpp

Lines changed: 2 additions & 1 deletion
@@ -13,7 +13,8 @@ namespace impl {
 namespace reference {
 namespace native {

-using Tensor = exec_aten::Tensor;
+using executorch::aten::Tensor;
+using executorch::runtime::getLeadingDims;
 using executorch::runtime::KernelRuntimeContext;

 void quantized_linear_out(

backends/cadence/reference/operators/quantized_matmul_out.cpp

Lines changed: 2 additions & 1 deletion
@@ -13,7 +13,8 @@ namespace impl {
 namespace reference {
 namespace native {

-using Tensor = exec_aten::Tensor;
+using executorch::aten::Tensor;
+using executorch::runtime::getLeadingDims;
 using executorch::runtime::KernelRuntimeContext;

 // The quantized matmul. The quantized matmul accumulates in a wider register,

kernels/portable/cpu/util/test/broadcast_test.cpp

Lines changed: 6 additions & 0 deletions
@@ -22,6 +22,12 @@ using exec_aten::ScalarType;
 using exec_aten::Tensor;
 using executorch::runtime::ArrayRef;
 using executorch::runtime::testing::TensorFactory;
+using torch::executor::broadcast_tensor;
+using torch::executor::delinearize_index;
+using torch::executor::get_broadcast_target_size;
+using torch::executor::linearize_access_indexes;
+using torch::executor::tensor_is_broadcastable_to;
+using torch::executor::tensors_are_broadcastable_between;

 TEST(BroadcastUtilTest, BroadcastTensor) {
   TensorFactory<ScalarType::Int> tf;

kernels/portable/cpu/util/test/reduce_test.cpp

Lines changed: 4 additions & 1 deletion
@@ -19,7 +19,10 @@ using exec_aten::ArrayRef;
 using exec_aten::optional;
 using exec_aten::ScalarType;
 using exec_aten::Tensor;
-using torch::executor::testing::TensorFactory;
+using executorch::runtime::testing::TensorFactory;
+using torch::executor::apply_over_dim;
+using torch::executor::apply_over_dim_list;
+using torch::executor::get_out_numel;

 void _apply_over_dim(const Tensor& in, const optional<int64_t>& dim) {
   int64_t* in_data = in.mutable_data_ptr<int64_t>();

kernels/test/custom_kernel_example/my_functions.yaml

Lines changed: 1 addition & 1 deletion
@@ -5,4 +5,4 @@
 - op: relu.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::my_custom_kernel::my_relu_out
+      kernel_name: my_custom_kernels::my_relu_out

kernels/test/custom_kernel_example/op_relu.cpp

Lines changed: 8 additions & 8 deletions
@@ -12,14 +12,15 @@
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>

-namespace torch {
-namespace my_custom_kernel {
+namespace my_custom_kernels {
 namespace native {

-using Tensor = exec_aten::Tensor;
-using ScalarType = exec_aten::ScalarType;
-using executor::Error;
+using exec_aten::ScalarType;
+using exec_aten::Tensor;
+using executorch::runtime::Error;
 using executorch::runtime::KernelRuntimeContext;
+using executorch::runtime::resize_tensor;
+using executorch::runtime::tensors_have_same_shape_and_dtype;

 namespace {

@@ -67,7 +68,7 @@ my_relu_out(KernelRuntimeContext& context, const Tensor& input, Tensor& out) {
   resize(out, input.sizes());
   ET_KERNEL_CHECK(
       context,
-      executor::tensors_have_same_shape_and_dtype(input, out),
+      tensors_have_same_shape_and_dtype(input, out),
       InvalidArgument,
       out);

@@ -94,5 +95,4 @@ my_relu_out(KernelRuntimeContext& context, const Tensor& input, Tensor& out) {
 }

 } // namespace native
-} // namespace my_custom_kernel
-} // namespace torch
+} // namespace my_custom_kernels

kernels/test/op_add_test.cpp

Lines changed: 7 additions & 5 deletions
@@ -18,11 +18,12 @@
 #include <iostream>

 using namespace ::testing;
-using exec_aten::Scalar;
-using exec_aten::ScalarType;
-using exec_aten::Tensor;
+using executorch::aten::Scalar;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
+using executorch::runtime::testing::TensorFactory;
 using torch::executor::testing::SupportedFeatures;
-using torch::executor::testing::TensorFactory;
+namespace etrt = executorch::runtime;

 class OpAddOutKernelTest : public OperatorTest {
  protected:
@@ -63,7 +64,8 @@ class OpAddOutKernelTest : public OperatorTest {
     test_add<DTYPE_A, DTYPE_B, ScalarType::Float>();
     test_add<DTYPE_A, DTYPE_B, ScalarType::Double>();
     // Integral out type is only allowed if both inputs are integral types
-    if (isIntegralType(DTYPE_A, false) && isIntegralType(DTYPE_B, false)) {
+    if (etrt::isIntegralType(DTYPE_A, false) &&
+        etrt::isIntegralType(DTYPE_B, false)) {
       test_add<DTYPE_A, DTYPE_B, ScalarType::Int>();
       test_add<DTYPE_A, DTYPE_B, ScalarType::Long>();
     }

kernels/test/op_mul_test.cpp

Lines changed: 7 additions & 5 deletions
@@ -17,11 +17,12 @@
 #include <gtest/gtest.h>

 using namespace ::testing;
-using exec_aten::Scalar;
-using exec_aten::ScalarType;
-using exec_aten::Tensor;
+using executorch::aten::Scalar;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
+using executorch::runtime::testing::TensorFactory;
 using torch::executor::testing::SupportedFeatures;
-using torch::executor::testing::TensorFactory;
+namespace etrt = executorch::runtime;

 class OpMulOutTest : public OperatorTest {
  protected:
@@ -61,7 +62,8 @@ class OpMulOutTest : public OperatorTest {
     test_mul<DTYPE_A, DTYPE_B, ScalarType::Float>();
     test_mul<DTYPE_A, DTYPE_B, ScalarType::Double>();
     // Integral out type is only allowed if both inputs are integral types
-    if (isIntegralType(DTYPE_A, false) && isIntegralType(DTYPE_B, false)) {
+    if (etrt::isIntegralType(DTYPE_A, false) &&
+        etrt::isIntegralType(DTYPE_B, false)) {
       test_mul<DTYPE_A, DTYPE_B, ScalarType::Int>();
       test_mul<DTYPE_A, DTYPE_B, ScalarType::Long>();
     }

kernels/test/op_sub_test.cpp

Lines changed: 7 additions & 5 deletions
@@ -16,11 +16,12 @@
 #include <gtest/gtest.h>

 using namespace ::testing;
-using exec_aten::Scalar;
-using exec_aten::ScalarType;
-using exec_aten::Tensor;
+using executorch::aten::Scalar;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
+using executorch::runtime::testing::TensorFactory;
 using torch::executor::testing::SupportedFeatures;
-using torch::executor::testing::TensorFactory;
+namespace etrt = executorch::runtime;

 class OpSubOutTest : public OperatorTest {
  protected:
@@ -60,7 +61,8 @@ class OpSubOutTest : public OperatorTest {
     test_sub<DTYPE_A, DTYPE_B, ScalarType::Float>();
     test_sub<DTYPE_A, DTYPE_B, ScalarType::Double>();
     // Integral out type is only allowed if both inputs are integral types
-    if (isIntegralType(DTYPE_A, false) && isIntegralType(DTYPE_B, false)) {
+    if (etrt::isIntegralType(DTYPE_A, false) &&
+        etrt::isIntegralType(DTYPE_B, false)) {
      test_sub<DTYPE_A, DTYPE_B, ScalarType::Int>();
      test_sub<DTYPE_A, DTYPE_B, ScalarType::Long>();
    }
