@@ -38,9 +38,9 @@ struct DescriptorDeleter {
3838// initialized the first time you call set() or any other initializing
3939// function.
4040template <typename T, miopenStatus_t (*ctor)(T**), miopenStatus_t (*dtor)(T*)>
41- class Descriptor
42- {
43- public:
41+ // NOLINTNEXTLINE(bugprone-exception-escape)
42+ class TORCH_CUDA_CPP_API Descriptor {
43+ public:
4444 // Use desc() to access the underlying descriptor pointer in
4545 // a read-only fashion. Most client code should use this.
4646 // If the descriptor was never initialized, this will return
@@ -56,7 +56,7 @@ class Descriptor
5656protected:
5757 void init () {
5858 if (desc_ == nullptr ) {
59- T* raw_desc;
59+ T* raw_desc = nullptr ;
6060 MIOPEN_CHECK (ctor (&raw_desc));
6161 desc_.reset (raw_desc);
6262 }
@@ -65,13 +65,12 @@ class Descriptor
6565 std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
6666};
6767
68- class TORCH_CUDA_CPP_API TensorDescriptor
69- : public Descriptor<miopenTensorDescriptor,
70- &miopenCreateTensorDescriptor,
71- &miopenDestroyTensorDescriptor>
72- {
73- public:
74- TensorDescriptor () {}
68+ class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor<
69+ miopenTensorDescriptor,
70+ &miopenCreateTensorDescriptor,
71+ &miopenDestroyTensorDescriptor> {
72+ public:
73+ TensorDescriptor () = default ;
7574 explicit TensorDescriptor (const at::Tensor &t, size_t pad = 0 ) {
7675 set (t, pad);
7776 }
@@ -89,11 +88,10 @@ class TORCH_CUDA_CPP_API TensorDescriptor
8988
9089std::ostream& operator <<(std::ostream & out, const TensorDescriptor& d);
9190
92- class FilterDescriptor
93- : public Descriptor<miopenTensorDescriptor,
94- &miopenCreateTensorDescriptor,
95- &miopenDestroyTensorDescriptor>
96- {
91+ class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor<
92+ miopenTensorDescriptor,
93+ &miopenCreateTensorDescriptor,
94+ &miopenDestroyTensorDescriptor> {
9795 public:
9896 void set (const at::Tensor &t, int64_t pad = 0 ) {
9997 set (t, at::MemoryFormat::Contiguous, pad);
@@ -107,11 +105,11 @@ class FilterDescriptor
107105 }
108106};
109107
110- struct ConvolutionDescriptor
111- : public Descriptor<miopenConvolutionDescriptor,
112- &miopenCreateConvolutionDescriptor ,
113- &miopenDestroyConvolutionDescriptor>
114- {
108+ struct TORCH_CUDA_CPP_API ConvolutionDescriptor
109+ : public Descriptor<
110+ miopenConvolutionDescriptor ,
111+ &miopenCreateConvolutionDescriptor,
112+ &miopenDestroyConvolutionDescriptor> {
115113 void set (miopenDataType_t dataType, miopenConvolutionMode_t c_mode, int dim, int * pad, int * stride, int * upscale /* aka dilation */ , int groups, bool benchmark, bool deterministic) {
116114 MIOPEN_CHECK (miopenInitConvolutionNdDescriptor (mut_desc (), dim, pad, stride, upscale, c_mode));
117115 MIOPEN_CHECK (miopenSetConvolutionGroupCount (mut_desc (), groups));
@@ -122,8 +120,24 @@ struct ConvolutionDescriptor
122120 }
123121};
124122
123+ // NOLINTNEXTLINE(bugprone-exception-escape)
124+ struct TORCH_CUDA_CPP_API DropoutDescriptor
125+ : public Descriptor<
126+ miopenDropoutDescriptor,
127+ &miopenCreateDropoutDescriptor,
128+ &miopenDestroyDropoutDescriptor> {
129+ void set (miopenHandle_t handle, float dropout, void * states, size_t stateSizeInBytes,
130+ unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) {
131+ MIOPEN_CHECK (miopenSetDropoutDescriptor (mut_desc (), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode));
132+ }
133+
134+ void restore (miopenHandle_t handle, float dropout, void * states, size_t stateSizeInBytes,
135+ unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) {
136+ MIOPEN_CHECK (miopenRestoreDropoutDescriptor (mut_desc (), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode));
137+ }
138+ };
125139
126- struct RNNDescriptor
140+ struct TORCH_CUDA_CPP_API RNNDescriptor
127141 : public Descriptor<miopenRNNDescriptor,
128142 &miopenCreateRNNDescriptor,
129143 &miopenDestroyRNNDescriptor>
0 commit comments