Skip to content

Commit 0071b5f

Browse files
authored
complete data layout transform (#7440)
* add data layout transform and optimize the implementation of data_transform
1 parent 9e17c46 commit 0071b5f

File tree

7 files changed

+120
-33
lines changed

7 files changed

+120
-33
lines changed

paddle/framework/CMakeLists.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,13 @@ cc_library(scope SRCS scope.cc DEPS glog threadpool)
3333
cc_test(scope_test SRCS scope_test.cc DEPS scope)
3434

3535
cc_library(data_device_transform SRCS data_device_transform.cc DEPS tensor)
36+
nv_test(data_device_transform_test SRCS data_device_transform_test.cu
37+
DEPS operator op_registry init math_function)
38+
3639
cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor)
40+
3741
cc_library(data_layout_transform SRCS data_layout_transform.cc DEPS tensor math_function)
42+
cc_test(data_layout_transform_test SRCS data_layout_transform_test.cc DEPS data_layout_transform)
3843

3944
cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor
4045
framework_proto selected_rows data_device_transform data_type_transform data_layout_transform)
@@ -82,5 +87,3 @@ cc_test(init_test SRCS init_test.cc DEPS init)
8287

8388
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
8489
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
85-
nv_test(data_device_transform_test SRCS data_device_transform_test.cu
86-
DEPS operator op_registry init math_function)

paddle/framework/data_device_transform_test.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ TEST(Operator, CPUtoGPU) {
150150
// get output
151151
auto* output2 = scope.Var("OUT2");
152152
gpu_op->Run(scope, cuda_place);
153+
VLOG(3) << "after gpu_op run";
153154

154155
// auto* output2_ptr = output2->Get<LoDTensor>().data<float>();
155156
DeviceContextPool& pool = DeviceContextPool::Instance();

paddle/framework/data_layout_transform.cc

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -14,12 +14,23 @@ limitations under the License. */
1414

1515
#include "paddle/framework/data_layout_transform.h"
1616

17-
#include "paddle/framework/tensor.h"
1817
#include "paddle/operators/math/math_function.h"
1918

2019
namespace paddle {
2120
namespace framework {
2221

22+
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to) {
23+
PADDLE_ENFORCE_NE(from, to,
24+
"layout transform should transform different layout");
25+
if (from == DataLayout::kNCHW && to == DataLayout::kNHWC) {
26+
return {0, 2, 3, 1};
27+
} else if (from == DataLayout::kNHWC && to == DataLayout::kNCHW) {
28+
return {0, 3, 1, 2};
29+
} else {
30+
PADDLE_THROW("unsupported transform");
31+
}
32+
}
33+
2334
struct CastDataLayout {
2435
CastDataLayout(const platform::DeviceContext* ctx,
2536
const std::vector<int>& axis, const framework::Tensor& in,
@@ -44,38 +55,36 @@ struct CastDataLayout {
4455
}
4556
};
4657

47-
void TransDataLayout(const std::vector<int>& axis,
48-
const platform::DeviceContext* ctx,
49-
const KernelTypePair& kernel_pair, const Variable& in,
50-
Variable* out) {
51-
PADDLE_ENFORCE(in.IsType<Tensor>(), "Only support Tensor transform!.");
58+
void TransDataLayout(const OpKernelType& kernel_type_for_var,
59+
const OpKernelType& expected_kernel_type, const Tensor& in,
60+
Tensor* out) {
5261
PADDLE_ENFORCE(
53-
platform::places_are_same_class(kernel_pair.first.place_,
54-
kernel_pair.second.place_),
62+
platform::places_are_same_class(kernel_type_for_var.place_,
63+
expected_kernel_type.place_),
5564
"TransDataLayout only support DataLayout transform on same place!");
56-
PADDLE_ENFORCE(kernel_pair.first.data_type_ == kernel_pair.second.data_type_,
57-
"TransDataLayout only support Datatype are same!");
5865

59-
auto src = in.Get<Tensor>();
60-
auto* dst = out->GetMutable<Tensor>();
61-
PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!");
66+
PADDLE_ENFORCE(arity(in.dims()) == 4, "Input Arity only support 4!");
67+
68+
auto& pool = platform::DeviceContextPool::Instance();
6269

63-
auto src_dim = src.dims();
70+
auto src_dim = in.dims();
6471
std::vector<int64_t> dst_dim;
6572

73+
auto axis = GetAxis(kernel_type_for_var.data_layout_,
74+
expected_kernel_type.data_layout_);
6675
dst_dim.resize(axis.size());
6776
for (size_t i = 0; i < axis.size(); i++) {
6877
dst_dim[i] = src_dim[axis[i]];
6978
}
7079

71-
dst->Resize(make_ddim(dst_dim));
72-
auto place = kernel_pair.second.place_;
73-
dst->mutable_data(place, src.type());
80+
out->Resize(make_ddim(dst_dim));
81+
out->mutable_data(expected_kernel_type.place_, in.type());
7482

75-
auto src_type = kernel_pair.first.data_type_;
76-
framework::VisitDataType(src_type, CastDataLayout(ctx, axis, src, dst));
83+
framework::VisitDataType(
84+
framework::ToDataType(in.type()),
85+
CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out));
7786

78-
dst->set_layout(kernel_pair.second.data_layout_);
87+
out->set_layout(expected_kernel_type.data_layout_);
7988
}
8089

8190
} // namespace framework

paddle/framework/data_layout_transform.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -15,17 +15,17 @@ limitations under the License. */
1515
#pragma once
1616

1717
#include "paddle/framework/op_kernel_type.h"
18+
#include "paddle/framework/tensor.h"
1819
#include "paddle/framework/variable.h"
1920

2021
namespace paddle {
2122
namespace framework {
2223

23-
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
24+
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);
2425

25-
void TransDataLayout(const std::vector<int>& axis,
26-
const platform::DeviceContext* ctx,
27-
const KernelTypePair& kernel_pair, const Variable& in,
28-
Variable* out);
26+
void TransDataLayout(const OpKernelType& kernel_type_for_var,
27+
const OpKernelType& expected_kernel_type, const Tensor& in,
28+
Tensor* out);
2929

3030
} // namespace framework
3131
} // namespace paddle
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/framework/data_layout_transform.h"
16+
17+
#include "gtest/gtest.h"
18+
#include "paddle/platform/device_context.h"
19+
20+
TEST(DataTransform, DataLayoutFunction) {
21+
using namespace paddle::framework;
22+
using namespace paddle::platform;
23+
24+
auto place = CPUPlace();
25+
Tensor in = Tensor();
26+
Tensor out = Tensor();
27+
in.mutable_data<double>(make_ddim({2, 3, 1, 2}), place);
28+
in.set_layout(DataLayout::kNHWC);
29+
30+
auto kernel_nhwc = OpKernelType(proto::DataType::FP32, place,
31+
DataLayout::kNHWC, LibraryType::kPlain);
32+
auto kernel_ncwh = OpKernelType(proto::DataType::FP32, place,
33+
DataLayout::kNCHW, LibraryType::kPlain);
34+
35+
TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out);
36+
37+
EXPECT_TRUE(out.layout() == DataLayout::kNCHW);
38+
EXPECT_TRUE(out.dims() == make_ddim({2, 2, 3, 1}));
39+
40+
TransDataLayout(kernel_ncwh, kernel_nhwc, in, &out);
41+
42+
EXPECT_TRUE(in.layout() == DataLayout::kNHWC);
43+
EXPECT_TRUE(in.dims() == make_ddim({2, 3, 1, 2}));
44+
}

paddle/framework/data_transform.cc

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,43 @@ limitations under the License. */
1515
#include "paddle/framework/data_transform.h"
1616

1717
#include "paddle/framework/data_device_transform.h"
18+
#include "paddle/framework/data_layout_transform.h"
1819

1920
namespace paddle {
2021
namespace framework {
2122

23+
static void PassTensorData(Tensor* from, Tensor* to) {
24+
to->ShareDataWith(*from);
25+
*from = Tensor();
26+
}
27+
2228
void DataTransform(const OpKernelType& expected_kernel_type,
2329
const OpKernelType& kernel_type_for_var,
24-
const Tensor& input_tensor, Tensor* out) {
30+
const Tensor& input_tensor, Tensor* output_tensor) {
31+
bool transformed = false;
32+
Tensor in;
33+
in.ShareDataWith(input_tensor);
34+
Tensor out;
35+
36+
// do layout transform
37+
if (NeedTransformLayout(expected_kernel_type.data_layout_,
38+
kernel_type_for_var.data_layout_)) {
39+
TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out);
40+
transformed = true;
41+
PassTensorData(&out, &in);
42+
}
43+
44+
// do device transform
2545
if (!platform::is_same_place(kernel_type_for_var.place_,
2646
expected_kernel_type.place_)) {
27-
DeviceTransform(input_tensor, expected_kernel_type.place_, out);
47+
DeviceTransform(in, expected_kernel_type.place_, &out);
48+
transformed = true;
49+
PassTensorData(&out, &in);
2850
}
29-
PADDLE_ENFORCE_NOT_NULL(out, "out should not be null");
51+
52+
PADDLE_ENFORCE(transformed, "no transform is done, please check!");
53+
// get output data
54+
output_tensor->ShareDataWith(in);
3055
}
3156

3257
void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,

paddle/framework/op_kernel_type.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,14 @@ inline std::string KernelTypeToString(const OpKernelType& kernel_key) {
8585
return stream.str();
8686
}
8787

88+
inline bool NeedTransformLayout(const DataLayout& l, const DataLayout& r) {
89+
return l != DataLayout::kAnyLayout && r != DataLayout::kAnyLayout && l != r;
90+
}
91+
8892
inline bool TransFromNeeded(const OpKernelType& l, const OpKernelType& r) {
8993
return (!platform::places_are_same_class(l.place_, r.place_)) ||
90-
(l.data_type_ != r.data_type_) || (l.data_layout_ != r.data_layout_);
94+
(l.data_type_ != r.data_type_) ||
95+
NeedTransformLayout(l.data_layout_, r.data_layout_);
9196
}
9297

9398
} // namespace framework

0 commit comments

Comments
 (0)