Skip to content

Commit ea2181d

Browse files
tarun292facebook-github-bot
authored andcommitted
Make make_tensor in broadcast utilities public and rename free_broadcast_tensor (#2785)
Summary: This diff does a couple of things: - Makes `make_tensor` a public function so that we can create temporary intermediate tensors in operators that need to do so. (Such as NMS that is implemented above in this stack) - Renames `free_broadcast_tensor` to a more generic name `free_tensor` Differential Revision: D55577026
1 parent 3f2e769 commit ea2181d

File tree

3 files changed

+27
-17
lines changed

3 files changed

+27
-17
lines changed

kernels/portable/cpu/util/broadcast_util.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,14 @@ namespace executor {
1919
using Tensor = exec_aten::Tensor;
2020
using ScalarType = exec_aten::ScalarType;
2121

22-
void free_broadcast_tensor(const Tensor& broadcast_tensor) {
23-
free((void*)broadcast_tensor.const_data_ptr());
24-
free((void*)broadcast_tensor.sizes().data());
25-
free((void*)broadcast_tensor.dim_order().data());
26-
free((void*)broadcast_tensor.strides().data());
27-
free(broadcast_tensor.unsafeGetTensorImpl());
22+
void free_tensor(const Tensor& tensor) {
23+
free((void*)tensor.const_data_ptr());
24+
free((void*)tensor.sizes().data());
25+
free((void*)tensor.dim_order().data());
26+
free((void*)tensor.strides().data());
27+
free(tensor.unsafeGetTensorImpl());
2828
}
2929

30-
namespace {
31-
3230
Tensor make_tensor(
3331
const ArrayRef<Tensor::SizesType>& sizes,
3432
const ArrayRef<Tensor::DimOrderType>& dim_order,
@@ -74,8 +72,6 @@ Tensor make_tensor(
7472
return Tensor{tensor_impl};
7573
}
7674

77-
} // namespace
78-
7975
bool tensor_is_broadcastable_to(
8076
const exec_aten::ArrayRef<Tensor::SizesType> broadcast_from_shape,
8177
const exec_aten::ArrayRef<Tensor::SizesType> broadcast_to_shape) {

kernels/portable/cpu/util/broadcast_util.h

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,21 @@ bool tensors_are_broadcastable_between(
6262
*/
6363
bool tensors_are_broadcastable_between(const Tensor& a, const Tensor& b);
6464

65+
/**
66+
* Create a new tensor with the given sizes, dim_order, and strides.
67+
*
68+
* @param[in] sizes The sizes of the tensor.
69+
* @param[in] dim_order The dim order of the tensor.
70+
* @param[in] strides The strides of the tensor.
71+
* @param[in] dtype The data type of the tensor.
72+
* @returns A new tensor with the given sizes, dim_order, and strides.
73+
*/
74+
Tensor make_tensor(
75+
const ArrayRef<Tensor::SizesType>& sizes,
76+
const ArrayRef<Tensor::DimOrderType>& dim_order,
77+
const ArrayRef<Tensor::StridesType>& strides,
78+
const ScalarType& dtype);
79+
6580
/**
6681
* DEPRECATED: Use `delinearize_index()` and `linearize_access_indexes()` for
6782
* index remapping to avoid memory allocation.
@@ -75,7 +90,7 @@ bool tensors_are_broadcastable_between(const Tensor& a, const Tensor& b);
7590
* @param[in] broadcast_to The tensor to which we want to broadcast to.
7691
* @returns A new tensor with the same shape as broadcast_to and the data
7792
* repeated as appropriate. This tensor contains dynamically allocated memory
78-
* and must be freed using free_broadcast_tensor.
93+
* and must be freed using free_tensor.
7994
*/
8095
__ET_DEPRECATED exec_aten::Tensor broadcast_tensor(
8196
const exec_aten::Tensor& broadcast_from,
@@ -202,8 +217,7 @@ __ET_NODISCARD inline Error resize_to_broadcast_target_size(
202217
* broadcast_tensor.
203218
* @returns void
204219
*/
205-
__ET_DEPRECATED void free_broadcast_tensor(
206-
const exec_aten::Tensor& broadcast_tensor);
220+
__ET_DEPRECATED void free_tensor(const exec_aten::Tensor& broadcast_tensor);
207221

208222
/**
209223
* Delinearize a flattened index to per-dimension indexes.

kernels/portable/cpu/util/test/broadcast_test.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ TEST(BroadcastUtilTest, BroadcastTensor) {
3232

3333
Tensor d = torch::executor::broadcast_tensor(a, c);
3434
EXPECT_TENSOR_DATA_EQ(d, tf.make({2, 2}, {2, 2, 2, 2}));
35-
torch::executor::free_broadcast_tensor(d);
35+
torch::executor::free_tensor(d);
3636

3737
d = torch::executor::broadcast_tensor(b, c);
3838
EXPECT_TENSOR_DATA_EQ(d, tf.make({2, 2}, {2, 2, 2, 2}));
39-
torch::executor::free_broadcast_tensor(d);
39+
torch::executor::free_tensor(d);
4040
}
4141

4242
TEST(BroadcastUtilTest, BroadcastableBetween) {
@@ -63,12 +63,12 @@ TEST(BroadcastUtilTest, BroadcastableToFrom) {
6363
ASSERT_TRUE(tensor_is_broadcastable_to(a, c));
6464
Tensor d = torch::executor::broadcast_tensor(a, c);
6565
EXPECT_TENSOR_DATA_EQ(d, tf.make({2, 2}, {2, 2, 2, 2}));
66-
torch::executor::free_broadcast_tensor(d);
66+
torch::executor::free_tensor(d);
6767

6868
ASSERT_TRUE(tensor_is_broadcastable_to(b, c));
6969
d = torch::executor::broadcast_tensor(b, c);
7070
EXPECT_TENSOR_DATA_EQ(d, tf.make({2, 2}, {2, 2, 2, 2}));
71-
torch::executor::free_broadcast_tensor(d);
71+
torch::executor::free_tensor(d);
7272
}
7373

7474
TEST(BroadcastUtilTest, NotBroadcastableTo) {

0 commit comments

Comments
 (0)