@@ -49,7 +49,22 @@ ConvBaseOperator::ConvBaseOperator(const OperatorConfig &config, bool useGpu)
49
49
isSelectAlgo_ = false ;
50
50
}
51
51
52
- void ConvBaseOperator::allocConvWorkSpace (size_t maxWorkSpace) {
52
+ void ConvBaseOperator::allocConvWorkSpace () {
53
+ hl_conv_workspace (imageDesc_,
54
+ outputDesc_,
55
+ filterDesc_,
56
+ convDesc_,
57
+ &fwdAlgo_,
58
+ &fwdLimitBytes_,
59
+ &bwdDataAlgo_,
60
+ &bwdDataLimitBytes_,
61
+ &bwdFilterAlgo_,
62
+ &bwdFilterLimitBytes_);
63
+
64
+ size_t maxWorkSpace = 0 ;
65
+ maxWorkSpace = std::max (fwdLimitBytes_, bwdDataLimitBytes_);
66
+ maxWorkSpace = std::max (maxWorkSpace, bwdFilterLimitBytes_);
67
+
53
68
if (maxWorkSpace > workSpaceInBytes_) {
54
69
if (workSpaceInBytes_ != 0 ) {
55
70
hl_free_mem_device (workSpace_);
@@ -60,59 +75,6 @@ void ConvBaseOperator::allocConvWorkSpace(size_t maxWorkSpace) {
60
75
}
61
76
}
62
77
63
- void ConvBaseOperator::reshape (int batchSize) {  // Derives frame sizes from ins_[0], validates them against the configured sizes, refreshes the image descriptors, and sizes the conv workspace. batchSize is not read anywhere in this body.
64
- if (isDeconv_) {  // deconv path: ins_[0] provides outputH_/outputW_ and the image size is derived from them
65
- outputH_ = ins_[0 ]->getFrameHeight ();
66
- outputW_ = ins_[0 ]->getFrameWidth ();
67
- if (outputH_ == 0 ) outputH_ = outputY_;  // a frame height of 0 falls back to outputY_
68
- if (outputW_ == 0 ) outputW_ = outputX_;  // a frame width of 0 falls back to outputX_
69
- imageH_ =
70
- imageSize (outputH_, filterSizeY_, paddingY_, strideY_, caffeMode_);
71
- imageW_ = imageSize (outputW_, filterSize_, padding_, stride_, caffeMode_);
72
- /// Check that the image sizes are consistent with config
73
- CHECK_EQ (imageH_, imgSizeY_);
74
- CHECK_EQ (imageW_, imgSize_);
75
- out_->setFrameHeight (imageH_);  // record the derived image size on out_
76
- out_->setFrameWidth (imageW_);
77
- } else {  // non-deconv path: ins_[0] provides imageH_/imageW_ and the output size is derived from them
78
- imageH_ = ins_[0 ]->getFrameHeight ();
79
- imageW_ = ins_[0 ]->getFrameWidth ();
80
- if (imageH_ == 0 ) imageH_ = imgSizeY_;  // a frame height of 0 falls back to imgSizeY_
81
- if (imageW_ == 0 ) imageW_ = imgSize_;  // a frame width of 0 falls back to imgSize_
82
- outputH_ =
83
- outputSize (imageH_, filterSizeY_, paddingY_, strideY_, caffeMode_);
84
- outputW_ = outputSize (imageW_, filterSize_, padding_, stride_, caffeMode_);
85
- /// Check that the output sizes are consistent with config
86
- CHECK_EQ (outputH_, outputY_);
87
- CHECK_EQ (outputW_, outputX_);
88
- out_->setFrameHeight (outputH_);  // record the derived output size on out_
89
- out_->setFrameWidth (outputW_);
90
- }
91
-
92
- reshapeImageDescriptors ();  // called only after imageH_/imageW_/outputH_/outputW_ have been assigned above
93
-
94
- if (!isSelectAlgo_) {  // skipped once isSelectAlgo_ has been set to true at the end of this function
95
- hl_conv_workspace (imageDesc_,  // the algo and *LimitBytes_ members are passed by address, presumably as outputs -- confirm against hl_conv_workspace
96
- outputDesc_,
97
- filterDesc_,
98
- convDesc_,
99
- &fwdAlgo_,
100
- &fwdLimitBytes_,
101
- &bwdDataAlgo_,
102
- &bwdDataLimitBytes_,
103
- &bwdFilterAlgo_,
104
- &bwdFilterLimitBytes_);
105
-
106
- size_t maxWorkSpace = 0 ;
107
- maxWorkSpace = std::max (fwdLimitBytes_, bwdDataLimitBytes_);  // the workspace request is the largest of the three limits
108
- maxWorkSpace = std::max (maxWorkSpace, bwdFilterLimitBytes_);
109
-
110
- allocConvWorkSpace (maxWorkSpace);
111
- }
112
-
113
- isSelectAlgo_ = true ;  // marks the workspace query above as done for subsequent calls
114
- }
115
-
116
78
void ConvBaseOperator::computeConvSizes () {
117
79
hl_create_filter_descriptor (
118
80
&filterDesc_, channels_, numFilters_, filterSizeY_, filterSize_);
@@ -153,15 +115,6 @@ void ConvBaseOperator::reshapeImageDescriptors() {
153
115
padding_,
154
116
strideY_,
155
117
stride_);
156
-
157
- if (isDeconv_) {
158
- inputOffset_ = numFilters_ * outputH_ * outputW_;
159
- outputOffset_ = channels_ * imageH_ * imageW_;
160
- } else {
161
- inputOffset_ = channels_ * imageH_ * imageW_;
162
- outputOffset_ = numFilters_ * outputH_ * outputW_;
163
- }
164
- weightOffset_ = numFilters_ * channels_ * filterSize_ * filterSizeY_;
165
118
}
166
119
167
120
void ConvBaseOperator::getConvParams () {
0 commit comments