11#include < ATen/miopen/Descriptors.h>
2+
23#include < ATen/ATen.h>
34#include < c10/util/irange.h>
45
56#include < iostream>
7+ #include < sstream>
68
7- namespace at { namespace native {
9+ namespace at :: native {
810
911namespace {
1012
13+
1114inline miopenDataType_t getDataType (const at::Tensor& t) {
1215 auto scalar_type = t.scalar_type ();
1316 if (scalar_type == at::kFloat ) {
@@ -16,34 +19,37 @@ inline miopenDataType_t getDataType(const at::Tensor& t) {
1619 return miopenHalf;
1720 } else if (scalar_type == at::kBFloat16 ) {
1821 return miopenBFloat16;
19- } else {
20- TORCH_CHECK (
21- false ,
22- " TensorDescriptor only supports float, half and bfloat16 tensors" );
2322 }
23+ TORCH_CHECK (false , " TensorDescriptor does not support " , scalar_type);
2424}
2525
26+ constexpr size_t MIOPEN_DIM_MAX = 5 ;
27+
2628} // anonymous namespace
2729
30+ void TensorDescriptor::set (const at::Tensor &t, at::MemoryFormat memory_format, size_t pad) {
31+ set (getDataType (t), t.sizes (), t.strides (), pad,
32+ memory_format == at::MemoryFormat::ChannelsLast ||
33+ memory_format == at::MemoryFormat::ChannelsLast3d);
34+ }
2835
2936void TensorDescriptor::set (const at::Tensor &t, size_t pad) {
30- set (getDataType (t), t.sizes (), t.strides (), pad);
37+ auto memory_format = t.suggest_memory_format ();
38+ set (getDataType (t), t.sizes (), t.strides (), pad,
39+ memory_format == at::MemoryFormat::ChannelsLast ||
40+ memory_format == at::MemoryFormat::ChannelsLast3d);
3141}
3242
33- constexpr size_t MIOPEN_DIM_MAX = 5 ;
34-
3543void TensorDescriptor::set (miopenDataType_t datatype, IntArrayRef t_sizes, IntArrayRef t_strides, size_t pad) {
44+ set (datatype, t_sizes, t_strides, pad,
45+ is_channels_last_strides_2d (t_sizes, t_strides) ||
46+ is_channels_last_strides_3d (t_sizes, t_strides));
47+ }
48+
49+ void TensorDescriptor::set (miopenDataType_t datatype, IntArrayRef t_sizes, IntArrayRef t_strides, size_t pad, bool nhwc) {
3650 size_t dim = t_sizes.size ();
3751 if (dim > MIOPEN_DIM_MAX || pad > MIOPEN_DIM_MAX)
38- #define _STR (X ) #X
39- #define STR (X ) _STR(X)
40- TORCH_CHECK (
41- false ,
42- " MIOpen supports only up to " ,
43- STR (MIOPEN_DIM_MAX),
44- " dimensions" );
45- #undef _STR
46- #undef STR
52+ TORCH_CHECK (false , " MIOpen supports only up to " , MIOPEN_DIM_MAX, " dimensions" );
4753 int size[MIOPEN_DIM_MAX];
4854 int stride[MIOPEN_DIM_MAX];
4955 for (const auto i : c10::irange (dim)) {
@@ -54,7 +60,7 @@ void TensorDescriptor::set(miopenDataType_t datatype, IntArrayRef t_sizes, IntAr
5460 size[i] = 1 ;
5561 stride[i] = 1 ;
5662 }
57- set (datatype, static_cast <int >(std::max (dim, pad)), size, stride);
63+ set (datatype, static_cast <int >(std::max (dim, pad)), size, stride, nhwc );
5864}
5965
6066std::string miopenTypeToString (miopenDataType_t dtype) {
@@ -74,10 +80,11 @@ std::string miopenTypeToString(miopenDataType_t dtype) {
7480
7581std::ostream& operator <<(std::ostream & out, const TensorDescriptor& d) {
7682 out << " TensorDescriptor " << static_cast <void *>(d.desc ()) << " \n " ;
77- int nbDims = 4 ;
83+ int nbDims = 0 ;
7884 int dimA[MIOPEN_DIM_MAX];
7985 int strideA[MIOPEN_DIM_MAX];
8086 miopenDataType_t dtype;
87+ miopenGetTensorDescriptorSize (d.desc (), &nbDims);
8188 miopenGetTensorDescriptor (d.desc (), &dtype, dimA, strideA);
8289 out << " type = " << miopenTypeToString (dtype) << " \n " ;
8390 out << " nbDims = " << nbDims << " \n " ;
@@ -99,19 +106,17 @@ void TensorDescriptor::print() { std::cout << *this; }
99106
100107void FilterDescriptor::set (const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad) {
101108 auto dim = t.ndimension ();
102- if (dim > static_cast <int64_t >(MIOPEN_DIM_MAX) || pad > static_cast <int64_t >(MIOPEN_DIM_MAX)) {
103- #define _STR (X ) #X
104- #define STR (X ) _STR(X)
105- TORCH_CHECK (
106- false ,
107- " MIOpen supports only up to " ,
108- STR (MIOPEN_DIM_MAX),
109- " dimensions" );
110- #undef _STR
111- #undef STR
112- }
109+ if (dim > MIOPEN_DIM_MAX || pad > MIOPEN_DIM_MAX)
110+ TORCH_CHECK (false , " MIOpen supports only up to " , MIOPEN_DIM_MAX, " dimensions" );
111+ // NB: It is possible for this test to be insufficient, because the
112+ // Tensor passed in to set the filter descriptor may not be the actual
113+ // Tensor whose data pointer is passed to cuDNN. Nevertheless,
114+ // that is the common case, so we can catch most client errors with this test.
113115 TORCH_CHECK (t.is_contiguous (memory_format),
114- " MIOpen filters (a.k.a. weights) must be contiguous" );
116+ " MIOpen filters (a.k.a. weights) must be contiguous in desired memory_format\n " ,
117+ " Weight sizes: " , t.sizes (), " \n " ,
118+ " Weight strides: " , t.strides (), " \n " ,
119+ " cuDNN suggested memory_format: " , memory_format);
115120
116121 int size[MIOPEN_DIM_MAX];
117122 int stride[MIOPEN_DIM_MAX];
@@ -131,7 +136,33 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo
131136 }
132137
133138 dim = std::max<int64_t >(dim, pad);
134- set (getDataType (t), ( int ) dim, size, stride);
139+ set (getDataType (t), static_cast < int >( dim) , size, stride);
135140}
136141
137- }}
142+ std::ostream& operator <<(std::ostream & out, const FilterDescriptor& d) {
143+ out << " FilterDescriptor " << static_cast <void *>(d.desc ()) << " \n " ;
144+ int nbDims = 0 ;
145+ int dimA[MIOPEN_DIM_MAX];
146+ int strideA[MIOPEN_DIM_MAX];
147+ miopenDataType_t dtype;
148+ miopenGetTensorDescriptorSize (d.desc (), &nbDims);
149+ miopenGetTensorDescriptor (d.desc (), &dtype, dimA, strideA);
150+ out << " type = " << miopenTypeToString (dtype) << " \n " ;
151+ out << " nbDims = " << nbDims << " \n " ;
152+ // Read out only nbDims of the arrays!
153+ out << " dimA = " ;
154+ for (auto i : ArrayRef<int >{dimA, static_cast <size_t >(nbDims)}) {
155+ out << i << " , " ;
156+ }
157+ out << " \n " ;
158+ out << " strideA = " ;
159+ for (auto i : ArrayRef<int >{strideA, static_cast <size_t >(nbDims)}) {
160+ out << i << " , " ;
161+ }
162+ out << " \n " ;
163+ return out;
164+ }
165+
166+ void FilterDescriptor::print () { std::cout << *this ; }
167+
168+ }
0 commit comments