File tree Expand file tree Collapse file tree 1 file changed +14
-1
lines changed
paddle/cinn/hlir/dialect/operator/transforms/lowering_pass Expand file tree Collapse file tree 1 file changed +14
-1
lines changed Original file line number Diff line number Diff line change @@ -253,6 +253,17 @@ CreateGroupShapeOrDataExprs(
253
253
InferSymbolicShapeForOperation (op, &local_shape_analysis);
254
254
}
255
255
256
+ auto broadcast_contains = [](const symbol::DimExpr& dimexpr,
257
+ const symbol::DimExpr& target) {
258
+ auto broadcast =
259
+ std::get_if<symbol::Broadcast<symbol::DimExpr>>(&dimexpr.variant ());
260
+ if (broadcast == nullptr ) return false ;
261
+ for (const auto & operand : *(broadcast->operands )) {
262
+ if (operand == target) return true ;
263
+ }
264
+ return false ;
265
+ };
266
+
256
267
// Add shape constraints after infer.
257
268
auto & mut_substitute_dimexpr_map = group->mut_substitute_dimexpr_map ();
258
269
for (auto * op : group->ops ()) {
@@ -264,7 +275,9 @@ CreateGroupShapeOrDataExprs(
264
275
if (global_result_shape.size () != local_result_shape.size ()) continue ;
265
276
for (size_t i = 0 ; i < global_result_shape.size (); ++i) {
266
277
if (global_result_shape[i] != local_result_shape[i] &&
267
- !global_result_shape[i].isa <std::int64_t >()) {
278
+ !global_result_shape[i].isa <std::int64_t >() &&
279
+ !broadcast_contains (local_result_shape[i],
280
+ global_result_shape[i])) {
268
281
mut_substitute_dimexpr_map[global_result_shape[i]] =
269
282
local_result_shape[i];
270
283
}
You can’t perform that action at this time.
0 commit comments