Skip to content

Commit 0f353ab

Browse files
authored
cpu gpu transform function (#7191)
* add rename guard * add device_data_transform * add device_data_transform_test * modify GetExpectedKernelType * update operator.run * support test test_label_semantic_roles * optimize code * optimize code * rename GetActualKernelType to GetExpectedKernelType * fix chunk_eval_op and device_data_transform_test * add is_same_place to place * optimize code, refine rename_guard * refine rename guard, add GetKernelTypeForVar * optimize code * add some log * rename guard * use sub scope to create var * fix compile * add IsInitialized for Tensor * add VarIsTensor * fix op_registry_test * test * tmp disable priority * restore switch_kernel.md * code clean
1 parent 8814bec commit 0f353ab

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+429
-200
lines changed

paddle/framework/CMakeLists.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
3232
cc_library(scope SRCS scope.cc DEPS glog threadpool)
3333
cc_test(scope_test SRCS scope_test.cc DEPS scope)
3434

35-
cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto)
35+
cc_library(device_data_transform SRCS device_data_transform.cc DEPS tensor)
36+
37+
cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto selected_rows device_data_transform)
3638
cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context)
3739

3840
cc_library(attribute SRCS attribute.cc DEPS framework_proto)
@@ -77,3 +79,6 @@ cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece operat
7779
cc_test(init_test SRCS init_test.cc DEPS init)
7880

7981
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
82+
83+
nv_test(device_data_transform_test SRCS device_data_transform_test.cu
84+
DEPS operator op_registry init math_function)

paddle/framework/data_transform.cc

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ limitations under the License. */
1414
#include <functional>
1515

1616
#include "paddle/framework/data_transform.h"
17+
#include "paddle/framework/device_data_transform.h"
1718
#include "paddle/framework/lod_tensor.h"
19+
#include "paddle/framework/selected_rows.h"
1820
#include "paddle/platform/device_context.h"
1921

2022
namespace paddle {
@@ -25,6 +27,37 @@ DataTransformFnMap& DataTransformFnMap::Instance() {
2527
return data_transform_map;
2628
}
2729

30+
// Produces a tensor suitable for a kernel of `expected_kernel_type` from
// `input_tensor`, whose current kernel type is `kernel_type_for_var`.
//
// Only a place mismatch is handled: when the two kernel types live on
// different places the input is handed to DeviceTransform, targeting the
// expected place. When the places are the same nothing is produced and the
// not-null enforce below fails.
//
// The returned pointer is whatever DeviceTransform returned; this function
// does not release it.
Tensor* DataTransform(const OpKernelType& expected_kernel_type,
                      const OpKernelType& kernel_type_for_var,
                      const Tensor& input_tensor) {
  const bool place_matches = platform::is_same_place(
      kernel_type_for_var.place_, expected_kernel_type.place_);

  // NOTE: the variable is deliberately still called `out`; the enforce macro
  // receives the expression itself and may embed its spelling in the error.
  Tensor* out =
      place_matches
          ? nullptr
          : DeviceTransform(input_tensor, expected_kernel_type.place_);

  PADDLE_ENFORCE_NOT_NULL(out, "out should not be null");
  return out;
}
41+
42+
void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
43+
Variable& out_var) {
44+
if (in_var.IsType<LoDTensor>()) {
45+
auto& in_lod_tensor = in_var.Get<LoDTensor>();
46+
auto* tran_lod_tensor = out_var.GetMutable<LoDTensor>();
47+
tran_lod_tensor->set_lod(in_lod_tensor.lod());
48+
tran_lod_tensor->set_layout(in_lod_tensor.layout());
49+
tran_lod_tensor->ShareDataWith(tensor);
50+
} else if (in_var.IsType<SelectedRows>()) {
51+
auto& in_selected_rows = in_var.Get<SelectedRows>();
52+
auto* trans_selected_rows = out_var.GetMutable<SelectedRows>();
53+
trans_selected_rows->set_height(in_selected_rows.height());
54+
trans_selected_rows->set_rows(in_selected_rows.rows());
55+
trans_selected_rows->mutable_value()->ShareDataWith(tensor);
56+
} else {
57+
PADDLE_THROW("unknown var type");
58+
}
59+
}
60+
2861
auto KernelFP32 = OpKernelType(proto::DataType::FP32, platform::CPUPlace(),
2962
DataLayout::kNHWC, LibraryType::kPlain);
3063

paddle/framework/data_transform.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License. */
1919
#include <vector>
2020

2121
#include "paddle/framework/op_kernel_type.h"
22+
#include "paddle/framework/selected_rows.h"
2223
#include "paddle/framework/tensor.h"
2324
#include "paddle/framework/variable.h"
2425
#include "paddle/operators/math/math_function.h"
@@ -49,6 +50,13 @@ struct KernelTypePairHash {
4950
}
5051
};
5152

53+
// Returns a tensor derived from `input_tensor` for a kernel of
// `expected_kernel_type`, given that the variable currently corresponds to
// `kernel_type_for_var`.
Tensor* DataTransform(const OpKernelType& expected_kernel_type,
                      const OpKernelType& kernel_type_for_var,
                      const Tensor& input_tensor);
56+
57+
void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
58+
Variable& out_var);
59+
5260
template <typename InType, typename OutType>
5361
struct CastDataTypeFunctor {
5462
HOSTDEVICE inline OutType operator()(InType in) const {
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software
9+
distributed under the License is distributed on an "AS IS" BASIS,
10+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
See the License for the specific language governing permissions and
12+
limitations under the License. */
13+
14+
#include "paddle/framework/device_data_transform.h"

#include <memory>
15+
16+
namespace paddle {
17+
namespace framework {
18+
19+
// Selects the device context used to move data from `src_place` to
// `dst_place`.
//
// For GPU -> CPU the context registered for `src_place` is returned; for
// CPU -> GPU the context registered for `dst_place` is returned. Every other
// combination of places is rejected with PADDLE_THROW.
static const platform::DeviceContext* GetDeviceContext(
    const platform::Place& src_place, const platform::Place& dst_place) {
  platform::DeviceContextPool& ctx_pool =
      platform::DeviceContextPool::Instance();

  const bool gpu_to_cpu =
      platform::is_gpu_place(src_place) && platform::is_cpu_place(dst_place);
  if (gpu_to_cpu) {
    return ctx_pool.Get(src_place);
  }

  const bool cpu_to_gpu =
      platform::is_cpu_place(src_place) && platform::is_gpu_place(dst_place);
  if (cpu_to_gpu) {
    return ctx_pool.Get(dst_place);
  }

  PADDLE_THROW(
      "Currently, model parallelism is only supported between CPU and CUDA");
}
33+
34+
// Copies `in` to `dst_place` and returns the copy as a newly allocated Tensor.
//
// The device context is chosen by GetDeviceContext from the pair
// (in.place(), dst_place); unsupported pairs make that helper throw.
//
// Returns: a heap-allocated Tensor. Nothing in this function frees it, so the
// caller is responsible for deleting it.
//
// Exception safety: the context is resolved before anything is allocated, and
// the new Tensor is held by a std::unique_ptr until it is handed back, so it
// is not leaked if the context lookup or the copy throws.
Tensor* DeviceTransform(const Tensor& in, const platform::Place& dst_place) {
  VLOG(3) << "DeviceTransform in, src_place " << in.place()
          << " dst_place: " << dst_place;

  // May throw for unsupported place combinations; do it before allocating.
  auto* dev_ctx = GetDeviceContext(in.place(), dst_place);

  std::unique_ptr<Tensor> out(new Tensor());

  // Wait() is issued on the context both before and after the copy.
  // NOTE(review): presumably to drain pending work first and to make sure the
  // copy has completed before returning - confirm.
  dev_ctx->Wait();
  CopyFrom(in, dst_place, *dev_ctx, out.get());
  dev_ctx->Wait();

  return out.release();
}
44+
45+
} // namespace framework
46+
} // namespace paddle
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software
9+
distributed under the License is distributed on an "AS IS" BASIS,
10+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
See the License for the specific language governing permissions and
12+
limitations under the License. */
13+
14+
#pragma once
15+
16+
#include "paddle/framework/lod_tensor.h"
17+
#include "paddle/framework/tensor.h"
18+
#include "paddle/framework/tensor_util.h"
19+
#include "paddle/platform/device_context.h"
20+
21+
namespace paddle {
22+
namespace framework {
23+
24+
// Copies `in` to `dst_place` and returns a pointer to the resulting Tensor.
// NOTE(review): the result appears to be heap-allocated and owned by the
// caller - confirm against the definition.
Tensor* DeviceTransform(const Tensor& in, const platform::Place& dst_place);
25+
26+
} // namespace framework
27+
} // namespace paddle
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "gtest/gtest.h"
16+
17+
#include "paddle/framework/init.h"
18+
#include "paddle/framework/lod_tensor.h"
19+
#include "paddle/framework/op_info.h"
20+
#include "paddle/framework/op_registry.h"
21+
#include "paddle/operators/elementwise_op_function.h"
22+
#include "paddle/operators/math/math_function.h"
23+
#include "paddle/platform/device_context.h"
24+
25+
namespace paddle {
26+
namespace framework {
27+
28+
template <typename T>
29+
struct AddFunctor {
30+
inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
31+
};
32+
33+
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
34+
public:
35+
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
36+
: OpProtoAndCheckerMaker(proto, op_checker) {
37+
AddInput("input", "input1 of test op");
38+
AddOutput("output", "output of test op");
39+
AddAttr<bool>("use_gpu", "force to use gpu kernel").SetDefault(false);
40+
AddComment("This is test op");
41+
}
42+
};
43+
44+
class TestOpWithKernel : public OperatorWithKernel {
45+
public:
46+
using OperatorWithKernel::OperatorWithKernel;
47+
48+
protected:
49+
void InferShape(framework::InferShapeContext* ctx) const override {}
50+
OpKernelType GetExpectedKernelType(
51+
const ExecutionContext& ctx) const override {
52+
if (Attr<bool>("use_gpu")) {
53+
VLOG(3) << "force use gpu kernel";
54+
return OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0));
55+
} else {
56+
VLOG(3) << "use default kernel";
57+
return OpKernelType(proto::DataType::FP32,
58+
ctx.Input<Tensor>("input")->place());
59+
}
60+
}
61+
};
62+
63+
template <typename DeviceContext, typename T>
64+
class TestKernel : public OpKernel<float> {
65+
public:
66+
void Compute(const ExecutionContext& ctx) const {
67+
std::cout << ctx.op().DebugString() << std::endl;
68+
69+
const Tensor* input = ctx.Input<Tensor>("input");
70+
71+
std::cout << "input place:" << input->place() << std::endl;
72+
auto* output = ctx.Output<framework::LoDTensor>("output");
73+
output->Resize(input->dims());
74+
output->mutable_data<T>(ctx.GetPlace());
75+
76+
operators::TransformFunctor<AddFunctor<T>, T, DeviceContext> functor(
77+
input, input, output, ctx.template device_context<DeviceContext>(),
78+
AddFunctor<T>());
79+
functor.Run();
80+
}
81+
};
82+
83+
} // namespace framework
84+
} // namespace paddle
85+
86+
// Register "test_op" (no gradient op) together with its proto maker.
REGISTER_OP_WITHOUT_GRADIENT(test_op, paddle::framework::TestOpWithKernel,
                             paddle::framework::OpKernelTestProtoAndCheckerMaker);

// CPU kernel for test_op, float elements.
REGISTER_OP_CPU_KERNEL(
    test_op,
    paddle::framework::TestKernel<paddle::platform::CPUDeviceContext, float>);

// CUDA kernel for test_op, float elements.
REGISTER_OP_CUDA_KERNEL(
    test_op,
    paddle::framework::TestKernel<paddle::platform::CUDADeviceContext, float>);
95+
96+
static void BuildVar(const std::string& param_name,
97+
std::initializer_list<const char*> arguments,
98+
paddle::framework::proto::OpDesc::Var* var) {
99+
var->set_parameter(param_name);
100+
for (auto& arg_name : arguments) {
101+
*var->mutable_arguments()->Add() = arg_name;
102+
}
103+
}
104+
105+
// Runs test_op on the CPU, then feeds its CPU-resident output into a second
// test_op that is forced onto the GPU via "use_gpu". Each op doubles its
// input, so the final result must be 4x the original values.
TEST(Operator, CPUtoGPU) {
  using namespace paddle::framework;
  using namespace paddle::platform;

  ASSERT_EQ(InitDevices({"CPU", "GPU:0"}), true);

  // Element count of the {2, 3} tensor used throughout the test.
  const int kNumElements = 2 * 3;

  paddle::framework::Scope scope;
  paddle::platform::CPUPlace cpu_place;

  // ---- Stage 1: op running on the CPU (IN1 -> OUT1) ----
  paddle::framework::proto::OpDesc cpu_op_desc;
  cpu_op_desc.set_type("test_op");
  BuildVar("input", {"IN1"}, cpu_op_desc.add_inputs());
  BuildVar("output", {"OUT1"}, cpu_op_desc.add_outputs());

  auto cpu_op = paddle::framework::OpRegistry::CreateOp(cpu_op_desc);

  // Input values are 0, 1, ..., 5 stored on the CPU.
  auto* cpu_input = scope.Var("IN1")->GetMutable<LoDTensor>();
  auto* cpu_input_data = cpu_input->mutable_data<float>({2, 3}, CPUPlace());
  for (int i = 0; i < kNumElements; ++i) {
    cpu_input_data[i] = static_cast<float>(i);
  }

  auto* output = scope.Var("OUT1");
  cpu_op->Run(scope, cpu_place);

  auto* output_ptr = output->Get<LoDTensor>().data<float>();
  for (int i = 0; i < kNumElements; ++i) {
    ASSERT_EQ(output_ptr[i], static_cast<float>(i) * 2);
  }

  // ---- Stage 2: op forced onto the GPU (OUT1 -> OUT2) ----
  paddle::framework::proto::OpDesc gpu_op_desc;
  gpu_op_desc.set_type("test_op");
  BuildVar("input", {"OUT1"}, gpu_op_desc.add_inputs());
  BuildVar("output", {"OUT2"}, gpu_op_desc.add_outputs());

  auto use_gpu_attr = gpu_op_desc.mutable_attrs()->Add();
  use_gpu_attr->set_name("use_gpu");
  use_gpu_attr->set_type(paddle::framework::proto::AttrType::BOOLEAN);
  use_gpu_attr->set_b(true);

  auto gpu_op = paddle::framework::OpRegistry::CreateOp(gpu_op_desc);

  paddle::platform::CUDAPlace cuda_place(0);
  auto* output2 = scope.Var("OUT2");
  gpu_op->Run(scope, cuda_place);

  // Copy the GPU result to the CPU and wait on the context before reading it.
  DeviceContextPool& ctx_pool = DeviceContextPool::Instance();
  auto gpu_ctx = ctx_pool.Get(cuda_place);

  paddle::framework::Tensor host_copy;
  CopyFrom(output2->Get<LoDTensor>(), paddle::platform::CPUPlace(), *gpu_ctx,
           &host_copy);
  gpu_ctx->Wait();

  float* output2_ptr = host_copy.data<float>();
  for (int i = 0; i < kNumElements; ++i) {
    ASSERT_EQ(output2_ptr[i], static_cast<float>(i) * 4);
  }
}

paddle/framework/op_registry_test.cc

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ class OpWithKernelTest : public OperatorWithKernel {
218218
protected:
219219
void InferShape(InferShapeContext* ctx) const override {}
220220

221-
framework::OpKernelType GetActualKernelType(
221+
framework::OpKernelType GetExpectedKernelType(
222222
const framework::ExecutionContext& ctx) const override {
223223
return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
224224
}
@@ -282,16 +282,11 @@ class OpWithMultiKernelTest : public OperatorWithKernel {
282282
protected:
283283
void InferShape(InferShapeContext* ctx) const override {}
284284

285-
framework::OpKernelType GetActualKernelType(
286-
const framework::ExecutionContext& ctx) const override {
287-
return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
288-
}
289-
290285
framework::OpKernelType GetExpectedKernelType(
291-
const framework::OpKernelType& kernel) const override {
292-
return framework::OpKernelType(kernel.data_type_, platform::CUDAPlace(0),
293-
kernel.data_layout_,
294-
framework::LibraryType::kCUDNN);
286+
const framework::ExecutionContext& ctx) const override {
287+
return framework::OpKernelType(
288+
proto::DataType::FP32, platform::CUDAPlace(0), DataLayout::kAnyLayout,
289+
framework::LibraryType::kCUDNN);
295290
}
296291
};
297292

@@ -371,6 +366,7 @@ TEST(OperatorRegistrar, OpWithMultiKernel) {
371366
op_desc.set_type("op_with_multi_kernel");
372367
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
373368

369+
// TODO(qiao) add priority back
374370
// use all available kernels
375371
paddle::framework::UseALL();
376372
op->Run(scope, cuda_place);
@@ -380,16 +376,16 @@ TEST(OperatorRegistrar, OpWithMultiKernel) {
380376
paddle::framework::UseCPU();
381377
op->Run(scope, cpu_place);
382378

383-
EXPECT_EQ(op_test_value, -9);
379+
EXPECT_EQ(op_test_value, -20);
384380

385381
// add cuda kernels
386382
paddle::framework::UseCUDA();
387383
op->Run(scope, cuda_place);
388384

389-
EXPECT_EQ(op_test_value, -10);
385+
EXPECT_EQ(op_test_value, -30);
390386

391387
// use cudnn kernel
392388
paddle::framework::UseCUDNN();
393389
op->Run(scope, cuda_place);
394-
EXPECT_EQ(op_test_value, -20);
390+
EXPECT_EQ(op_test_value, -40);
395391
}

0 commit comments

Comments
 (0)