Commit aeb6421

[release/2.8] revamp miopen integration (#2601)
Update sources under ATen/miopen and ATen/native/miopen to align with best practices. Avoid reshape_ calls inside backward operations.

Pull Request resolved: pytorch#161687
Approved by: https://github.com/jeffdaily
1 parent db3ba66 commit aeb6421
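A note on the reshape_ point in the message: mutating a tensor's shape in place inside a backward function can corrupt a tensor that autograd (or the caller) still references. Below is a minimal sketch of the preferred non-mutating pattern; the helper name and shapes are hypothetical, not the MIOpen code this commit actually touches.

#include <ATen/ATen.h>

// Hypothetical backward helper (illustrative only): flatten spatial dims
// of grad_output without mutating the caller's tensor.
at::Tensor bias_backward_sketch(const at::Tensor& grad_output) {
  // The in-place pattern the commit message refers to would alter a tensor
  // the autograd graph may still hold. reshape() instead returns a new
  // tensor (a view when strides permit, a copy otherwise) and leaves
  // grad_output untouched.
  auto go = grad_output.reshape({grad_output.size(0), grad_output.size(1), -1});
  return go.sum(/*dim=*/{0, 2});  // reduce over batch and flattened spatial dims
}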

File tree

8 files changed: +1162 additions, -772 deletions

aten/src/ATen/cudnn/Descriptors.h

Lines changed: 1 addition & 0 deletions

@@ -38,6 +38,7 @@ inline int dataSize(cudnnDataType_t dataType)
   }
 }
 
+// NOTE [ cudnn fixSizeOneDimStride ]
 // The stride for a size-1 dimensions is not uniquely determined; in
 // fact, it can be anything you want, because the fact that the
 // tensor is size 1 at this dimension means that you will never actually
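For reference, the reason the stride of a size-1 dimension is a free parameter: an element's flat offset is the sum over d of i_d * stride_d, and when size[d] == 1 the only valid index i_d is 0, so that stride never contributes. A standalone sketch (not part of the diff) demonstrating this:

#include <cstdint>
#include <cstdio>

// Flat offset of element (i0, i1, i2, i3) in a 4-d strided tensor.
int64_t offset(const int64_t idx[4], const int64_t stride[4]) {
  int64_t off = 0;
  for (int d = 0; d < 4; ++d) off += idx[d] * stride[d];
  return off;
}

int main() {
  // Shape {2, 1, 3, 4}: dimension 1 has size 1, so idx[1] is always 0.
  int64_t idx[4] = {1, 0, 2, 3};
  int64_t a[4]   = {12, 12, 4, 1};   // "contiguous" stride for the size-1 dim
  int64_t b[4]   = {12, 999, 4, 1};  // arbitrary stride for the size-1 dim
  // Both print 23: addressing is identical. Backends can still reject or
  // mishandle odd size-1 strides, hence the fixSizeOneDimStride helpers below.
  std::printf("%lld %lld\n", (long long)offset(idx, a), (long long)offset(idx, b));
  return 0;
}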

aten/src/ATen/miopen/Descriptors.cpp

Lines changed: 35 additions & 28 deletions

@@ -19,31 +19,37 @@ inline miopenDataType_t getDataType(const at::Tensor& t) {
   } else {
     TORCH_CHECK(
         false,
-        "TensorDescriptor only supports float, half and bfloat16 tensors");
+        "TensorDescriptor does not support ", scalar_type);
   }
 }
 
 } // anonymous namespace
 
+constexpr size_t MIOPEN_DIM_MAX = 5;
 
-void TensorDescriptor::set(const at::Tensor &t, size_t pad) {
-  set(getDataType(t), t.sizes(), t.strides(), pad);
+void TensorDescriptor::set(const at::Tensor &t, at::MemoryFormat memory_format, size_t pad) {
+  set(getDataType(t), t.sizes(), t.strides(), pad,
+      memory_format == at::MemoryFormat::ChannelsLast ||
+      memory_format == at::MemoryFormat::ChannelsLast3d);
 }
 
-constexpr size_t MIOPEN_DIM_MAX = 5;
+void TensorDescriptor::set(const at::Tensor &t, size_t pad) {
+  auto memory_format = t.suggest_memory_format();
+  set(getDataType(t), t.sizes(), t.strides(), pad,
+      memory_format == at::MemoryFormat::ChannelsLast ||
+      memory_format == at::MemoryFormat::ChannelsLast3d);
+}
 
 void TensorDescriptor::set(miopenDataType_t datatype, IntArrayRef t_sizes, IntArrayRef t_strides, size_t pad) {
+  set(datatype, t_sizes, t_strides, pad,
+      is_channels_last_strides_2d(t_sizes, t_strides) ||
+      is_channels_last_strides_3d(t_sizes, t_strides));
+}
+
+void TensorDescriptor::set(miopenDataType_t datatype, IntArrayRef t_sizes, IntArrayRef t_strides, size_t pad, bool nhwc) {
   size_t dim = t_sizes.size();
   if (dim > MIOPEN_DIM_MAX || pad > MIOPEN_DIM_MAX)
-#define _STR(X) #X
-#define STR(X) _STR(X)
-    TORCH_CHECK(
-        false,
-        "MIOpen supports only up to ",
-        STR(MIOPEN_DIM_MAX),
-        " dimensions");
-#undef _STR
-#undef STR
+    TORCH_CHECK(false, "MIOpen supports only up to ", MIOPEN_DIM_MAX, " dimensions");
   int size[MIOPEN_DIM_MAX];
   int stride[MIOPEN_DIM_MAX];
   for (const auto i : c10::irange(dim)) {
@@ -54,7 +60,7 @@ void TensorDescriptor::set(miopenDataType_t datatype, IntArrayRef t_sizes, IntAr
     size[i] = 1;
     stride[i] = 1;
   }
-  set(datatype, static_cast<int>(std::max(dim, pad)), size, stride);
+  set(datatype, static_cast<int>(std::max(dim, pad)), size, stride, nhwc);
 }
 
 std::string miopenTypeToString(miopenDataType_t dtype) {
@@ -74,10 +80,11 @@ std::string miopenTypeToString(miopenDataType_t dtype) {
 
 std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d) {
   out << "TensorDescriptor " << static_cast<void*>(d.desc()) << "\n";
-  int nbDims = 4;
+  int nbDims = 0;
   int dimA[MIOPEN_DIM_MAX];
  int strideA[MIOPEN_DIM_MAX];
  miopenDataType_t dtype;
+  miopenGetTensorDescriptorSize(d.desc(), &nbDims);
  miopenGetTensorDescriptor(d.desc(), &dtype, dimA, strideA);
  out << "    type = " << miopenTypeToString(dtype) << "\n";
  out << "    nbDims = " << nbDims << "\n";
@@ -99,19 +106,17 @@ void TensorDescriptor::print() { std::cout << *this; }
 
 void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad) {
   auto dim = t.ndimension();
-  if (dim > static_cast<int64_t>(MIOPEN_DIM_MAX) || pad > static_cast<int64_t>(MIOPEN_DIM_MAX)) {
-#define _STR(X) #X
-#define STR(X) _STR(X)
-    TORCH_CHECK(
-        false,
-        "MIOpen supports only up to ",
-        STR(MIOPEN_DIM_MAX),
-        " dimensions");
-#undef _STR
-#undef STR
-  }
+  if (dim > MIOPEN_DIM_MAX || pad > MIOPEN_DIM_MAX)
+    TORCH_CHECK(false, "MIOpen supports only up to ", MIOPEN_DIM_MAX, " dimensions");
+  // NB: It is possible for this test to be insufficient, because the
+  // Tensor passed in to set the filter descriptor may not be the actual
+  // Tensor whose data pointer is passed to cuDNN. Nevertheless,
+  // that is the common case, so we can catch most client errors with this test.
   TORCH_CHECK(t.is_contiguous(memory_format),
-    "MIOpen filters (a.k.a. weights) must be contiguous");
+      "MIOpen filters (a.k.a. weights) must be contiguous in desired memory_format\n",
+      "Weight sizes: ", t.sizes(), "\n",
+      "Weight strides: ", t.strides(), "\n",
+      "cuDNN suggested memory_format: ", memory_format);
 
   int size[MIOPEN_DIM_MAX];
   int stride[MIOPEN_DIM_MAX];
@@ -131,7 +136,9 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo
   }
 
   dim = std::max<int64_t>(dim, pad);
-  set(getDataType(t), (int) dim, size, stride);
+  set(getDataType(t), static_cast<int>(dim), size, stride,
+      memory_format == at::MemoryFormat::ChannelsLast ||
+      memory_format == at::MemoryFormat::ChannelsLast3d);
 }
 
 }}
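A usage sketch of the new overloads (illustrative only; assumes a ROCm/MIOpen build, with TensorDescriptor default-constructible as declared in the header below):

#include <ATen/ATen.h>
#include <ATen/miopen/Descriptors.h>

void describe(const at::Tensor& t) {
  at::native::TensorDescriptor desc;
  // One-argument path: suggest_memory_format() decides, so a tensor
  // allocated with MemoryFormat::ChannelsLast yields nhwc == true.
  desc.set(t);
  // Explicit path added by this commit: the caller pins the layout,
  // e.g. to keep descriptors consistent across forward and backward.
  desc.set(t, at::MemoryFormat::ChannelsLast);
}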

aten/src/ATen/miopen/Descriptors.h

Lines changed: 41 additions & 5 deletions

@@ -9,6 +9,8 @@
 
 namespace at { namespace native {
 
+std::string miopenTypeToString(miopenDataType_t dtype);
+
 inline int dataSize(miopenDataType_t dataType)
 {
   switch (dataType) {
@@ -19,6 +21,32 @@ inline int dataSize(miopenDataType_t dataType)
   }
 }
 
+// See NOTE [ cudnn fixSizeOneDimStride ] in aten/src/ATen/cudnn/Descriptors.h
+template <typename T>
+static inline void fixSizeOneDimStride(int dim, const T *size, T *stride, bool nhwc) {
+  int64_t z = 1;
+  int index = 0;
+  std::vector<int> permutation(dim);
+
+  if (nhwc) {
+    permutation[index++] = 1;
+  }
+  for (int d = dim-1; d > 1; d--) {
+    permutation[index++] = d;
+  }
+  if (!nhwc) {
+    permutation[index++] = 1;
+  }
+  permutation[index++] = 0;
+  for (int d : permutation) {
+    if (size[d] == 1) {
+      stride[d] = z;
+    } else {
+      z *= size[d];
+    }
+  }
+}
+
 template <typename T, miopenStatus_t (*dtor)(T*)>
 struct DescriptorDeleter {
   void operator()(T* x) {
@@ -75,14 +103,20 @@ class TORCH_HIP_CPP_API TensorDescriptor : public Descriptor<
     set(t, pad);
   }
 
+  // See Note [CuDNN broadcast padding]
   void set(const at::Tensor &t, size_t pad = 0);
+  void set(const at::Tensor &t, at::MemoryFormat memory_format, size_t pad = 0);
   void set(miopenDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0);
 
   void print();
 
 private:
-  void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
-    MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
+  void set(miopenDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad, bool nhwc);
+
+  void set(miopenDataType_t dataType, int dim, int* size, int* stride, bool nhwc) {
+    std::vector<int> strides_copy(stride, stride + dim);
+    fixSizeOneDimStride<int>(dim, size, strides_copy.data(), nhwc);
+    MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, strides_copy.data()));
   }
 };
 
@@ -100,8 +134,10 @@ class TORCH_HIP_CPP_API FilterDescriptor : public Descriptor<
   void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0);
 
 private:
-  void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
-    MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
+  void set(miopenDataType_t dataType, int dim, int* size, int* stride, bool nhwc) {
+    std::vector<int> strides_copy(stride, stride + dim);
+    fixSizeOneDimStride<int>(dim, size, strides_copy.data(), nhwc);
+    MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, strides_copy.data()));
   }
 };
 
@@ -166,4 +202,4 @@ union Constant
   }
 };
 
-}} // namespace
+}} // namespace
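A standalone demonstration of the fill order (the template body mirrors the hunk above; main() and the sample shape are illustrative). For NHWC it visits C first, then W, H, then N, so a size-1 channel dimension receives the innermost stride as channels-last expects; only size-1 dimensions are rewritten, other strides pass through untouched.

#include <cstdint>
#include <cstdio>
#include <vector>

template <typename T>
static void fixSizeOneDimStride(int dim, const T* size, T* stride, bool nhwc) {
  int64_t z = 1;
  int index = 0;
  std::vector<int> permutation(dim);
  if (nhwc) permutation[index++] = 1;                          // C innermost for NHWC
  for (int d = dim - 1; d > 1; d--) permutation[index++] = d;  // W, then H
  if (!nhwc) permutation[index++] = 1;                         // C after spatial for NCHW
  permutation[index++] = 0;                                    // N outermost either way
  for (int d : permutation) {
    if (size[d] == 1) stride[d] = z; else z *= size[d];
  }
}

int main() {
  int size[4]   = {8, 1, 32, 32};    // N=8, C=1 (the size-1 dim), H=W=32
  int s_nchw[4] = {1024, 7, 32, 1};  // 7 stands in for an arbitrary size-1 stride
  int s_nhwc[4] = {1024, 7, 32, 1};
  fixSizeOneDimStride(4, size, s_nchw, /*nhwc=*/false);
  fixSizeOneDimStride(4, size, s_nhwc, /*nhwc=*/true);
  // Prints: nchw strides: 1024 1024 32 1
  std::printf("nchw strides: %d %d %d %d\n", s_nchw[0], s_nchw[1], s_nchw[2], s_nchw[3]);
  // Prints: nhwc strides: 1024 1 32 1
  std::printf("nhwc strides: %d %d %d %d\n", s_nhwc[0], s_nhwc[1], s_nhwc[2], s_nhwc[3]);
  return 0;
}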

aten/src/ATen/native/ConvUtils.h

Lines changed: 20 additions & 7 deletions

@@ -353,19 +353,21 @@ TORCH_API void _cudnn_set_conv_benchmark_empty_cache(bool enable);
 TORCH_API bool _cudnn_get_conv_benchmark_empty_cache();
 
 
-inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
-
+inline at::MemoryFormat miopen_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
   // disable NHWC for float64 input.
   if (!at::detail::getCUDAHooks().compiledWithMIOpen() ||
       input.scalar_type() == at::kDouble ||
       weight.scalar_type() == at::kDouble) {
-    return false;
+    return at::MemoryFormat::Contiguous;
   }
 
   // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
-  // See #64427
-  static std::optional<bool> PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC");
-  static bool suggest_nhwc = PYTORCH_MIOPEN_SUGGEST_NHWC && *PYTORCH_MIOPEN_SUGGEST_NHWC;
+  // See https://github.com/pytorch/pytorch/issues/64427.
+  // non static variable is used to be able to change environment variable in runtime for testing
+  // enabled by default for ROCm >= 7.0.0 with miopen 3.5
+  int miopen_version = detail::getCUDAHooks().compiledWithMIOpen() ? detail::getCUDAHooks().versionMIOpen() : 0;
+  bool is_miopen_3_5 = miopen_version >= 30500; // ROCm 7.0
+  bool suggest_nhwc = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC").value_or(is_miopen_3_5);
 
   auto input_memory_format = input.suggest_memory_format();
   auto weight_memory_format = weight.suggest_memory_format();
@@ -375,13 +377,24 @@ inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Ten
       (input_memory_format == at::MemoryFormat::ChannelsLast) ||
       (weight_memory_format == at::MemoryFormat::ChannelsLast)
   );
+  if (can_use_miopen_channels_last_2d) {
+    return at::MemoryFormat::ChannelsLast;
+  }
 
   bool can_use_miopen_channels_last_3d = suggest_nhwc && (weight_ndim == 5) && (
       (input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
       (weight_memory_format == at::MemoryFormat::ChannelsLast3d)
   );
+  if (can_use_miopen_channels_last_3d) {
+    return at::MemoryFormat::ChannelsLast3d;
+  }
+
+  return at::MemoryFormat::Contiguous;
+}
 
-  return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
+// deprecated, but to remove would be BC-breaking
+inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
+  return miopen_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous;
 }
 
 inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
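The precedence above condenses to: an explicitly set PYTORCH_MIOPEN_SUGGEST_NHWC always wins; otherwise NHWC defaults on starting with MIOpen 3.5 (ROCm 7.0). A standalone sketch of that decision, with plain parameters standing in for the CUDA hooks and memory-format checks; note that c10::utils::check_env parses more value forms than this simplified getenv comparison does.

#include <cstdlib>
#include <cstring>

enum class Fmt { Contiguous, ChannelsLast, ChannelsLast3d };

Fmt suggest(bool has_miopen, int miopen_version, bool fp64,
            bool cl2d_hint, bool cl3d_hint, int weight_ndim) {
  if (!has_miopen || fp64) return Fmt::Contiguous;
  // New default: NHWC on once MIOpen >= 3.5 (version encoding per the
  // "miopen_version >= 30500" check above)...
  bool suggest_nhwc = miopen_version >= 30500;
  // ...but the env var, when present, overrides in either direction.
  if (const char* env = std::getenv("PYTORCH_MIOPEN_SUGGEST_NHWC")) {
    suggest_nhwc = (std::strcmp(env, "0") != 0);
  }
  if (suggest_nhwc && weight_ndim == 4 && cl2d_hint) return Fmt::ChannelsLast;
  if (suggest_nhwc && weight_ndim == 5 && cl3d_hint) return Fmt::ChannelsLast3d;
  return Fmt::Contiguous;
}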

aten/src/ATen/native/Convolution.cpp

Lines changed: 5 additions & 4 deletions

@@ -458,6 +458,9 @@ struct ConvParams {
 
   // Use cudnn for FP16 depthwise convolutions
   bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const {
+    if (!detail::getCUDAHooks().compiledWithCuDNN()) {
+      return false;
+    }
     if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous && use_cudnn(input, weight)) {
       // always use cudnn_depthwise for channels_last format
       return true;
@@ -1418,10 +1421,8 @@ static inline at::MemoryFormat determine_backend_memory_format(
     case ConvBackend::Miopen:
     case ConvBackend::MiopenDepthwise:
     case ConvBackend::MiopenTranspose:
-      if (detail::getCUDAHooks().compiledWithMIOpen() && miopen_conv_use_channels_last(input, weight)) {
-        TORCH_INTERNAL_ASSERT((k == 4 || k == 5),
-          "Expected 4D or 5D input for miopen memory format selection in determine_backend_memory_format()");
-        backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
+      if (detail::getCUDAHooks().compiledWithMIOpen()) {
+        backend_memory_format = miopen_conv_suggest_memory_format(input, weight);
       }
       break;
     case ConvBackend::Mkldnn:
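End-to-end, the practical effect (illustrative shapes; assumes a ROCm build with MIOpen >= 3.5): determine_backend_memory_format() now returns ChannelsLast for a call like the one below without PYTORCH_MIOPEN_SUGGEST_NHWC being set, where it previously fell back to Contiguous unless the env var was on.

#include <ATen/ATen.h>

at::Tensor channels_last_conv_sketch() {
  auto opts = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
  auto input  = at::randn({8, 3, 224, 224}, opts)
                    .contiguous(at::MemoryFormat::ChannelsLast);
  auto weight = at::randn({16, 3, 3, 3}, opts)
                    .contiguous(at::MemoryFormat::ChannelsLast);
  // Dispatches to the MIOpen backend on ROCm; with this commit the NHWC
  // path is chosen via miopen_conv_suggest_memory_format().
  return at::conv2d(input, weight);
}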
