Commit d87ac4d

Author: wangyang59
GPU of bilinear_interp_op done
1 parent: ad3b3d9

File tree: 2 files changed (+13 −4 lines)

paddle/fluid/operators/bilinear_interp_op.cu

Lines changed: 7 additions & 0 deletions
@@ -11,6 +11,8 @@

 #include "paddle/fluid/operators/bilinear_interp_op.cu.h"
 #include "paddle/fluid/operators/bilinear_interp_op.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/cuda_helper.h"

 namespace paddle {
 namespace operators {
@@ -64,6 +66,11 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
     auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
     auto* d_output = d_output_t->data<T>();

+    auto& device_ctx =
+        ctx.template device_context<platform::CUDADeviceContext>();
+    math::SetConstant<platform::CUDADeviceContext, T> zero;
+    zero(device_ctx, d_input_t, static_cast<T>(0.0));
+
     int out_h = ctx.Attr<int>("out_h");
     int out_w = ctx.Attr<int>("out_w");
     int batch_size = d_input_t->dims()[0];
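
The zero-fill matters because the backward pass accumulates into d_input_t rather than assigning to it: neighbouring output pixels map back to overlapping input neighbours, so their contributions must be summed (on the GPU presumably via atomic adds, which would also explain the new cuda_helper.h include). A minimal standalone CUDA sketch of that pattern follows; it is independent of PaddlePaddle, and scatter_add with its 2-to-1 index folding is illustrative only.

#include <cstdio>

#include <cuda_runtime.h>

// Each thread folds one output-gradient element into a shared input slot,
// mimicking how a bilinear backward pass scatters into overlapping inputs.
__global__ void scatter_add(const float* d_output, float* d_input, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    atomicAdd(&d_input[i / 2], d_output[i]);  // overlapping writes, so accumulate
  }
}

int main() {
  const int n = 8;
  float *d_output = nullptr, *d_input = nullptr;
  cudaMallocManaged(&d_output, n * sizeof(float));
  cudaMallocManaged(&d_input, (n / 2) * sizeof(float));
  for (int i = 0; i < n; ++i) d_output[i] = 1.0f;

  // Counterpart of the SetConstant call in the commit: clear the gradient
  // buffer so the accumulation starts from zero rather than stale memory.
  cudaMemset(d_input, 0, (n / 2) * sizeof(float));

  scatter_add<<<1, n>>>(d_output, d_input, n);
  cudaDeviceSynchronize();
  for (int i = 0; i < n / 2; ++i) std::printf("%f\n", d_input[i]);  // 2.000000 each
  cudaFree(d_output);
  cudaFree(d_input);
  return 0;
}

Dropping the cudaMemset line would leave d_input holding whatever the allocation happened to contain, which is exactly the bug the added SetConstant call guards against.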

paddle/fluid/operators/bilinear_interp_op.h

Lines changed: 6 additions & 4 deletions
@@ -10,16 +10,13 @@
 limitations under the License. */

 #pragma once
-#include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"

 namespace paddle {
 namespace operators {

 using Tensor = framework::Tensor;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

 template <typename T>
 class BilinearInterpKernel : public framework::OpKernel<T> {
@@ -89,6 +86,11 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
     auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
     auto* d_output = d_output_t->data<T>();

+    auto& device_ctx =
+        ctx.template device_context<platform::CPUDeviceContext>();
+    math::SetConstant<platform::CPUDeviceContext, T> zero;
+    zero(device_ctx, d_input_t, static_cast<T>(0.0));
+
     int out_h = ctx.Attr<int>("out_h");
     int out_w = ctx.Attr<int>("out_w");
     int batch_size = d_input_t->dims()[0];
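
For context, the reason several gradients land on the same d_input element in the first place is the bilinear weighting itself: each output value reads four input neighbours with weights derived from the fractional sample position, and adjacent outputs share neighbours. A hypothetical single-sample sketch of that weight math (host-only code; the array shape and offsets are made up for illustration, not taken from the operator):

#include <cstdio>

int main() {
  // Gradient buffer zero-filled up front, as the SetConstant call does.
  float in_grad[2][2] = {{0.f, 0.f}, {0.f, 0.f}};
  float out_grad = 1.0f;           // upstream gradient for one output pixel
  float ly = 0.25f, lx = 0.75f;    // fractional offsets inside the input cell
  float hy = 1.f - ly, hx = 1.f - lx;

  // Each of the four neighbours receives grad * its interpolation weight,
  // added on top of whatever other output pixels have already contributed
  // (only one output here, but the += mirrors the kernels' accumulation).
  in_grad[0][0] += hy * hx * out_grad;
  in_grad[0][1] += hy * lx * out_grad;
  in_grad[1][0] += ly * hx * out_grad;
  in_grad[1][1] += ly * lx * out_grad;

  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      std::printf("in_grad[%d][%d] = %f\n", i, j, in_grad[i][j]);
  return 0;
}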
