|
| 1 | +#include "ctranslate2/ops/conv1d.h" |
| 2 | + |
| 3 | +#include "cuda/utils.h" |
| 4 | + |
| 5 | +namespace ctranslate2 { |
| 6 | + namespace ops { |
| 7 | + |
| 8 | + template <Device D, typename T> |
| 9 | + void Conv1D::compute(const StorageView& input, |
| 10 | + const StorageView& weight, |
| 11 | + const StorageView* bias, |
| 12 | + StorageView& output, |
| 13 | + const StorageView* qscale) const { |
| 14 | + if (qscale) |
| 15 | + throw std::runtime_error("Quantization is not supported in this Conv1D implementation"); |
| 16 | + |
| 17 | + const int batch_size = input.dim(0); |
| 18 | + const int in_channels = input.dim(1); |
| 19 | + const int input_length = input.dim(2); |
| 20 | + const int output_length = output.dim(2); |
| 21 | + const int out_channels = weight.dim(0); |
| 22 | + const int in_channels_per_group = weight.dim(1); |
| 23 | + const int kernel_size = weight.dim(2); |
| 24 | + |
| 25 | + cudnnDataType_t data_type = cuda::get_cudnn_data_type(input.dtype()); |
| 26 | + |
| 27 | + cudnnTensorDescriptor_t input_desc; |
| 28 | + CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc)); |
| 29 | + CUDNN_CHECK(cudnnSetTensor4dDescriptor(input_desc, CUDNN_TENSOR_NCHW, data_type, |
| 30 | + batch_size, in_channels, 1, input_length)); |
| 31 | + |
| 32 | + cudnnTensorDescriptor_t output_desc; |
| 33 | + CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc)); |
| 34 | + CUDNN_CHECK(cudnnSetTensor4dDescriptor(output_desc, CUDNN_TENSOR_NCHW, data_type, |
| 35 | + batch_size, out_channels, 1, output_length)); |
| 36 | + |
| 37 | + cudnnFilterDescriptor_t weight_desc; |
| 38 | + CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc)); |
| 39 | + CUDNN_CHECK(cudnnSetFilter4dDescriptor(weight_desc, data_type, CUDNN_TENSOR_NCHW, |
| 40 | + out_channels, in_channels_per_group, 1, kernel_size)); |
| 41 | + |
| 42 | + cudnnConvolutionDescriptor_t conv_desc; |
| 43 | + CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc)); |
| 44 | + CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc, |
| 45 | + /*pad_h=*/0, /*pad_w=*/_padding, |
| 46 | + /*stride_h=*/1, /*stride_w=*/_stride, |
| 47 | + /*dilation_h=*/1, /*dilation_w=*/_dilation, |
| 48 | + CUDNN_CROSS_CORRELATION, |
| 49 | + CUDNN_DATA_FLOAT)); |
| 50 | + |
| 51 | + CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_DEFAULT_MATH)); |
| 52 | + if (_groups > 1) |
| 53 | + CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, _groups)); |
| 54 | + if (data_type == CUDNN_DATA_HALF) |
| 55 | + CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH)); |
| 56 | + |
| 57 | + cudnnHandle_t handle = cuda::get_cudnn_handle(); |
| 58 | + |
| 59 | + cudnnConvolutionFwdAlgo_t algo = (bias |
| 60 | + ? CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM |
| 61 | + : CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM); |
| 62 | + |
| 63 | + size_t workspace_size = 0; |
| 64 | + void* workspace = nullptr; |
| 65 | + CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle, |
| 66 | + input_desc, |
| 67 | + weight_desc, |
| 68 | + conv_desc, |
| 69 | + output_desc, |
| 70 | + algo, |
| 71 | + &workspace_size)); |
| 72 | + |
| 73 | + if (workspace_size > 0) |
| 74 | + workspace = get_allocator<Device::CUDA>().allocate(workspace_size); |
| 75 | + |
| 76 | + float alpha = 1; |
| 77 | + float beta = 0; |
| 78 | + |
| 79 | + if (bias) { |
| 80 | + cudnnTensorDescriptor_t bias_desc; |
| 81 | + CUDNN_CHECK(cudnnCreateTensorDescriptor(&bias_desc)); |
| 82 | + CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc, CUDNN_TENSOR_NCHW, data_type, |
| 83 | + 1, out_channels, 1, 1)); |
| 84 | + |
| 85 | + cudnnActivationDescriptor_t activation_desc; |
| 86 | + CUDNN_CHECK(cudnnCreateActivationDescriptor(&activation_desc)); |
| 87 | + CUDNN_CHECK(cudnnSetActivationDescriptor(activation_desc, |
| 88 | + CUDNN_ACTIVATION_IDENTITY, |
| 89 | + CUDNN_NOT_PROPAGATE_NAN, |
| 90 | + /*coef=*/0)); |
| 91 | + |
| 92 | + CUDNN_CHECK(cudnnConvolutionBiasActivationForward(handle, |
| 93 | + &alpha, |
| 94 | + input_desc, |
| 95 | + input.buffer(), |
| 96 | + weight_desc, |
| 97 | + weight.buffer(), |
| 98 | + conv_desc, |
| 99 | + algo, |
| 100 | + workspace, |
| 101 | + workspace_size, |
| 102 | + &beta, |
| 103 | + output_desc, |
| 104 | + output.buffer(), |
| 105 | + bias_desc, |
| 106 | + bias->buffer(), |
| 107 | + activation_desc, |
| 108 | + output_desc, |
| 109 | + output.buffer())); |
| 110 | + |
| 111 | + CUDNN_CHECK(cudnnDestroyActivationDescriptor(activation_desc)); |
| 112 | + CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc)); |
| 113 | + |
| 114 | + } else { |
| 115 | + CUDNN_CHECK(cudnnConvolutionForward(handle, |
| 116 | + &alpha, |
| 117 | + input_desc, |
| 118 | + input.buffer(), |
| 119 | + weight_desc, |
| 120 | + weight.buffer(), |
| 121 | + conv_desc, |
| 122 | + algo, |
| 123 | + workspace, |
| 124 | + workspace_size, |
| 125 | + &beta, |
| 126 | + output_desc, |
| 127 | + output.buffer())); |
| 128 | + } |
| 129 | + |
| 130 | + if (workspace) |
| 131 | + get_allocator<Device::CUDA>().free(workspace); |
| 132 | + |
| 133 | + CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc)); |
| 134 | + CUDNN_CHECK(cudnnDestroyFilterDescriptor(weight_desc)); |
| 135 | + CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc)); |
| 136 | + CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc)); |
| 137 | + } |
| 138 | + |
| 139 | +#define DECLARE_IMPL(T) \ |
| 140 | + template void \ |
| 141 | + Conv1D::compute<Device::CUDA, T>(const StorageView& input, \ |
| 142 | + const StorageView& weight, \ |
| 143 | + const StorageView* bias, \ |
| 144 | + StorageView& output, \ |
| 145 | + const StorageView* qscale) const; |
| 146 | + |
| 147 | + DECLARE_IMPL(float) |
| 148 | + DECLARE_IMPL(float16_t) |
| 149 | + DECLARE_IMPL(bfloat16_t) |
| 150 | + |
| 151 | + } |
| 152 | +} |
0 commit comments