47
47
48
48
#include < cassert>
49
49
#include < cstdint>
50
+ #include < numeric>
50
51
51
52
#include " mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
52
53
// Pull in all enum type and utility function definitions.
@@ -2412,9 +2413,38 @@ foldToElementsFromElements(ToElementsOp toElementsOp,
2412
2413
return success ();
2413
2414
}
2414
2415
2416
+ // / Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
2417
+ // /
2418
+ // / Example:
2419
+ // / %b = vector.broadcast %x : i32 to vector<3xf32>
2420
+ // / %e:3 = vector.to_elements %b : vector<3xf32>
2421
+ // / user_op %e#0, %e#1, %e#2
2422
+ // / becomes:
2423
+ // / user_op %x, %x, %x
2424
+ // /
2425
+ // / The vector source case is handled by a canonicalization pattern.
2426
+ static LogicalResult
2427
+ foldToElementsOfBroadcast (ToElementsOp toElementsOp,
2428
+ SmallVectorImpl<OpFoldResult> &results) {
2429
+ auto bcastOp = toElementsOp.getSource ().getDefiningOp <BroadcastOp>();
2430
+ if (!bcastOp)
2431
+ return failure ();
2432
+ // Vectors are handled in the ToElementsOfBroadcast RewritePattern.
2433
+ if (isa<VectorType>(bcastOp.getSource ().getType ()))
2434
+ return failure ();
2435
+
2436
+ auto resultVecType = cast<VectorType>(toElementsOp.getSource ().getType ());
2437
+
2438
+ Value scalar = bcastOp.getSource ();
2439
+ results.assign (resultVecType.getNumElements (), scalar);
2440
+ return success ();
2441
+ }
2442
+
2415
2443
/// Folder for vector.to_elements. Tries the to_elements(from_elements) fold
/// first and, only if that fails, the to_elements(broadcast) fold.
LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
                                 SmallVectorImpl<OpFoldResult> &results) {
  LogicalResult status = foldToElementsFromElements(*this, results);
  // Fall back to the broadcast fold when the first attempt did not apply.
  if (failed(status))
    status = foldToElementsOfBroadcast(*this, results);
  return status;
}
2419
2449
2420
2450
LogicalResult
@@ -2427,6 +2457,94 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2427
2457
return success ();
2428
2458
}
2429
2459
2460
+ // / Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
2461
+ // / vector.
2462
+ // / - Build `vector.to_elements %v` and remap each destination element to the
2463
+ // / corresponding source element using broadcast rules (match or 1 →
2464
+ // / replicate).
2465
+ // /
2466
+ // / Example:
2467
+ // / %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
2468
+ // / %e:6 = vector.to_elements %v : vector<3x2xf32>
2469
+ // / becomes:
2470
+ // / %src_elems:2 = vector.to_elements %src : vector<2xf32>
2471
+ // / // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2472
+ // / // %src_elems#1, %src_elems#0, %src_elems#1
2473
+ struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
2474
+ using Base::Base;
2475
+
2476
+ LogicalResult matchAndRewrite (ToElementsOp toElementsOp,
2477
+ PatternRewriter &rewriter) const override {
2478
+ auto bcastOp = toElementsOp.getSource ().getDefiningOp <BroadcastOp>();
2479
+ if (!bcastOp)
2480
+ return failure ();
2481
+
2482
+ // Only handle broadcasts from a vector source here.
2483
+ auto srcType = dyn_cast<VectorType>(bcastOp.getSource ().getType ());
2484
+ if (!srcType)
2485
+ return failure ();
2486
+
2487
+ auto dstType = cast<VectorType>(toElementsOp.getSource ().getType ());
2488
+
2489
+ ArrayRef<int64_t > dstShape = dstType.getShape ();
2490
+ ArrayRef<int64_t > srcShape = srcType.getShape ();
2491
+
2492
+ int64_t dstRank = dstShape.size ();
2493
+ int64_t srcRank = srcShape.size ();
2494
+
2495
+ // Create elements for the broadcast source vector.
2496
+ auto srcElems = vector::ToElementsOp::create (
2497
+ rewriter, toElementsOp.getLoc (), bcastOp.getSource ());
2498
+
2499
+ int64_t dstCount = std::accumulate (dstShape.begin (), dstShape.end (), 1 ,
2500
+ std::multiplies<int64_t >());
2501
+
2502
+ SmallVector<Value> replacements;
2503
+ replacements.reserve (dstCount);
2504
+
2505
+ // For each element of the destination, determine which element of the
2506
+ // source should be used. We walk all destination positions using a single
2507
+ // counter, decode it into per-dimension indices, then build the matching
2508
+ // source position: use the same index where sizes match, and use 0 where
2509
+ // the source size is 1 (replication). This mapping is needed so we can
2510
+ // replace each result of to_elements with the corresponding element from
2511
+ // the broadcast source.
2512
+ // Inner-dimension stretch example:
2513
+ // %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
2514
+ // %e:12 = vector.to_elements %v : vector<2x3x2xf32>
2515
+ // becomes:
2516
+ // %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
2517
+ // // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2518
+ // // %src_elems#1, %src_elems#0, %src_elems#1,
2519
+ // // %src_elems#2, %src_elems#3, %src_elems#2,
2520
+ // // %src_elems#3, %src_elems#2, %src_elems#3
2521
+
2522
+ // Row-major strides for the destination shape.
2523
+ SmallVector<int64_t > dstStrides = computeStrides (dstShape);
2524
+ // Row-major strides for the source shape.
2525
+ SmallVector<int64_t > srcStrides = computeStrides (srcShape);
2526
+ SmallVector<int64_t > dstIdx (dstRank);
2527
+ SmallVector<int64_t > srcIdx (srcRank);
2528
+ for (int64_t lin = 0 ; lin < dstCount; ++lin) {
2529
+ // Convert linear destination index to per-dimension indices.
2530
+ dstIdx = delinearize (lin, dstStrides);
2531
+ for (int64_t k = 0 ; k < srcRank; ++k)
2532
+ srcIdx[k] = (srcShape[k] == 1 ) ? 0 : dstIdx[dstRank - srcRank + k];
2533
+ // Convert per-dimension source indices back to a linear index.
2534
+ int64_t srcLin = linearize (srcIdx, srcStrides);
2535
+ replacements.push_back (srcElems.getResult (srcLin));
2536
+ }
2537
+
2538
+ rewriter.replaceOp (toElementsOp, replacements);
2539
+ return success ();
2540
+ }
2541
+ };
2542
+
2543
+ void ToElementsOp::getCanonicalizationPatterns (RewritePatternSet &results,
2544
+ MLIRContext *context) {
2545
+ results.add <ToElementsOfBroadcast>(context);
2546
+ }
2547
+
2430
2548
// ===----------------------------------------------------------------------===//
2431
2549
// FromElementsOp
2432
2550
// ===----------------------------------------------------------------------===//
0 commit comments