Commit 6250be4

Merge branch 'windows/build' into windows/online
test=develop
2 parents: e0d47cc + 30849d1

19 files changed: +773 -65 lines

cmake/operators.cmake

Lines changed: 2 additions & 1 deletion
@@ -109,7 +109,8 @@ function(op_library TARGET)
 
   # Define operators that don't need pybind here.
   foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
-      "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op")
+      "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
+      "fusion_transpose_flatten_concat_op")
     if ("${TARGET}" STREQUAL "${manual_pybind_op}")
       set(pybind_flag 1)
     endif()

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
 cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
 cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)
 
-cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto)
+cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
 cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
     shape_inference data_transform lod_tensor profiler transfer_scope_cache)

paddle/fluid/framework/transfer_scope_cache.cc

Lines changed: 14 additions & 24 deletions
@@ -17,16 +17,28 @@
 namespace paddle {
 namespace framework {
 
+// Holds all the transfer scopes across the process.
 std::unordered_map<size_t, Scope*>& global_transfer_data_cache() {
-  thread_local auto* x = new std::unordered_map<size_t, Scope*>;
+  typedef std::unordered_map<size_t, Scope*> map_t;
+  thread_local std::unique_ptr<map_t> x(new map_t);
   return *x;
 }
 
+// Holds all the transfer scopes for this thread.
 std::unordered_set<Scope*>& global_transfer_scope_cache() {
-  thread_local auto* x = new std::unordered_set<Scope*>;
+  typedef std::unordered_set<Scope*> set_t;
+  thread_local std::unique_ptr<set_t> x(new set_t);
   return *x;
 }
 
+// Try to create a transfer scope. If a cached scope already matches the
+// requirement, return that one.
+// Inputs:
+// @type0: the source kernel type.
+// @type1: the target kernel type.
+// @scope: the execution scope of this op.
+// Returns: a scope used to hold the transfer data across the different
+// kernel types.
 Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1,
                               const Scope* scope) {
   Scope* new_scope{nullptr};
@@ -46,27 +58,5 @@ Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1,
   return new_scope;
 }
 
-void RemoveKidsFromTransferScopeCache(Scope* scope) {
-  auto it = global_transfer_scope_cache().find(scope);
-  if (it != global_transfer_scope_cache().end()) {
-    global_transfer_scope_cache().erase(it);
-  }
-  for (auto* s : scope->kids()) {
-    auto it = global_transfer_scope_cache().find(s);
-    if (it != global_transfer_scope_cache().end()) {
-      global_transfer_scope_cache().erase(it);
-    }
-  }
-
-  // remove global transfer data cache
-  auto& cache = global_transfer_data_cache();
-  for (auto it = cache.begin(); it != cache.end();) {
-    if (it->second == scope)
-      it = cache.erase(it);
-    else
-      it++;
-  }
-}
-
 }  // namespace framework
 }  // namespace paddle
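
A note on the unique_ptr change above: a minimal sketch of the difference, assuming the motivation is that a raw thread_local pointer allocated with `new` is never freed when its thread exits (the commit itself doesn't state this), so every short-lived thread would leak one container:

// Illustrative sketch, not Paddle code.
#include <memory>
#include <unordered_set>

std::unordered_set<int>& leaky_cache() {
  thread_local auto* x = new std::unordered_set<int>;  // never deleted
  return *x;
}

std::unordered_set<int>& owned_cache() {
  thread_local std::unique_ptr<std::unordered_set<int>> x(
      new std::unordered_set<int>);  // destroyed at thread exit
  return *x;
}

Both versions hand out a per-thread container; only the second releases it when the thread goes away, which matters for servers that spawn many worker threads.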

paddle/fluid/memory/allocation/retry_allocator_test.cc

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ TEST(RetryAllocator, RetryAllocator) {
 
   size_t thread_num = 32;
   size_t sleep_time = 40;
-  size_t extra_time = 2;
+  size_t extra_time = 10;
 
   // Reserve to perform more tests in the future
  std::vector<std::shared_ptr<Allocator>> allocators;
paddle/fluid/operators/fused/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
@@ -1,2 +1,6 @@
 include(operators)
-register_operators()
+register_operators(EXCLUDES fusion_transpose_flatten_concat_op)
+if (WITH_GPU)
+  op_library(fusion_transpose_flatten_concat_op)
+  file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fusion_transpose_flatten_concat);\n")
+endif()
paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cc

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.h"
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class TransposeFlattenConcatFusionOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
+                      "Inputs(X) of ConcatOp should not be empty.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of ConcatOp should not be null.");
+
+    auto ins = ctx->GetInputsDim("X");
+    const size_t n = ins.size();
+    PADDLE_ENFORCE_GT(n, 0, "Input tensor count should be > 0.");
+
+    std::vector<int> trans_axis =
+        ctx->Attrs().Get<std::vector<int>>("trans_axis");
+    int flatten_axis = ctx->Attrs().Get<int>("flatten_axis");
+    int concat_axis = ctx->Attrs().Get<int>("concat_axis");
+
+    size_t x_rank = ins[0].size();
+    size_t trans_axis_size = trans_axis.size();
+    PADDLE_ENFORCE_EQ(x_rank, trans_axis_size,
+                      "The input tensor's rank (%d) "
+                      "should be equal to the permutation axis's size (%d)",
+                      x_rank, trans_axis_size);
+
+    auto dims0 =
+        GetFlattenShape(flatten_axis, GetPermuteShape(trans_axis, ins[0]));
+    std::vector<int> out_dims(dims0);
+    for (size_t i = 1; i < n; i++) {
+      auto dimsi =
+          GetFlattenShape(flatten_axis, GetPermuteShape(trans_axis, ins[i]));
+      for (int j = 0; j < static_cast<int>(dims0.size()); j++) {
+        if (j == concat_axis) {
+          out_dims[concat_axis] += dimsi[j];
+        } else {
+          PADDLE_ENFORCE_EQ(out_dims[j], dimsi[j],
+                            "After flattening, the %d-th dim should be the "
+                            "same, except along the concat axis.",
+                            j);
+        }
+      }
+    }
+    if (out_dims[concat_axis] < 0) {
+      out_dims[concat_axis] = -1;
+    }
+    ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
+  }
+};
+
+class TransposeFlattenConcatFusionOpMaker
+    : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput(
+        "X",
+        "(Tensor) The input tensor; tensors with rank up to 6 are supported.")
+        .AsDuplicable();
+    AddOutput("Out", "(Tensor) The output tensor.");
+    AddAttr<std::vector<int>>(
+        "trans_axis",
+        "(vector<int>) A list of values whose size must equal the input "
+        "tensor's rank. This operator permutes the input tensor's axes "
+        "according to the values given.");
+    AddAttr<int>("flatten_axis",
+                 "(int) Indicates up to which input dimensions (exclusive) "
+                 "should be flattened to the outer dimension of the output. "
+                 "The value for axis must be in the range [0, R], where R is "
+                 "the rank of the input tensor. When axis = 0, the shape of "
+                 "the output tensor is (1, d_0 X d_1 ... d_n), where the "
+                 "shape of the input tensor is (d_0, d_1, ... d_n).");
+    AddAttr<int>("concat_axis",
+                 "The axis along which the input tensors will be "
+                 "concatenated. It should be 0 or 1, since the tensor is 2-D "
+                 "after flattening.");
+    AddComment(R"DOC(
+Fusion of transpose, flatten and concat: each input in X is permuted by
+trans_axis, flattened to 2-D at flatten_axis, then concatenated along concat_axis.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(fusion_transpose_flatten_concat,
+                  ops::TransposeFlattenConcatFusionOp,
+                  ops::TransposeFlattenConcatFusionOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
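
To make the InferShape arithmetic concrete, here is a hedged, self-contained sketch with illustrative numbers (not taken from the commit); perm and flat are simplified stand-ins for the GetPermuteShape and GetFlattenShape helpers declared in the op's header further below:

#include <cassert>
#include <vector>

static std::vector<int> perm(const std::vector<int>& axis,
                             const std::vector<int>& d) {
  std::vector<int> o(d.size());
  for (size_t i = 0; i < axis.size(); ++i) o[i] = d[axis[i]];
  return o;
}

static std::vector<int> flat(int axis, const std::vector<int>& d) {
  int outer = 1, inner = 1;
  for (int i = 0; i < static_cast<int>(d.size()); ++i)
    (i < axis ? outer : inner) *= d[i];
  return {outer, inner};
}

int main() {
  std::vector<int> trans_axis{0, 2, 3, 1};
  const int flatten_axis = 2, concat_axis = 1;
  // Input 0: (2,3,4,5) -> permute -> (2,4,5,3) -> flatten -> (8,15)
  // Input 1: (2,6,4,5) -> permute -> (2,4,5,6) -> flatten -> (8,30)
  auto d0 = flat(flatten_axis, perm(trans_axis, {2, 3, 4, 5}));
  auto d1 = flat(flatten_axis, perm(trans_axis, {2, 6, 4, 5}));
  assert(d0[0] == d1[0]);               // non-concat dims must agree
  std::vector<int> out = d0;
  out[concat_axis] += d1[concat_axis];  // concatenate: (8, 45)
  assert(out[0] == 8 && out[1] == 45);
  return 0;
}
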
paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cu

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.h"
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/cudnn_helper.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+using CudnnDataType = platform::CudnnDataType<T>;
+
+template <typename T>
+class TransposeFlattenConcatFusionKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto* out = ctx.Output<framework::Tensor>("Out");
+    out->mutable_data<T>(ctx.GetPlace());
+    auto odims = out->dims();
+
+    std::vector<int> trans_axis = ctx.Attr<std::vector<int>>("trans_axis");
+    int flatten_axis = ctx.Attr<int>("flatten_axis");
+    int concat_axis = ctx.Attr<int>("concat_axis");
+
+    int rank = ins[0]->dims().size();
+    // Use at least 4-D descriptors in cudnnTransformTensor.
+    int max_dim = rank < 4 ? 4 : rank;
+    std::vector<int> stride_x(max_dim, 0);
+    std::vector<int> stride_y(max_dim, 0);
+    std::vector<int> dims_y(max_dim, 0);
+
+    cudnnTensorDescriptor_t in_desc;
+    cudnnTensorDescriptor_t out_desc;
+    CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&in_desc));
+    CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&out_desc));
+    cudnnDataType_t cudnn_dtype = CudnnDataType<T>::type;
+
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto handle = dev_ctx.cudnn_handle();
+
+    T* odata = out->data<T>();
+    for (size_t k = 0; k < ins.size(); ++k) {
+      auto perm_shape = GetPermuteShape(trans_axis, ins[k]->dims());
+      int osize = 1;
+      auto idims = ins[k]->dims();
+      for (int i = 0; i < rank; i++) {
+        stride_x[i] = 1;
+        for (int j = trans_axis[i] + 1; j < rank; j++) {
+          stride_x[i] *= idims[j];
+        }
+        dims_y[i] = perm_shape[i];
+        osize *= perm_shape[i];
+      }
+      stride_y[rank - 1] = 1;
+      for (int i = rank - 2; i >= 0; i--) {
+        if (((i + 1) == flatten_axis) && (concat_axis == 1)) {
+          stride_y[i] = odims[1];
+        } else {
+          stride_y[i] = stride_y[i + 1] * perm_shape[i + 1];
+        }
+      }
+
+      // Since concat comes after flatten, the output is a 2-D tensor.
+      // If concat_axis is 0, each input's permuted tensor is contiguous.
+      // If concat_axis is 1, the stride of the 0-th dim of each input's
+      // permuted tensor is odims()[1].
+
+      for (int i = rank; i < max_dim; i++) {
+        stride_x[i] = 1;
+        stride_y[i] = 1;
+        dims_y[i] = 1;
+      }
+
+      CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
+          in_desc, cudnn_dtype, max_dim, dims_y.data(), stride_x.data()));
+      CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
+          out_desc, cudnn_dtype, max_dim, dims_y.data(), stride_y.data()));
+
+      CUDNN_ENFORCE(platform::dynload::cudnnTransformTensor(
+          handle, CudnnDataType<T>::kOne(), in_desc,
+          static_cast<const void*>(ins[k]->data<T>()),
+          CudnnDataType<T>::kZero(), out_desc, static_cast<void*>(odata)));
+      if (concat_axis == 0) {
+        odata += osize;
+      } else {
+        auto flat_shape = GetFlattenShape(flatten_axis, perm_shape);
+        odata += flat_shape[1];
+      }
+    }
+    CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(in_desc));
+    CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(out_desc));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(fusion_transpose_flatten_concat,
                        ops::TransposeFlattenConcatFusionKernel<float>,
                        ops::TransposeFlattenConcatFusionKernel<double>);
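
The descriptor setup above is the whole trick: the transpose never runs as an explicit kernel. Here is a hedged CPU analogue (not Paddle code) of what the stride_x / stride_y descriptors express: describe the input with strides taken in permuted order and the output with packed strides, then let a single strides-respecting copy do the permutation. On the GPU that copy is cudnnTransformTensor:

#include <vector>

void strided_transpose(const float* in, float* out,
                       const std::vector<int>& dims,
                       const std::vector<int>& axis) {
  const int rank = static_cast<int>(dims.size());
  std::vector<int> in_stride(rank, 1), out_dims(rank), out_stride(rank, 1);
  for (int i = rank - 2; i >= 0; --i)
    in_stride[i] = in_stride[i + 1] * dims[i + 1];
  for (int i = 0; i < rank; ++i) out_dims[i] = dims[axis[i]];
  for (int i = rank - 2; i >= 0; --i)
    out_stride[i] = out_stride[i + 1] * out_dims[i + 1];

  int total = 1;
  for (int d : dims) total *= d;
  for (int linear = 0; linear < total; ++linear) {
    // Decompose the packed output index into coordinates, then map output
    // coordinate i back to input axis axis[i]; the stride_x computation in
    // the kernel above does the same thing.
    int rem = linear, src = 0;
    for (int i = 0; i < rank; ++i) {
      src += (rem / out_stride[i]) * in_stride[axis[i]];
      rem %= out_stride[i];
    }
    out[linear] = in[src];
  }
}
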
paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.h

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/ddim.h"
+
+namespace paddle {
+namespace operators {
+
+inline std::vector<int32_t> GetPermuteShape(const std::vector<int>& axis,
+                                            const framework::DDim& in_dims) {
+  std::vector<int32_t> out_dims(in_dims.size());
+  for (size_t i = 0; i < axis.size(); i++) {
+    out_dims[i] = in_dims[axis[i]];
+  }
+  return out_dims;
+}
+
+inline std::vector<int32_t> GetFlattenShape(const int axis,
+                                            const std::vector<int>& in_dims) {
+  int64_t outer = 1, inner = 1;
+  for (int i = 0; i < static_cast<int>(in_dims.size()); ++i) {
+    if (i < axis) {
+      outer *= in_dims[i];
+    } else {
+      inner *= in_dims[i];
+    }
+  }
+  std::vector<int32_t> out_shape(2);
+  out_shape[0] = outer;
+  out_shape[1] = inner;
+  return out_shape;
+}
+
+}  // namespace operators
+}  // namespace paddle
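
One edge case worth noting, matching the flatten_axis attribute docs in the .cc above: with axis = 0 the outer product is empty, so everything collapses into the second dimension. A tiny sketch (std::vector stands in for framework::DDim so it runs standalone):

#include <vector>

std::vector<int> flatten_at_zero(const std::vector<int>& dims) {
  int inner = 1;
  for (int d : dims) inner *= d;
  return {1, inner};  // e.g. (2, 4, 5, 3) -> (1, 120)
}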

paddle/fluid/operators/lookup_sparse_table_op.cc

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
                            framework::proto::VarType::FP32,
                            "The sparse table only supports FP32");
     w_t->Get(ids_t, out_t, true, is_test);
+    out_t->set_lod(ids_t.lod());
   }
 };
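
Why the one-line fix above matters: the lookup emits one output row per id row, so unless the ids' LoD (sequence offsets) is copied onto the output, downstream sequence-aware ops see a flat batch. A hedged sketch with mock types (not Paddle's):

#include <cstddef>
#include <vector>

// LoD offsets record how rows group into sequences,
// e.g. {{0, 2, 5}} = two sequences of 2 and 3 rows.
using LoD = std::vector<std::vector<std::size_t>>;

struct MockLoDTensor {
  std::vector<std::vector<float>> rows;  // one embedding row per id
  LoD lod;                               // sequence offsets
};

void lookup(const MockLoDTensor& ids, MockLoDTensor* out) {
  out->rows.assign(ids.rows.size(), std::vector<float>(8, 0.f));  // fake fetch
  out->lod = ids.lod;  // the fix: out_t->set_lod(ids_t.lod());
}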

paddle/fluid/operators/sum_op.h

Lines changed: 3 additions & 0 deletions
@@ -127,6 +127,9 @@ class SumKernel : public framework::OpKernel<T> {
       math::scatter::MergeAdd<DeviceContext, T> merge_add;
       merge_add(context.template device_context<DeviceContext>(), inputs,
                 out);
+
+      out->SyncIndex();
+
     } else {
       // no data, just set an empty out tensor.
       out->mutable_value()->mutable_data<T>(framework::make_ddim({0}),
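
The SyncIndex() call deserves a gloss; this is a hedged guess at the intent, since the commit doesn't explain it: MergeAdd rewrites the SelectedRows' row-id list, so any cached id-to-position lookup structure would be stale and must be rebuilt before later reads. A mock illustration (not Paddle code):

#include <cstdint>
#include <unordered_map>
#include <vector>

struct MockSelectedRows {
  std::vector<int64_t> rows;                        // sparse row ids
  std::unordered_map<int64_t, size_t> id_to_index;  // cached lookup table

  void SyncIndex() {  // rebuild the cache after rows change
    id_to_index.clear();
    for (size_t i = 0; i < rows.size(); ++i) id_to_index[rows[i]] = i;
  }
};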
