Skip to content

Commit 85671b8

Browse files
jacquesqiaodzhwinter
authored andcommitted
Data type transform (#7653)
* init complete data layout transform * can compile * test passed * optimize code * fix while_grad_op first step loss lod problem * optimize in out ptr for transform * add check * update copyright * clean code * add NeedTransformLayout * add comment * change the interface of data_type_transform * init data_type_transform_test * complete data_type_transform_test * add TransDataType to data_transform
1 parent 02add30 commit 85671b8

File tree

7 files changed

+85
-33
lines changed

7 files changed

+85
-33
lines changed

paddle/framework/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ nv_test(data_device_transform_test SRCS data_device_transform_test.cu
3737
DEPS operator op_registry init math_function)
3838

3939
cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor)
40+
cc_test(data_type_transform_test SRCS data_type_transform_test.cc DEPS data_type_transform)
4041

4142
cc_library(data_layout_transform SRCS data_layout_transform.cc DEPS tensor math_function)
4243
cc_test(data_layout_transform_test SRCS data_layout_transform_test.cc DEPS data_layout_transform)

paddle/framework/data_device_transform.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ static const platform::DeviceContext* GetDeviceContext(
3131
}
3232
}
3333

34-
void DeviceTransform(const Tensor& in, const platform::Place& dst_place,
34+
void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
3535
Tensor* out) {
3636
VLOG(3) << "DeviceTransform in, src_place " << in.place()
3737
<< " dst_place: " << dst_place;

paddle/framework/data_device_transform.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ limitations under the License. */
2121
namespace paddle {
2222
namespace framework {
2323

24-
void DeviceTransform(const Tensor& in, const platform::Place& dst_place,
24+
void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
2525
Tensor* out);
2626

2727
} // namespace framework

paddle/framework/data_transform.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License. */
1616

1717
#include "paddle/framework/data_device_transform.h"
1818
#include "paddle/framework/data_layout_transform.h"
19+
#include "paddle/framework/data_type_transform.h"
1920

2021
namespace paddle {
2122
namespace framework {
@@ -41,15 +42,21 @@ void DataTransform(const OpKernelType& expected_kernel_type,
4142
PassTensorData(&out, &in);
4243
}
4344

45+
if (expected_kernel_type.data_type_ != kernel_type_for_var.data_type_) {
46+
TransDataType(kernel_type_for_var, expected_kernel_type, in, &out);
47+
transformed = true;
48+
PassTensorData(&out, &in);
49+
}
50+
4451
// do device transform
4552
if (!platform::is_same_place(kernel_type_for_var.place_,
4653
expected_kernel_type.place_)) {
47-
DeviceTransform(in, expected_kernel_type.place_, &out);
54+
TransDataDevice(in, expected_kernel_type.place_, &out);
4855
transformed = true;
4956
PassTensorData(&out, &in);
5057
}
5158

52-
PADDLE_ENFORCE(transformed, "no transform is done, please check!");
59+
PADDLE_ENFORCE(transformed, "No transform is applied, please check!");
5360
// get output data
5461
output_tensor->ShareDataWith(in);
5562
}

paddle/framework/data_type_transform.cc

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,11 @@ struct CastDataType {
3838

3939
template <typename OutType>
4040
void operator()() {
41-
auto place = ctx_->GetPlace();
42-
4341
auto* in_begin = in_.data<InType>();
44-
auto numel = in_.numel();
45-
auto* in_end = in_begin + numel;
46-
auto* out_begin = out_->mutable_data<OutType>(place);
42+
auto* in_end = in_begin + in_.numel();
43+
auto* out_begin = out_->mutable_data<OutType>(in_.place());
4744

48-
if (platform::is_cpu_place(place)) {
45+
if (platform::is_cpu_place(in_.place())) {
4946
platform::Transform<platform::CPUDeviceContext> trans;
5047
auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
5148
trans(*context, in_begin, in_end, out_begin,
@@ -57,38 +54,31 @@ struct CastDataType {
5754
}
5855
};
5956

60-
void TransDataType(const platform::DeviceContext* ctx,
61-
const KernelTypePair& kernel_pair, const Variable& in,
62-
Variable* out) {
63-
PADDLE_ENFORCE(in.IsType<Tensor>(), "Only Support Tensor transform!.");
64-
PADDLE_ENFORCE(
65-
platform::places_are_same_class(kernel_pair.first.place_,
66-
kernel_pair.second.place_),
67-
"TransDataType Only Support DataType transform on same place!");
68-
69-
auto src = in.Get<Tensor>();
70-
auto* dst = out->GetMutable<Tensor>();
57+
void TransDataType(const OpKernelType& kernel_type_for_var,
58+
const OpKernelType& expected_kernel_type, const Tensor& in,
59+
Tensor* out) {
60+
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
7161

72-
auto dims = src.dims();
73-
dst->Resize(dims);
74-
auto dst_type = kernel_pair.second.data_type_;
75-
auto src_type = kernel_pair.first.data_type_;
62+
out->Resize(in.dims());
63+
auto src_type = kernel_type_for_var.data_type_;
64+
auto dst_type = expected_kernel_type.data_type_;
65+
auto ctx = pool.Get(in.place());
7666

7767
switch (src_type) {
7868
case proto::DataType::FP32:
79-
framework::VisitDataType(dst_type, CastDataType<float>(src, dst, ctx));
69+
framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));
8070
break;
8171
case proto::DataType::FP64:
82-
framework::VisitDataType(dst_type, CastDataType<double>(src, dst, ctx));
72+
framework::VisitDataType(dst_type, CastDataType<double>(in, out, ctx));
8373
break;
8474
case proto::DataType::INT32:
85-
framework::VisitDataType(dst_type, CastDataType<int>(src, dst, ctx));
75+
framework::VisitDataType(dst_type, CastDataType<int>(in, out, ctx));
8676
break;
8777
case proto::DataType::INT64:
88-
framework::VisitDataType(dst_type, CastDataType<int64_t>(src, dst, ctx));
78+
framework::VisitDataType(dst_type, CastDataType<int64_t>(in, out, ctx));
8979
break;
9080
case proto::DataType::BOOL:
91-
framework::VisitDataType(dst_type, CastDataType<bool>(src, dst, ctx));
81+
framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx));
9282
break;
9383
default:
9484
PADDLE_THROW("Not support type %d", src_type);

paddle/framework/data_type_transform.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#pragma once
1616

1717
#include "paddle/framework/op_kernel_type.h"
18+
#include "paddle/framework/tensor.h"
1819
#include "paddle/framework/variable.h"
1920
#include "paddle/platform/device_context.h"
2021

@@ -23,9 +24,9 @@ namespace framework {
2324

2425
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
2526

26-
void TransDataType(const platform::DeviceContext* ctx,
27-
const KernelTypePair& kernel_pair, const Variable& in,
28-
Variable* out);
27+
void TransDataType(const OpKernelType& kernel_type_for_var,
28+
const OpKernelType& expected_kernel_type, const Tensor& in,
29+
Tensor* out);
2930

3031
} // namespace framework
3132
} // namespace paddle
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/framework/data_type_transform.h"
16+
17+
#include "gtest/gtest.h"
18+
19+
TEST(DataTypeTransform, CPUTransform) {
20+
using namespace paddle::framework;
21+
using namespace paddle::platform;
22+
23+
auto place = CPUPlace();
24+
25+
Tensor in;
26+
Tensor out;
27+
28+
float* ptr = in.mutable_data<float>(make_ddim({2, 3}), place);
29+
int data_number = 2 * 3;
30+
31+
for (int i = 0; i < data_number; ++i) {
32+
ptr[i] = i / 3;
33+
}
34+
35+
auto kernel_fp32 = OpKernelType(proto::DataType::FP32, place,
36+
DataLayout::kAnyLayout, LibraryType::kPlain);
37+
auto kernel_fp64 = OpKernelType(proto::DataType::FP64, place,
38+
DataLayout::kAnyLayout, LibraryType::kPlain);
39+
auto kernel_int32 = OpKernelType(proto::DataType::INT32, place,
40+
DataLayout::kAnyLayout, LibraryType::kPlain);
41+
42+
TransDataType(kernel_fp32, kernel_fp64, in, &out);
43+
double* out_data_double = out.data<double>();
44+
for (int i = 0; i < data_number; ++i) {
45+
ASSERT_EQ(out_data_double[i], static_cast<double>(i / 3));
46+
}
47+
48+
TransDataType(kernel_fp32, kernel_int32, in, &out);
49+
int* out_data_int = out.data<int>();
50+
for (int i = 0; i < data_number; ++i) {
51+
ASSERT_EQ(out_data_int[i], static_cast<int>(i / 3));
52+
}
53+
}

0 commit comments

Comments
 (0)