Skip to content

Commit dc488c1

Browse files
committed
Merge branch 'develop' of github.com:baidu/Paddle into feature/parallel_for_unittest
2 parents 12aca86 + 87f9b58 commit dc488c1

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

44 files changed

+1930
-845
lines changed

paddle/framework/CMakeLists.txt

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,12 @@ cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
3232
cc_library(scope SRCS scope.cc DEPS glog threadpool)
3333
cc_test(scope_test SRCS scope_test.cc DEPS scope)
3434

35-
cc_library(device_data_transform SRCS device_data_transform.cc DEPS tensor)
35+
cc_library(data_device_transform SRCS data_device_transform.cc DEPS tensor)
36+
cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor)
37+
cc_library(data_layout_transform SRCS data_layout_transform.cc DEPS tensor math_function)
3638

37-
cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto selected_rows device_data_transform)
38-
cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context)
39+
cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor
40+
framework_proto selected_rows data_device_transform data_type_transform data_layout_transform)
3941

4042
cc_library(attribute SRCS attribute.cc DEPS framework_proto)
4143
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
@@ -80,5 +82,5 @@ cc_test(init_test SRCS init_test.cc DEPS init)
8082

8183
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
8284
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
83-
nv_test(device_data_transform_test SRCS device_data_transform_test.cu
85+
nv_test(data_device_transform_test SRCS data_device_transform_test.cu
8486
DEPS operator op_registry init math_function)

paddle/framework/device_data_transform.cc renamed to paddle/framework/data_device_transform.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1111
See the License for the specific language governing permissions and
1212
limitations under the License. */
1313

14-
#include "paddle/framework/device_data_transform.h"
14+
#include "paddle/framework/data_device_transform.h"
1515

1616
namespace paddle {
1717
namespace framework {

paddle/framework/data_layout.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#pragma once
16-
#include "paddle/platform/enforce.h"
1716

1817
#include <iostream>
1918
#include "paddle/platform/enforce.h"
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/framework/data_layout_transform.h"
16+
17+
#include "paddle/framework/tensor.h"
18+
#include "paddle/operators/math/math_function.h"
19+
20+
namespace paddle {
21+
namespace framework {
22+
23+
struct CastDataLayout {
24+
CastDataLayout(const platform::DeviceContext* ctx,
25+
const std::vector<int>& axis, const framework::Tensor& in,
26+
framework::Tensor* out)
27+
: in_(in), out_(out), ctx_(ctx), axis_(axis) {}
28+
const framework::Tensor in_;
29+
framework::Tensor* out_;
30+
const platform::DeviceContext* ctx_;
31+
const std::vector<int> axis_;
32+
33+
template <typename T>
34+
void operator()() {
35+
auto place = ctx_->GetPlace();
36+
37+
if (platform::is_cpu_place(place)) {
38+
operators::math::Transpose<platform::CPUDeviceContext, T, 4> trans4;
39+
auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
40+
trans4(*context, in_, out_, axis_);
41+
} else {
42+
PADDLE_THROW("Unsupport CPU <-> GPU!");
43+
}
44+
}
45+
};
46+
47+
void TransDataLayout(const std::vector<int>& axis,
48+
const platform::DeviceContext* ctx,
49+
const KernelTypePair& kernel_pair, const Variable& in,
50+
Variable* out) {
51+
PADDLE_ENFORCE(in.IsType<Tensor>(), "Only support Tensor transform!.");
52+
PADDLE_ENFORCE(
53+
platform::places_are_same_class(kernel_pair.first.place_,
54+
kernel_pair.second.place_),
55+
"TransDataLayout only support DataLayout transform on same place!");
56+
PADDLE_ENFORCE(kernel_pair.first.data_type_ == kernel_pair.second.data_type_,
57+
"TransDataLayout only support Datatype are same!");
58+
59+
auto src = in.Get<Tensor>();
60+
auto* dst = out->GetMutable<Tensor>();
61+
PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!");
62+
63+
auto src_dim = src.dims();
64+
std::vector<int64_t> dst_dim;
65+
66+
dst_dim.resize(axis.size());
67+
for (size_t i = 0; i < axis.size(); i++) {
68+
dst_dim[i] = src_dim[axis[i]];
69+
}
70+
71+
dst->Resize(make_ddim(dst_dim));
72+
auto place = kernel_pair.second.place_;
73+
dst->mutable_data(place, src.type());
74+
75+
auto src_type = kernel_pair.first.data_type_;
76+
framework::VisitDataType(src_type, CastDataLayout(ctx, axis, src, dst));
77+
78+
dst->set_layout(kernel_pair.second.data_layout_);
79+
}
80+
81+
} // namespace framework
82+
} // namespace paddle
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once

// Self-contained header: include what we use (std::pair, std::vector were
// previously pulled in only transitively).
#include <utility>
#include <vector>

#include "paddle/framework/op_kernel_type.h"
#include "paddle/framework/variable.h"

namespace paddle {
namespace framework {

// A (source, destination) pair of kernel types describing a data transform.
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;

// Transposes the Tensor held by `in` into `out` following the permutation
// `axis` (e.g. NHWC <-> NCHW). See data_layout_transform.cc for the
// enforced preconditions (4-D tensor, same place class, same data type).
void TransDataLayout(const std::vector<int>& axis,
                     const platform::DeviceContext* ctx,
                     const KernelTypePair& kernel_pair, const Variable& in,
                     Variable* out);

}  // namespace framework
}  // namespace paddle

paddle/framework/data_transform.cc

Lines changed: 2 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS,
1111
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
14-
#include <functional>
1514

1615
#include "paddle/framework/data_transform.h"
17-
#include "paddle/framework/device_data_transform.h"
18-
#include "paddle/framework/lod_tensor.h"
19-
#include "paddle/framework/selected_rows.h"
20-
#include "paddle/platform/device_context.h"
16+
17+
#include "paddle/framework/data_device_transform.h"
2118

2219
namespace paddle {
2320
namespace framework {
2421

25-
DataTransformFnMap& DataTransformFnMap::Instance() {
26-
static DataTransformFnMap data_transform_map;
27-
return data_transform_map;
28-
}
29-
3022
Tensor* DataTransform(const OpKernelType& expected_kernel_type,
3123
const OpKernelType& kernel_type_for_var,
3224
const Tensor& input_tensor) {
@@ -58,134 +50,5 @@ void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
5850
}
5951
}
6052

61-
auto KernelFP32 = OpKernelType(proto::DataType::FP32, platform::CPUPlace(),
62-
DataLayout::kNHWC, LibraryType::kPlain);
63-
64-
auto KernelFP64 = OpKernelType(proto::DataType::FP64, platform::CPUPlace(),
65-
DataLayout::kNHWC, LibraryType::kPlain);
66-
67-
auto KernelNHWC = OpKernelType(proto::DataType::FP64, platform::CPUPlace(),
68-
DataLayout::kNHWC, LibraryType::kPlain);
69-
70-
auto KernelNCHW = OpKernelType(proto::DataType::FP64, platform::CPUPlace(),
71-
DataLayout::kNCHW, LibraryType::kPlain);
72-
73-
// TODO(dzhwinter): Only for testing multiple op kernel.
74-
// Dummy transform function for library_type
75-
// should be removed.
76-
auto KernelPlain = OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0),
77-
DataLayout::kAnyLayout, LibraryType::kPlain);
78-
79-
auto KernelCUDNN = OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0),
80-
DataLayout::kAnyLayout, LibraryType::kCUDNN);
81-
82-
void DummyTrans(const platform::DeviceContext* ctx,
83-
const KernelTypePair& kernel_pair, const Variable& in,
84-
Variable* out) {
85-
PADDLE_ENFORCE(in.IsType<Tensor>(), "Only Support Tensor transform!.");
86-
PADDLE_ENFORCE(
87-
platform::places_are_same_class(kernel_pair.first.place_,
88-
kernel_pair.second.place_),
89-
"TransDataType Only Support DataType transform on same place!");
90-
auto src = in.Get<Tensor>();
91-
auto* dst = out->GetMutable<Tensor>();
92-
*dst = src;
93-
}
94-
95-
void TransDataType(const platform::DeviceContext* ctx,
96-
const KernelTypePair& kernel_pair, const Variable& in,
97-
Variable* out) {
98-
PADDLE_ENFORCE(in.IsType<Tensor>(), "Only Support Tensor transform!.");
99-
PADDLE_ENFORCE(
100-
platform::places_are_same_class(kernel_pair.first.place_,
101-
kernel_pair.second.place_),
102-
"TransDataType Only Support DataType transform on same place!");
103-
104-
auto src = in.Get<Tensor>();
105-
auto* dst = out->GetMutable<Tensor>();
106-
107-
auto dims = src.dims();
108-
dst->Resize(dims);
109-
auto dst_type = kernel_pair.second.data_type_;
110-
auto src_type = kernel_pair.first.data_type_;
111-
112-
switch (src_type) {
113-
case proto::DataType::FP32:
114-
framework::VisitDataType(dst_type, CastDataType<float>(src, dst, ctx));
115-
break;
116-
case proto::DataType::FP64:
117-
framework::VisitDataType(dst_type, CastDataType<double>(src, dst, ctx));
118-
break;
119-
case proto::DataType::INT32:
120-
framework::VisitDataType(dst_type, CastDataType<int>(src, dst, ctx));
121-
break;
122-
case proto::DataType::INT64:
123-
framework::VisitDataType(dst_type, CastDataType<int64_t>(src, dst, ctx));
124-
break;
125-
case proto::DataType::BOOL:
126-
framework::VisitDataType(dst_type, CastDataType<bool>(src, dst, ctx));
127-
break;
128-
default:
129-
PADDLE_THROW("Not support type %d", src_type);
130-
}
131-
}
132-
133-
void TransDataLayout(const std::vector<int>& axis,
134-
const platform::DeviceContext* ctx,
135-
const KernelTypePair& kernel_pair, const Variable& in,
136-
Variable* out) {
137-
PADDLE_ENFORCE(in.IsType<Tensor>(), "Only support Tensor transform!.");
138-
PADDLE_ENFORCE(
139-
platform::places_are_same_class(kernel_pair.first.place_,
140-
kernel_pair.second.place_),
141-
"TransDataLayout only support DataLayout transform on same place!");
142-
PADDLE_ENFORCE(kernel_pair.first.data_type_ == kernel_pair.second.data_type_,
143-
"TransDataLayout only support Datatype are same!");
144-
145-
auto src = in.Get<Tensor>();
146-
auto* dst = out->GetMutable<Tensor>();
147-
PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!");
148-
149-
auto src_dim = src.dims();
150-
std::vector<int64_t> dst_dim;
151-
152-
dst_dim.resize(axis.size());
153-
for (size_t i = 0; i < axis.size(); i++) {
154-
dst_dim[i] = src_dim[axis[i]];
155-
}
156-
157-
dst->Resize(make_ddim(dst_dim));
158-
auto place = kernel_pair.second.place_;
159-
dst->mutable_data(place, src.type());
160-
161-
auto src_type = kernel_pair.first.data_type_;
162-
framework::VisitDataType(src_type, CastDataLayout(ctx, axis, src, dst));
163-
164-
dst->set_layout(kernel_pair.second.data_layout_);
165-
}
166-
16753
} // namespace framework
16854
} // namespace paddle
169-
170-
namespace f = paddle::framework;
171-
172-
namespace {
173-
std::vector<int> NHWC2NCHW = {0, 3, 1, 2};
174-
std::vector<int> NCHW2NHWC = {0, 2, 3, 1};
175-
}
176-
177-
REGISTER_DATA_TRANSFORM_FN(f::KernelFP32, f::KernelFP64, f::TransDataType);
178-
REGISTER_DATA_TRANSFORM_FN(f::KernelPlain, f::KernelCUDNN, f::DummyTrans);
179-
REGISTER_DATA_TRANSFORM_FN(f::KernelCUDNN, f::KernelPlain, f::DummyTrans);
180-
REGISTER_DATA_TRANSFORM_FN(f::KernelNHWC, f::KernelNCHW,
181-
std::bind(f::TransDataLayout, NHWC2NCHW,
182-
std::placeholders::_1,
183-
std::placeholders::_2,
184-
std::placeholders::_3,
185-
std::placeholders::_4));
186-
REGISTER_DATA_TRANSFORM_FN(f::KernelNCHW, f::KernelNHWC,
187-
std::bind(f::TransDataLayout, NCHW2NHWC,
188-
std::placeholders::_1,
189-
std::placeholders::_2,
190-
std::placeholders::_3,
191-
std::placeholders::_4));

0 commit comments

Comments
 (0)