
Commit 0c8fde7

"cherry picked cpp tests" (#12182)
* "cherry picked cpp tests" * "cherry picked" * "cherry picked tests" * "merge develop branch"
1 parent 595a2c8 commit 0c8fde7

9 files changed: +163 -21 lines changed


paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3 boost)
 cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
 nv_test(dim_test SRCS dim_test.cu DEPS ddim)
 cc_library(data_type SRCS data_type.cc DEPS framework_proto ddim device_context)
+cc_test(data_type_test SRCS data_type_test.cc DEPS data_type place tensor)
 if(WITH_GPU)
   nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS place memory data_type device_context)
 else()

paddle/fluid/framework/data_type.cc

Lines changed: 3 additions & 1 deletion

@@ -17,6 +17,8 @@
 #include <string>
 #include <unordered_map>

+using float16 = paddle::platform::float16;
+
 namespace paddle {
 namespace framework {

@@ -53,7 +55,7 @@ static DataTypeMap* InitDataTypeMap() {
   RegisterType<cc_type>(retv, proto_type, #cc_type)

   // NOTE: Add your customize type here.
-  RegType(platform::float16, proto::VarType::FP16);
+  RegType(float16, proto::VarType::FP16);
   RegType(float, proto::VarType::FP32);
   RegType(double, proto::VarType::FP64);
   RegType(int, proto::VarType::INT32);

paddle/fluid/framework/data_type_test.cc (new file, added by the cc_test target above)

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "paddle/fluid/framework/data_type.h"
+
+#include <string>
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/tensor.h"
+
+TEST(DataType, float16) {
+  using paddle::framework::Tensor;
+  using paddle::platform::CPUPlace;
+  using paddle::platform::float16;
+  namespace f = paddle::framework;
+  f::proto::VarType::Type dtype = f::proto::VarType::FP16;
+
+  Tensor tensor;
+  CPUPlace cpu;
+  tensor.mutable_data(cpu, f::ToTypeIndex(dtype));
+
+  // test fp16 tensor
+  EXPECT_EQ(tensor.type(), std::type_index(typeid(float16)));
+
+  // test fp16 size
+  EXPECT_EQ(f::SizeOfType(f::ToTypeIndex(dtype)), 2u);
+
+  // test debug info
+  std::string type = "float16";
+  EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
+}

paddle/fluid/framework/op_kernel_type_test.cc

Lines changed: 7 additions & 0 deletions

@@ -29,6 +29,13 @@ TEST(OpKernelType, ToString) {
   ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type),
             "data_type[float]:data_layout[NCHW]:place[CPUPlace]:library_type["
             "CUDNN]");
+
+  using CUDAPlace = paddle::platform::CUDAPlace;
+  OpKernelType op_kernel_type2(DataType::FP16, CUDAPlace(0), DataLayout::kNCHW,
+                               LibraryType::kCUDNN);
+  ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type2),
+            "data_type[float16]:data_layout[NCHW]:place[CUDAPlace(0)]:library_"
+            "type[CUDNN]");
 }

 TEST(OpKernelType, Hash) {

paddle/fluid/framework/operator.cc

Lines changed: 17 additions & 0 deletions

@@ -69,6 +69,21 @@ static DDim GetDims(const Scope& scope, const std::string& name,
   }
 }

+static std::string GetDtype(const Scope& scope, const std::string& name) {
+  Variable* var = scope.FindVar(name);
+  if (var == nullptr) {
+    return "";
+  }
+  if (var->IsType<LoDTensor>()) {
+    return DataTypeToString(ToDataType(var->Get<LoDTensor>().type()));
+  } else if (var->IsType<SelectedRows>()) {
+    return DataTypeToString(
+        ToDataType(var->Get<SelectedRows>().value().type()));
+  } else {
+    return "";
+  }
+}
+
 static int GetRowSize(const Scope& scope, const std::string& name) {
   Variable* var = scope.FindVar(name);
   if (var == nullptr) {
@@ -172,6 +187,8 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
       if (row_size >= 0) {
         ss << "[row_size=" << row_size << "]";
       }
+      std::string dtype = GetDtype(*scope, input.second[i]);
+      ss << ":" << dtype;
       ss << "[" << GetDims(*scope, input.second[i], true) << "]";
       ss << "(" << GetLoD(*scope, input.second[i]) << ")";
     }

paddle/fluid/framework/tensor_test.cc

Lines changed: 15 additions & 0 deletions

@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/tensor.h"
 #include <gtest/gtest.h>
 #include <string>
+#include "paddle/fluid/platform/float16.h"

 namespace framework = paddle::framework;
 namespace platform = paddle::platform;
@@ -213,3 +214,17 @@ TEST(Tensor, Layout) {
   src.set_layout(framework::DataLayout::kAnyLayout);
   ASSERT_EQ(src.layout(), framework::DataLayout::kAnyLayout);
 }
+
+TEST(Tensor, FP16) {
+  using platform::float16;
+  framework::Tensor src;
+  float16* src_ptr = src.mutable_data<float16>({2, 3}, platform::CPUPlace());
+  for (int i = 0; i < 2 * 3; ++i) {
+    src_ptr[i] = static_cast<float16>(i);
+  }
+  EXPECT_EQ(src.memory_size(), 2 * 3 * sizeof(float16));
+  // EXPECT a human readable error message
+  // src.data<uint8_t>();
+  // Tensor holds the wrong type, it holds N6paddle8platform7float16E at
+  // [/paddle/Paddle/paddle/fluid/framework/tensor_impl.h:43]
+}

python/paddle/fluid/tests/unittests/op_test.py

Lines changed: 61 additions & 11 deletions

@@ -66,20 +66,35 @@ def get_output():
         tensor_to_check_dtype = np.float32
     elif tensor_to_check_dtype == core.VarDesc.VarType.FP64:
         tensor_to_check_dtype = np.float64
+    elif tensor_to_check_dtype == core.VarDesc.VarType.FP16:
+        tensor_to_check_dtype = np.float16
+        # set delta as np.float16, will automatic convert to float32, float64
+        delta = np.array(delta).astype(np.float16)
     else:
         raise ValueError("Not supported data type " + str(
             tensor_to_check_dtype))

     gradient_flat = np.zeros(shape=(tensor_size, ), dtype=tensor_to_check_dtype)

     def __get_elem__(tensor, i):
-        if tensor_to_check_dtype == np.float32:
+        if tensor_to_check_dtype == np.float16:
+            numpy_tensor = np.array(tensor).astype(np.float16)
+            numpy_tensor = numpy_tensor.flatten()
+            return numpy_tensor[i]
+        elif tensor_to_check_dtype == np.float32:
             return tensor._get_float_element(i)
         else:
             return tensor._get_double_element(i)

     def __set_elem__(tensor, i, e):
-        if tensor_to_check_dtype == np.float32:
+        if tensor_to_check_dtype == np.float16:
+            numpy_tensor = np.array(tensor).astype(np.float16)
+            shape = numpy_tensor.shape
+            numpy_tensor = numpy_tensor.flatten()
+            numpy_tensor[i] = e
+            numpy_tensor = numpy_tensor.reshape(shape).view(np.uint16)
+            tensor.set(numpy_tensor, place)
+        elif tensor_to_check_dtype == np.float32:
             tensor._set_float_element(i, e)
         else:
             tensor._set_double_element(i, e)
@@ -133,6 +148,11 @@ def try_call_once(self, data_type):
         if not self.call_once:
             self.call_once = True
             self.dtype = data_type
+            # See the comment of np_dtype_to_fluid_dtype
+            # If the input type is uint16, we assume use float16
+            # for lodtensor dtype.
+            if self.dtype == np.uint16:
+                self.dtype == np.float16

     def infer_dtype_from_inputs_outputs(self, inputs, outputs):
         def infer_dtype(numpy_dict):
@@ -161,19 +181,25 @@ def feed_var(self, input_vars, place):
                 for name, np_value in self.inputs[var_name]:
                     tensor = core.LoDTensor()
                     if isinstance(np_value, tuple):
-                        tensor.set(np_value[0], place)
+                        tensor.set(
+                            OpTest.np_value_to_fluid_value(np_value[0]), place)
                         tensor.set_recursive_sequence_lengths(np_value[1])
                     else:
-                        tensor.set(np_value, place)
+                        tensor.set(
+                            OpTest.np_value_to_fluid_value(np_value), place)
                     feed_map[name] = tensor
             else:
                 tensor = core.LoDTensor()
                 if isinstance(self.inputs[var_name], tuple):
-                    tensor.set(self.inputs[var_name][0], place)
+                    tensor.set(
+                        OpTest.np_value_to_fluid_value(self.inputs[var_name][
+                            0]), place)
                     tensor.set_recursive_sequence_lengths(self.inputs[var_name][
                         1])
                 else:
-                    tensor.set(self.inputs[var_name], place)
+                    tensor.set(
+                        OpTest.np_value_to_fluid_value(self.inputs[var_name]),
+                        place)
                 feed_map[var_name] = tensor

         return feed_map
@@ -307,13 +333,22 @@ def find_actual(target_name, fetch_list):
                     np.allclose(
                         actual_t, expect_t, atol=atol),
                     "Output (" + out_name + ") has diff at " + str(place) +
-                    str(actual_t) + "\n" + str(expect_t))
+                    "\nExpect " + str(expect_t) + "\n" + "But Got" +
+                    str(actual_t))
                 if isinstance(expect, tuple):
                     self.assertListEqual(actual.recursive_sequence_lengths(),
                                          expect[1], "Output (" + out_name +
                                          ") has different lod at " + str(place))

     def _get_places(self):
+        if self.dtype == np.float16:
+            if core.is_compiled_with_cuda() and core.op_support_gpu(
+                    self.op_type):
+                place = core.CUDAPlace(0)
+                if core.is_float16_supported(place):
+                    return [place]
+            else:
+                return []
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type):
             places.append(core.CUDAPlace(0))
@@ -344,9 +379,9 @@ def __assert_is_close(self, numeric_grads, analytic_grads, names,
        def err_msg():
            offset = np.argmax(diff_mat > max_relative_error)
            return ("%s Variable %s max gradient diff %f over limit %f, "
-                    "the first error element is %d, %f, %f") % (
-                        msg_prefix, name, max_diff, max_relative_error,
-                        offset, a.flatten()[offset], b.flatten()[offset])
+                    "the first error element is %d, expected %f, but got %f"
+                    ) % (msg_prefix, name, max_diff, max_relative_error,
+                         offset, a.flatten()[offset], b.flatten()[offset])

         self.assertLessEqual(max_diff, max_relative_error, err_msg())

@@ -435,6 +470,21 @@ def np_dtype_to_fluid_dtype(input):
             input.dtype = np.uint16
         return input

+    @staticmethod
+    def fluid_dtype_to_np_dtype(self, dtype):
+        """
+        See above, convert the dtype to normal type.
+        """
+        if dtype == np.uint16:
+            dtype = np.float16
+        return dtype
+
+    @staticmethod
+    def np_value_to_fluid_value(input):
+        if input.dtype == np.float16:
+            input = input.view(np.uint16)
+        return input
+
     def _get_gradient(self,
                       input_to_check,
                       place,
@@ -457,7 +507,7 @@ def _get_gradient(self,
             if isinstance(place, fluid.CUDAPlace(0)):
                 use_cuda = True
             executor = fluid.ParallelExecutor(
-                use_cuda=use_cuda, loss_name=loss.name, main_program=program)
+                use_cuda=use_cuda, loss_name=loss.name, main_program=prog)
         else:
             executor = Executor(place)
         return map(np.array,
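For readers unfamiliar with the trick used by np_value_to_fluid_value and __set_elem__ above: float16 data crosses pybind as raw uint16 memory. A minimal NumPy-only sketch (not part of this commit; the array names are illustrative) of that reinterpretation:

import numpy as np

# A float16 array as an op test would produce it.
src = np.arange(6, dtype=np.float16).reshape(2, 3)

# view() reinterprets the same bytes as uint16 without converting values;
# this is the form that gets handed to tensor.set(...) for FP16 inputs.
raw = src.view(np.uint16)
assert raw.dtype == np.uint16
assert raw.nbytes == src.nbytes  # same underlying buffer size

# Viewing back as float16 recovers the original values bit-for-bit.
assert np.array_equal(raw.view(np.float16), src)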

python/paddle/fluid/tests/unittests/test_hsigmoid_op.py

Lines changed: 2 additions & 0 deletions

@@ -17,6 +17,8 @@
 import math
 from op_test import OpTest

+np.random.seed(100)
+

 def find_latest_set(num):
     return 1 + int(math.floor(math.log(num, 2)))

python/paddle/fluid/tests/unittests/testsuite.py

Lines changed: 17 additions & 9 deletions

@@ -18,14 +18,6 @@
 from paddle.fluid.op import Operator


-def as_lodtensor(np_array, lod, place):
-    tensor = core.LoDTensor()
-    tensor.set(np_value, place)
-    if lod is not None:
-        tensor.set_recursive_sequence_lengths(lod)
-    return tensor
-
-
 def create_op(scope, op_type, inputs, outputs, attrs):
     kwargs = dict()

@@ -69,14 +61,19 @@ def __create_var__(name, var_name):


 def set_input(scope, op, inputs, place):
+    def np_value_to_fluid_value(input):
+        if input.dtype == np.float16:
+            input = input.view(np.uint16)
+        return input
+
     def __set_input__(var_name, var):
         if isinstance(var, tuple) or isinstance(var, np.ndarray):
             tensor = scope.find_var(var_name).get_tensor()
             if isinstance(var, tuple):
                 tensor.set_recursive_sequence_lengths(var[1])
                 var = var[0]
             tensor._set_dims(var.shape)
-            tensor.set(var, place)
+            tensor.set(np_value_to_fluid_value(var), place)
         elif isinstance(var, float):
             scope.find_var(var_name).set_float(var)
         elif isinstance(var, int):
@@ -104,6 +101,7 @@ def create_var(block, name, np_list, var_proto):
         if name not in np_list:
             assert var_proto.intermediate, "{} not found".format(name)
         else:
+            # inferece the dtype from numpy value.
             np_value = np_list[name]
             if isinstance(np_value, tuple):
                 dtype = np_value[0].dtype
@@ -116,6 +114,16 @@ def create_var(block, name, np_list, var_proto):
                 if is_input:
                     shape = list(np_value.shape)
                     lod_level = 0
+            # NOTE(dzhwinter): type hacking
+            # numpy float16 is binded to paddle::platform::float16
+            # in tensor_py.h via the help of uint16 datatype. Because
+            # the internal memory representation of float16 is
+            # actually uint16_t in paddle. So we use np.uint16 in numpy for
+            # raw memory, it can pass through the pybind. So in the testcase,
+            # we feed data use data.view(uint16), but the dtype is float16 in fact.
+            # The data.view(uint16) means do not cast the data type, but process data as the uint16
+            if dtype == np.uint16:
+                dtype = np.float16
         return block.create_var(
             dtype=dtype, shape=shape, lod_level=lod_level, name=name)
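The NOTE(dzhwinter) comment above is the reason create_var maps np.uint16 back to np.float16 when declaring the variable: the fed buffer is a uint16 view of float16 memory, but the LoDTensor dtype should still read as float16. A small standalone sketch of that mapping (the infer_var_dtype helper is hypothetical, written only to mirror the added branch):

import numpy as np

def infer_var_dtype(np_value):
    # Mirror of the branch added to create_var(): a raw uint16 view is
    # treated as float16 when declaring the variable's dtype.
    dtype = np_value.dtype
    if dtype == np.uint16:
        dtype = np.float16
    return dtype

fed = np.zeros((2, 3), dtype=np.float16).view(np.uint16)
assert infer_var_dtype(fed) == np.float16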
