55#include < ATen/miopen/miopen-wrapper.h>
66#include < ATen/core/Tensor.h>
77#include < ATen/TensorUtils.h>
8+ #include < c10/macros/Export.h>
89
910namespace at { namespace native {
1011
@@ -37,9 +38,9 @@ struct DescriptorDeleter {
3738// initialized the first time you call set() or any other initializing
3839// function.
3940template <typename T, miopenStatus_t (*ctor)(T**), miopenStatus_t (*dtor)(T*)>
40- class Descriptor
41- {
42- public:
41+ // NOLINTNEXTLINE(bugprone-exception-escape)
42+ class TORCH_CUDA_CPP_API Descriptor {
43+ public:
4344 // Use desc() to access the underlying descriptor pointer in
4445 // a read-only fashion. Most client code should use this.
4546 // If the descriptor was never initialized, this will return
@@ -55,7 +56,7 @@ class Descriptor
5556protected:
5657 void init () {
5758 if (desc_ == nullptr ) {
58- T* raw_desc;
59+ T* raw_desc = nullptr ;
5960 MIOPEN_CHECK (ctor (&raw_desc));
6061 desc_.reset (raw_desc);
6162 }
@@ -64,13 +65,12 @@ class Descriptor
6465 std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
6566};
6667
67- class TensorDescriptor
68- : public Descriptor<miopenTensorDescriptor,
69- &miopenCreateTensorDescriptor,
70- &miopenDestroyTensorDescriptor>
71- {
72- public:
73- TensorDescriptor () {}
68+ class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor<
69+ miopenTensorDescriptor,
70+ &miopenCreateTensorDescriptor,
71+ &miopenDestroyTensorDescriptor> {
72+ public:
73+ TensorDescriptor () = default ;
7474 explicit TensorDescriptor (const at::Tensor &t, size_t pad = 0 ) {
7575 set (t, pad);
7676 }
@@ -88,11 +88,10 @@ class TensorDescriptor
8888
8989std::ostream& operator <<(std::ostream & out, const TensorDescriptor& d);
9090
91- class FilterDescriptor
92- : public Descriptor<miopenTensorDescriptor,
93- &miopenCreateTensorDescriptor,
94- &miopenDestroyTensorDescriptor>
95- {
91+ class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor<
92+ miopenTensorDescriptor,
93+ &miopenCreateTensorDescriptor,
94+ &miopenDestroyTensorDescriptor> {
9695 public:
9796 void set (const at::Tensor &t, int64_t pad = 0 ) {
9897 set (t, at::MemoryFormat::Contiguous, pad);
@@ -106,11 +105,11 @@ class FilterDescriptor
106105 }
107106};
108107
109- struct ConvolutionDescriptor
110- : public Descriptor<miopenConvolutionDescriptor,
111- &miopenCreateConvolutionDescriptor ,
112- &miopenDestroyConvolutionDescriptor>
113- {
108+ struct TORCH_CUDA_CPP_API ConvolutionDescriptor
109+ : public Descriptor<
110+ miopenConvolutionDescriptor ,
111+ &miopenCreateConvolutionDescriptor,
112+ &miopenDestroyConvolutionDescriptor> {
114113 void set (miopenDataType_t dataType, miopenConvolutionMode_t c_mode, int dim, int * pad, int * stride, int * upscale /* aka dilation */ , int groups, bool benchmark, bool deterministic) {
115114 MIOPEN_CHECK (miopenInitConvolutionNdDescriptor (mut_desc (), dim, pad, stride, upscale, c_mode));
116115 MIOPEN_CHECK (miopenSetConvolutionGroupCount (mut_desc (), groups));
@@ -121,8 +120,24 @@ struct ConvolutionDescriptor
121120 }
122121};
123122
123+ // NOLINTNEXTLINE(bugprone-exception-escape)
124+ struct TORCH_CUDA_CPP_API DropoutDescriptor
125+ : public Descriptor<
126+ miopenDropoutDescriptor,
127+ &miopenCreateDropoutDescriptor,
128+ &miopenDestroyDropoutDescriptor> {
129+ void set (miopenHandle_t handle, float dropout, void * states, size_t stateSizeInBytes,
130+ unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) {
131+ MIOPEN_CHECK (miopenSetDropoutDescriptor (mut_desc (), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode));
132+ }
133+
134+ void restore (miopenHandle_t handle, float dropout, void * states, size_t stateSizeInBytes,
135+ unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) {
136+ MIOPEN_CHECK (miopenRestoreDropoutDescriptor (mut_desc (), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode));
137+ }
138+ };
124139
125- struct RNNDescriptor
140+ struct TORCH_CUDA_CPP_API RNNDescriptor
126141 : public Descriptor<miopenRNNDescriptor,
127142 &miopenCreateRNNDescriptor,
128143 &miopenDestroyRNNDescriptor>
0 commit comments