Skip to content

Commit dbaed63

Browse files
authored
[CINN] Create new AxisTransform when substitute dimexpr (#71587) (#71653)
1 parent bbc3129 commit dbaed63

File tree

1 file changed

+33
-23
lines changed

1 file changed

+33
-23
lines changed

paddle/cinn/operator_fusion/fusion_tracker/interpreter.cc

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -161,31 +161,41 @@ void RunItersTransformInstr(const std::shared_ptr<ItersTransformInstr>& instr,
161161

162162
void RunAxisTransformInstr(const std::shared_ptr<AxisTransformInstr>& instr,
163163
FusionInterpreter* interpreter) {
164-
auto substitute_dimexpr_for_shape = [&](std::vector<symbol::DimExpr>& shape) {
165-
for (auto& dim_expr : shape) {
166-
if (dim_expr.isa<std::int64_t>()) continue;
167-
symbol::DimExpr origin_dim_expr = dim_expr;
168-
while (true) {
169-
dim_expr = symbol::SubstituteDimExpr(
170-
dim_expr, interpreter->substitute_dimexpr_map);
171-
if (dim_expr == origin_dim_expr || dim_expr.isa<std::int64_t>()) break;
172-
origin_dim_expr = dim_expr;
173-
}
174-
}
175-
};
176-
auto substitute_dimexpr_for_transform =
177-
adt::match{[&](const AppendAxisTransformPtr& transform) {
178-
substitute_dimexpr_for_shape(transform->shape);
179-
},
180-
[&](const ReshapeTransformPtr& transform) {
181-
substitute_dimexpr_for_shape(transform->in_shape);
182-
substitute_dimexpr_for_shape(transform->out_shape);
183-
},
184-
[&](const auto& transform) {}};
164+
auto substitute_dimexpr_for_shape =
165+
[&](const std::vector<symbol::DimExpr>& shape) {
166+
std::vector<symbol::DimExpr> result;
167+
for (const auto& dim_expr : shape) {
168+
symbol::DimExpr substituted = dim_expr;
169+
while (true) {
170+
if (substituted.isa<std::int64_t>()) break;
171+
auto tmp_substituted = symbol::SubstituteDimExpr(
172+
substituted, interpreter->substitute_dimexpr_map);
173+
if (tmp_substituted == substituted) break;
174+
substituted = tmp_substituted;
175+
}
176+
result.emplace_back(substituted);
177+
}
178+
return result;
179+
};
180+
auto substitute_dimexpr_for_transform = adt::match{
181+
[&](const AppendAxisTransformPtr& trans) -> AxisTransform {
182+
auto substituted_shape = substitute_dimexpr_for_shape(trans->shape);
183+
return std::make_shared<AppendAxisTransform>(trans->axis,
184+
substituted_shape);
185+
},
186+
[&](const ReshapeTransformPtr& trans) -> AxisTransform {
187+
auto substituted_in_shape =
188+
substitute_dimexpr_for_shape(trans->in_shape);
189+
auto substituted_out_shape =
190+
substitute_dimexpr_for_shape(trans->out_shape);
191+
return std::make_shared<ReshapeTransform>(substituted_in_shape,
192+
substituted_out_shape);
193+
},
194+
[&](const auto& trans) -> AxisTransform { return trans; }};
185195
auto axis_transform = [&](ir::Expr op_expr) -> ir::Expr {
186196
for (auto trans : instr->axis_transform_route_) {
187-
std::visit(substitute_dimexpr_for_transform, trans);
188-
op_expr = std::visit(ApplyAxisTransform(op_expr), trans);
197+
auto new_trans = std::visit(substitute_dimexpr_for_transform, trans);
198+
op_expr = std::visit(ApplyAxisTransform(op_expr), new_trans);
189199
}
190200
return op_expr;
191201
};

0 commit comments

Comments
 (0)