Skip to content

Commit 0357128

Browse files
committed
fix VisitVariable
1 parent fbb75c6 commit 0357128

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

paddle/fluid/framework/details/broadcast_op_handle.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ namespace framework {
2929
namespace details {
3030

3131
struct BroadcastOpHandle : public OpHandleBase {
32-
const std::vector<Scope *> &local_scopes_;
33-
const std::vector<platform::Place> &places_;
34-
32+
public:
3533
BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
3634
const std::vector<platform::Place> &places);
3735

@@ -41,10 +39,12 @@ struct BroadcastOpHandle : public OpHandleBase {
4139

4240
protected:
4341
void RunImpl() override;
44-
4542
void WaitInputVarGenerated(const VarHandle &in_var);
46-
};
4743

44+
private:
45+
const std::vector<Scope *> &local_scopes_;
46+
const std::vector<platform::Place> &places_;
47+
};
4848
} // namespace details
4949
} // namespace framework
5050
} // namespace paddle

paddle/fluid/framework/details/variable_visitor.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,22 @@ namespace paddle {
1818
namespace framework {
1919
namespace details {
2020
template <typename Func>
21-
static void VisitVariable(Variable* var, Func func) {
21+
static void VisitVariable(Variable* var, Func* func) {
2222
if (var->IsType<LoDTensor>()) {
23-
func(var->GetMutable<LoDTensor>());
23+
(*func)(var->GetMutable<LoDTensor>());
2424
} else if (var->IsType<SelectedRows>()) {
25-
func(var->GetMutable<SelectedRows>());
25+
(*func)(var->GetMutable<SelectedRows>());
2626
} else {
2727
PADDLE_THROW("Not supported type %s", var->Type().name());
2828
}
2929
}
3030

3131
template <typename Func>
32-
static void VisitVariable(const Variable& var, Func func) {
32+
static void VisitVariable(const Variable& var, Func* func) {
3333
if (var.IsType<LoDTensor>()) {
34-
func(var.Get<LoDTensor>());
34+
(*func)(var.Get<LoDTensor>());
3535
} else if (var.IsType<SelectedRows>()) {
36-
func(var.Get<SelectedRows>());
36+
(*func)(var.Get<SelectedRows>());
3737
} else {
3838
PADDLE_THROW("Not supported type %s", var.Type().name());
3939
}
@@ -56,7 +56,7 @@ struct TensorVisitor {
5656

5757
Tensor& VariableVisitor::GetMutableTensor(Variable* var) {
5858
TensorVisitor vistor;
59-
VisitVariable(var, vistor);
59+
VisitVariable(var, &vistor);
6060
return *vistor.result_;
6161
}
6262

@@ -85,7 +85,7 @@ struct ShareDimsAndLoDVisitor {
8585

8686
void VariableVisitor::ShareDimsAndLoD(const Variable& src, Variable* trg) {
8787
ShareDimsAndLoDVisitor visitor{trg};
88-
VisitVariable(src, visitor);
88+
VisitVariable(src, &visitor);
8989
}
9090

9191
} // namespace details

0 commit comments

Comments
 (0)