|
40 | 40 | #include "llvm/IR/PatternMatch.h" |
41 | 41 | #include "llvm/Support/Alignment.h" |
42 | 42 | #include "llvm/Support/CommandLine.h" |
| 43 | +#include "llvm/Support/Compiler.h" |
43 | 44 | #include "llvm/Support/Debug.h" |
44 | 45 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
45 | 46 | #include "llvm/Transforms/Utils/LoopUtils.h" |
@@ -221,7 +222,16 @@ struct ShapeInfo { |
221 | 222 |
|
222 | 223 | /// Returns the transposed shape. |
223 | 224 | ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); } |
| 225 | + |
| 226 | + friend raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI); |
| 227 | + |
| 228 | + LLVM_DUMP_METHOD void dump() const { dbgs() << *this << '\n'; } |
224 | 229 | }; |
| 230 | + |
| 231 | +raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI) { |
| 232 | + return OS << SI.NumRows << 'x' << SI.NumColumns; |
| 233 | +} |
| 234 | + |
225 | 235 | } // namespace |
226 | 236 |
|
227 | 237 | static bool isUniformShape(Value *V) { |
@@ -466,6 +476,8 @@ class LowerMatrixIntrinsics { |
466 | 476 | return getNumColumns(); |
467 | 477 | } |
468 | 478 |
|
| 479 | + ShapeInfo shape() const { return {getNumRows(), getNumColumns()}; } |
| 480 | + |
469 | 481 | /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the |
470 | 482 | /// matrix is column-major, the result vector is extracted from a column |
471 | 483 | /// vector, otherwise from a row vector. |
@@ -578,6 +590,25 @@ class LowerMatrixIntrinsics { |
578 | 590 | SplitVecs.push_back(V); |
579 | 591 | } |
580 | 592 |
|
| 593 | + LLVM_DEBUG(if (Instruction *Inst = dyn_cast<Instruction>(MatrixVal)) { |
| 594 | + if (Found != Inst2ColumnMatrix.end()) { |
| 595 | + // FIXME: re: "at least": SplitVecs.size() doesn't count the shuffles |
| 596 | + // that embedInVector created. |
| 597 | + dbgs() << "matrix reshape from " << Found->second.shape() << " to " |
| 598 | + << SI << " using at least " << SplitVecs.size() |
| 599 | + << " shuffles on behalf of " << *Inst << '\n'; |
| 600 | + } else if (!ShapeMap.contains(MatrixVal)) { |
| 601 | + dbgs() << "splitting a " << SI << " matrix with " << SplitVecs.size() |
| 602 | + << " shuffles beacuse we do not have a shape-aware lowering for " |
| 603 | + "its def: " |
| 604 | + << *Inst << '\n'; |
| 605 | + } else { |
| 606 | + // The ShapeMap has it, so it's a case where we're being lowered |
| 607 | + // before the def, and we expect that InstCombine will clean things up |
| 608 | + // afterward. |
| 609 | + } |
| 610 | + }); |
| 611 | + |
581 | 612 | return {SplitVecs}; |
582 | 613 | } |
583 | 614 |
|
@@ -1386,11 +1417,19 @@ class LowerMatrixIntrinsics { |
1386 | 1417 | ToRemove.push_back(Inst); |
1387 | 1418 | Value *Flattened = nullptr; |
1388 | 1419 | for (Use &U : llvm::make_early_inc_range(Inst->uses())) { |
1389 | | - if (!ShapeMap.contains(U.getUser())) { |
1390 | | - if (!Flattened) |
1391 | | - Flattened = Matrix.embedInVector(Builder); |
1392 | | - U.set(Flattened); |
| 1420 | + if (ShapeMap.contains(U.getUser())) |
| 1421 | + continue; |
| 1422 | + |
| 1423 | + if (!Flattened) { |
| 1424 | + Flattened = Matrix.embedInVector(Builder); |
| 1425 | + LLVM_DEBUG( |
| 1426 | + if (Instruction *User = dyn_cast<Instruction>(U.getUser())) dbgs() |
| 1427 | + << "flattening a " << Matrix.shape() << " matrix " << *Inst |
| 1428 | + << " because we do not have a shape-aware lowering for its " |
| 1429 | + "user: " |
| 1430 | + << *User << '\n';); |
1393 | 1431 | } |
| 1432 | + U.set(Flattened); |
1394 | 1433 | } |
1395 | 1434 | } |
1396 | 1435 |
|
|
0 commit comments