Skip to content

Commit 9998434

Browse files
authored
[Inference] resolve_conflicts_between_passes (#64830)
* resolve_conflicts_between_passes * fix bug: transfer_layout_pass should rewrite values that occur in control flow
1 parent 60871a2 commit 9998434

File tree

4 files changed

+66
-14
lines changed

4 files changed

+66
-14
lines changed

paddle/fluid/inference/api/paddle_pass_builder.cc

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -617,7 +617,8 @@ const std::vector<std::string> kPirGpuPasses{
617617
"matmul_transpose_fuse_pass",
618618
"transpose_flatten_concat_fuse_pass",
619619
"remove_redundant_transpose_pass",
620-
"transfer_layout_pass"};
620+
"transfer_layout_pass",
621+
};
621622

622623
const std::vector<std::string> kPirXpuPasses{// Functional pass
623624
"map_op_to_another_pass",

paddle/fluid/pir/dialect/operator/interface/layout_transformation.cc

Lines changed: 46 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -48,11 +48,30 @@ common::DataLayout PreferLayoutImpl<Conv2dOp>(pir::Operation* op) {
4848
data_format_attr));
4949
}
5050

51-
// Note(lyk): We exhibit the layout transformation for conv2d
51+
auto concrete_op = op->dyn_cast<Conv2dOp>();
52+
if (auto in = concrete_op.input()) {
53+
if (auto in_type = in.type()) {
54+
if (in_type.isa<DenseTensorType>()) {
55+
if (auto tensor_type = in_type.dyn_cast<DenseTensorType>()) {
56+
if (tensor_type.dtype().isa<pir::Float16Type>()) {
57+
return common::DataLayout::NHWC;
58+
}
59+
}
60+
}
61+
}
62+
}
63+
64+
return common::StringToDataLayout(data_format_attr.AsString());
65+
}
66+
67+
template <>
68+
std::vector<pir::Value> RelevantInputsImpl<Conv2dOp>(pir::Operation* op) {
69+
// Note(lyk): We exhibit the layout transformation for filter of conv2d
5270
// due to issues with its infermeta and kernel not functioning
5371
// properly in NHWC layout. However, if the FLAGS_manually_trans_conv_filter
5472
// is enabled, the transfer_layout_pass can also operate correctly.
55-
return common::StringToDataLayout(data_format_attr.AsString());
73+
auto concrete_op = op->dyn_cast<Conv2dOp>();
74+
return {concrete_op.input()};
5675
}
5776

5877
template <>
@@ -124,6 +143,31 @@ void RewriteByLayoutImpl<FusedConv2dAddActOp>(pir::Operation* op,
124143
RewriteByInfermeta<FusedConv2dAddActOp>(op, new_layout);
125144
}
126145

146+
template <>
147+
bool CanBeModifiedImpl<FusedConv2dAddActOp>(pir::Operation* op) {
148+
auto data_format_attr = op->attribute<pir::StrAttribute>("data_format");
149+
if (!data_format_attr) {
150+
PADDLE_THROW(phi::errors::InvalidArgument(
151+
"op (%s) should have attribute `data_format`, but got %s",
152+
op,
153+
data_format_attr));
154+
}
155+
auto cur_layout = common::StringToDataLayout(data_format_attr.AsString());
156+
auto prefer_layout = PreferLayoutImpl<FusedConv2dAddActOp>(op);
157+
auto can_be_modified = cur_layout != prefer_layout;
158+
159+
for (auto value : RelevantOutputsImpl<FusedConv2dAddActOp>(op)) {
160+
// TODO(lyk) if value was used in another block, we cannot rewrite this op
161+
for (auto it = value.use_begin(); it != value.use_end(); ++it) {
162+
if (it->owner()->GetParent() != op->GetParent()) {
163+
return false;
164+
}
165+
}
166+
}
167+
168+
return can_be_modified;
169+
}
170+
127171
template <>
128172
void RewriteByLayoutImpl<GroupNormOp>(pir::Operation* op,
129173
common::DataLayout new_layout) {

paddle/fluid/pir/dialect/operator/interface/layout_transformation.hpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -105,9 +105,11 @@ bool CanBeModifiedImpl(pir::Operation* op) {
105105
class FusedConv2dAddActOp;
106106
OVERLOAD_PREFER_LAYOUT(FusedConv2dAddActOp);
107107
OVERLOAD_REWRITE_BY_LAYOUT(FusedConv2dAddActOp);
108+
OVERLOAD_CAN_BE_MODIFIED(FusedConv2dAddActOp);
108109

109110
class Conv2dOp;
110111
OVERLOAD_PREFER_LAYOUT(Conv2dOp);
112+
OVERLOAD_RELEVANT_INPUTS(Conv2dOp);
111113
OVERLOAD_REWRITE_BY_LAYOUT(Conv2dOp);
112114

113115
class GroupNormOp;

paddle/fluid/pir/transforms/general/transfer_layout_pass.cc

Lines changed: 16 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -278,18 +278,22 @@ struct FlowGraph {
278278
}
279279
}
280280

281-
std::unordered_set<Node> nhwc_nodes;
281+
std::unordered_set<Node> mutable_nodes;
282282
for (auto& op : *(program.block())) {
283283
auto layout_transform_iface =
284284
op.dyn_cast<paddle::dialect::LayoutTransformationInterface>();
285285
if (!layout_transform_iface) {
286286
continue;
287287
}
288288

289+
if (!layout_transform_iface.CanBeModified(&op)) {
290+
continue;
291+
}
292+
289293
auto prefer_layout = layout_transform_iface.PreferLayout(&op);
290294
if (prefer_layout == common::DataLayout::NHWC) {
291295
Node op_node(&op);
292-
nhwc_nodes.insert(op_node);
296+
mutable_nodes.insert(op_node);
293297
AddEdge(op_node, dst_node(), INF);
294298
VLOG(10) << "[PreProcess] node: " << op_node
295299
<< " should be set to NHWC";
@@ -302,7 +306,7 @@ struct FlowGraph {
302306
// operation who have a dertermined layout and spread its layout to
303307
// its output and inputs recursively.
304308
std::queue<Node> q;
305-
for (auto& n : nhwc_nodes) {
309+
for (auto& n : mutable_nodes) {
306310
q.push(n);
307311
}
308312
std::unordered_set<Node> is_node_layout_visited;
@@ -362,13 +366,14 @@ struct FlowGraph {
362366
// a point of cut edge. So we set its outputs and inputs to
363367
// immutable.
364368
Node in_node = Node(v.defining_op());
365-
nhwc_nodes.erase(in_node);
366-
VLOG(10) << "erase node: " << in_node << " from nhwc set";
369+
mutable_nodes.erase(in_node);
370+
VLOG(10) << "erase node: " << in_node << " from mutable set";
367371

368372
for (auto it = v.use_begin(); it != v.use_end(); ++it) {
369373
Node out_node(it->owner());
370-
nhwc_nodes.erase(out_node);
371-
VLOG(10) << "erase node: " << out_node << " from nhwc set";
374+
mutable_nodes.erase(out_node);
375+
VLOG(10)
376+
<< "erase node: " << out_node << " from mutable set";
372377
}
373378
}
374379
return !can_be_transformed;
@@ -380,8 +385,8 @@ struct FlowGraph {
380385
continue;
381386
}
382387

383-
VLOG(10) << "add node to nhwc set: " << node;
384-
nhwc_nodes.insert(node);
388+
VLOG(10) << "add node to mutable set: " << node;
389+
mutable_nodes.insert(node);
385390

386391
VLOG(10) << "processing node successor: " << node;
387392

@@ -403,7 +408,7 @@ struct FlowGraph {
403408
continue;
404409
}
405410
is_node_layout_visited.insert(node);
406-
if (nhwc_nodes.count(node) == 0) {
411+
if (mutable_nodes.count(node) == 0) {
407412
VLOG(10) << "add node to nchw set: " << node;
408413
AddEdge(src_node(), node, INF);
409414
}
@@ -542,7 +547,7 @@ using Edge = FlowGraph::Edge;
542547

543548
class TransferLayoutPass : public pir::Pass {
544549
public:
545-
TransferLayoutPass() : pir::Pass("transfer_layout_pass", 3) {}
550+
TransferLayoutPass() : pir::Pass("transfer_layout_pass", 2) {}
546551

547552
bool CanApplyOn(pir::Operation* op) const override {
548553
if (!op->isa<pir::ModuleOp>()) {

0 commit comments

Comments
 (0)