File tree Expand file tree Collapse file tree 2 files changed +13
-4
lines changed Expand file tree Collapse file tree 2 files changed +13
-4
lines changed Original file line number Diff line number Diff line change 11
11
12
12
#include " paddle/fluid/operators/bilinear_interp_op.cu.h"
13
13
#include " paddle/fluid/operators/bilinear_interp_op.h"
14
+ #include " paddle/fluid/operators/math/math_function.h"
15
+ #include " paddle/fluid/platform/cuda_helper.h"
14
16
15
17
namespace paddle {
16
18
namespace operators {
@@ -64,6 +66,11 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
64
66
auto * d_input = d_input_t ->mutable_data <T>(ctx.GetPlace ());
65
67
auto * d_output = d_output_t ->data <T>();
66
68
69
+ auto & device_ctx =
70
+ ctx.template device_context <platform::CUDADeviceContext>();
71
+ math::SetConstant<platform::CUDADeviceContext, T> zero;
72
+ zero (device_ctx, d_input_t , static_cast <T>(0.0 ));
73
+
67
74
int out_h = ctx.Attr <int >(" out_h" );
68
75
int out_w = ctx.Attr <int >(" out_w" );
69
76
int batch_size = d_input_t ->dims ()[0 ];
Original file line number Diff line number Diff line change 10
10
limitations under the License. */
11
11
12
12
#pragma once
13
- #include " paddle/fluid/framework/eigen.h"
14
13
#include " paddle/fluid/framework/op_registry.h"
14
+ #include " paddle/fluid/operators/math/math_function.h"
15
15
16
16
namespace paddle {
17
17
namespace operators {
18
18
19
19
using Tensor = framework::Tensor;
20
- template <typename T, int MajorType = Eigen::RowMajor,
21
- typename IndexType = Eigen::DenseIndex>
22
- using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
23
20
24
21
template <typename T>
25
22
class BilinearInterpKernel : public framework ::OpKernel<T> {
@@ -89,6 +86,11 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
89
86
auto * d_input = d_input_t ->mutable_data <T>(ctx.GetPlace ());
90
87
auto * d_output = d_output_t ->data <T>();
91
88
89
+ auto & device_ctx =
90
+ ctx.template device_context <platform::CPUDeviceContext>();
91
+ math::SetConstant<platform::CPUDeviceContext, T> zero;
92
+ zero (device_ctx, d_input_t , static_cast <T>(0.0 ));
93
+
92
94
int out_h = ctx.Attr <int >(" out_h" );
93
95
int out_w = ctx.Attr <int >(" out_w" );
94
96
int batch_size = d_input_t ->dims ()[0 ];
You can’t perform that action at this time.
0 commit comments