Skip to content

Commit 54bf220

Browse files
authored
[CINN] Fix substitute dimexpr circle (#74432)
1 parent 1be810b commit 54bf220

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/collect_sym_expr.cc

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,17 @@ CreateGroupShapeOrDataExprs(
253253
InferSymbolicShapeForOperation(op, &local_shape_analysis);
254254
}
255255

256+
auto broadcast_contains = [](const symbol::DimExpr& dimexpr,
257+
const symbol::DimExpr& target) {
258+
auto broadcast =
259+
std::get_if<symbol::Broadcast<symbol::DimExpr>>(&dimexpr.variant());
260+
if (broadcast == nullptr) return false;
261+
for (const auto& operand : *(broadcast->operands)) {
262+
if (operand == target) return true;
263+
}
264+
return false;
265+
};
266+
256267
// Add shape constraints after infer.
257268
auto& mut_substitute_dimexpr_map = group->mut_substitute_dimexpr_map();
258269
for (auto* op : group->ops()) {
@@ -264,7 +275,9 @@ CreateGroupShapeOrDataExprs(
264275
if (global_result_shape.size() != local_result_shape.size()) continue;
265276
for (size_t i = 0; i < global_result_shape.size(); ++i) {
266277
if (global_result_shape[i] != local_result_shape[i] &&
267-
!global_result_shape[i].isa<std::int64_t>()) {
278+
!global_result_shape[i].isa<std::int64_t>() &&
279+
!broadcast_contains(local_result_shape[i],
280+
global_result_shape[i])) {
268281
mut_substitute_dimexpr_map[global_result_shape[i]] =
269282
local_result_shape[i];
270283
}

0 commit comments

Comments
 (0)