diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h index 6c2492b12e6b9..85f0286542e75 100644 --- a/aten/src/ATen/cudnn/Descriptors.h +++ b/aten/src/ATen/cudnn/Descriptors.h @@ -38,6 +38,7 @@ inline int dataSize(cudnnDataType_t dataType) } } +// NOTE [ cudnn fixSizeOneDimStride ] // The stride for a size-1 dimensions is not uniquely determined; in // fact, it can be anything you want, because the fact that the // tensor is size 1 at this dimension means that you will never actually diff --git a/aten/src/ATen/miopen/Descriptors.cpp b/aten/src/ATen/miopen/Descriptors.cpp index 08c09b88f99cb..86e42ee3b66dc 100644 --- a/aten/src/ATen/miopen/Descriptors.cpp +++ b/aten/src/ATen/miopen/Descriptors.cpp @@ -19,31 +19,37 @@ inline miopenDataType_t getDataType(const at::Tensor& t) { } else { TORCH_CHECK( false, - "TensorDescriptor only supports float, half and bfloat16 tensors"); + "TensorDescriptor does not support ", scalar_type); } } } // anonymous namespace +constexpr size_t MIOPEN_DIM_MAX = 5; -void TensorDescriptor::set(const at::Tensor &t, size_t pad) { - set(getDataType(t), t.sizes(), t.strides(), pad); +void TensorDescriptor::set(const at::Tensor &t, at::MemoryFormat memory_format, size_t pad) { + set(getDataType(t), t.sizes(), t.strides(), pad, + memory_format == at::MemoryFormat::ChannelsLast || + memory_format == at::MemoryFormat::ChannelsLast3d); } -constexpr size_t MIOPEN_DIM_MAX = 5; +void TensorDescriptor::set(const at::Tensor &t, size_t pad) { + auto memory_format = t.suggest_memory_format(); + set(getDataType(t), t.sizes(), t.strides(), pad, + memory_format == at::MemoryFormat::ChannelsLast || + memory_format == at::MemoryFormat::ChannelsLast3d); +} void TensorDescriptor::set(miopenDataType_t datatype, IntArrayRef t_sizes, IntArrayRef t_strides, size_t pad) { + set(datatype, t_sizes, t_strides, pad, + is_channels_last_strides_2d(t_sizes, t_strides) || + is_channels_last_strides_3d(t_sizes, t_strides)); +} + +void TensorDescriptor::set(miopenDataType_t datatype, IntArrayRef t_sizes, IntArrayRef t_strides, size_t pad, bool nhwc) { size_t dim = t_sizes.size(); if (dim > MIOPEN_DIM_MAX || pad > MIOPEN_DIM_MAX) -#define _STR(X) #X -#define STR(X) _STR(X) - TORCH_CHECK( - false, - "MIOpen supports only up to ", - STR(MIOPEN_DIM_MAX), - " dimensions"); -#undef _STR -#undef STR + TORCH_CHECK(false, "MIOpen supports only up to ", MIOPEN_DIM_MAX, " dimensions"); int size[MIOPEN_DIM_MAX]; int stride[MIOPEN_DIM_MAX]; for (const auto i : c10::irange(dim)) { @@ -54,7 +60,7 @@ void TensorDescriptor::set(miopenDataType_t datatype, IntArrayRef t_sizes, IntAr size[i] = 1; stride[i] = 1; } - set(datatype, static_cast(std::max(dim, pad)), size, stride); + set(datatype, static_cast(std::max(dim, pad)), size, stride, nhwc); } std::string miopenTypeToString(miopenDataType_t dtype) { @@ -74,10 +80,11 @@ std::string miopenTypeToString(miopenDataType_t dtype) { std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d) { out << "TensorDescriptor " << static_cast(d.desc()) << "\n"; - int nbDims = 4; + int nbDims = 0; int dimA[MIOPEN_DIM_MAX]; int strideA[MIOPEN_DIM_MAX]; miopenDataType_t dtype; + miopenGetTensorDescriptorSize(d.desc(), &nbDims); miopenGetTensorDescriptor(d.desc(), &dtype, dimA, strideA); out << " type = " << miopenTypeToString(dtype) << "\n"; out << " nbDims = " << nbDims << "\n"; @@ -99,19 +106,17 @@ void TensorDescriptor::print() { std::cout << *this; } void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t 
pad) {
   auto dim = t.ndimension();
-  if (dim > static_cast<int64_t>(MIOPEN_DIM_MAX) || pad > static_cast<int64_t>(MIOPEN_DIM_MAX)) {
-#define _STR(X) #X
-#define STR(X) _STR(X)
-    TORCH_CHECK(
-        false,
-        "MIOpen supports only up to ",
-        STR(MIOPEN_DIM_MAX),
-        " dimensions");
-#undef _STR
-#undef STR
-  }
+  if (dim > MIOPEN_DIM_MAX || pad > MIOPEN_DIM_MAX)
+    TORCH_CHECK(false, "MIOpen supports only up to ", MIOPEN_DIM_MAX, " dimensions");
+  // NB: It is possible for this test to be insufficient, because the
+  // Tensor passed in to set the filter descriptor may not be the actual
+  // Tensor whose data pointer is passed to cuDNN. Nevertheless,
+  // that is the common case, so we can catch most client errors with this test.
   TORCH_CHECK(t.is_contiguous(memory_format),
-      "MIOpen filters (a.k.a. weights) must be contiguous");
+      "MIOpen filters (a.k.a. weights) must be contiguous in desired memory_format\n",
+      "Weight sizes: ", t.sizes(), "\n",
+      "Weight strides: ", t.strides(), "\n",
+      "cuDNN suggested memory_format: ", memory_format);
   int size[MIOPEN_DIM_MAX];
   int stride[MIOPEN_DIM_MAX];
@@ -131,7 +136,9 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo
   }
   dim = std::max(dim, pad);
-  set(getDataType(t), (int) dim, size, stride);
+  set(getDataType(t), static_cast<int>(dim), size, stride,
+      memory_format == at::MemoryFormat::ChannelsLast ||
+      memory_format == at::MemoryFormat::ChannelsLast3d);
 }
 }}
diff --git a/aten/src/ATen/miopen/Descriptors.h b/aten/src/ATen/miopen/Descriptors.h
index 2eee837cd533d..8825575c9231b 100644
--- a/aten/src/ATen/miopen/Descriptors.h
+++ b/aten/src/ATen/miopen/Descriptors.h
@@ -9,6 +9,8 @@ namespace at { namespace native {
+std::string miopenTypeToString(miopenDataType_t dtype);
+
 inline int dataSize(miopenDataType_t dataType)
 {
   switch (dataType) {
@@ -19,6 +21,32 @@ inline int dataSize(miopenDataType_t dataType)
   }
 }
+// See NOTE [ cudnn fixSizeOneDimStride ] in aten/src/ATen/cudnn/Descriptors.h
+template <typename T>
+static inline void fixSizeOneDimStride(int dim, const T *size, T *stride, bool nhwc) {
+  int64_t z = 1;
+  int index = 0;
+  std::vector<int> permutation(dim);
+
+  if (nhwc) {
+    permutation[index++] = 1;
+  }
+  for (int d = dim-1; d > 1; d--) {
+    permutation[index++] = d;
+  }
+  if (!nhwc) {
+    permutation[index++] = 1;
+  }
+  permutation[index++] = 0;
+  for (int d : permutation) {
+    if (size[d] == 1) {
+      stride[d] = z;
+    } else {
+      z *= size[d];
+    }
+  }
+}
+
 template
 struct DescriptorDeleter {
   void operator()(T* x) {
@@ -75,14 +103,20 @@ class TORCH_HIP_CPP_API TensorDescriptor : public Descriptor<
     set(t, pad);
   }
+  // See Note [CuDNN broadcast padding]
   void set(const at::Tensor &t, size_t pad = 0);
+  void set(const at::Tensor &t, at::MemoryFormat memory_format, size_t pad = 0);
   void set(miopenDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0);
   void print();
 private:
-  void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
-    MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
+  void set(miopenDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad, bool nhwc);
+
+  void set(miopenDataType_t dataType, int dim, int* size, int* stride, bool nhwc) {
+    std::vector<int> strides_copy(stride, stride + dim);
+    fixSizeOneDimStride(dim, size, strides_copy.data(), nhwc);
+    MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, strides_copy.data()));
   }
 };
@@ -100,8 +134,10 @@ class TORCH_HIP_CPP_API FilterDescriptor : public Descriptor<
   void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0);
 private:
-  void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
-    MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
+  void set(miopenDataType_t dataType, int dim, int* size, int* stride, bool nhwc) {
+    std::vector<int> strides_copy(stride, stride + dim);
+    fixSizeOneDimStride(dim, size, strides_copy.data(), nhwc);
+    MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, strides_copy.data()));
   }
 };
@@ -166,4 +202,4 @@ union Constant
   }
 };
-}} // namespace
+}} // namespace
diff --git a/aten/src/ATen/native/ConvUtils.h b/aten/src/ATen/native/ConvUtils.h
index 84381efe55b0b..e160c84ced331 100644
--- a/aten/src/ATen/native/ConvUtils.h
+++ b/aten/src/ATen/native/ConvUtils.h
@@ -353,19 +353,21 @@ TORCH_API void _cudnn_set_conv_benchmark_empty_cache(bool enable);
 TORCH_API bool _cudnn_get_conv_benchmark_empty_cache();
-inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
-
+inline at::MemoryFormat miopen_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
   // disable NHWC for float64 input.
   if (!at::detail::getCUDAHooks().compiledWithMIOpen() ||
      input.scalar_type() == at::kDouble ||
      weight.scalar_type() == at::kDouble) {
-    return false;
+    return at::MemoryFormat::Contiguous;
   }
   // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
-  // See #64427
-  static std::optional<bool> PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC");
-  static bool suggest_nhwc = PYTORCH_MIOPEN_SUGGEST_NHWC && *PYTORCH_MIOPEN_SUGGEST_NHWC;
+  // See https://github.com/pytorch/pytorch/issues/64427.
+  // A non-static variable is used so the environment variable can be changed at runtime for testing.
+  // Enabled by default for ROCm >= 7.0.0 with MIOpen 3.5.
+  int miopen_version = detail::getCUDAHooks().compiledWithMIOpen() ?
detail::getCUDAHooks().versionMIOpen() : 0; + bool is_miopen_3_5 = miopen_version >= 30500; // ROCm 7.0 + bool suggest_nhwc = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC").value_or(is_miopen_3_5); auto input_memory_format = input.suggest_memory_format(); auto weight_memory_format = weight.suggest_memory_format(); @@ -375,13 +377,24 @@ inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Ten (input_memory_format == at::MemoryFormat::ChannelsLast) || (weight_memory_format == at::MemoryFormat::ChannelsLast) ); + if (can_use_miopen_channels_last_2d) { + return at::MemoryFormat::ChannelsLast; + } bool can_use_miopen_channels_last_3d = suggest_nhwc && (weight_ndim == 5) && ( (input_memory_format == at::MemoryFormat::ChannelsLast3d) || (weight_memory_format == at::MemoryFormat::ChannelsLast3d) ); + if (can_use_miopen_channels_last_3d) { + return at::MemoryFormat::ChannelsLast3d; + } + + return at::MemoryFormat::Contiguous; +} - return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d; +// deprecated, but to remove would be BC-breaking +inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { + return miopen_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous; } inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 1122d9c8d38af..634b71b9e3eb2 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -458,6 +458,9 @@ struct ConvParams { // Use cudnn for FP16 depthwise convolutions bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const { + if (!detail::getCUDAHooks().compiledWithCuDNN()) { + return false; + } if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous && use_cudnn(input, weight)) { // always use cudnn_depthwise for channels_last format return true; @@ -1418,10 +1421,8 @@ static inline at::MemoryFormat determine_backend_memory_format( case ConvBackend::Miopen: case ConvBackend::MiopenDepthwise: case ConvBackend::MiopenTranspose: - if (detail::getCUDAHooks().compiledWithMIOpen() && miopen_conv_use_channels_last(input, weight)) { - TORCH_INTERNAL_ASSERT((k == 4 || k == 5), - "Expected 4D or 5D input for miopen memory format selection in determine_backend_memory_format()"); - backend_memory_format = (k == 5) ? 
at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast; + if (detail::getCUDAHooks().compiledWithMIOpen()) { + backend_memory_format = miopen_conv_suggest_memory_format(input, weight); } break; case ConvBackend::Mkldnn: diff --git a/aten/src/ATen/native/miopen/Conv_miopen.cpp b/aten/src/ATen/native/miopen/Conv_miopen.cpp index f9ac375c205ec..9dcfe783d5340 100644 --- a/aten/src/ATen/native/miopen/Conv_miopen.cpp +++ b/aten/src/ATen/native/miopen/Conv_miopen.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #endif // TODO: Remove the condition on AT_ROCM_ENABLED entirely, @@ -145,13 +146,13 @@ at::Tensor miopen_convolution_relu( #include #include +#include #include #include #include #include -#include #include #include #include @@ -162,10 +163,13 @@ at::Tensor miopen_convolution_relu( namespace at { namespace native { -Tensor narrowGroup(const Tensor& t, int dim, int group_idx, int64_t groups) { - auto group_size = t.size(dim) / groups; - return t.narrow(dim, group_idx * group_size, group_size); -} +// See NOTE [ Convolution design ] in aten/src/ATen/native/cudnn/ConvShared.cpp + +// --------------------------------------------------------------------- +// +// Helper classes +// +// --------------------------------------------------------------------- // This POD struct is used to let us easily compute hashes of the // parameters @@ -174,6 +178,8 @@ struct ConvolutionParams miopenHandle_t handle; miopenDataType_t dataType; int input_size[2 + max_dim]; + uint8_t input_dim; + at::MemoryFormat memory_format; int input_stride[2 + max_dim]; int weight_size[2 + max_dim]; int padding[max_dim]; @@ -181,25 +187,29 @@ struct ConvolutionParams int dilation[max_dim]; int64_t groups; bool deterministic; - int device_id; //This is needed to distinguish between miopen handles of multiple gpus. + c10::DeviceIndex device_id; //This is needed to distinguish between miopen handles of multiple gpus. 
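
Annotation: the fields added to ConvolutionParams here (input_dim, memory_format, and the typed c10::DeviceIndex device_id) matter because this struct is used verbatim as the key of MIOpen's benchmark cache: setConvolutionParams memsets it to zero and the cache hashes and compares its raw bytes, so anything that can affect the chosen algorithm, including the suggested memory format, has to live in the key or NCHW and NHWC configurations would collide. A minimal, self-contained sketch of that byte-wise FNV-1a scheme (the same scheme as the local ParamsHash/ParamsEqual functors this diff replaces with the shared at::native helpers; the struct and names below are illustrative only, not the diff's code):

#include <cstddef>
#include <cstdint>
#include <cstring>

// Illustrative stand-in for ConvolutionParams: plain data only, so that
// hashing and memcmp-based equality over the raw bytes are well defined.
struct ParamsKeySketch {
  int input_size[7];
  int input_dim;
  int memory_format;   // e.g. 0 = Contiguous, 1 = ChannelsLast
  int device_id;
};

// Byte-wise FNV-1a hash over the zero-initialized struct, mirroring the
// benchmark-cache hash functor.
struct ParamsHashSketch {
  std::size_t operator()(const ParamsKeySketch& params) const {
    auto ptr = reinterpret_cast<const std::uint8_t*>(&params);
    std::uint32_t value = 0x811C9DC5;  // FNV offset basis
    for (std::size_t i = 0; i < sizeof(ParamsKeySketch); ++i) {
      value ^= ptr[i];
      value *= 0x01000193;             // FNV prime
    }
    return static_cast<std::size_t>(value);
  }
};

// Usage: memset the key to zero before filling it, so padding bytes do not
// make logically equal keys hash differently.
inline std::size_t hash_key_example() {
  ParamsKeySketch key;
  std::memset(&key, 0, sizeof(key));
  key.input_dim = 4;
  key.memory_format = 1;
  return ParamsHashSketch{}(key);
}
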
// NB: transposed purposely omitted: transposed just swaps // forward and backward, so you can reuse the benchmark entry, }; -// ConvolutionParams must be a POD because we read out its memory -// contenst as char* when hashing -static_assert(std::is_standard_layout_v, "ConvolutionParams not POD"); void setConvolutionParams( - ConvolutionParams* params, miopenHandle_t handle, - const at::Tensor& input, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool deterministic) { - + ConvolutionParams* params, + miopenHandle_t handle, + const at::Tensor& input, + const at::Tensor& weight, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool deterministic, + at::MemoryFormat memory_format) { miopenDataType_t dataType = getMiopenDataType(input); memset(params, 0, sizeof(ConvolutionParams)); params->dataType = dataType; params->handle = handle; // ASSERT(weight.dim() == input.dim()) + params->input_dim = input.dim(); + params->memory_format = memory_format; for (int i = 0; i != input.dim(); ++i) { params->input_size[i] = (int) input.size(i); params->input_stride[i] = (int) input.stride(i); @@ -214,9 +224,7 @@ void setConvolutionParams( } params->groups = groups; params->deterministic = deterministic; - int device_id; - HIP_CHECK(hipGetDevice(&device_id)); - params->device_id = device_id; + params->device_id = at::cuda::current_device(); } // Convenience struct for passing around descriptors and data @@ -239,31 +247,10 @@ struct ConvolutionArgs { // // --------------------------------------------------------------------- -// Hashing machinery for ConvolutionParams -struct ParamsHash { - std::size_t operator()(const ConvolutionParams& params) const { - auto ptr = reinterpret_cast(¶ms); - uint32_t value = 0x811C9DC5; - for (const auto i : c10::irange((int)sizeof(ConvolutionParams))) { - value ^= ptr[i]; - value *= 0x01000193; - } - return (size_t)value; - } -}; - -struct ParamsEqual { - bool operator()(const ConvolutionParams& a, const ConvolutionParams& b) const { - auto ptr1 = reinterpret_cast(&a); - auto ptr2 = reinterpret_cast(&b); - return memcmp(ptr1, ptr2, sizeof(ConvolutionParams)) == 0; - } -}; - template struct BenchmarkCache { std::mutex mutex; - std::unordered_map map; + std::unordered_map, ParamsEqual> map; bool find(const ConvolutionParams& params, T* results) { std::lock_guard guard(mutex); @@ -314,39 +301,39 @@ size_t getWorkspaceSize( const ConvolutionArgs& args, const miopenConvFwdAlgorithm_t) { size_t sz = 0; - miopenConvolutionForwardGetWorkSpaceSize( + MIOPEN_CHECK(miopenConvolutionForwardGetWorkSpaceSize( args.handle, args.wdesc.desc(), args.idesc.desc(), args.cdesc.desc(), args.odesc.desc(), - &sz); + &sz)); return sz; } size_t getWorkspaceSize( const ConvolutionArgs& args, const miopenConvBwdDataAlgorithm_t) { size_t sz = 0; - miopenConvolutionBackwardDataGetWorkSpaceSize( + MIOPEN_CHECK(miopenConvolutionBackwardDataGetWorkSpaceSize( args.handle, args.odesc.desc(), args.wdesc.desc(), args.cdesc.desc(), args.idesc.desc(), - &sz); + &sz)); return sz; } size_t getWorkspaceSize( const ConvolutionArgs& args, const miopenConvBwdWeightsAlgorithm_t) { size_t sz = 0; - miopenConvolutionBackwardWeightsGetWorkSpaceSize( + MIOPEN_CHECK(miopenConvolutionBackwardWeightsGetWorkSpaceSize( args.handle, args.odesc.desc(), args.idesc.desc(), args.cdesc.desc(), args.wdesc.desc(), - &sz); + &sz)); return sz; } @@ -649,6 +636,94 @@ Workspace chooseSolution(const ConvolutionArgs& args, uint64_t* 
solution_id) } } +// See NOTE [ raw_cudnn_convolution_forward_out ] in aten/src/ATen/native/cudnn/Conv_v7.cpp + +// --------------------------------------------------------------------- +// +// Splitting to 32bit +// +// --------------------------------------------------------------------- + +template +static inline void split_batch_dim_to_32bit_out( + const at::Tensor& output, + const at::Tensor& input, + const at::Tensor& weight, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic, + bool depthwise, + int64_t max_worksize, + func_t func_32bit) { + constexpr int64_t int_max = std::numeric_limits::max(); + const int64_t ni = input.numel(); + const int64_t no = output.numel(); + // Assume the shape of the tensor is (N, C, D1, D2, ...) + // if N * C * D1 * D2 * ... <= int_max, then no need to split at all + if (ni <= int_max && no <= int_max) { + func_32bit( + output, + input, + weight, + padding, + stride, + dilation, + groups, + benchmark, + deterministic, + depthwise); + return; + } + // else, if C * D1 * D2 * ... <= int_max, then we just need to split across + // the N dimension + // + // Here we use a simple heuristics to determine the size of each split + // We don't max out the 2^31 address space because this number is super + // large and very likely to get an OOM. + int64_t n = output.size(0); + int64_t max_inner_size = std::max(ni, no) / n; + int64_t split_size = std::max(max_worksize / max_inner_size, 1L); + int64_t num_splits = (n + split_size - 1) / split_size; + if (split_size * max_inner_size < int_max) { + for (const auto i : c10::irange(num_splits)) { + int64_t start = split_size * i; + int64_t split_size_ = std::min(split_size, n - start); + Tensor input_ = input.narrow(0, start, split_size_); + Tensor output_ = output.narrow(0, start, split_size_); + func_32bit( + output_, + input_, + weight, + padding, + stride, + dilation, + groups, + benchmark, + deterministic, + depthwise); + } + return; + } + // If control flow reaches here, this means even splitting N is not enough, + // then things starts to become complicated: For example, for conv2d, there + // following questions needs to be considered. + // - Is the memory layout NCHW or NHWC ? + // - If the conv is NCHW -> NC'H'W', then should we + // - split only NC? + // - split only N'C'? + // - split both? + // - If the conv is NHWC, then we need to split across H, we need to be very + // careful about the boundary condition + // to make sure that the boundary is handled correctly. + // - If we decide to make these splits, is the memory contiguous? Do we need + // to copy the memory? 
Considering the complexity of this issue, it is better + // not to use cuDNN for this case + TORCH_INTERNAL_ASSERT(false, "This case should not be dispatched to cuDNN."); +} + // --------------------------------------------------------------------- // // Bias addition @@ -690,8 +765,47 @@ void miopen_convolution_add_bias_(CheckedFrom c, const TensorArg& output, const */ } -// see NOTE [ Convolution design ] in src/Aten/native/cudnn/Conv.cpp +Tensor miopen_convolution_backward_bias(const Tensor& grad_output_t) +{ + TensorArg grad_output{ grad_output_t, "grad_output", 1 }; + + // TODO: Workaround since MIOpen does not support NHWC bias + // See #64426 + std::vector discard_dims; + for( int i = 0; i < grad_output_t.dim(); i++ ) { + if(i != output_channels_dim ) { + discard_dims.push_back(i); + } + } + + Tensor outputBias = at::squeeze( at::sum(grad_output_t, discard_dims, true) ); + if( outputBias.dim() == 0 ) { + // always return a tensor of shape [_] + return outputBias.unsqueeze(0); + } + else { + return outputBias; + } + +/* MIOpen does not support NHWC bias. Activate once support is added. + auto grad_bias_t = at::empty( { grad_output->size(output_channels_dim) }, grad_output->options()); + + TensorArg grad_bias{ grad_bias_t, "result", 0 }; + + TensorDescriptor bdesc{grad_bias->expand({1, grad_bias->size(0)}), + static_cast(grad_output->dim())}; + TensorDescriptor odesc{*grad_output}; + + auto handle = getMiopenHandle(); + auto dataType = getMiopenDataType(*grad_bias); + Constant one(dataType, 1); + Constant zero(dataType, 0); + MIOPEN_CHECK(miopenConvolutionBackwardBias(handle, &one, odesc.desc(), grad_output->data_ptr(), + &zero, bdesc.desc(), grad_bias->data_ptr())); + return *grad_bias; +*/ +} // --------------------------------------------------------------------- // @@ -699,30 +813,47 @@ void miopen_convolution_add_bias_(CheckedFrom c, const TensorArg& output, const // // --------------------------------------------------------------------- -// The raw API directly invokes MIOpen. -// -// There are a few reasons this should never be directly exposed -// via ATen: -// -// - It takes output as a parameter (this should be computed!) -// - It doesn't do input checking -// - It doesn't resize output (it is assumed to be correctly sized) -// -void raw_miopen_convolution_forward_out( - const Tensor& output, const Tensor& input, const Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) { - +void raw_miopen_convolution_forward_out_32bit( + const Tensor& output, + const Tensor& input, + const Tensor& weight, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic, + bool depthwise=false) { auto dataType = getMiopenDataType(input); - miopenConvolutionMode_t c_mode = miopenConvolution; + miopenConvolutionMode_t c_mode = depthwise ? 
miopenDepthwise : miopenConvolution; - ConvolutionArgs args{ input, output, weight }; + ConvolutionArgs args{input, output, weight}; args.handle = getMiopenHandle(); - setConvolutionParams(&args.params, args.handle, input, weight, padding, stride, dilation, groups, deterministic); - args.idesc.set(input); - args.wdesc.set(weight, input.suggest_memory_format(), 0); - args.odesc.set(output); - args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic); + at::MemoryFormat memory_format = miopen_conv_suggest_memory_format(input, weight); + setConvolutionParams( + &args.params, + args.handle, + input, + weight, + padding, + stride, + dilation, + groups, + deterministic, + memory_format); + args.idesc.set(input, memory_format); + args.wdesc.set(weight, memory_format, 0); + args.odesc.set(output, memory_format); + args.cdesc.set( + dataType, + c_mode, + input.dim() - 2, + args.params.padding, + args.params.stride, + args.params.dilation, + args.params.groups, + benchmark, + deterministic); if (deterministic && !benchmark) { // immediate mode is triggered for the specific combination of benchmark=off deterministic=on @@ -731,10 +862,16 @@ void raw_miopen_convolution_forward_out( MIOPEN_CHECK(miopenConvolutionForwardImmediate( args.handle, - args.wdesc.desc(), weight.const_data_ptr(), - args.idesc.desc(), input.const_data_ptr(), + args.wdesc.desc(), + weight.const_data_ptr(), + args.idesc.desc(), + input.const_data_ptr(), args.cdesc.desc(), - args.odesc.desc(), output.data_ptr(), workspace.data, workspace.size, solution_id)); + args.odesc.desc(), + output.data_ptr(), + workspace.data, + workspace.size, + solution_id)); } else { miopenConvFwdAlgorithm_t fwdAlg; @@ -745,475 +882,216 @@ void raw_miopen_convolution_forward_out( MIOPEN_CHECK(miopenConvolutionForward( args.handle, - &one, args.idesc.desc(), input.const_data_ptr(), - args.wdesc.desc(), weight.const_data_ptr(), - args.cdesc.desc(), fwdAlg, &zero, - args.odesc.desc(), output.data_ptr(), workspace.data, workspace.size)); + &one, + args.idesc.desc(), + input.const_data_ptr(), + args.wdesc.desc(), + weight.const_data_ptr(), + args.cdesc.desc(), + fwdAlg, + &zero, + args.odesc.desc(), + output.data_ptr(), + workspace.data, + workspace.size)); } } -Tensor miopen_convolution_forward( +void raw_miopen_convolution_forward_out( + const Tensor& output, + const Tensor& input, + const Tensor& weight, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic, + bool depthwise=false) { + split_batch_dim_to_32bit_out( + output, + input, + weight, + padding, + stride, + dilation, + groups, + benchmark, + deterministic, + depthwise, + 1024 * 1024 * 256, + raw_miopen_convolution_forward_out_32bit); +} + +void miopen_convolution_forward_out( + TensorArg& output, CheckedFrom c, - const TensorArg& input, const TensorArg& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) -{ + const TensorArg& input, + const TensorArg& weight, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic, + bool depthwise=false) { checkAllSameType(c, {input, weight}); checkAllSameGPU(c, {input, weight}); - auto memory_format = at::MemoryFormat::Contiguous; - if (miopen_conv_use_channels_last(*input, *weight)) { - memory_format = (weight->ndimension() == 5) ? 
at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast; - } - - Tensor output_t = at::detail::empty_cuda( - conv_output_size(input->sizes(), weight->sizes(), - padding, stride, dilation), - input->options().memory_format(memory_format)); - - if (output_t.numel() == 0) { - return output_t; - } - - // Avoid ambiguity of "output" when this is being used as backwards - TensorArg output{ output_t, "result", 0 }; - convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups); + auto memory_format = output->suggest_memory_format(); + convolution_shape_check( + c, input, weight, output, padding, stride, dilation, groups); - // See #4500 Tensor weight_contig = weight->contiguous(memory_format); - // Make sure that NC11 strides follow formula - weight_contig.resize_(weight_contig.sizes(), memory_format); Tensor input_contig = input->contiguous(memory_format); - input_contig.resize_(input_contig.sizes(), memory_format); - - raw_miopen_convolution_forward_out( - *output, input_contig, weight_contig, - padding, stride, dilation, groups, benchmark, deterministic); - - return *output; + *output, + input_contig, + weight_contig, + padding, + stride, + dilation, + groups, + benchmark, + deterministic, + depthwise); } Tensor miopen_convolution( - const Tensor& input_t, const Tensor& weight_t, const std::optional& bias_t_opt, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic) -{ + const Tensor& input_t, + const Tensor& weight_t, + const std::optional& bias_t_opt, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic) { // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt); const Tensor& bias_t = *bias_t_maybe_owned; - TensorArg input { input_t, "input", 1 }, - weight { weight_t, "weight", 2 }, - bias { bias_t, "bias", 3 }; + TensorArg input{input_t, "input", 1 }, weight{weight_t, "weight", 2}, bias{bias_t, "bias", 3}; CheckedFrom c = "miopen_convolution"; - auto output_t = miopen_convolution_forward( - c, input, weight, padding, stride, dilation, groups, benchmark, deterministic); + auto memory_format = miopen_conv_suggest_memory_format(input_t, weight_t); + Tensor output_t = at::detail::empty_cuda( + conv_output_size( + input_t.sizes(), weight_t.sizes(), padding, stride, dilation), + input->options().memory_format(memory_format)); + if (output_t.numel() == 0) { + return output_t; + } + // Avoid ambiguity of "output" when this is being used as backwards + TensorArg output{output_t, "result", 0}; + miopen_convolution_forward_out( + output, + c, + input, + weight, + padding, + stride, + dilation, + groups, + benchmark, + deterministic); if (bias->defined()) { - miopen_convolution_add_bias_(c, { output_t, "result", 0 }, bias); + miopen_convolution_add_bias_(c, output, bias); } - return output_t; + return *output; } -//Depthwise Convolutions -void raw_miopen_depthwise_convolution_forward_out( - const Tensor& output, const Tensor& input, const Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) { +Tensor miopen_convolution_transpose_backward_input( + const Tensor& grad_output_t, + const Tensor& weight_t, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic) { + TensorArg grad_output{ grad_output_t, 
"grad_output", 1 }, weight{weight_t, "weight", 2}; + auto memory_format = + miopen_conv_suggest_memory_format(grad_output_t, weight_t); + Tensor output_t = at::detail::empty_cuda( + conv_output_size( + grad_output_t.sizes(), weight_t.sizes(), padding, stride, dilation), + grad_output_t.options().memory_format(memory_format)); - auto dataType = getMiopenDataType(input); - miopenConvolutionMode_t c_mode = miopenDepthwise; + if (output_t.numel() == 0) { + return output_t; + } + TensorArg output{output_t, "result", 0}; + miopen_convolution_forward_out( + output, + "miopen_convolution_transpose_backward_input", + grad_output, + weight, + padding, + stride, + dilation, + groups, + benchmark, + deterministic); + return *output; +} - ConvolutionArgs args{ input, output, weight }; - args.handle = getMiopenHandle(); - setConvolutionParams(&args.params, args.handle, input, weight, padding, stride, dilation, groups, deterministic); - args.idesc.set(input); - args.wdesc.set(weight, input.suggest_memory_format(), 0); - args.odesc.set(output); - args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic); +// file organization would put miopen_convolution_transpose_backward_weight here, +// but it depends on miopen_convolution_backward_weight which is defined later +Tensor miopen_convolution_transpose_backward_weight( + IntArrayRef weight_size, + const Tensor& grad_output_t, + const Tensor& input_t, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic); - if (deterministic && !benchmark) { - // immediate mode is triggered for the specific combination of benchmark=off deterministic=on - uint64_t solution_id; - Workspace workspace = chooseSolution(args, &solution_id); +std::tuple miopen_convolution_transpose_backward( + const at::Tensor& input, + const at::Tensor& grad_output_t, + const at::Tensor& weight, + IntArrayRef padding, + IntArrayRef output_padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic, + std::array output_mask) { + Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format()); - MIOPEN_CHECK(miopenConvolutionForwardImmediate( - args.handle, - args.wdesc.desc(), weight.const_data_ptr(), - args.idesc.desc(), input.const_data_ptr(), - args.cdesc.desc(), - args.odesc.desc(), output.data_ptr(), workspace.data, workspace.size, solution_id)); + Tensor grad_input, grad_weight, grad_bias; + if (output_mask[0]) { + grad_input = miopen_convolution_transpose_backward_input( + grad_output, + weight, + padding, + stride, + dilation, + groups, + benchmark, + deterministic); } - else { - miopenConvFwdAlgorithm_t fwdAlg; - Workspace workspace = chooseAlgorithm(args, benchmark, &fwdAlg); - - Constant one(dataType, 1); - Constant zero(dataType, 0); - - MIOPEN_CHECK(miopenConvolutionForward( - args.handle, - &one, args.idesc.desc(), input.const_data_ptr(), - args.wdesc.desc(), weight.const_data_ptr(), - args.cdesc.desc(), fwdAlg, &zero, - args.odesc.desc(), output.data_ptr(), workspace.data, workspace.size)); + if (output_mask[1]) { + grad_weight = miopen_convolution_transpose_backward_weight( + weight.sizes(), + grad_output, + input, + padding, + stride, + dilation, + groups, + benchmark, + deterministic); + } + if (output_mask[2]) { + grad_bias = miopen_convolution_backward_bias(grad_output); } -} - -Tensor miopen_depthwise_convolution_forward( - CheckedFrom c, - 
const TensorArg& input, const TensorArg& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) -{ - checkAllSameType(c, {input, weight}); - checkAllSameGPU(c, {input, weight}); - - auto memory_format = at::MemoryFormat::Contiguous; - if (miopen_conv_use_channels_last(*input, *weight)) { - memory_format = (weight->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast; - } - - Tensor output_t = at::detail::empty_cuda( - conv_output_size(input->sizes(), weight->sizes(), - padding, stride, dilation), - input->options().memory_format(memory_format)); - - TensorArg output{ output_t, "result", 0 }; - convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups); - - // See #4500 - Tensor weight_contig = weight->contiguous(memory_format); - // Make sure that NC11 strides follow formula - weight_contig.resize_(weight_contig.sizes(), memory_format); - Tensor input_contig = input->contiguous(memory_format); - input_contig.resize_(input_contig.sizes(), memory_format); - - raw_miopen_depthwise_convolution_forward_out( - *output, input_contig, weight_contig, - padding, stride, dilation, groups, benchmark, deterministic); - - return *output; -} - -Tensor miopen_depthwise_convolution( - const Tensor& input_t, const Tensor& weight_t, const std::optional& bias_t_opt, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic) -{ - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt); - const Tensor& bias_t = *bias_t_maybe_owned; - - TensorArg input { input_t, "input", 1 }, - weight { weight_t, "weight", 2 }, - bias { bias_t, "bias", 3 }; - CheckedFrom c = "miopen_depthwise_convolution"; - auto output_t = miopen_depthwise_convolution_forward( - c, input, weight, padding, stride, dilation, groups, benchmark, deterministic); - if (bias->defined()) { - miopen_convolution_add_bias_(c, { output_t, "result", 0 }, bias); - } - return output_t; -} - -// --------------------------------------------------------------------- -// -// Convolution backward (bias) -// -// --------------------------------------------------------------------- - -Tensor miopen_convolution_backward_bias( - const Tensor& grad_output_t) -{ - TensorArg grad_output{ grad_output_t, "grad_output", 1 }; - - // TODO: Workaround since MIOpen does not support NHWC bias - // See #64426 - std::vector discard_dims; - for( int i = 0; i < grad_output_t.dim(); i++ ) { - if(i != output_channels_dim ) { - discard_dims.push_back(i); - } - } - - Tensor outputBias = at::squeeze( at::sum(grad_output_t, discard_dims, true) ); - if( outputBias.dim() == 0 ) { - // always return a tensor of shape [_] - return outputBias.unsqueeze(0); - } - else { - return outputBias; - } - -/* MIOpen does not support NHWC bias. Activate once support is added. 
- auto grad_bias_t = at::empty( { grad_output->size(output_channels_dim) }, grad_output->options()); - - TensorArg grad_bias{ grad_bias_t, "result", 0 }; - - TensorDescriptor bdesc{grad_bias->expand({1, grad_bias->size(0)}), - static_cast(grad_output->dim())}; - TensorDescriptor odesc{*grad_output}; - - auto handle = getMiopenHandle(); - auto dataType = getMiopenDataType(*grad_bias); - Constant one(dataType, 1); - Constant zero(dataType, 0); - - MIOPEN_CHECK(miopenConvolutionBackwardBias(handle, &one, odesc.desc(), grad_output->data_ptr(), - &zero, bdesc.desc(), grad_bias->data_ptr())); - return *grad_bias; -*/ -} - -// --------------------------------------------------------------------- -// -// Convolution backward (weight) -// -// --------------------------------------------------------------------- - -void raw_miopen_convolution_backward_weight_out( - const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) { - - auto dataType = getMiopenDataType(input); - miopenConvolutionMode_t c_mode = miopenConvolution; - - ConvolutionArgs args{ input, grad_output, grad_weight }; - args.handle = getMiopenHandle(); - setConvolutionParams(&args.params, args.handle, input, grad_weight, padding, stride, dilation, groups, deterministic); - args.idesc.set(input); - args.wdesc.set(grad_weight, input.suggest_memory_format(), 0); - args.odesc.set(grad_output); - args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic); - - if (deterministic && !benchmark) { - // immediate mode is triggered for the specific combination of benchmark=off deterministic=on - uint64_t solution_id; - Workspace workspace = chooseSolution(args, &solution_id); - - MIOPEN_CHECK(miopenConvolutionBackwardWeightsImmediate( - args.handle, - args.odesc.desc(), grad_output.const_data_ptr(), - args.idesc.desc(), input.const_data_ptr(), - args.cdesc.desc(), - args.wdesc.desc(), grad_weight.data_ptr(), workspace.data, workspace.size, solution_id)); - } - else { - miopenConvBwdWeightsAlgorithm_t bwdFilterAlg; - Workspace workspace = chooseAlgorithm(args, benchmark, &bwdFilterAlg); - - Constant one(dataType, 1); - Constant zero(dataType, 0); - - MIOPEN_CHECK(miopenConvolutionBackwardWeights( - args.handle, - &one, args.odesc.desc(), grad_output.const_data_ptr(), - args.idesc.desc(), input.const_data_ptr(), - args.cdesc.desc(), bwdFilterAlg, &zero, - args.wdesc.desc(), grad_weight.data_ptr(), workspace.data, workspace.size)); - } -} - -//Depthwise backward weights. 
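
Annotation: the backward-bias code removed here is the same NHWC workaround the diff re-adds earlier in the file, so behavior is unchanged: because MIOpen has no NHWC bias kernel, the bias gradient is computed with plain ATen ops by summing grad_output over every dimension except the channel dimension (output_channels_dim == 1), and the unsqueeze(0) branch only keeps the result 1-D when there is a single output channel. A small self-contained check of that reduction, assuming an illustrative 4-D [N, C, H, W] grad_output:

#include <ATen/ATen.h>

// For a grad_output of shape [N, C, H, W], summing over the "discard" dims
// {0, 2, 3} with keepdim and then squeezing is just the per-channel sum,
// i.e. a [C]-shaped bias gradient.
inline void check_bias_backward_reduction() {
  at::Tensor grad_output = at::randn({2, 3, 4, 5});  // [N, C, H, W]
  at::Tensor via_discard_dims =
      at::squeeze(at::sum(grad_output, {0, 2, 3}, /*keepdim=*/true));
  at::Tensor direct = at::sum(grad_output, {0, 2, 3});  // shape [3]
  TORCH_CHECK(via_discard_dims.sizes() == direct.sizes());
  TORCH_CHECK(at::allclose(via_discard_dims, direct));
}
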
-void raw_miopen_depthwise_convolution_backward_weight_out( - const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) { - - auto dataType = getMiopenDataType(input); - miopenConvolutionMode_t c_mode = miopenDepthwise; - - ConvolutionArgs args{ input, grad_output, grad_weight }; - args.handle = getMiopenHandle(); - setConvolutionParams(&args.params, args.handle, input, grad_weight, padding, stride, dilation, groups, deterministic); - args.idesc.set(input); - args.wdesc.set(grad_weight, input.suggest_memory_format(), 0); - args.odesc.set(grad_output); - args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic); - - if (deterministic && !benchmark) { - // immediate mode is triggered for the specific combination of benchmark=off deterministic=on - uint64_t solution_id; - Workspace workspace = chooseSolution(args, &solution_id); - - MIOPEN_CHECK(miopenConvolutionBackwardWeightsImmediate( - args.handle, - args.odesc.desc(), grad_output.const_data_ptr(), - args.idesc.desc(), input.const_data_ptr(), - args.cdesc.desc(), - args.wdesc.desc(), grad_weight.data_ptr(), workspace.data, workspace.size, solution_id)); - } - else { - miopenConvBwdWeightsAlgorithm_t bwdFilterAlg; - Workspace workspace = chooseAlgorithm(args, benchmark, &bwdFilterAlg); - - Constant one(dataType, 1); - Constant zero(dataType, 0); - - MIOPEN_CHECK(miopenConvolutionBackwardWeights( - args.handle, - &one, args.odesc.desc(), grad_output.const_data_ptr(), - args.idesc.desc(), input.const_data_ptr(), - args.cdesc.desc(), bwdFilterAlg, &zero, - args.wdesc.desc(), grad_weight.data_ptr(), workspace.data, workspace.size)); - } -} - -Tensor miopen_depthwise_convolution_backward_weight( - CheckedFrom c, - IntArrayRef weight_size, const TensorArg& grad_output, const TensorArg& input, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) -{ - - checkAllSameType(c, {grad_output, input}); - checkAllSameGPU(c, {grad_output, input}); - - auto memory_format = at::MemoryFormat::Contiguous; - if (miopen_conv_use_channels_last(*input, *grad_output)) { - memory_format = (input->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast; - } - - Tensor grad_output_contig_t = grad_output->contiguous(memory_format); - // Make sure that NC11 strides follow formula - grad_output_contig_t.resize_(grad_output_contig_t.sizes(), memory_format); - TensorArg grad_output_contig{ grad_output_contig_t, "grad_output", 1 }; - - Tensor input_contig_t = input->contiguous(memory_format); - input_contig_t.resize_(input_contig_t.sizes(), memory_format); - TensorArg input_contig{ input_contig_t, "input", 2}; - - auto grad_weight_t = at::empty(weight_size, grad_output_contig->options(), memory_format); - - // For uniformity with everything else, although it seems grad_weight - // would be unambiguous too. 
- TensorArg grad_weight{ grad_weight_t, "result", 0 }; - convolution_shape_check(c, input, grad_weight, grad_output_contig, padding, stride, dilation, groups); - - raw_miopen_depthwise_convolution_backward_weight_out( - *grad_weight, *grad_output_contig, *input_contig, - padding, stride, dilation, groups, benchmark, deterministic); - - return grad_weight_t; -} - -Tensor miopen_depthwise_convolution_backward_weight( - IntArrayRef weight_size, - const Tensor& grad_output_t, - const Tensor& input_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) -{ - TensorArg grad_output{ grad_output_t, "grad_output", 1 }, - input{ input_t, "input", 2 }; - return miopen_depthwise_convolution_backward_weight( - "miopen_depthwise_convolution_backward_weight", - weight_size, grad_output, input, - padding, stride, dilation, groups, benchmark, deterministic); -} - -Tensor miopen_convolution_backward_weight( - CheckedFrom c, - IntArrayRef weight_size, const TensorArg& grad_output, const TensorArg& input, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) -{ - - checkAllSameType(c, {grad_output, input}); - checkAllSameGPU(c, {grad_output, input}); - - auto memory_format = at::MemoryFormat::Contiguous; - if (miopen_conv_use_channels_last(*input, *grad_output)) { - memory_format = (input->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast; - } - - Tensor grad_output_contig_t = grad_output->contiguous(memory_format); - // Make sure that NC11 strides follow formula - grad_output_contig_t.resize_(grad_output_contig_t.sizes(), memory_format); - TensorArg grad_output_contig{ grad_output_contig_t, "grad_output", 1 }; - - Tensor input_contig_t = input->contiguous(memory_format); - input_contig_t.resize_(input_contig_t.sizes(), memory_format); - TensorArg input_contig{ input_contig_t, "input", 2}; - - auto grad_weight_t = at::empty(weight_size, grad_output_contig->options(), memory_format); - - // For uniformity with everything else, although it seems grad_weight - // would be unambiguous too. 
- TensorArg grad_weight{ grad_weight_t, "result", 0 }; - convolution_shape_check(c, input, grad_weight, grad_output_contig, padding, stride, dilation, groups); - - raw_miopen_convolution_backward_weight_out( - *grad_weight, *grad_output_contig, *input_contig, - padding, stride, dilation, groups, benchmark, deterministic); - - return grad_weight_t; -} - -Tensor miopen_convolution_backward_weight( - IntArrayRef weight_size, - const Tensor& grad_output_t, - const Tensor& input_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) -{ - TensorArg grad_output{ grad_output_t, "grad_output", 1 }, - input{ input_t, "input", 2 }; - return miopen_convolution_backward_weight( - "miopen_convolution_backward_weight", - weight_size, grad_output, input, - padding, stride, dilation, groups, benchmark, deterministic); -} - -Tensor miopen_convolution_transpose_backward_input( - const Tensor& grad_output_t, const Tensor& weight_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic) -{ - TensorArg grad_output { grad_output_t, "grad_output", 1 }, - weight { weight_t, "weight", 2 }; - return miopen_convolution_forward( - "miopen_convolution_transpose_backward_input", - grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic); -} - -Tensor miopen_convolution_transpose_backward_weight( - IntArrayRef weight_size, - const Tensor& grad_output_t, - const Tensor& input_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) -{ - TensorArg grad_output{ grad_output_t, "grad_output", 1 }, - input{ input_t, "input", 2 }; - return miopen_convolution_backward_weight( - "miopen_convolution_backward_weight", - weight_size, input, grad_output, - padding, stride, dilation, groups, benchmark, deterministic); -} - -std::tuple miopen_convolution_transpose_backward( - const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, std::array output_mask) { - - Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format()); - - Tensor grad_input, grad_weight, grad_bias; - if (output_mask[0]) { - grad_input = miopen_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic); - } - if (output_mask[1]) { - grad_weight = miopen_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic); - } - if (output_mask[2]) { - grad_bias = miopen_convolution_backward_bias(grad_output); - } - - return std::tuple{grad_input, grad_weight, grad_bias}; + + return std::tuple{grad_input, grad_weight, grad_bias}; } // --------------------------------------------------------------------- @@ -1222,23 +1100,50 @@ std::tuple miopen_convolution_transpose_backwa // // --------------------------------------------------------------------- -void raw_miopen_convolution_backward_input_out( +// See NOTE [ Backward vs transpose convolutions ] in aten/src/ATen/native/cudnn/ConvShared.cpp + +void raw_miopen_convolution_backward_input_out_32bit( const at::Tensor& grad_input, const at::Tensor& grad_output, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool 
deterministic) { - + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic, + bool depthwise=false) { auto dataType = getMiopenDataType(grad_output); - miopenConvolutionMode_t c_mode = miopenConvolution; + miopenConvolutionMode_t c_mode = depthwise ? miopenDepthwise : miopenConvolution; - ConvolutionArgs args{ grad_input, grad_output, weight }; + ConvolutionArgs args{grad_input, grad_output, weight}; args.handle = getMiopenHandle(); - setConvolutionParams(&args.params, args.handle, grad_input, weight, padding, stride, dilation, groups, deterministic); - args.idesc.set(grad_input); - args.wdesc.set(weight, grad_output.suggest_memory_format(), 0); - args.odesc.set(grad_output); - args.cdesc.set(dataType, c_mode, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic); + at::MemoryFormat memory_format = + miopen_conv_suggest_memory_format(grad_input, weight); + setConvolutionParams( + &args.params, + args.handle, + grad_input, + weight, + padding, + stride, + dilation, + groups, + deterministic, + memory_format); + args.idesc.set(grad_input, memory_format); + args.wdesc.set(weight, memory_format, 0); + args.odesc.set(grad_output, memory_format); + args.cdesc.set( + dataType, + c_mode, + grad_output.dim() - 2, + args.params.padding, + args.params.stride, + args.params.dilation, + args.params.groups, + benchmark, + deterministic); if (deterministic && !benchmark) { // immediate mode is triggered for the specific combination of benchmark=off deterministic=on @@ -1250,7 +1155,10 @@ void raw_miopen_convolution_backward_input_out( args.odesc.desc(), grad_output.const_data_ptr(), args.wdesc.desc(), weight.const_data_ptr(), args.cdesc.desc(), - args.idesc.desc(), grad_input.mutable_data_ptr(), workspace.data, workspace.size, solution_id)); + args.idesc.desc(), grad_input.mutable_data_ptr(), + workspace.data, + workspace.size, + solution_id)); } else { miopenConvBwdDataAlgorithm_t bwdDataAlg; @@ -1261,217 +1169,522 @@ void raw_miopen_convolution_backward_input_out( MIOPEN_CHECK(miopenConvolutionBackwardData( args.handle, - &one, args.odesc.desc(), grad_output.const_data_ptr(), + &one, + args.odesc.desc(), grad_output.const_data_ptr(), args.wdesc.desc(), weight.const_data_ptr(), - args.cdesc.desc(), bwdDataAlg, &zero, - args.idesc.desc(), grad_input.mutable_data_ptr(), workspace.data, workspace.size)); + args.cdesc.desc(), + bwdDataAlg, + &zero, + args.idesc.desc(), grad_input.mutable_data_ptr(), + workspace.data, + workspace.size)); } } -// see NOTE [ Backward vs transpose convolutions ] in src/Aten/native/cudnn/Conv.cpp +void raw_miopen_convolution_backward_input_out( + const at::Tensor& grad_input, + const at::Tensor& grad_output, + const at::Tensor& weight, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic, + bool depthwise=false) { + split_batch_dim_to_32bit_out( + grad_input, + grad_output, + weight, + padding, + stride, + dilation, + groups, + benchmark, + deterministic, + depthwise, + 1024 * 1024 * 128, + raw_miopen_convolution_backward_input_out_32bit); +} Tensor miopen_convolution_backward_input( CheckedFrom c, - IntArrayRef input_size, const TensorArg& grad_output, const TensorArg& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) -{ + IntArrayRef input_size, + const TensorArg& grad_output, + const 
TensorArg& weight, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic, + bool depthwise=false) { checkAllSameType(c, {grad_output, weight}); checkAllSameGPU(c, {grad_output, weight}); - auto memory_format = at::MemoryFormat::Contiguous; - if (miopen_conv_use_channels_last(*grad_output, *weight)) { - memory_format = (weight->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast; - } - + auto memory_format = miopen_conv_suggest_memory_format(*grad_output, *weight); Tensor grad_input_t = at::detail::empty_cuda( input_size, grad_output->options().memory_format(memory_format)); // Avoid "grad_input" when this is being used as transposed convolution - TensorArg grad_input{ grad_input_t, "result", 0 }; - convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups); + TensorArg grad_input{grad_input_t, "result", 0}; + convolution_shape_check( + c, grad_input, weight, grad_output, padding, stride, dilation, groups); - // See #4500 Tensor weight_contig = weight->contiguous(memory_format); - // Make sure that NC11 strides follow formula - weight_contig.resize_(weight_contig.sizes(), memory_format); - Tensor grad_output_contig = grad_output->contiguous(memory_format); - grad_output_contig.resize_(grad_output_contig.sizes(), memory_format); raw_miopen_convolution_backward_input_out( - *grad_input, grad_output_contig, weight_contig, - padding, stride, dilation, groups, benchmark, deterministic); + *grad_input, + grad_output_contig, + weight_contig, + padding, + stride, + dilation, + groups, + benchmark, + deterministic, + depthwise); return *grad_input; } -Tensor miopen_convolution_transpose_forward( - CheckedFrom c, - const TensorArg& grad_output, const TensorArg& weight, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) -{ - auto input_size = conv_input_size(grad_output->sizes(), weight->sizes(), - padding, output_padding, stride, dilation, groups); - return miopen_convolution_backward_input(c, input_size, grad_output, weight, - padding, stride, dilation, groups, benchmark, deterministic); -} - +// overload Tensor miopen_convolution_backward_input( - IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) -{ - TensorArg grad_output{ grad_output_t, "grad_output", 1 }, - weight{ weight_t, "weight", 2 }; + IntArrayRef input_size, + const Tensor& grad_output_t, + const Tensor& weight_t, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic, + bool depthwise=false) { + TensorArg grad_output{grad_output_t, "grad_output", 1}, + weight{weight_t, "weight", 2}; return miopen_convolution_backward_input( "miopen_convolution_backward_input", - input_size, grad_output, weight, - padding, stride, dilation, groups, benchmark, deterministic); + input_size, + grad_output, + weight, + padding, + stride, + dilation, + groups, + benchmark, + deterministic, + depthwise); } -//Depthwise convolutions backward data. 
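
Annotation: like the forward path, the backward-input wrapper above now routes through split_batch_dim_to_32bit_out so that a single MIOpen call never sees a tensor with more than 2^31 elements; the forward pass gets a budget of 1024*1024*256 elements per split and backward-input 1024*1024*128. A standalone sketch of the split-size arithmetic with illustrative numbers (the helper below is hypothetical, not part of the diff):

#include <algorithm>
#include <cstdint>

struct SplitPlan {
  std::int64_t split_size;
  std::int64_t num_splits;
};

// n: batch size; ni / no: total element counts of input and output;
// max_worksize: per-split element budget (256M forward, 128M backward-input).
inline SplitPlan plan_batch_split(std::int64_t n, std::int64_t ni,
                                  std::int64_t no, std::int64_t max_worksize) {
  std::int64_t max_inner_size = std::max(ni, no) / n;   // elements per sample
  std::int64_t split_size =
      std::max(max_worksize / max_inner_size, static_cast<std::int64_t>(1));
  std::int64_t num_splits = (n + split_size - 1) / split_size;  // ceil(n / split_size)
  return {split_size, num_splits};
}

// Example: n = 4096 samples of 1024*1024 elements each, with the
// backward-input budget of 1024*1024*128, gives split_size = 128 and
// num_splits = 32, so each MIOpen call stays well under 2^31 elements.
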
-void raw_miopen_depthwise_convolution_backward_input_out( - const at::Tensor& grad_input, - const at::Tensor& grad_output, - const at::Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) { - - auto dataType = getMiopenDataType(grad_output); - miopenConvolutionMode_t c_mode = miopenDepthwise; +void raw_miopen_convolution_backward_weight_out_32bit( + const Tensor& grad_weight, + const Tensor& grad_output, + const Tensor& input, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic, + bool depthwise=false) { + auto dataType = getMiopenDataType(input); + miopenConvolutionMode_t c_mode = depthwise ? miopenDepthwise : miopenConvolution; - ConvolutionArgs args{ grad_input, grad_output, weight }; + ConvolutionArgs args{input, grad_output, grad_weight}; args.handle = getMiopenHandle(); - setConvolutionParams(&args.params, args.handle, grad_input, weight, padding, stride, dilation, groups, deterministic); - args.idesc.set(grad_input); - args.wdesc.set(weight, grad_output.suggest_memory_format(), 0); - args.odesc.set(grad_output); - args.cdesc.set(dataType, c_mode, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups, benchmark, deterministic); + at::MemoryFormat memory_format = + miopen_conv_suggest_memory_format(input, grad_weight); + setConvolutionParams( + &args.params, + args.handle, + input, + grad_weight, + padding, + stride, + dilation, + groups, + deterministic, + memory_format); + args.idesc.set(input, memory_format); + args.wdesc.set(grad_weight, memory_format, 0); + args.odesc.set(grad_output, memory_format); + args.cdesc.set( + dataType, + c_mode, + input.dim() - 2, + args.params.padding, + args.params.stride, + args.params.dilation, + args.params.groups, + benchmark, + deterministic); if (deterministic && !benchmark) { // immediate mode is triggered for the specific combination of benchmark=off deterministic=on uint64_t solution_id; - Workspace workspace = chooseSolution(args, &solution_id); + Workspace workspace = chooseSolution(args, &solution_id); - MIOPEN_CHECK(miopenConvolutionBackwardDataImmediate( + MIOPEN_CHECK(miopenConvolutionBackwardWeightsImmediate( args.handle, args.odesc.desc(), grad_output.const_data_ptr(), - args.wdesc.desc(), weight.const_data_ptr(), + args.idesc.desc(), input.const_data_ptr(), args.cdesc.desc(), - args.idesc.desc(), grad_input.mutable_data_ptr(), workspace.data, workspace.size, solution_id)); + args.wdesc.desc(), grad_weight.data_ptr(), + workspace.data, + workspace.size, + solution_id)); + } + else { + miopenConvBwdWeightsAlgorithm_t bwdFilterAlg; + Workspace workspace = chooseAlgorithm(args, benchmark, &bwdFilterAlg); + + Constant one(dataType, 1); + Constant zero(dataType, 0); + + MIOPEN_CHECK(miopenConvolutionBackwardWeights( + args.handle, + &one, + args.odesc.desc(), grad_output.const_data_ptr(), + args.idesc.desc(), input.const_data_ptr(), + args.cdesc.desc(), + bwdFilterAlg, + &zero, + args.wdesc.desc(), grad_weight.data_ptr(), + workspace.data, + workspace.size)); + } +} + +void raw_miopen_convolution_backward_weight_out( + const Tensor& grad_weight, + const Tensor& grad_output, + const Tensor& input, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic, + bool depthwise=false) { + constexpr int64_t int_max = std::numeric_limits::max(); + const int64_t ni = 
input.numel(); + const int64_t no = grad_output.numel(); + // Assume the shape of the tensor is (N, C, D1, D2, ...) + // if N * C * D1 * D2 * ... <= int_max, then no need to split at all + if (ni <= int_max && no <= int_max) { + raw_miopen_convolution_backward_weight_out_32bit( + grad_weight, + grad_output, + input, + padding, + stride, + dilation, + groups, + benchmark, + deterministic, + depthwise); + return; } - else { - miopenConvBwdDataAlgorithm_t bwdDataAlg; - Workspace workspace = chooseAlgorithm(args, benchmark, &bwdDataAlg); - - Constant one(dataType, 1); - Constant zero(dataType, 0); - - MIOPEN_CHECK(miopenConvolutionBackwardData( - args.handle, - &one, args.odesc.desc(), grad_output.const_data_ptr(), - args.wdesc.desc(), weight.const_data_ptr(), - args.cdesc.desc(), bwdDataAlg, &zero, - args.idesc.desc(), grad_input.mutable_data_ptr(), workspace.data, workspace.size)); + // else, if C * D1 * D2 * ... <= int_max, then we just need to split across + // the N dimension + // + // Here we use a simple heuristics to determine the size of each split + // We don't max out the 2^31 address space because this number is super + // large and very likely to get an OOM. + int64_t n = grad_output.size(0); + int64_t max_inner_size = std::max(ni, no) / n; + int64_t split_size = + std::max(1024 * 1024 * 512 / max_inner_size, 1L); + int64_t num_splits = (n + split_size - 1) / split_size; + if (split_size * max_inner_size < int_max) { + const auto kAccType = (grad_weight.scalar_type() == kHalf || + grad_weight.scalar_type() == kBFloat16) + ? kFloat + : grad_weight.scalar_type(); + Tensor grad_weight_accumulator = + at::zeros(grad_weight.sizes(), grad_weight.options().dtype(kAccType)); + for (const auto i : c10::irange(num_splits)) { + int64_t start = split_size * i; + int64_t split_size_ = std::min(split_size, n - start); + Tensor input_ = input.narrow(0, start, split_size_); + Tensor grad_output_ = grad_output.narrow(0, start, split_size_); + Tensor grad_weight_ = at::empty_like(grad_weight); + raw_miopen_convolution_backward_weight_out_32bit( + grad_weight_, + grad_output_, + input_, + padding, + stride, + dilation, + groups, + benchmark, + deterministic, + depthwise); + grad_weight_accumulator.add_(grad_weight_); + } + grad_weight.copy_(grad_weight_accumulator); + return; } + // If control flow reaches here, this means even splitting N is not enough, + // then things starts to become complicated: For example, for conv2d, there + // following questions needs to be considered. + // - Is the memory layout NCHW or NHWC ? + // - If the conv is NCHW -> NC'H'W', then should we + // - split only NC? + // - split only N'C'? + // - split both? + // - If the conv is NHWC, then we need to split across H, we need to be very + // careful about the boundary condition + // to make sure that the boundary is handled correctly. + // - If we decide to make these splits, is the memory contiguous? Do we need + // to copy the memory? 
Considering the complexity of this issue, it is better + // not to use cuDNN for this case + TORCH_INTERNAL_ASSERT(false, "This case should not be dispatched to cuDNN."); } -Tensor miopen_depthwise_convolution_backward_input( +Tensor miopen_convolution_backward_weight( CheckedFrom c, - IntArrayRef input_size, const TensorArg& grad_output, const TensorArg& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) -{ - checkAllSameType(c, {grad_output, weight}); - checkAllSameGPU(c, {grad_output, weight}); + IntArrayRef weight_size, + const Tensor& grad_output_t, + const Tensor& input_t, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic, + bool depthwise=false) { + auto memory_format = miopen_conv_suggest_memory_format(input_t, grad_output_t); - auto memory_format = at::MemoryFormat::Contiguous; - if (miopen_conv_use_channels_last(*grad_output, *weight)) { - memory_format = (weight->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast; - } + Tensor grad_output_contig_t = grad_output_t.contiguous(memory_format); + TensorArg grad_output_contig{grad_output_contig_t, "grad_output", 1}; - Tensor grad_input_t = at::detail::empty_cuda( - input_size, grad_output->options().memory_format(memory_format)); + Tensor input_contig_t = input_t.contiguous(memory_format); + TensorArg input{input_contig_t, "input", 2}; - TensorArg grad_input{ grad_input_t, "result", 0 }; - convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups); + checkAllSameType(c, {grad_output_contig, input}); + checkAllSameGPU(c, {grad_output_contig, input}); - // See #4500 - Tensor weight_contig = weight->contiguous(memory_format); - // Make sure that NC11 strides follow formula - weight_contig.resize_(weight_contig.sizes(), memory_format); + auto grad_weight_t = + at::empty(weight_size, grad_output_contig->options(), memory_format); - Tensor grad_output_contig = grad_output->contiguous(memory_format); - grad_output_contig.resize_(grad_output_contig.sizes(), memory_format); + // For uniformity with everything else, although it seems grad_weight + // would be unambiguous too. 
+ TensorArg grad_weight{grad_weight_t, "result", 0}; + convolution_shape_check( + c, + input, + grad_weight, + grad_output_contig, + padding, + stride, + dilation, + groups); - raw_miopen_depthwise_convolution_backward_input_out( - *grad_input, grad_output_contig, weight_contig, - padding, stride, dilation, groups, benchmark, deterministic); + raw_miopen_convolution_backward_weight_out( + *grad_weight, + *grad_output_contig, + *input, + padding, + stride, + dilation, + groups, + benchmark, + deterministic, + depthwise); - return *grad_input; + return grad_weight_t; } -Tensor miopen_depthwise_convolution_backward_input( - IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic) -{ - TensorArg grad_output{ grad_output_t, "grad_output", 1 }, - weight{ weight_t, "weight", 2 }; - return miopen_depthwise_convolution_backward_input( - "miopen_depthwise_convolution_backward_input", - input_size, grad_output, weight, - padding, stride, dilation, groups, benchmark, deterministic); +// overload +Tensor miopen_convolution_backward_weight( + IntArrayRef weight_size, + const Tensor& grad_output_t, + const Tensor& input_t, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic, + bool depthwise=false) { + return miopen_convolution_backward_weight( + "miopen_convolution_backward_weight", + weight_size, + grad_output_t, + input_t, + padding, + stride, + dilation, + groups, + benchmark, + deterministic, + depthwise); } -std::tuple miopen_convolution_backward( - const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, std::array output_mask) { - - Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format()); +std::tuple miopen_convolution_backward( + const at::Tensor& input, + const at::Tensor& grad_output_t, + const at::Tensor& weight, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic, + std::array output_mask) { + Tensor grad_output = grad_output_t.to(input.suggest_memory_format()); Tensor grad_input, grad_weight, grad_bias; - if (output_mask[0]) { - grad_input = miopen_convolution_backward_input(input.sizes(), grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic); - } - if (output_mask[1]) { - grad_weight = miopen_convolution_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic); - } - if (output_mask[2]) { - grad_bias = miopen_convolution_backward_bias(grad_output); + if (input.numel() == 0) { + if (output_mask[0]) { + grad_input = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (output_mask[1]) { + grad_weight = at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (output_mask[2]) { + grad_bias = at::zeros_like(grad_output_t, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + } else { + if (output_mask[0]) { + grad_input = miopen_convolution_backward_input( + input.sizes(), + grad_output, + weight, + padding, + stride, + dilation, + groups, + benchmark, + deterministic); + } + if (output_mask[1]) { + grad_weight = miopen_convolution_backward_weight( + weight.sizes(), + grad_output, + input, + padding, + stride, + dilation, + groups, + benchmark, + 
deterministic); + } + if (output_mask[2]) { + grad_bias = miopen_convolution_backward_bias(grad_output); + } } - return std::tuple{grad_input, grad_weight, grad_bias}; + return std::tuple{grad_input, grad_weight, grad_bias}; } -std::tuple miopen_depthwise_convolution_backward( - const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, - bool benchmark, bool deterministic, std::array output_mask) { +Tensor miopen_convolution_transpose_forward( + CheckedFrom c, + const TensorArg& grad_output, + const TensorArg& weight, + IntArrayRef padding, + IntArrayRef output_padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic) { + auto input_size = conv_input_size( + grad_output->sizes(), + weight->sizes(), + padding, + output_padding, + stride, + dilation, + groups); + return miopen_convolution_backward_input( + c, + input_size, + grad_output, + weight, + padding, + stride, + dilation, + groups, + benchmark, + deterministic); +} - Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format()); +Tensor miopen_convolution_transpose_backward_weight( + IntArrayRef weight_size, + const Tensor& grad_output_t, + const Tensor& input_t, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic) { + return miopen_convolution_backward_weight( + "miopen_convolution_backward_weight", + weight_size, + input_t, + grad_output_t, + padding, + stride, + dilation, + groups, + benchmark, + deterministic); +} - Tensor grad_input, grad_weight, grad_bias; - if (output_mask[0]) { - grad_input = miopen_depthwise_convolution_backward_input(input.sizes(), grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic); - } - if (output_mask[1]) { - grad_weight = miopen_depthwise_convolution_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic); - } - if (output_mask[2]) { - grad_bias = miopen_convolution_backward_bias(grad_output); - } +Tensor miopen_convolution_transpose( + const Tensor& input_t, + const Tensor& weight_t, + const std::optional& bias_t_opt, + IntArrayRef padding, + IntArrayRef output_padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt); + const Tensor& bias_t = *bias_t_maybe_owned; - return std::tuple{grad_input, grad_weight, grad_bias}; + TensorArg input{input_t, "input", 1}, weight{weight_t, "weight", 2}, bias{bias_t, "bias", 3}; + CheckedFrom c = "miopen_convolution_transpose"; + auto output_t = miopen_convolution_transpose_forward( + c, + input, + weight, + padding, + output_padding, + stride, + dilation, + groups, + benchmark, + deterministic); + if (bias->defined()) { + miopen_convolution_add_bias_(c, { output_t, "result", 0 }, bias); + } + return output_t; } -Tensor miopen_convolution_transpose( - const Tensor& input_t, const Tensor& weight_t, const std::optional& bias_t_opt, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups, bool benchmark, bool deterministic) +// --------------------------------------------------------------------- +// +// Convolution depthwise +// +// 
--------------------------------------------------------------------- + +Tensor miopen_depthwise_convolution( + const Tensor& input_t, + const Tensor& weight_t, + const std::optional& bias_t_opt, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic) { // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt); @@ -1480,16 +1693,86 @@ Tensor miopen_convolution_transpose( TensorArg input { input_t, "input", 1 }, weight { weight_t, "weight", 2 }, bias { bias_t, "bias", 3 }; - CheckedFrom c = "miopen_convolution_transpose"; - auto output_t = miopen_convolution_transpose_forward( - c, input, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic); + CheckedFrom c = "miopen_depthwise_convolution"; + auto memory_format = miopen_conv_suggest_memory_format(input_t, weight_t); + Tensor output_t = at::detail::empty_cuda( + conv_output_size( + input_t.sizes(), weight_t.sizes(), padding, stride, dilation), + input_t.options().memory_format(memory_format)); + if (output_t.numel() == 0) { + return output_t; + } + // Avoid ambiguity of "output" when this is being used as backwards + TensorArg output{output_t, "result", 0}; + miopen_convolution_forward_out( + output, + c, + input, + weight, + padding, + stride, + dilation, + groups, + benchmark, + deterministic, + true); if (bias->defined()) { - miopen_convolution_add_bias_(c, { output_t, "result", 0 }, bias); + miopen_convolution_add_bias_(c, output, bias); } - return output_t; + return *output; } -// MIOpen fused convolution bias activation forward +std::tuple miopen_depthwise_convolution_backward( + const at::Tensor& input, + const at::Tensor& grad_output_t, + const at::Tensor& weight, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic, + std::array output_mask) { + Tensor grad_output = grad_output_t.to(input.suggest_memory_format()); + + Tensor grad_input, grad_weight, grad_bias; + if (output_mask[0]) { + grad_input = miopen_convolution_backward_input( + input.sizes(), + grad_output, + weight, + padding, + stride, + dilation, + groups, + benchmark, + deterministic, + true); + } + if (output_mask[1]) { + grad_weight = miopen_convolution_backward_weight( + weight.sizes(), + grad_output, + input, + padding, + stride, + dilation, + groups, + benchmark, + deterministic, + true); + } + if (output_mask[2]) { + grad_bias = miopen_convolution_backward_bias(grad_output); + } + + return std::tuple{grad_input, grad_weight, grad_bias}; +} + +// --------------------------------------------------------------------- +// fusions +// --------------------------------------------------------------------- + void raw_miopen_convolution_relu_out( const Tensor& output, const Tensor& input, @@ -1501,17 +1784,35 @@ void raw_miopen_convolution_relu_out( int64_t groups, bool benchmark, bool deterministic) { - auto dataType = getMiopenDataType(input); miopenConvolutionMode_t c_mode = miopenConvolution; - ConvolutionArgs args{ input, output, weight }; args.handle = getMiopenHandle(); - setConvolutionParams(&args.params, args.handle, input, weight, padding, stride, dilation, groups, deterministic); - args.idesc.set(input); - args.wdesc.set(weight, input.suggest_memory_format(), 0); - args.odesc.set(output); - args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, 
args.params.groups, benchmark, deterministic); + at::MemoryFormat memory_format = miopen_conv_suggest_memory_format(input, weight); + setConvolutionParams( + &args.params, + args.handle, + input, + weight, + padding, + stride, + dilation, + groups, + deterministic, + memory_format); + args.idesc.set(input, memory_format); + args.wdesc.set(weight, memory_format, 0); + args.odesc.set(output, memory_format); + args.cdesc.set( + dataType, + c_mode, + input.dim() - 2, + args.params.padding, + args.params.stride, + args.params.dilation, + args.params.groups, + benchmark, + deterministic); TensorDescriptor bdesc; bdesc.set(bias.expand({1, bias.size(0)}), output.dim()); @@ -1555,8 +1856,8 @@ static at::Tensor self_or_new_memory_format(at::Tensor& self, at::MemoryFormat m } Tensor miopen_convolution_add_relu( - const Tensor& input, - const Tensor& weight, + const Tensor& input_t, + const Tensor& weight_t, const Tensor& z, const std::optional& alpha, const std::optional& bias, @@ -1568,17 +1869,28 @@ Tensor miopen_convolution_add_relu( // MIOpen does not support fusion of add, the alpha2 * z step of the below cuDNN function: // y = act ( alpha1 * conv(x) + alpha2 * z + bias ) - auto memory_format = input.suggest_memory_format(); + auto memory_format = miopen_conv_suggest_memory_format(input_t, weight_t); auto& ctx = at::globalContext(); bool benchmark = ctx.benchmarkCuDNN(); - TensorArg input_arg { input, "input", 1 }, - weight_arg { weight, "weight", 2 }; - auto output = miopen_convolution_forward( + TensorArg input { input_t, "input", 1 }, + weight { weight_t, "weight", 2 }; + + Tensor output_t = at::detail::empty_cuda( + conv_output_size( + input_t.sizes(), weight_t.sizes(), padding, stride, dilation), + input_t.options().memory_format(memory_format)); + if (output_t.numel() == 0){ + return output_t; + } + // Avoid ambiguity of "output" when this is being used as backwards + TensorArg output{output_t, "result", 0}; + miopen_convolution_forward_out( + output, "miopen_convolution_add_relu", - input_arg, - weight_arg, + input, + weight, padding, stride, dilation, @@ -1587,53 +1899,51 @@ Tensor miopen_convolution_add_relu( false // deterministic ); - auto contig_output = self_or_new_memory_format(output, memory_format); + auto contig_output_t = self_or_new_memory_format(output_t, memory_format); - if (!output.is_same(contig_output)) { - contig_output.copy_(output); + if (!output_t.is_same(contig_output_t)) { + contig_output_t.copy_(output_t); } auto _alpha = alpha.has_value() ? alpha.value().to() : 1.0; auto _bias = bias.has_value() ? 
bias.value() : at::zeros( - {contig_output.size(1)}, - optTypeMetaToScalarType(contig_output.options().dtype_opt()), - contig_output.options().layout_opt(), - contig_output.options().device_opt(), - contig_output.options().pinned_memory_opt()); + {contig_output_t.size(1)}, + optTypeMetaToScalarType(contig_output_t.options().dtype_opt()), + contig_output_t.options().layout_opt(), + contig_output_t.options().device_opt(), + contig_output_t.options().pinned_memory_opt()); - at::Tensor alpha_mul_z_add_bias = at::native::reshape_bias(input.dim(), _bias).add(z, _alpha); - contig_output.add_(alpha_mul_z_add_bias); - contig_output.relu_(); + at::Tensor alpha_mul_z_add_bias = at::native::reshape_bias(input_t.dim(), _bias).add(z, _alpha); + contig_output_t.add_(alpha_mul_z_add_bias); + contig_output_t.relu_(); - return contig_output; + return contig_output_t; } Tensor miopen_convolution_relu( - const Tensor& input, - const Tensor& weight, + const Tensor& input_t, + const Tensor& weight_t, const std::optional& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) { - auto memory_format = input.suggest_memory_format(); - auto& ctx = at::globalContext(); bool benchmark = ctx.benchmarkCuDNN(); // MIOpen currently only supports MemoryFormat::Contiguous and fp32 and 2d - if (input.suggest_memory_format() == at::MemoryFormat::Contiguous - && input.scalar_type() == at::kFloat - && input.ndimension() == 4) { + if (input_t.suggest_memory_format() == at::MemoryFormat::Contiguous + && input_t.scalar_type() == at::kFloat + && input_t.ndimension() == 4) { // FuseFrozenConvAddRelu performs some tensor shape checking Tensor output_t = at::detail::empty_cuda( conv_output_size( - input.sizes(), weight.sizes(), padding, stride, dilation), - input.options().memory_format(input.suggest_memory_format())); + input_t.sizes(), weight_t.sizes(), padding, stride, dilation), + input_t.options().memory_format(input_t.suggest_memory_format())); if (output_t.numel() == 0) { return output_t; } @@ -1649,8 +1959,8 @@ Tensor miopen_convolution_relu( raw_miopen_convolution_relu_out( output_t, - input, - weight, + input_t, + weight_t, _bias, stride, padding, @@ -1665,12 +1975,25 @@ Tensor miopen_convolution_relu( else { // fallback - TensorArg input_arg { input, "input", 1 }, - weight_arg { weight, "weight", 2 }; - auto output = miopen_convolution_forward( + auto memory_format = miopen_conv_suggest_memory_format(input_t, weight_t); + + TensorArg input { input_t, "input", 1 }, + weight { weight_t, "weight", 2 }; + + Tensor output_t = at::detail::empty_cuda( + conv_output_size( + input_t.sizes(), weight_t.sizes(), padding, stride, dilation), + input->options().memory_format(memory_format)); + if (output_t.numel() == 0){ + return output_t; + } + // Avoid ambiguity of "output" when this is being used as backwards + TensorArg output{output_t, "result", 0}; + miopen_convolution_forward_out( + output, "miopen_convolution_relu", - input_arg, - weight_arg, + input, + weight, padding, stride, dilation, @@ -1679,26 +2002,26 @@ Tensor miopen_convolution_relu( false // deterministic ); - auto contig_output = self_or_new_memory_format(output, memory_format); + auto contig_output_t = self_or_new_memory_format(output_t, memory_format); - if (!output.is_same(contig_output)) { - contig_output.copy_(output); + if (!output_t.is_same(contig_output_t)) { + contig_output_t.copy_(output_t); } auto _bias = bias.has_value() ? 
bias.value() : at::zeros( - {contig_output.size(1)}, - optTypeMetaToScalarType(contig_output.options().dtype_opt()), - contig_output.options().layout_opt(), - contig_output.options().device_opt(), - contig_output.options().pinned_memory_opt()); + {contig_output_t.size(1)}, + optTypeMetaToScalarType(contig_output_t.options().dtype_opt()), + contig_output_t.options().layout_opt(), + contig_output_t.options().device_opt(), + contig_output_t.options().pinned_memory_opt()); - at::Tensor reshaped_bias = at::native::reshape_bias(input.dim(), _bias); - contig_output.add_(reshaped_bias); - contig_output.relu_(); + at::Tensor reshaped_bias = at::native::reshape_bias(input_t.dim(), _bias); + contig_output_t.add_(reshaped_bias); + contig_output_t.relu_(); - return contig_output; + return contig_output_t; } } diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index 858e72416c890..22555a9b588d6 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -1,6 +1,7 @@ # Owner(s): ["module: nn"] import itertools import math +import os import unittest import warnings from itertools import product @@ -61,6 +62,10 @@ AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported() +if TEST_WITH_ROCM: + os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1" + + if TEST_SCIPY: import scipy.ndimage import scipy.signal @@ -4042,6 +4047,7 @@ def test_conv_double_backward_strided_with_3D_input_and_weight(self, device): self.assertEqual(grad_input.shape, input.shape) self.assertEqual(grad_weight.shape, weight.shape) + @skipCUDAIfRocm @onlyCUDA @largeTensorTest("40GB") @largeTensorTest("24GB", "cpu") diff --git a/test/test_nn.py b/test/test_nn.py index e7a898f0cf22c..f7c018a32e9fe 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -58,6 +58,9 @@ AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported() +if TEST_WITH_ROCM: + os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1" + # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings load_tests = load_tests
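A note on the 32-bit indexing workaround added in raw_miopen_convolution_backward_weight_out above: when input or grad_output exceeds 2^31 - 1 elements, the batch dimension is split so each slice stays under the 32-bit limit, and the per-slice weight gradients are summed (in float when the dtype is half or bfloat16). A worked example of the split-size heuristic, using hypothetical shapes rather than anything from the patch:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <limits>

    int main() {
      constexpr int64_t int_max = std::numeric_limits<int32_t>::max();
      // Assume input and grad_output are both (N, C, H, W) = (4096, 64, 512, 512).
      const int64_t n = 4096;
      const int64_t inner = 64 * 512 * 512;          // C * H * W = 16,777,216
      const int64_t ni = n * inner;                  // 68,719,476,736 > int_max
      // The heuristic budgets roughly 512M elements per slice.
      const int64_t max_inner_size = inner;
      const int64_t split_size =
          std::max<int64_t>(1024 * 1024 * 512 / max_inner_size, 1);  // 32
      const int64_t num_splits = (n + split_size - 1) / split_size;  // 128
      // Each slice of 32 samples covers 32 * 16,777,216 = 536,870,912 elements,
      // which is below int_max, so the 32-bit kernel can run per slice.
      std::cout << ni << " elements -> " << num_splits
                << " splits of up to " << split_size << " samples\n";
    }
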
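Similarly, miopen_convolution_transpose_forward above implements the transposed-convolution forward pass by computing the would-be input size of the corresponding regular convolution (conv_input_size) and running the backward-data path on it. The per-dimension size relationship, with an illustrative round trip (the helpers below mirror the ATen formulas conceptually and are not the library code; the numbers are made up):

    #include <cassert>
    #include <cstdint>

    // Inverse of the usual convolution output-size formula for one spatial dim.
    int64_t conv_input_size_1d(int64_t out, int64_t kernel, int64_t pad,
                               int64_t stride, int64_t dilation,
                               int64_t output_padding) {
      return (out - 1) * stride - 2 * pad + dilation * (kernel - 1) + 1 +
             output_padding;
    }

    int64_t conv_output_size_1d(int64_t in, int64_t kernel, int64_t pad,
                                int64_t stride, int64_t dilation) {
      return (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
    }

    int main() {
      // A transposed conv mapping 5 -> 10 corresponds to a regular conv
      // mapping 10 -> 5 with the same kernel/stride/padding/dilation.
      int64_t in = conv_input_size_1d(/*out=*/5, /*kernel=*/3, /*pad=*/1,
                                      /*stride=*/2, /*dilation=*/1,
                                      /*output_padding=*/1);       // 10
      assert(in == 10);
      assert(conv_output_size_1d(in, 3, 1, 2, 1) == 5);
      return 0;
    }
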