Skip to content

Commit a854563

Browse files
authored
[release/2.6] Exposing Some MIOpen Symbols pytorch#154545 (#2217)
(This is a cherry-pick of pytorch#154545) This PR exposes some MIOpen symbols, namely: 1. `miopenDataType_t getMiopenDataType(const at::Tensor& tensor)` 2. `miopenHandle_t getMiopenHandle()` 3. `class TensorDescriptor` 4. `class Descriptor` 5. `class FilterDescriptor` 6. `struct ConvolutionDescriptor` 7. `struct DropoutDescriptor` 8. `struct RNNDescriptor` to enable adding extensions that make use of them.
1 parent 0ad7380 commit a854563

File tree

3 files changed

+41
-28
lines changed

3 files changed

+41
-28
lines changed

aten/src/ATen/miopen/Descriptors.h

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ struct DescriptorDeleter {
3838
// initialized the first time you call set() or any other initializing
3939
// function.
4040
template <typename T, miopenStatus_t (*ctor)(T**), miopenStatus_t (*dtor)(T*)>
41-
class Descriptor
42-
{
43-
public:
41+
// NOLINTNEXTLINE(bugprone-exception-escape)
42+
class TORCH_CUDA_CPP_API Descriptor {
43+
public:
4444
// Use desc() to access the underlying descriptor pointer in
4545
// a read-only fashion. Most client code should use this.
4646
// If the descriptor was never initialized, this will return
@@ -56,7 +56,7 @@ class Descriptor
5656
protected:
5757
void init() {
5858
if (desc_ == nullptr) {
59-
T* raw_desc;
59+
T* raw_desc = nullptr;
6060
MIOPEN_CHECK(ctor(&raw_desc));
6161
desc_.reset(raw_desc);
6262
}
@@ -65,13 +65,12 @@ class Descriptor
6565
std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
6666
};
6767

68-
class TORCH_CUDA_CPP_API TensorDescriptor
69-
: public Descriptor<miopenTensorDescriptor,
70-
&miopenCreateTensorDescriptor,
71-
&miopenDestroyTensorDescriptor>
72-
{
73-
public:
74-
TensorDescriptor() {}
68+
class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor<
69+
miopenTensorDescriptor,
70+
&miopenCreateTensorDescriptor,
71+
&miopenDestroyTensorDescriptor> {
72+
public:
73+
TensorDescriptor() = default;
7574
explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
7675
set(t, pad);
7776
}
@@ -89,11 +88,10 @@ class TORCH_CUDA_CPP_API TensorDescriptor
8988

9089
std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);
9190

92-
class FilterDescriptor
93-
: public Descriptor<miopenTensorDescriptor,
94-
&miopenCreateTensorDescriptor,
95-
&miopenDestroyTensorDescriptor>
96-
{
91+
class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor<
92+
miopenTensorDescriptor,
93+
&miopenCreateTensorDescriptor,
94+
&miopenDestroyTensorDescriptor> {
9795
public:
9896
void set(const at::Tensor &t, int64_t pad = 0) {
9997
set(t, at::MemoryFormat::Contiguous, pad);
@@ -107,11 +105,11 @@ class FilterDescriptor
107105
}
108106
};
109107

110-
struct ConvolutionDescriptor
111-
: public Descriptor<miopenConvolutionDescriptor,
112-
&miopenCreateConvolutionDescriptor,
113-
&miopenDestroyConvolutionDescriptor>
114-
{
108+
struct TORCH_CUDA_CPP_API ConvolutionDescriptor
109+
: public Descriptor<
110+
miopenConvolutionDescriptor,
111+
&miopenCreateConvolutionDescriptor,
112+
&miopenDestroyConvolutionDescriptor> {
115113
void set(miopenDataType_t dataType, miopenConvolutionMode_t c_mode, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups, bool benchmark, bool deterministic) {
116114
MIOPEN_CHECK(miopenInitConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale, c_mode));
117115
MIOPEN_CHECK(miopenSetConvolutionGroupCount(mut_desc(), groups));
@@ -122,8 +120,24 @@ struct ConvolutionDescriptor
122120
}
123121
};
124122

123+
// NOLINTNEXTLINE(bugprone-exception-escape)
124+
struct TORCH_CUDA_CPP_API DropoutDescriptor
125+
: public Descriptor<
126+
miopenDropoutDescriptor,
127+
&miopenCreateDropoutDescriptor,
128+
&miopenDestroyDropoutDescriptor> {
129+
void set(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes,
130+
unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) {
131+
MIOPEN_CHECK(miopenSetDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode));
132+
}
133+
134+
void restore(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes,
135+
unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) {
136+
MIOPEN_CHECK(miopenRestoreDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode));
137+
}
138+
};
125139

126-
struct RNNDescriptor
140+
struct TORCH_CUDA_CPP_API RNNDescriptor
127141
: public Descriptor<miopenRNNDescriptor,
128142
&miopenCreateRNNDescriptor,
129143
&miopenDestroyRNNDescriptor>

aten/src/ATen/miopen/Handle.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
#include <ATen/miopen/miopen-wrapper.h>
44
#include <c10/macros/Export.h>
55

6-
namespace at { namespace native {
6+
namespace at::native {
77

88
TORCH_CUDA_CPP_API miopenHandle_t getMiopenHandle();
9-
10-
}} // namespace
9+
} // namespace at::native

aten/src/ATen/miopen/Types.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
#pragma once
22

3-
#include <ATen/miopen/miopen-wrapper.h>
43
#include <ATen/Tensor.h>
4+
#include <ATen/miopen/miopen-wrapper.h>
55
#include <c10/macros/Export.h>
66

7-
namespace at { namespace native {
7+
namespace at::native {
88

99
TORCH_CUDA_CPP_API miopenDataType_t getMiopenDataType(const at::Tensor& tensor);
1010

1111
int64_t miopen_version();
1212

13-
}} // namespace at::miopen
13+
} // namespace at::native

0 commit comments

Comments
 (0)