
Commit 835201b

Fix multi-thread memory out-of-bounds error for passes (#21920) (#22132)
* fix seqconv_eltadd_relu_fuse_pass during multi-threads predictor, test=develop
* fix attention_lstm_fuse_pass during multi-threads inference, test=develop
* fix embedding_fc_lstm_fuse_pass during multi-threads inference, test=develop
* fix fc_lstm_fuse_pass during multi-threads inference, test=develop
* fix seq_concat_fc_fuse_pass during multi-threads inference, test=develop
1 parent 5a611af commit 835201b
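All five passes previously allocated the temporary outputs of their fused op directly in the shared parameter scope via scope.Var(name)->GetMutable<LoDTensor>(). When several predictors ran the same fused program concurrently, every thread read and wrote that single shared tensor, which is what produced the out-of-bounds memory errors. The fix is the same in each pass: declare the temporary as a non-persistable variable node in the graph and link it to the fused op, so each executor can materialize its own copy in its local scope. A minimal sketch of the before/after pattern, using the calls visible in the hunks below (var_name and op are illustrative placeholders):

// Before (racy): the temporary lives in the parameter scope shared by
// all predictors, so concurrent threads clobber the same LoDTensor.
scope.Var(var_name)->GetMutable<framework::LoDTensor>();

// After (thread-safe): declare a non-persistable graph var node instead;
// each executor then allocates the tensor in its own local scope.
VarDesc desc(var_name);
desc.SetPersistable(false);                // not a shared parameter
auto* node = graph->CreateVarNode(&desc);  // the temp becomes a graph node
IR_NODE_LINK_TO(op, node);                 // wire it as the fused op's output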

File tree

5 files changed: +74 −50 lines

paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc

paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc

Lines changed: 18 additions & 9 deletions
@@ -41,7 +41,7 @@ struct Param {
   std::string LSTMOUT = "at.lstmout.new";
 };
 
-void PrepareParameters(Graph* graph, const Param& param);
+void PrepareParameters(Graph* graph, const Param& param, ir::Node* lstm_op);
 
 void FindWhileOp(Graph* graph) {
   GraphPatternDetector gpd;

@@ -98,7 +98,7 @@ void FindWhileOp(Graph* graph) {
   auto* hidden_init = graph->RetrieveNode(8);
 
   auto* lstm_op = graph->CreateOpNode(&op_desc);
-  PrepareParameters(graph, param);
+  PrepareParameters(graph, param, lstm_op);
 
   IR_NODE_LINK_TO(X, lstm_op);
   IR_NODE_LINK_TO(cell_init, lstm_op);

@@ -133,20 +133,29 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
                      const LoDTensor& B_output, const LoDTensor& B_cell,
                      LoDTensor* out);
 
-void PrepareParameters(Graph* graph, const Param& param) {
+void PrepareParameters(Graph* graph, const Param& param, ir::Node* lstm_op) {
   // Check parameters
   PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
   auto& scope = graph->Get<Scope>(kParamScopeAttr);
 
   // Create new parameters.
+  // AddInput
   scope.Var(param.LSTMWeight)->GetMutable<LoDTensor>();
   scope.Var(param.LSTMBias)->GetMutable<LoDTensor>();
-  scope.Var(param.Hidden)->GetMutable<LoDTensor>();
-  scope.Var(param.Cell)->GetMutable<LoDTensor>();
-  scope.Var(param.AttentionedX)->GetMutable<LoDTensor>();
-  scope.Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
-  scope.Var(param.LSTMX)->GetMutable<LoDTensor>();
-  scope.Var(param.LSTMOUT)->GetMutable<LoDTensor>();
+  // AddOutput
+#define IR_NODE(x)                                 \
+  VarDesc key_##x(param.x);                        \
+  key_##x.SetPersistable(false);                   \
+  auto* node_##x = graph->CreateVarNode(&key_##x); \
+  IR_NODE_LINK_TO(lstm_op, node_##x);
+
+  IR_NODE(Hidden);
+  IR_NODE(Cell);
+  IR_NODE(AttentionedX);
+  IR_NODE(AttentionFCOut);
+  IR_NODE(LSTMX);
+  IR_NODE(LSTMOUT);
+#undef IR_NODE
 
 #define GATE_W(name__)                                    \
   auto* W_##name__##_w0 = scope.FindVar(#name__ ".w_0"); \
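For reference, each IR_NODE(x) invocation in the hunk above is a mechanical macro expansion of the thread-safe pattern; IR_NODE(Hidden), for example, expands to:

VarDesc key_Hidden(param.Hidden);          // non-persistable temp variable
key_Hidden.SetPersistable(false);
auto* node_Hidden = graph->CreateVarNode(&key_Hidden);
IR_NODE_LINK_TO(lstm_op, node_Hidden);     // output edge from the fused LSTM op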

paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc

Lines changed: 23 additions & 17 deletions
@@ -127,47 +127,53 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
                     embedding_data, k, weightx_data, n, beta, embeddings_data, n);
   op_desc.SetInput("Embeddings", {embeddings});
 
-  // Create temp variables.
-  const std::string BatchedInput = patterns::UniqueKey("BatchedInput");
-  const std::string BatchedCellPreAct =
-      patterns::UniqueKey("BatchedCellPreAct");
-  const std::string BatchedGate = patterns::UniqueKey("BatchedGate");
-
-  scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>();
-  scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>();
-  scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>();
-
   op_desc.SetInput("H0", {});
   op_desc.SetInput("C0", {});
   op_desc.SetOutput("Hidden", {hidden->Name()});
   op_desc.SetOutput("Cell", {cell->Name()});
   op_desc.SetOutput("XX", {xx->Name()});
-  op_desc.SetOutput("BatchedGate", {BatchedGate});
-  op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
-  op_desc.SetOutput("BatchedInput", {BatchedInput});
   op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
   op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
   // TODO(TJ): get from attr
   op_desc.SetAttr("use_seq", true);
 
-  PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-  auto& scope = graph->Get<Scope>(kParamScopeAttr);
+  // Create temp variables.
 #define OP_SET_OUT(x)                            \
   const std::string x = patterns::UniqueKey(#x); \
-  op_desc.SetOutput(#x, {x});                    \
-  scope.Var(x)->GetMutable<LoDTensor>()
+  op_desc.SetOutput(#x, {x});
+
+  OP_SET_OUT(BatchedGate);
+  OP_SET_OUT(BatchCellPreAct);
+  OP_SET_OUT(BatchedInput);
   OP_SET_OUT(BatchedCell);
   OP_SET_OUT(BatchedHidden);
   OP_SET_OUT(ReorderedH0);
   OP_SET_OUT(ReorderedC0);
 #undef OP_SET_OUT
 
   auto* op = graph->CreateOpNode(&op_desc);
+
   IR_NODE_LINK_TO(input, op);
   IR_NODE_LINK_TO(weight_x, op);
   IR_NODE_LINK_TO(weight_h, op);
   IR_NODE_LINK_TO(bias, op);
   IR_NODE_LINK_TO(op, hidden);
+
+#define IR_NODE(x)                                 \
+  VarDesc key_##x(x);                              \
+  key_##x.SetPersistable(false);                   \
+  auto* node_##x = graph->CreateVarNode(&key_##x); \
+  IR_NODE_LINK_TO(op, node_##x);
+
+  IR_NODE(BatchedGate);
+  IR_NODE(BatchCellPreAct);
+  IR_NODE(BatchedInput);
+  IR_NODE(BatchedCell);
+  IR_NODE(BatchedHidden);
+  IR_NODE(ReorderedH0);
+  IR_NODE(ReorderedC0);
+#undef IR_NODE
+
   return op;
 };

paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc

Lines changed: 25 additions & 20 deletions
@@ -74,50 +74,55 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
     op_desc.SetInput("Bias", {new_bias_var});
   }
 
-  // Create temp variables.
-  const std::string BatchedInput = patterns::UniqueKey("BatchedInput");
-  const std::string BatchedCellPreAct =
-      patterns::UniqueKey("BatchedCellPreAct");
-  const std::string BatchedGate = patterns::UniqueKey("BatchedGate");
-  const std::string CheckedCell = patterns::UniqueKey("CheckedCell");
-
-  scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>();
-  scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>();
-  scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>();
-  scope->Var(CheckedCell)->GetMutable<framework::LoDTensor>();
-
   op_desc.SetInput("H0", {});
   op_desc.SetInput("C0", {});
   op_desc.SetOutput("Hidden", {hidden->Name()});
   op_desc.SetOutput("Cell", {cell->Name()});
   op_desc.SetOutput("XX", {xx->Name()});
-  op_desc.SetOutput("BatchedGate", {BatchedGate});
-  op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
-  op_desc.SetOutput("BatchedInput", {BatchedInput});
-  op_desc.SetOutput("CheckedCell", {CheckedCell});
   op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
   op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
   // TODO(TJ): get from attr
   op_desc.SetAttr("use_seq", true);
 
-  PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-  auto& scope = graph->Get<Scope>(kParamScopeAttr);
+  // Create temp variables.
 #define OP_SET_OUT(x)                            \
   const std::string x = patterns::UniqueKey(#x); \
-  op_desc.SetOutput(#x, {x});                    \
-  scope.Var(x)->GetMutable<LoDTensor>()
+  op_desc.SetOutput(#x, {x});
+
+  OP_SET_OUT(BatchedGate);
+  OP_SET_OUT(BatchedCellPreAct);
+  OP_SET_OUT(BatchedInput);
+  OP_SET_OUT(CheckedCell);
   OP_SET_OUT(BatchedCell);
   OP_SET_OUT(BatchedHidden);
   OP_SET_OUT(ReorderedH0);
   OP_SET_OUT(ReorderedC0);
 #undef OP_SET_OUT
 
   auto* op = graph->CreateOpNode(&op_desc);
+
   IR_NODE_LINK_TO(input, op);
   IR_NODE_LINK_TO(weight_x, op);
   IR_NODE_LINK_TO(weight_h, op);
   IR_NODE_LINK_TO(bias, op);
   IR_NODE_LINK_TO(op, hidden);
+
+#define IR_NODE(x)                                 \
+  VarDesc key_##x(x);                              \
+  key_##x.SetPersistable(false);                   \
+  auto* node_##x = graph->CreateVarNode(&key_##x); \
+  IR_NODE_LINK_TO(op, node_##x);
+
+  IR_NODE(BatchedGate);
+  IR_NODE(BatchedCellPreAct);
+  IR_NODE(BatchedInput);
+  IR_NODE(CheckedCell);
+  IR_NODE(BatchedCell);
+  IR_NODE(BatchedHidden);
+  IR_NODE(ReorderedH0);
+  IR_NODE(ReorderedC0);
+#undef IR_NODE
+
   return op;
 };

paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc

Lines changed: 4 additions & 1 deletion
@@ -214,7 +214,9 @@ void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const {
   op_desc.SetInput("FCWeight", {fc_w->Name()});
   op_desc.SetInput("FCBias", {fc_bias->Name()});
   const std::string fc_out_tmp = fc_out->Name() + ".tmp";
-  param_scope()->Var(fc_out_tmp)->GetMutable<framework::LoDTensor>();
+  VarDesc fc_out_key(fc_out_tmp);
+  fc_out_key.SetPersistable(false);
+  auto* fc_out_node = graph->CreateVarNode(&fc_out_key);
   op_desc.SetOutput("FCOut", {fc_out_tmp});
   op_desc.SetOutput("Out", {fc_out->Name()});
   op_desc.SetAttr("fc_activation", act->Op()->Type());

@@ -227,6 +229,7 @@ void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const {
   IR_NODE_LINK_TO(sequence_expand0_in, op_node);
   IR_NODE_LINK_TO(sequence_expand1_in, op_node);
   IR_NODE_LINK_TO(op_node, fc_out);
+  IR_NODE_LINK_TO(op_node, fc_out_node);
 
   // Clean nodes.
   std::unordered_set<const Node*> marked_nodes;

paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc

Lines changed: 4 additions & 3 deletions
@@ -42,18 +42,19 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
   op_desc.SetAttr("contextLength", seqconv->Op()->GetAttr("contextLength"));
   op_desc.SetAttr("contextStart", seqconv->Op()->GetAttr("contextStart"));
   op_desc.SetAttr("contextStride", seqconv->Op()->GetAttr("contextStride"));
-  PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-  auto& scope = graph->Get<Scope>(kParamScopeAttr);
   const std::string ColMat = patterns::UniqueKey("SeqConvColMat");
   op_desc.SetOutput("ColMat", {ColMat});
   op_desc.SetOutput("Out", {relu_out->Name()});
-  scope.Var(ColMat)->GetMutable<LoDTensor>();
+  VarDesc key(ColMat);
+  key.SetPersistable(false);
+  auto* key_col_mat = graph->CreateVarNode(&key);
 
   auto* op = graph->CreateOpNode(&op_desc);
   IR_NODE_LINK_TO(input, op);
   IR_NODE_LINK_TO(seqconv_weight, op);
   IR_NODE_LINK_TO(eltadd_bias, op);
   IR_NODE_LINK_TO(op, relu_out);
+  IR_NODE_LINK_TO(op, key_col_mat);
   return op;
 };
