|
| 1 | +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include "ConvBaseOperator.h" |
| 16 | +#include "paddle/math/MathUtils.h" |
| 17 | +#include "paddle/math/Matrix.h" |
| 18 | + |
| 19 | +namespace paddle { |
| 20 | + |
| 21 | +/** |
| 22 | + * @brief ConvBaseOperator takes two inputs to perform the convolution. |
| 23 | + * The first input is the image, and the second input is the convolution kernel. |
| 24 | + * The height of data for two inputs are the same. Each data of the first input |
| 25 | + * is convolved with each data of the second input indepedently. |
| 26 | + * |
| 27 | + * The config file api is conv_operator. |
| 28 | + */ |
| 29 | + |
| 30 | +ConvBaseOperator::ConvBaseOperator(const OperatorConfig &config, bool useGpu) |
| 31 | + : Operator(config, useGpu) { |
| 32 | + CHECK(useGpu); |
| 33 | + CHECK_EQ(config_.input_indices_size(), 2L); |
| 34 | + |
| 35 | + caffeMode_ = true; |
| 36 | + getConvParams(); |
| 37 | + computeConvSizes(); |
| 38 | + |
| 39 | + // initialize all to default algorithms |
| 40 | + fwdAlgo_ = 0; |
| 41 | + bwdFilterAlgo_ = 0; |
| 42 | + bwdDataAlgo_ = 0; |
| 43 | + fwdLimitBytes_ = 0; |
| 44 | + bwdDataLimitBytes_ = 0; |
| 45 | + bwdFilterLimitBytes_ = 0; |
| 46 | + workSpaceInBytes_ = 0; |
| 47 | + workSpace_ = nullptr; |
| 48 | + |
| 49 | + isSelectAlgo_ = false; |
| 50 | +} |
| 51 | + |
| 52 | +void ConvBaseOperator::allocConvWorkSpace() { |
| 53 | + hl_conv_workspace(imageDesc_, |
| 54 | + outputDesc_, |
| 55 | + filterDesc_, |
| 56 | + convDesc_, |
| 57 | + &fwdAlgo_, |
| 58 | + &fwdLimitBytes_, |
| 59 | + &bwdDataAlgo_, |
| 60 | + &bwdDataLimitBytes_, |
| 61 | + &bwdFilterAlgo_, |
| 62 | + &bwdFilterLimitBytes_); |
| 63 | + |
| 64 | + size_t maxWorkSpace = 0; |
| 65 | + maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_); |
| 66 | + maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_); |
| 67 | + |
| 68 | + if (maxWorkSpace > workSpaceInBytes_) { |
| 69 | + if (workSpaceInBytes_ != 0) { |
| 70 | + hl_free_mem_device(workSpace_); |
| 71 | + } |
| 72 | + // total amount of storage needed |
| 73 | + workSpace_ = hl_malloc_device(maxWorkSpace); |
| 74 | + workSpaceInBytes_ = maxWorkSpace; |
| 75 | + } |
| 76 | +} |
| 77 | + |
| 78 | +void ConvBaseOperator::computeConvSizes() { |
| 79 | + hl_create_filter_descriptor( |
| 80 | + &filterDesc_, channels_, numFilters_, filterSizeY_, filterSize_); |
| 81 | + hl_create_tensor_descriptor(&imageDesc_); |
| 82 | + hl_create_tensor_descriptor(&outputDesc_); |
| 83 | + hl_create_convolution_descriptor(&convDesc_, |
| 84 | + imageDesc_, |
| 85 | + filterDesc_, |
| 86 | + paddingY_, |
| 87 | + padding_, |
| 88 | + strideY_, |
| 89 | + stride_); |
| 90 | +} |
| 91 | + |
| 92 | +void ConvBaseOperator::reshapeImageDescriptors() { |
| 93 | + hl_tensor_reshape(imageDesc_, |
| 94 | + 1, |
| 95 | + channels_, |
| 96 | + imageH_, |
| 97 | + imageW_, |
| 98 | + channels_ * imageH_ * imageW_, |
| 99 | + imageH_ * imageW_, |
| 100 | + imageW_, |
| 101 | + 1); |
| 102 | + hl_tensor_reshape(outputDesc_, |
| 103 | + 1, |
| 104 | + numFilters_, |
| 105 | + outputH_, |
| 106 | + outputW_, |
| 107 | + numFilters_ * outputH_ * outputW_, |
| 108 | + outputH_ * outputW_, |
| 109 | + outputW_, |
| 110 | + 1); |
| 111 | + hl_reset_convolution_descriptor(convDesc_, |
| 112 | + imageDesc_, |
| 113 | + filterDesc_, |
| 114 | + paddingY_, |
| 115 | + padding_, |
| 116 | + strideY_, |
| 117 | + stride_); |
| 118 | +} |
| 119 | + |
| 120 | +void ConvBaseOperator::getConvParams() { |
| 121 | + configNumFilters_ = config_.num_filters(); |
| 122 | + const ConvConfig &conf = config_.conv_conf(); |
| 123 | + padding_ = conf.padding(); |
| 124 | + stride_ = conf.stride(); |
| 125 | + filterSize_ = conf.filter_size(); |
| 126 | + paddingY_ = conf.padding_y(); |
| 127 | + strideY_ = conf.stride_y(); |
| 128 | + filterSizeY_ = conf.filter_size_y(); |
| 129 | + filterPixels_ = filterSize_ * filterSizeY_; |
| 130 | + configChannels_ = conf.channels(); |
| 131 | + imgSize_ = conf.img_size(); |
| 132 | + imgSizeY_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size(); |
| 133 | + imgPixels_ = imgSize_ * imgSizeY_; |
| 134 | + CHECK_EQ(conf.groups(), 1U); |
| 135 | + filterChannels_ = conf.filter_channels(); |
| 136 | + outputX_ = conf.output_x(); |
| 137 | + outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x(); |
| 138 | + outputs_ = outputX_ * outputX_; |
| 139 | + |
| 140 | + isDeconv_ = (config_.type() == "conv") ? false : true; |
| 141 | + if (isDeconv_) { |
| 142 | + channels_ = configNumFilters_; |
| 143 | + numFilters_ = configChannels_; |
| 144 | + } else { |
| 145 | + channels_ = configChannels_; |
| 146 | + numFilters_ = configNumFilters_; |
| 147 | + } |
| 148 | +} |
| 149 | + |
| 150 | +} // namespace paddle |
0 commit comments