
Commit 88a9bd4

[ROCm] revamp miopen integration
Update sources under ATen/miopen and ATen/native/miopen to align with best practices. Avoid reshape_ calls inside backward operations.
1 parent 4e630f0 commit 88a9bd4

8 files changed: +1659 additions, −936 deletions


aten/src/ATen/miopen/Descriptors.cpp

Lines changed: 64 additions & 33 deletions

@@ -1,13 +1,16 @@
 #include <ATen/miopen/Descriptors.h>
+
 #include <ATen/ATen.h>
 #include <c10/util/irange.h>
 
 #include <iostream>
+#include <sstream>
 
-namespace at { namespace native {
+namespace at::native {
 
 namespace {
 
+
 inline miopenDataType_t getDataType(const at::Tensor& t) {
   auto scalar_type = t.scalar_type();
   if (scalar_type == at::kFloat) {
@@ -16,34 +19,37 @@ inline miopenDataType_t getDataType(const at::Tensor& t) {
     return miopenHalf;
   } else if (scalar_type == at::kBFloat16) {
     return miopenBFloat16;
-  } else {
-    TORCH_CHECK(
-        false,
-        "TensorDescriptor only supports float, half and bfloat16 tensors");
   }
+  TORCH_CHECK(false, "TensorDescriptor does not support ", scalar_type);
 }
 
+constexpr size_t MIOPEN_DIM_MAX = 5;
+
 } // anonymous namespace
 
+void TensorDescriptor::set(const at::Tensor &t, at::MemoryFormat memory_format, size_t pad) {
+  set(getDataType(t), t.sizes(), t.strides(), pad,
+      memory_format == at::MemoryFormat::ChannelsLast ||
+      memory_format == at::MemoryFormat::ChannelsLast3d);
+}
 
 void TensorDescriptor::set(const at::Tensor &t, size_t pad) {
-  set(getDataType(t), t.sizes(), t.strides(), pad);
+  auto memory_format = t.suggest_memory_format();
+  set(getDataType(t), t.sizes(), t.strides(), pad,
+      memory_format == at::MemoryFormat::ChannelsLast ||
+      memory_format == at::MemoryFormat::ChannelsLast3d);
 }
 
-constexpr size_t MIOPEN_DIM_MAX = 5;
-
 void TensorDescriptor::set(miopenDataType_t datatype, IntArrayRef t_sizes, IntArrayRef t_strides, size_t pad) {
+  set(datatype, t_sizes, t_strides, pad,
+      is_channels_last_strides_2d(t_sizes, t_strides) ||
+      is_channels_last_strides_3d(t_sizes, t_strides));
+}
+
+void TensorDescriptor::set(miopenDataType_t datatype, IntArrayRef t_sizes, IntArrayRef t_strides, size_t pad, bool nhwc) {
   size_t dim = t_sizes.size();
   if (dim > MIOPEN_DIM_MAX || pad > MIOPEN_DIM_MAX)
-#define _STR(X) #X
-#define STR(X) _STR(X)
-    TORCH_CHECK(
-        false,
-        "MIOpen supports only up to ",
-        STR(MIOPEN_DIM_MAX),
-        " dimensions");
-#undef _STR
-#undef STR
+    TORCH_CHECK(false, "MIOpen supports only up to ", MIOPEN_DIM_MAX, " dimensions");
   int size[MIOPEN_DIM_MAX];
   int stride[MIOPEN_DIM_MAX];
   for (const auto i : c10::irange(dim)) {
@@ -54,7 +60,7 @@ void TensorDescriptor::set(miopenDataType_t datatype, IntArrayRef t_sizes, IntAr
     size[i] = 1;
     stride[i] = 1;
   }
-  set(datatype, static_cast<int>(std::max(dim, pad)), size, stride);
+  set(datatype, static_cast<int>(std::max(dim, pad)), size, stride, nhwc);
 }
 
 std::string miopenTypeToString(miopenDataType_t dtype) {
@@ -74,10 +80,11 @@ std::string miopenTypeToString(miopenDataType_t dtype) {
 
 std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d) {
   out << "TensorDescriptor " << static_cast<void*>(d.desc()) << "\n";
-  int nbDims = 4;
+  int nbDims = 0;
   int dimA[MIOPEN_DIM_MAX];
   int strideA[MIOPEN_DIM_MAX];
   miopenDataType_t dtype;
+  miopenGetTensorDescriptorSize(d.desc(), &nbDims);
   miopenGetTensorDescriptor(d.desc(), &dtype, dimA, strideA);
   out << "    type = " << miopenTypeToString(dtype) << "\n";
   out << "    nbDims = " << nbDims << "\n";
@@ -99,19 +106,17 @@ void TensorDescriptor::print() { std::cout << *this; }
 
 void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad) {
   auto dim = t.ndimension();
-  if (dim > static_cast<int64_t>(MIOPEN_DIM_MAX) || pad > static_cast<int64_t>(MIOPEN_DIM_MAX)) {
-#define _STR(X) #X
-#define STR(X) _STR(X)
-    TORCH_CHECK(
-        false,
-        "MIOpen supports only up to ",
-        STR(MIOPEN_DIM_MAX),
-        " dimensions");
-#undef _STR
-#undef STR
-  }
+  if (dim > MIOPEN_DIM_MAX || pad > MIOPEN_DIM_MAX)
+    TORCH_CHECK(false, "MIOpen supports only up to ", MIOPEN_DIM_MAX, " dimensions");
+  // NB: It is possible for this test to be insufficient, because the
+  // Tensor passed in to set the filter descriptor may not be the actual
+  // Tensor whose data pointer is passed to cuDNN. Nevertheless,
+  // that is the common case, so we can catch most client errors with this test.
   TORCH_CHECK(t.is_contiguous(memory_format),
-    "MIOpen filters (a.k.a. weights) must be contiguous");
+    "MIOpen filters (a.k.a. weights) must be contiguous in desired memory_format\n",
+    "Weight sizes: ", t.sizes(), "\n",
+    "Weight strides: ", t.strides(), "\n",
+    "cuDNN suggested memory_format: ", memory_format);
 
   int size[MIOPEN_DIM_MAX];
   int stride[MIOPEN_DIM_MAX];
@@ -131,7 +136,33 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo
   }
 
   dim = std::max<int64_t>(dim, pad);
-  set(getDataType(t), (int) dim, size, stride);
+  set(getDataType(t), static_cast<int>(dim), size, stride);
 }
 
-}}
+std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d) {
+  out << "FilterDescriptor " << static_cast<void*>(d.desc()) << "\n";
+  int nbDims = 0;
+  int dimA[MIOPEN_DIM_MAX];
+  int strideA[MIOPEN_DIM_MAX];
+  miopenDataType_t dtype;
+  miopenGetTensorDescriptorSize(d.desc(), &nbDims);
+  miopenGetTensorDescriptor(d.desc(), &dtype, dimA, strideA);
+  out << "    type = " << miopenTypeToString(dtype) << "\n";
+  out << "    nbDims = " << nbDims << "\n";
+  // Read out only nbDims of the arrays!
+  out << "    dimA = ";
+  for (auto i : ArrayRef<int>{dimA, static_cast<size_t>(nbDims)}) {
+    out << i << ", ";
+  }
+  out << "\n";
+  out << "    strideA = ";
+  for (auto i : ArrayRef<int>{strideA, static_cast<size_t>(nbDims)}) {
+    out << i << ", ";
+  }
+  out << "\n";
+  return out;
+}
+
+void FilterDescriptor::print() { std::cout << *this; }
+
+}
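
For readers trying the new descriptor API: a minimal usage sketch, assuming a ROCm/HIP build of PyTorch where ATen and the ATen/miopen headers are available. The shapes, dtype, and variable names are illustrative, not part of this commit.

// Sketch: build a TensorDescriptor for a channels-last tensor using the new
// overload that takes an explicit memory format.
#include <ATen/ATen.h>
#include <ATen/miopen/Descriptors.h>
#include <iostream>

int main() {
  // An NCHW-shaped tensor laid out as channels-last (NHWC strides).
  at::Tensor t = at::empty({2, 8, 16, 16},
                           at::device(at::kCUDA).dtype(at::kFloat))
                     .contiguous(at::MemoryFormat::ChannelsLast);

  // New overload: the caller states the memory format explicitly, so the
  // descriptor is built with NHWC-style strides (via fixSizeOneDimStride).
  at::native::TensorDescriptor desc;
  desc.set(t, t.suggest_memory_format());

  // The old two-argument overload still works; it now infers the format
  // from the tensor's suggest_memory_format().
  at::native::TensorDescriptor desc2;
  desc2.set(t);

  std::cout << desc;  // operator<< prints dtype, nbDims, dimA, strideA
  return 0;
}

The explicit memory_format overload lets convolution code pass along the format it has already decided on, instead of re-deriving it from the tensor's strides.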

aten/src/ATen/miopen/Descriptors.h

Lines changed: 68 additions & 5 deletions

@@ -1,13 +1,16 @@
 #pragma once
 
-#include <ATen/miopen/Exceptions.h>
+#include <string>
 
+#include <ATen/miopen/Exceptions.h>
 #include <ATen/miopen/miopen-wrapper.h>
 #include <ATen/core/Tensor.h>
 #include <ATen/TensorUtils.h>
 #include <c10/macros/Export.h>
 
-namespace at { namespace native {
+namespace at::native {
+
+std::string miopenTypeToString(miopenDataType_t dtype);
 
 inline int dataSize(miopenDataType_t dataType)
 {
@@ -19,6 +22,43 @@ inline int dataSize(miopenDataType_t dataType)
   }
 }
 
+// The stride for a size-1 dimension is not uniquely determined; in
+// fact, it can be anything you want, because the fact that the
+// tensor is size 1 at this dimension means that you will never actually
+// try advancing your pointer by this stride.
+//
+// We duplicate the CuDNN restriction here for MIOpen.
+//
+// However, CuDNN has a much more stringent requirement on strides:
+// if you are passing a contiguous input, it better be the case
+// that the stride for dim i is the product of the sizes of dims
+// i+1 to the end. This stride is indeed uniquely determined. This
+// function modifies 'stride' in place so this invariant holds.
+template <typename T>
+static inline void fixSizeOneDimStride(int dim, const T *size, T *stride, bool nhwc) {
+  int64_t z = 1;
+  int index = 0;
+  std::vector<int> permutation(dim);
+
+  if (nhwc) {
+    permutation[index++] = 1;
+  }
+  for (int d = dim-1; d > 1; d--) {
+    permutation[index++] = d;
+  }
+  if (!nhwc) {
+    permutation[index++] = 1;
+  }
+  permutation[index++] = 0;
+  for (int d : permutation) {
+    if (size[d] == 1) {
+      stride[d] = z;
+    } else {
+      z *= size[d];
+    }
+  }
+}
+
 template <typename T, miopenStatus_t (*dtor)(T*)>
 struct DescriptorDeleter {
   void operator()(T* x) {
@@ -41,6 +81,8 @@ template <typename T, miopenStatus_t (*ctor)(T**), miopenStatus_t (*dtor)(T*)>
 // NOLINTNEXTLINE(bugprone-exception-escape)
 class TORCH_HIP_CPP_API Descriptor {
 public:
+  // TODO: Figure out why const-correctness doesn't work here
+
   // Use desc() to access the underlying descriptor pointer in
   // a read-only fashion. Most client code should use this.
   // If the descriptor was never initialized, this will return
@@ -75,14 +117,32 @@ class TORCH_HIP_CPP_API TensorDescriptor : public Descriptor<
     set(t, pad);
   }
 
+  // Note [MIOpen broadcast padding]
+  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+  // pad specifies the minimum dimensionality of the tensor descriptor
+  // we produce (it doesn't have anything to do with, e.g., convolution
+  // padding). If 't' is lower-dimensional than 'pad', the remaining
+  // dimensions (on the right) are padded with ones. This doesn't
+  // affect the underlying data layout. This is particularly useful for
+  // dealing with a peculiarity of the MIOpen API, which is that broadcasting in MIOpen is
+  // done in two steps: first, the client code is expected to pad out
+  // (the dimensions) input tensors to be the same dimension as the
+  // target broadcast, and then second, MIOpen takes care of actually
+  // broadcasting size 1 dimensions.
+
   void set(const at::Tensor &t, size_t pad = 0);
+  void set(const at::Tensor &t, at::MemoryFormat memory_format, size_t pad = 0);
   void set(miopenDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0);
 
   void print();
 
 private:
-  void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
-    MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
+  void set(miopenDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad, bool nhwc);
+
+  void set(miopenDataType_t dataType, int dim, int* size, int* stride, bool nhwc) {
+    std::vector<int> strides_copy(stride, stride + dim);
+    fixSizeOneDimStride<int>(dim, size, strides_copy.data(), nhwc);
+    MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, strides_copy.data()));
   }
 };
 
@@ -99,12 +159,15 @@ class TORCH_HIP_CPP_API FilterDescriptor : public Descriptor<
 
   void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0);
 
+  void print();
 private:
   void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
     MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
   }
 };
 
+std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d);
+
 struct TORCH_HIP_CPP_API ConvolutionDescriptor
   : public Descriptor<
       miopenConvolutionDescriptor,
@@ -166,4 +229,4 @@ union Constant
   }
 };
 
-}} // namespace
+} // namespace
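
The nhwc flag changes the order in which fixSizeOneDimStride walks the dimensions, so size-1 dimensions get the stride MIOpen expects for an NHWC layout. Below is a standalone trace of the helper, copied from the hunk above so it can run outside of ATen; the example sizes and strides are illustrative.

// Standalone trace of fixSizeOneDimStride with an ambiguous size-1 dim.
#include <cstdint>
#include <iostream>
#include <vector>

template <typename T>
static inline void fixSizeOneDimStride(int dim, const T *size, T *stride, bool nhwc) {
  int64_t z = 1;
  int index = 0;
  std::vector<int> permutation(dim);

  if (nhwc) {
    permutation[index++] = 1;
  }
  for (int d = dim-1; d > 1; d--) {
    permutation[index++] = d;
  }
  if (!nhwc) {
    permutation[index++] = 1;
  }
  permutation[index++] = 0;
  for (int d : permutation) {
    if (size[d] == 1) {
      stride[d] = z;
    } else {
      z *= size[d];
    }
  }
}

int main() {
  // A (N=2, C=1, H=3, W=3) tensor. With C == 1 the stride of dim 1 is
  // arbitrary; a contiguous NCHW tensor would report 9 there.
  int size[4]   = {2, 1, 3, 3};
  int stride[4] = {9, 9, 3, 1};

  fixSizeOneDimStride<int>(4, size, stride, /*nhwc=*/true);

  // The size-1 channel dim is normalized to the canonical NHWC stride:
  // prints "9 1 3 1".
  for (int s : stride) std::cout << s << " ";
  std::cout << "\n";
  return 0;
}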

aten/src/ATen/native/ConvUtils.h

Lines changed: 15 additions & 7 deletions

@@ -353,19 +353,21 @@ TORCH_API void _cudnn_set_conv_benchmark_empty_cache(bool enable);
 TORCH_API bool _cudnn_get_conv_benchmark_empty_cache();
 
 
-inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
-
+inline at::MemoryFormat miopen_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
   // disable NHWC for float64 input.
   if (!at::detail::getCUDAHooks().compiledWithMIOpen() ||
       input.scalar_type() == at::kDouble ||
       weight.scalar_type() == at::kDouble) {
-    return false;
+    return at::MemoryFormat::Contiguous;
   }
 
   // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
-  // See #64427
-  static std::optional<bool> PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC");
-  static bool suggest_nhwc = PYTORCH_MIOPEN_SUGGEST_NHWC && *PYTORCH_MIOPEN_SUGGEST_NHWC;
+  // See https://github.com/pytorch/pytorch/issues/64427.
+  // A non-static variable is used here so the environment variable can be changed at runtime for testing.
+  // Enabled by default for ROCm >= 7.0.0 with MIOpen 3.5.
+  int miopen_version = detail::getCUDAHooks().compiledWithMIOpen() ? detail::getCUDAHooks().versionMIOpen() : 0;
+  bool is_miopen_3_5 = miopen_version >= 30500; // ROCm 7.0
+  bool suggest_nhwc = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC").value_or(is_miopen_3_5);
 
   auto input_memory_format = input.suggest_memory_format();
   auto weight_memory_format = weight.suggest_memory_format();
@@ -375,13 +377,19 @@ inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Ten
       (input_memory_format == at::MemoryFormat::ChannelsLast) ||
       (weight_memory_format == at::MemoryFormat::ChannelsLast)
   );
+  if (can_use_miopen_channels_last_2d) {
+    return at::MemoryFormat::ChannelsLast;
+  }
 
   bool can_use_miopen_channels_last_3d = suggest_nhwc && (weight_ndim == 5) && (
       (input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
       (weight_memory_format == at::MemoryFormat::ChannelsLast3d)
   );
+  if (can_use_miopen_channels_last_3d) {
+    return at::MemoryFormat::ChannelsLast3d;
+  }
 
-  return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
+  return at::MemoryFormat::Contiguous;
 }
 
 inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
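
Callers that previously branched on the boolean miopen_conv_use_channels_last can now consume the suggested format directly. A minimal caller-side sketch, assuming a ROCm build where this internal header is reachable; prepare_for_miopen is a hypothetical helper for illustration, not part of this commit.

// Sketch: pick the layout the MIOpen path will use and make the input match it.
#include <ATen/ATen.h>
#include <ATen/native/ConvUtils.h>

at::Tensor prepare_for_miopen(const at::Tensor& input, const at::Tensor& weight) {
  // Returns ChannelsLast / ChannelsLast3d when NHWC is enabled (MIOpen >= 3.5
  // by default, or PYTORCH_MIOPEN_SUGGEST_NHWC=1) and either operand already
  // suggests it; otherwise Contiguous.
  auto fmt = at::native::miopen_conv_suggest_memory_format(input, weight);
  return input.contiguous(fmt);
}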

aten/src/ATen/native/Convolution.cpp

Lines changed: 5 additions & 4 deletions

@@ -459,6 +459,9 @@ struct ConvParams {
 
   // Use cudnn for FP16 depthwise convolutions
   bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const {
+    if (!detail::getCUDAHooks().compiledWithCuDNN()) {
+      return false;
+    }
     if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous && use_cudnn(input, weight)) {
       // always use cudnn_depthwise for channels_last format
       return true;
@@ -1419,10 +1422,8 @@ static inline at::MemoryFormat determine_backend_memory_format(
     case ConvBackend::Miopen:
     case ConvBackend::MiopenDepthwise:
     case ConvBackend::MiopenTranspose:
-      if (detail::getCUDAHooks().compiledWithMIOpen() && miopen_conv_use_channels_last(input, weight)) {
-        TORCH_INTERNAL_ASSERT((k == 4 || k == 5),
-            "Expected 4D or 5D input for miopen memory format selection in determine_backend_memory_format()");
-        backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
+      if (detail::getCUDAHooks().compiledWithMIOpen()) {
+        backend_memory_format = miopen_conv_suggest_memory_format(input, weight);
       }
       break;
     case ConvBackend::Mkldnn:
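
A quick way to observe the new dispatch end to end is to run a channels-last convolution and inspect the output layout. A sketch, assuming a ROCm build with MIOpen compiled in and NHWC enabled (MIOpen >= 3.5 or PYTORCH_MIOPEN_SUGGEST_NHWC=1); shapes are illustrative.

// Sketch: with channels-last operands, the MIOpen backend memory format is
// now suggested by miopen_conv_suggest_memory_format, so the output is
// expected to come back channels-last as well.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto opts = at::device(at::kCUDA).dtype(at::kFloat);
  auto input  = at::randn({8, 16, 32, 32}, opts).contiguous(at::MemoryFormat::ChannelsLast);
  auto weight = at::randn({32, 16, 3, 3}, opts).contiguous(at::MemoryFormat::ChannelsLast);

  auto out = at::conv2d(input, weight, /*bias=*/{}, /*stride=*/1, /*padding=*/1);
  std::cout << "output is channels_last: "
            << out.is_contiguous(at::MemoryFormat::ChannelsLast) << "\n";
  return 0;
}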

0 commit comments