
Commit bc8da63

janeyx99 authored and pytorchmergebot committed
Move MemoryFormat/Layout to headeronly (pytorch#168034)
~This PR does change the semantics of the >> operator by using STD_TORCH_CHECK to throw the error instead of TORCH_CHECK. Jane (who is writing this message) thinks it is okay because it is the error case when an invalid MemoryFormat or Layout is getting passed into >>, so the UX benefits of TORCH_CHECK over STD_TORCH_CHECK there are not significant enough to warrant making a new copy of Layout and MemoryFormat's >> APIs.~

Never mind! We shouldn't change TORCH_CHECK to STD_TORCH_CHECK for core usage ever, cuz the traceback info and c10::Error is very much desired!! So the solution is to not migrate the >>s. I pushed new commits to the stack to remove the >> code, but for reference, pytorch@8a30179 has all the code that I ended up deleting.

Pull Request resolved: pytorch#168034
Approved by: https://github.com/janeyx99
ghstack dependencies: pytorch#168025, pytorch#167802, pytorch#167803, pytorch#167804, pytorch#167962
Co-authored-by: Jane Xu <[email protected]>
1 parent f890837 commit bc8da63
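
The description above turns on the difference between TORCH_CHECK and STD_TORCH_CHECK in the enums' stream-printing error paths. As a reader aid, here is a minimal sketch of the pattern that stays behind in c10 (print_memory_format is a hypothetical helper name, not the exact c10 source): the invalid-enum branch keeps TORCH_CHECK so it still raises a c10::Error with traceback info, which the header-only STD_TORCH_CHECK does not provide.

#include <ostream>

#include <c10/core/MemoryFormat.h>
#include <c10/util/Exception.h>

// Hypothetical sketch of the kind of streaming helper kept in c10 rather than
// migrated to torch/headeronly, so its error path can keep using TORCH_CHECK.
inline std::ostream& print_memory_format(
    std::ostream& stream,
    c10::MemoryFormat memory_format) {
  switch (memory_format) {
    case c10::MemoryFormat::Contiguous:
      return stream << "Contiguous";
    case c10::MemoryFormat::Preserve:
      return stream << "Preserve";
    case c10::MemoryFormat::ChannelsLast:
      return stream << "ChannelsLast";
    case c10::MemoryFormat::ChannelsLast3d:
      return stream << "ChannelsLast3d";
    default:
      break;
  }
  // Invalid enum value: TORCH_CHECK raises c10::Error with traceback info,
  // which is why these printers were not moved to headeronly.
  TORCH_CHECK(false, "Unknown memory format");
  return stream;  // unreachable; keeps compilers quiet about the return path
}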

13 files changed: +366 −81 lines changed


c10/core/Layout.h

Lines changed: 1 addition & 22 deletions
@@ -3,30 +3,9 @@
 #include <c10/core/Backend.h>
 #include <c10/util/Exception.h>
 
-#include <cstdint>
-#include <ostream>
+#include <torch/headeronly/core/Layout.h>
 
 namespace c10 {
-enum class Layout : int8_t {
-  Strided,
-  Sparse,
-  SparseCsr,
-  Mkldnn,
-  SparseCsc,
-  SparseBsr,
-  SparseBsc,
-  Jagged,
-  NumOptions
-};
-
-constexpr auto kStrided = Layout::Strided;
-constexpr auto kSparse = Layout::Sparse;
-constexpr auto kSparseCsr = Layout::SparseCsr;
-constexpr auto kMkldnn = Layout::Mkldnn;
-constexpr auto kSparseCsc = Layout::SparseCsc;
-constexpr auto kSparseBsr = Layout::SparseBsr;
-constexpr auto kSparseBsc = Layout::SparseBsc;
-constexpr auto kJagged = Layout::Jagged;
 
 inline Layout layout_from_backend(Backend backend) {
   C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")

c10/core/MemoryFormat.h

Lines changed: 2 additions & 30 deletions
@@ -3,46 +3,18 @@
 #include <c10/util/ArrayRef.h>
 #include <c10/util/Exception.h>
 
+#include <torch/headeronly/core/MemoryFormat.h>
+
 #include <cstdint>
-#include <ostream>
 #include <vector>
 
-// Memory format is not the property of a Tensor. It is the way to tell an
-// operator how the result should be organized in memory and nothing more. That
-// means memory format should never be used as return value for any tensor state
-// interrogation functions (internally and externally).
-//
-// Possible options are:
-//  Preserve:
-//    If any of the input tensors is in channels_last format, operator output
-//    should be in channels_last format
-//
-//  Contiguous:
-//    Regardless of input tensors format, the output should be contiguous
-//    Tensor.
-//
-//  ChannelsLast:
-//    Regardless of input tensors format, the output should be in channels_last
-//    format.
-
 namespace c10 {
-enum class MemoryFormat : int8_t {
-  Contiguous,
-  Preserve,
-  ChannelsLast,
-  ChannelsLast3d,
-  NumOptions
-};
 
 // If you are seeing this, it means that this call site was not checked if
 // the memory format could be preserved, and it was switched to old default
 // behaviour of contiguous
 #define LEGACY_CONTIGUOUS_MEMORY_FORMAT c10::get_contiguous_memory_format()
 
-inline MemoryFormat get_contiguous_memory_format() {
-  return MemoryFormat::Contiguous;
-}
-
 inline std::ostream& operator<<(
     std::ostream& stream,
     at::MemoryFormat memory_format) {
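
With the enum definitions now in torch/headeronly, code that has to stay header-only (for example stable-ABI extensions) can use Layout and MemoryFormat without pulling in the rest of c10. A minimal sketch of that usage follows, mirroring what the new aoti_abi_check tests below exercise; wants_dense_contiguous is a hypothetical helper, not part of this PR.

#include <torch/headeronly/core/Layout.h>
#include <torch/headeronly/core/MemoryFormat.h>

// Hypothetical helper: the enums and their convenience constants/functions
// now come from the header-only headers, so this compiles without c10.
inline bool wants_dense_contiguous(
    torch::headeronly::Layout layout,
    torch::headeronly::MemoryFormat memory_format) {
  return layout == torch::headeronly::kStrided &&
      memory_format == torch::headeronly::get_contiguous_memory_format();
}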

test/cpp/aoti_abi_check/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -16,8 +16,10 @@ set(AOTI_ABI_CHECK_TEST_SRCS
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp
+  ${AOTI_ABI_CHECK_TEST_ROOT}/test_layout.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp
+  ${AOTI_ABI_CHECK_TEST_ROOT}/test_memoryformat.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_metaprogramming.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp
   ${AOTI_ABI_CHECK_TEST_ROOT}/test_scalartype.cpp

test/cpp/aoti_abi_check/test_layout.cpp (new file)

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+#include <gtest/gtest.h>
+
+#include <torch/headeronly/core/Layout.h>
+
+TEST(TestLayout, TestLayout) {
+  using torch::headeronly::Layout;
+  constexpr Layout expected_layouts[] = {
+      torch::headeronly::kStrided,
+      torch::headeronly::kSparse,
+      torch::headeronly::kSparseCsr,
+      torch::headeronly::kMkldnn,
+      torch::headeronly::kSparseCsc,
+      torch::headeronly::kSparseBsr,
+      torch::headeronly::kSparseBsc,
+      torch::headeronly::kJagged,
+  };
+  for (int8_t i = 0; i < static_cast<int8_t>(Layout::NumOptions); i++) {
+    EXPECT_EQ(static_cast<Layout>(i), expected_layouts[i]);
+  }
+}

test/cpp/aoti_abi_check/test_memoryformat.cpp (new file)

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+#include <gtest/gtest.h>
+
+#include <torch/headeronly/core/MemoryFormat.h>
+
+TEST(TestMemoryFormat, TestMemoryFormat) {
+  using torch::headeronly::MemoryFormat;
+  constexpr MemoryFormat expected_memory_formats[] = {
+      MemoryFormat::Contiguous,
+      MemoryFormat::Preserve,
+      MemoryFormat::ChannelsLast,
+      MemoryFormat::ChannelsLast3d,
+  };
+  for (int8_t i = 0; i < static_cast<int8_t>(MemoryFormat::NumOptions); i++) {
+    EXPECT_EQ(static_cast<MemoryFormat>(i), expected_memory_formats[i]);
+  }
+}
+
+TEST(TestMemoryFormat, get_contiguous_memory_format) {
+  using torch::headeronly::get_contiguous_memory_format;
+  using torch::headeronly::MemoryFormat;
+
+  EXPECT_EQ(get_contiguous_memory_format(), MemoryFormat::Contiguous);
+}

test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp

Lines changed: 5 additions & 3 deletions
@@ -10,14 +10,16 @@ using torch::stable::Tensor;
 Tensor my_empty(
     torch::headeronly::HeaderOnlyArrayRef<int64_t> size,
     std::optional<torch::headeronly::ScalarType> dtype,
+    std::optional<torch::headeronly::Layout> layout,
     std::optional<torch::stable::Device> device,
-    std::optional<bool> pin_memory) {
-  return empty(size, dtype, device, pin_memory);
+    std::optional<bool> pin_memory,
+    std::optional<torch::headeronly::MemoryFormat> memory_format) {
+  return empty(size, dtype, layout, device, pin_memory, memory_format);
 }
 
 STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
   m.def(
-      "my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor");
+      "my_empty(int[] size, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor");
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {

test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py

Lines changed: 7 additions & 3 deletions
@@ -156,20 +156,24 @@ def test_get_num_threads() -> int:
     return torch.ops.libtorch_agnostic_2_10.test_get_num_threads.default()
 
 
-def my_empty(size, dtype=None, device=None, pin_memory=None) -> Tensor:
+def my_empty(
+    size, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
+) -> Tensor:
     """
-    Creates an empty tensor with the specified size, dtype, device, and pin_memory.
+    Creates an empty tensor with the specified size, dtype, layout, device, pin_memory, and memory_format.
 
     Args:
         size: list[int] - size of the tensor to create
         dtype: ScalarType or None - data type of the tensor
+        layout: Layout or None - layout of the tensor
        device: Device or None - device on which to create the tensor
        pin_memory: bool or None - whether to use pinned memory
+        memory_format: MemoryFormat or None - memory format of the tensor
 
    Returns: Tensor - an uninitialized tensor with the specified properties
    """
    return torch.ops.libtorch_agnostic_2_10.my_empty.default(
-        size, dtype, device, pin_memory
+        size, dtype, layout, device, pin_memory, memory_format
    )

test/cpp_extensions/test_libtorch_agnostic.py

Lines changed: 65 additions & 15 deletions
@@ -14,6 +14,7 @@
 from torch.testing._internal.common_utils import (
     install_cpp_extension,
     IS_WINDOWS,
+    parametrize,
     run_tests,
     skipIfTorchDynamo,
     TestCase,
@@ -618,43 +619,92 @@ def test_get_num_threads(self, device):
         self.assertEqual(num_threads, expected_num_threads)
 
     @skipIfTorchVersionLessThan(2, 10)
-    def test_my_empty(self, device):
+    @parametrize("layout", [None, torch.strided, torch.sparse_coo])
+    @parametrize(
+        "memory_format", [None, torch.channels_last, torch.contiguous_format]
+    )
+    def test_my_empty(self, device, layout, memory_format):
         import libtorch_agnostic_2_10 as libtorch_agnostic
 
         deterministic = torch.are_deterministic_algorithms_enabled()
         try:
             # set use_deterministic_algorithms to fill uninitialized memory
             torch.use_deterministic_algorithms(True)
 
-            size = [2, 3]
-            result = libtorch_agnostic.ops.my_empty(size, None, None, None)
-            expected = torch.empty(size)
-            self.assertEqual(result, expected, exact_device=True)
+            # Use 4D size for channels_last, 2D otherwise
+            size = [2, 3, 4, 5] if memory_format == torch.channels_last else [2, 3]
+
+            # sparse_coo layout doesn't support memory_format parameter
+            if layout == torch.sparse_coo and memory_format is not None:
+                return
+
+            # Test default parameters
+            result = libtorch_agnostic.ops.my_empty(
+                size, None, layout, None, None, memory_format
+            )
+            expected = torch.empty(size, layout=layout, memory_format=memory_format)
+            self.assertEqual(result, expected, exact_device=True, exact_layout=True)
 
+            # Test with dtype
             result_float = libtorch_agnostic.ops.my_empty(
-                size, torch.float32, None, None
+                size, torch.float32, layout, None, None, memory_format
+            )
+            expected_float = torch.empty(
+                size,
+                dtype=torch.float32,
+                layout=layout,
+                memory_format=memory_format,
+            )
+            self.assertEqual(
+                result_float, expected_float, exact_device=True, exact_layout=True
             )
-            expected_float = torch.empty(size, dtype=torch.float32)
-            self.assertEqual(result_float, expected_float, exact_device=True)
 
+            # Test with dtype and device
             result_with_device = libtorch_agnostic.ops.my_empty(
-                size, torch.float64, device, None
+                size, torch.float64, layout, device, None, memory_format
             )
             expected_with_device = torch.empty(
-                size, dtype=torch.float64, device=device
+                size,
+                dtype=torch.float64,
+                layout=layout,
+                device=device,
+                memory_format=memory_format,
             )
             self.assertEqual(
-                result_with_device, expected_with_device, exact_device=True
+                result_with_device,
+                expected_with_device,
+                exact_device=True,
+                exact_layout=True,
             )
 
-            if device == "cuda":
+            # Verify layout if specified
+            if layout is not None:
+                self.assertEqual(result_with_device.layout, layout)
+
+            # Verify memory format if specified
+            if memory_format == torch.channels_last:
+                self.assertTrue(
+                    result_with_device.is_contiguous(
+                        memory_format=torch.channels_last
+                    )
+                )
+            elif memory_format == torch.contiguous_format:
+                self.assertTrue(result_with_device.is_contiguous())
+
+            # Test pin_memory on CUDA (only once, not for every parameter combination)
+            if device == "cuda" and layout is None and memory_format is None:
                 result_pinned = libtorch_agnostic.ops.my_empty(
-                    size, torch.float32, "cpu", True
+                    [2, 3], torch.float32, None, "cpu", True, None
                 )
                 expected_pinned = torch.empty(
-                    size, dtype=torch.float32, device="cpu", pin_memory=True
+                    [2, 3], dtype=torch.float32, device="cpu", pin_memory=True
+                )
+                self.assertEqual(
+                    result_pinned,
+                    expected_pinned,
+                    exact_device=True,
+                    exact_layout=True,
                 )
-                self.assertEqual(result_pinned, expected_pinned)
                 self.assertTrue(result_pinned.is_pinned())
         finally:
             torch.use_deterministic_algorithms(deterministic)

torch/csrc/stable/ops.h

Lines changed: 10 additions & 8 deletions
@@ -326,24 +326,26 @@ inline uint32_t get_num_threads() {
   return num_threads;
 }
 
-// We expect this to be the stable version of the empty op that takes in
-// device and dtype parameters. The empty op creates a tensor with uninitialized
-// values of the specified size, dtype, and device.
-// This function is only available in 2.10 because it uses the stableivalue
-// conversion for HeaderOnlyArrayRef<T>, which is only available in 2.10.
+// We expect this to be the stable version of the empty.memory_format op that
+// takes in device and dtype parameters. This function is only available in 2.10
+// because it uses the stableivalue conversion for HeaderOnlyArrayRef<T>, which
+// is only available in 2.10.
 inline torch::stable::Tensor empty(
     torch::headeronly::IntHeaderOnlyArrayRef size,
     std::optional<torch::headeronly::ScalarType> dtype = std::nullopt,
+    std::optional<torch::headeronly::Layout> layout = std::nullopt,
    std::optional<torch::stable::Device> device = std::nullopt,
-    std::optional<bool> pin_memory = std::nullopt) {
+    std::optional<bool> pin_memory = std::nullopt,
+    std::optional<torch::headeronly::MemoryFormat> memory_format =
+        std::nullopt) {
   const auto num_args = 6;
   std::array<StableIValue, num_args> stack{
       torch::stable::detail::from(size),
       torch::stable::detail::from(dtype),
-      torch::stable::detail::from(std::nullopt),
+      torch::stable::detail::from(layout),
      torch::stable::detail::from(device),
      torch::stable::detail::from(pin_memory),
-      torch::stable::detail::from(std::nullopt)};
+      torch::stable::detail::from(memory_format)};
   TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
       "aten::empty", "memory_format", stack.data(), TORCH_ABI_VERSION));
   return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);