
Commit c5b6573

Author: chengduo
Fix input<tensor> (#14208)
* fix input<tensor> test=develop
* fix split_ids test=develop
* ElementwiseMul should not support SelectedRows
* fix scale op test=develop
* change GetTensorFromVar() method to GetTensorOrSelectedRowsFromVar()
* fix operator
* refine MultiOutput
* fix MultiOutput test=develop
* disable test_dist_save_load test=develop
* fix elementwise_op test=develop
* add get_sparse_as_op test=develop
* add info for check test=develop
* rename get_sparse_as_op with extract_rows_as_op. test=develop
* elementwise doesn't support selected_rows
* fix regularizer
* remove extract_rows_as test=develop
* fix ci test=develop
* add test for sum_op
* fix regularizer test=develop
* test=develop
* fix pserver weight decay multi inputs test=develop
1 parent 813e54e · commit c5b6573

24 files changed: +240 −393 lines

paddle/fluid/framework/details/multi_devices_graph_pass.cc

Lines changed: 6 additions & 0 deletions
@@ -648,6 +648,12 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(
     const ir::Graph &graph, const std::string &varname,
     const std::unordered_map<std::string, int> &sharded_var_device) const {
   auto got = sharded_var_device.find(varname);
+  if (got == sharded_var_device.end()) {
+    auto pos = varname.find(framework::kNewGradSuffix);
+    if (pos != std::string::npos) {
+      got = sharded_var_device.find(varname.substr(0, pos));
+    }
+  }
   return got == sharded_var_device.end() ? -1 : got->second;
 }
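
Note: the added fallback lets GetVarDeviceID resolve gradient variables renamed with the new kNewGradSuffix (declared in operator.h below): when the exact name misses, the lookup retries with everything before "@NEWGRAD@". A minimal standalone sketch of that lookup, using plain std::string/std::unordered_map stand-ins for the framework types and an illustrative variable name:

#include <string>
#include <unordered_map>

constexpr char kNewGradSuffix[] = "@NEWGRAD@";

// Exact match first; on a miss, retry with the prefix before the suffix.
int GetVarDeviceID(const std::unordered_map<std::string, int> &sharded_var_device,
                   const std::string &varname) {
  auto got = sharded_var_device.find(varname);
  if (got == sharded_var_device.end()) {
    auto pos = varname.find(kNewGradSuffix);
    if (pos != std::string::npos) {
      got = sharded_var_device.find(varname.substr(0, pos));
    }
  }
  return got == sharded_var_device.end() ? -1 : got->second;
}

int main() {
  std::unordered_map<std::string, int> sharded = {{"fc_0.w_0@GRAD", 1}};
  // A hypothetical renamed gradient resolves to the same device:
  return GetVarDeviceID(sharded, "fc_0.w_0@GRAD@NEWGRAD@") == 1 ? 0 : 1;
}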

paddle/fluid/framework/operator.cc

Lines changed: 22 additions & 14 deletions
@@ -358,7 +358,7 @@ static bool VarIsTensor(const Variable& var) {
   return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
 }
 
-const Tensor* GetTensorFromVar(const Variable& var) {
+const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var) {
   if (var.IsType<LoDTensor>()) {
     return static_cast<const Tensor*>(&(var.Get<LoDTensor>()));
   } else if (var.IsType<SelectedRows>()) {

@@ -369,7 +369,7 @@ const Tensor* GetTensorFromVar(const Variable& var) {
   }
 }
 
-static Tensor* GetMutableTensorFromVar(Variable* var) {
+Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var) {
   if (var->IsType<LoDTensor>()) {
     return var->GetMutable<LoDTensor>();
   } else if (var->IsType<SelectedRows>()) {

@@ -414,8 +414,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const {
 
 template <>
 const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
-  auto* var = InputVar(name);
-  return var == nullptr ? nullptr : GetTensorFromVar(*var);
+  return Input<LoDTensor>(name);
 }
 
 template <>

@@ -425,17 +424,21 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
   std::vector<const Tensor*> res;
   res.reserve(names.size());
   std::transform(names.begin(), names.end(), std::back_inserter(res),
-                 [&](const std::string& sub_name) {
+                 [&](const std::string& sub_name) -> const Tensor* {
                    auto var = scope_.FindVar(sub_name);
-                   return var == nullptr ? nullptr : GetTensorFromVar(*var);
+                   if (var == nullptr) return nullptr;
+                   PADDLE_ENFORCE(
+                       var->IsType<LoDTensor>(),
+                       "%s should be LoDTensor, but the received type is %s",
+                       sub_name, var->Type().name());
+                   return &(var->Get<LoDTensor>());
                  });
   return res;
 }
 
 template <>
 Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
-  auto var = OutputVar(name);
-  return var == nullptr ? nullptr : GetMutableTensorFromVar(var);
+  return Output<LoDTensor>(name);
 }
 
 template <>

@@ -445,10 +448,14 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
   std::vector<Tensor*> res;
   res.reserve(names.size());
   std::transform(names.begin(), names.end(), std::back_inserter(res),
-                 [&](const std::string& sub_name) {
+                 [&](const std::string& sub_name) -> Tensor* {
                    auto var = scope_.FindVar(sub_name);
-                   return var == nullptr ? nullptr
-                                         : GetMutableTensorFromVar(var);
+                   if (var == nullptr) return nullptr;
+                   PADDLE_ENFORCE(
+                       var->IsType<LoDTensor>(),
+                       "%s should be LoDTensor, but the received type is %s",
+                       sub_name, var->Type().name());
+                   return var->GetMutable<LoDTensor>();
                  });
   return res;
 }

@@ -768,11 +775,12 @@ void OperatorWithKernel::TransferInplaceVarsBack(
     const Scope& transfer_scope) const {
   for (auto& var_name : inplace_vars) {
     VLOG(3) << "share inplace var " + var_name + " back to it's original scope";
-    auto* original_tensor = GetMutableTensorFromVar(scope.FindVar(var_name));
+    auto* original_tensor =
+        GetMutableLoDTensorOrSelectedRowsValueFromVar(scope.FindVar(var_name));
     auto* var = transfer_scope.FindVar(var_name);
     PADDLE_ENFORCE(var != nullptr, "The var[%s] should not be nullptr",
                    var_name);
-    auto* transformed_tensor = GetTensorFromVar(*var);
+    auto* transformed_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var);
     original_tensor->ShareDataWith(*transformed_tensor);
   }
 }

@@ -789,7 +797,7 @@ Scope* OperatorWithKernel::TryTransferData(
       continue;
     }
 
-    auto* tensor_in = GetTensorFromVar(*var);
+    auto* tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
     if (!tensor_in->IsInitialized()) {
       continue;
     }
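
With this rename, ExecutionContext::Input<Tensor>/Output<Tensor> become strict LoDTensor accessors, and only the explicitly named GetLoDTensorOrSelectedRowsValueFromVar helpers still reach through a SelectedRows to its underlying value tensor. A simplified, self-contained sketch of that dispatch (C++17, with hand-rolled stand-ins for the framework's Variable/LoDTensor/SelectedRows types):

#include <cassert>
#include <variant>

struct Tensor { int numel = 0; };          // stand-in for framework::Tensor
struct LoDTensor : Tensor {};              // stand-in for framework::LoDTensor
struct SelectedRows { LoDTensor value; };  // rows/height omitted for brevity
using Variable = std::variant<LoDTensor, SelectedRows>;

// Mirrors the renamed helper: a LoDTensor is returned as-is, while a
// SelectedRows yields the tensor stored in its value field.
const Tensor *GetLoDTensorOrSelectedRowsValueFromVar(const Variable &var) {
  if (auto *t = std::get_if<LoDTensor>(&var)) return t;
  return &std::get<SelectedRows>(var).value;
}

int main() {
  Variable dense = LoDTensor{{4}};
  Variable sparse = SelectedRows{LoDTensor{{2}}};
  assert(GetLoDTensorOrSelectedRowsValueFromVar(dense)->numel == 4);
  assert(GetLoDTensorOrSelectedRowsValueFromVar(sparse)->numel == 2);
}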

paddle/fluid/framework/operator.h

Lines changed: 7 additions & 3 deletions
@@ -54,6 +54,9 @@ constexpr char kGradVarSuffix[] = "@GRAD";
 /// Variables with this suffix are supposed to be filled up with zeros.
 constexpr char kZeroVarSuffix[] = "@ZERO";
 
+/// Variables with this suffix are the new Gradient.
+constexpr char kNewGradSuffix[] = "@NEWGRAD@";
+
 // define some kernel priority
 /* Define multiple kernel type fallback order*/
 extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;

@@ -63,7 +66,8 @@ inline std::string GradVarName(const std::string& var_name) {
 }
 
 proto::VarType::Type GetDataTypeOfVar(const Variable* var);
-const Tensor* GetTensorFromVar(const Variable& var);
+const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
+Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
 
 class OperatorBase;
 class ExecutionContext;

@@ -224,7 +228,7 @@ class ExecutionContext {
     std::vector<const T*> res;
     res.reserve(names.size());
     std::transform(names.begin(), names.end(), std::back_inserter(res),
-                   [&](const std::string& sub_name) {
+                   [&](const std::string& sub_name) -> const T* {
                      auto var = scope_.FindVar(sub_name);
                      return var == nullptr ? nullptr : &var->Get<T>();
                    });

@@ -237,7 +241,7 @@
     std::vector<T*> res;
     res.reserve(names.size());
     std::transform(names.begin(), names.end(), std::back_inserter(res),
-                   [&](const std::string& sub_name) {
+                   [&](const std::string& sub_name) -> T* {
                      auto var = scope_.FindVar(sub_name);
                      return var == nullptr ? nullptr : var->GetMutable<T>();
                    });
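
The explicit -> const T* / -> T* trailing return types matter most for the multi-statement lambdas in operator.cc above: without one, each return statement deduces its own type, and return nullptr; deduces std::nullptr_t, which conflicts with the pointer returned on the other path. Here in operator.h the ternary would still deduce correctly, so the annotation mainly documents intent. A tiny standalone illustration of the failure mode the annotation avoids (names are illustrative):

#include <string>

const std::string *Find(bool found, const std::string &s) {
  // Without "-> const std::string *" a lambda with these two returns
  // would not compile: one deduces std::nullptr_t, the other
  // const std::string*, and the deduced types must agree.
  auto get = [&](bool ok) -> const std::string * {
    if (!ok) return nullptr;
    return &s;
  };
  return get(found);
}

int main() {
  std::string name = "fc_0.w_0";
  return Find(true, name) == &name ? 0 : 1;
}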

paddle/fluid/operators/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -296,7 +296,6 @@ op_library(cos_sim_op DEPS cos_sim_functor)
 op_library(parallel_do_op DEPS executor)
 op_library(unsqueeze_op DEPS reshape_op)
 op_library(squeeze_op DEPS reshape_op)
-op_library(extract_rows_op DEPS memory)
 op_library(flatten_op DEPS reshape_op)
 op_library(sequence_pad_op DEPS sequence_padding)
 op_library(unstack_op DEPS stack_op)

paddle/fluid/operators/elementwise_add_op.h

Lines changed: 35 additions & 36 deletions
@@ -28,9 +28,9 @@ struct AddFunctor {
 };
 
 template <typename DeviceContext, typename T>
-void default_elementwise_add(const framework::ExecutionContext& ctx,
-                             const framework::Tensor* x,
-                             const framework::Tensor* y, framework::Tensor* z) {
+void default_elementwise_add(const framework::ExecutionContext &ctx,
+                             const framework::Tensor *x,
+                             const framework::Tensor *y, framework::Tensor *z) {
   int axis = ctx.Attr<int>("axis");
   ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
                                                         AddFunctor<T>(), z);

@@ -40,9 +40,9 @@ template <typename DeviceContext, typename T>
 typename std::enable_if<
     std::is_floating_point<T>::value &&
     std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
-elementwise_add(const framework::ExecutionContext& ctx,
-                const framework::Tensor* x, const framework::Tensor* y,
-                framework::Tensor* z) {
+elementwise_add(const framework::ExecutionContext &ctx,
+                const framework::Tensor *x, const framework::Tensor *y,
+                framework::Tensor *z) {
   auto eigen_x = framework::EigenVector<T>::Flatten(*x);
   auto eigen_y = framework::EigenVector<T>::Flatten(*y);
   auto eigen_z = framework::EigenVector<T>::Flatten(*z);

@@ -55,21 +55,20 @@ template <typename DeviceContext, typename T>
 typename std::enable_if<
     !std::is_floating_point<T>::value ||
     !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
-elementwise_add(const framework::ExecutionContext& ctx,
-                const framework::Tensor* x, const framework::Tensor* y,
-                framework::Tensor* z) {
+elementwise_add(const framework::ExecutionContext &ctx,
+                const framework::Tensor *x, const framework::Tensor *y,
+                framework::Tensor *z) {
   default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
 }
 
 template <typename DeviceContext, typename T>
 class ElementwiseAddKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    using Tensor = framework::Tensor;
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *x = ctx.Input<framework::LoDTensor>("X");
+    auto *y = ctx.Input<framework::LoDTensor>("Y");
+    auto *z = ctx.Output<framework::LoDTensor>("Out");
 
-    const auto x = ctx.Input<Tensor>("X");
-    const auto y = ctx.Input<Tensor>("Y");
-    auto z = ctx.Output<Tensor>("Out");
     z->mutable_data<T>(ctx.GetPlace());
 
     auto dims_equal = x->dims() == y->dims();

@@ -87,13 +86,13 @@ struct IdentityGrad {
 };
 
 template <typename DeviceContext, typename T>
-void default_elementwise_add_grad(const framework::ExecutionContext& ctx,
-                                  const framework::Tensor* x,
-                                  const framework::Tensor* y,
-                                  const framework::Tensor* out,
-                                  const framework::Tensor* dout,
-                                  framework::Tensor* dx,
-                                  framework::Tensor* dy) {
+void default_elementwise_add_grad(const framework::ExecutionContext &ctx,
+                                  const framework::Tensor *x,
+                                  const framework::Tensor *y,
+                                  const framework::Tensor *out,
+                                  const framework::Tensor *dout,
+                                  framework::Tensor *dx,
+                                  framework::Tensor *dy) {
   int axis = ctx.Attr<int>("axis");
 
   ElemwiseExplicitGradCompute<DeviceContext, T, IdentityGrad<T>,

@@ -106,11 +105,11 @@ template <typename DeviceContext, typename T>
 typename std::enable_if<
     std::is_floating_point<T>::value &&
     std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
-elementwise_add_grad(const framework::ExecutionContext& ctx,
-                     const framework::Tensor* x, const framework::Tensor* y,
-                     const framework::Tensor* out,
-                     const framework::Tensor* dout, framework::Tensor* dx,
-                     framework::Tensor* dy) {
+elementwise_add_grad(const framework::ExecutionContext &ctx,
+                     const framework::Tensor *x, const framework::Tensor *y,
+                     const framework::Tensor *out,
+                     const framework::Tensor *dout, framework::Tensor *dx,
+                     framework::Tensor *dy) {
   auto blas = math::GetBlas<DeviceContext, T>(ctx);
 
   if (dx) {

@@ -128,27 +127,27 @@ template <typename DeviceContext, typename T>
 typename std::enable_if<
     !std::is_floating_point<T>::value ||
     !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
-elementwise_add_grad(const framework::ExecutionContext& ctx,
-                     const framework::Tensor* x, const framework::Tensor* y,
-                     const framework::Tensor* out,
-                     const framework::Tensor* dout, framework::Tensor* dx,
-                     framework::Tensor* dy) {
+elementwise_add_grad(const framework::ExecutionContext &ctx,
+                     const framework::Tensor *x, const framework::Tensor *y,
+                     const framework::Tensor *out,
+                     const framework::Tensor *dout, framework::Tensor *dx,
+                     framework::Tensor *dy) {
   default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
 }
 
 template <typename DeviceContext, typename T>
 class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
+  void Compute(const framework::ExecutionContext &ctx) const override {
     ElemwiseGradKernel<T>::Compute(ctx);
 
     using Tensor = framework::Tensor;
 
-    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
     // skip out, x, y
-    auto* out = dout;
+    auto *out = dout;
     auto *x = dout, *y = dout;
 
     if (platform::is_cpu_place(ctx.GetPlace()) && dx != nullptr &&
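
Most of this file is the reference/pointer placement change from reformatting, but the kernel body now fetches framework::LoDTensor directly, so a SelectedRows input fails the new enforce in MultiInput instead of being silently accepted. The surrounding enable_if overloads are the dispatch that picks the fast path only for floating-point types on CPU; a self-contained sketch of that SFINAE pattern (toy device tags and plain arrays in place of the framework types):

#include <cstdio>
#include <type_traits>

struct CPUDeviceContext {};   // toy stand-in for platform::CPUDeviceContext
struct CUDADeviceContext {};  // toy stand-in for the CUDA context

// Enabled only for floating-point T on CPU (the real code uses Eigen/BLAS here).
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_floating_point<T>::value &&
    std::is_same<DeviceContext, CPUDeviceContext>::value>::type
elementwise_add(const T *x, const T *y, T *z, int n) {
  std::puts("fast CPU floating-point path");
  for (int i = 0; i < n; ++i) z[i] = x[i] + y[i];
}

// Enabled for every other combination; falls back to the generic kernel.
template <typename DeviceContext, typename T>
typename std::enable_if<
    !std::is_floating_point<T>::value ||
    !std::is_same<DeviceContext, CPUDeviceContext>::value>::type
elementwise_add(const T *x, const T *y, T *z, int n) {
  std::puts("default broadcast path");
  for (int i = 0; i < n; ++i) z[i] = x[i] + y[i];
}

int main() {
  float xf[] = {1.f, 2.f}, yf[] = {3.f, 4.f}, zf[2];
  int xi[] = {1, 2}, yi[] = {3, 4}, zi[2];
  elementwise_add<CPUDeviceContext, float>(xf, yf, zf, 2);  // fast path
  elementwise_add<CPUDeviceContext, int>(xi, yi, zi, 2);    // fallback
}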

paddle/fluid/operators/elementwise_div_op.h

Lines changed: 3 additions & 4 deletions
@@ -28,11 +28,10 @@ template <typename DeviceContext, typename T>
 class ElementwiseDivKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    using Tensor = framework::Tensor;
+    auto* x = ctx.Input<framework::LoDTensor>("X");
+    auto* y = ctx.Input<framework::LoDTensor>("Y");
+    auto* z = ctx.Output<framework::LoDTensor>("Out");
 
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* z = ctx.Output<Tensor>("Out");
     z->mutable_data<T>(ctx.GetPlace());
     int axis = ctx.Attr<int>("axis");
     ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis,

paddle/fluid/operators/elementwise_max_op.h

Lines changed: 3 additions & 4 deletions
@@ -29,11 +29,10 @@ template <typename DeviceContext, typename T>
 class ElementwiseMaxKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    using Tensor = framework::Tensor;
+    auto* x = ctx.Input<framework::LoDTensor>("X");
+    auto* y = ctx.Input<framework::LoDTensor>("Y");
+    auto* z = ctx.Output<framework::LoDTensor>("Out");
 
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* z = ctx.Output<Tensor>("Out");
     z->mutable_data<T>(ctx.GetPlace());
     int axis = ctx.Attr<int>("axis");
     ElementwiseComputeEx<MaxFunctor<T>, DeviceContext, T>(ctx, x, y, axis,

paddle/fluid/operators/elementwise_min_op.h

Lines changed: 3 additions & 4 deletions
@@ -28,11 +28,10 @@ template <typename DeviceContext, typename T>
 class ElementwiseMinKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    using Tensor = framework::Tensor;
+    auto* x = ctx.Input<framework::LoDTensor>("X");
+    auto* y = ctx.Input<framework::LoDTensor>("Y");
+    auto* z = ctx.Output<framework::LoDTensor>("Out");
 
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* z = ctx.Output<Tensor>("Out");
     z->mutable_data<T>(ctx.GetPlace());
     int axis = ctx.Attr<int>("axis");
     ElementwiseComputeEx<MinFunctor<T>, DeviceContext, T>(ctx, x, y, axis,

paddle/fluid/operators/elementwise_mul_op.h

Lines changed: 3 additions & 4 deletions
@@ -60,11 +60,10 @@ template <typename DeviceContext, typename T>
 class ElementwiseMulKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    using Tensor = framework::Tensor;
+    auto* x = ctx.Input<framework::LoDTensor>("X");
+    auto* y = ctx.Input<framework::LoDTensor>("Y");
+    auto* z = ctx.Output<framework::LoDTensor>("Out");
 
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* z = ctx.Output<Tensor>("Out");
     z->mutable_data<T>(ctx.GetPlace());
     if (x->numel() == y->numel()) {
       elementwise_mul<DeviceContext, T>(ctx, x, y, z);
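
The div/max/min/mul kernels get the same treatment as elementwise_add: inputs and outputs are fetched as framework::LoDTensor, so per the commit message these ops no longer accept SelectedRows, and such an input now trips the check added in MultiInput<Tensor>. A hand-rolled sketch of that guard (a plain exception standing in for the real PADDLE_ENFORCE macro):

#include <iostream>
#include <stdexcept>
#include <string>
#include <typeinfo>

struct LoDTensor {};
struct SelectedRows {};

// Stand-in for the PADDLE_ENFORCE(var->IsType<LoDTensor>(), ...) guard.
template <typename Actual>
void EnforceIsLoDTensor(const std::string &name) {
  if (typeid(Actual) != typeid(LoDTensor)) {
    throw std::runtime_error(name + " should be LoDTensor, but the received type is " +
                             typeid(Actual).name());
  }
}

int main() {
  EnforceIsLoDTensor<LoDTensor>("X");  // passes silently
  try {
    EnforceIsLoDTensor<SelectedRows>("Y");  // mimics the new failure
  } catch (const std::runtime_error &e) {
    std::cout << e.what() << "\n";
  }
}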
