|
10 | 10 | #include "llvm/Support/Debug.h" |
11 | 11 |
|
12 | 12 | #define DEBUG_TYPE "tritonintelgpu-pipeline" |
| 13 | +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
| 14 | +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") |
13 | 15 |
|
14 | 16 | using namespace mlir; |
15 | 17 | namespace tt = mlir::triton; |
@@ -55,30 +57,25 @@ static ttg::DotOperandEncodingAttr getDotEncodingFromUser(Operation *user) { |
55 | 57 | if (!tensorType) |
56 | 58 | return nullptr; |
57 | 59 |
|
58 | | - if (isa<ttg::SharedEncodingAttr>(tensorType.getEncoding())) |
59 | | - return allTransitiveUsesHaveDotEncoding(res); |
60 | | - |
61 | | - return llvm::dyn_cast_or_null<ttg::DotOperandEncodingAttr>( |
62 | | - tensorType.getEncoding()); |
| 60 | + Attribute layout = tensorType.getEncoding(); |
| 61 | + return isa<ttg::SharedEncodingAttr, ttg::BlockedEncodingAttr>(layout) |
| 62 | + ? allTransitiveUsesHaveDotEncoding(res) |
| 63 | + : llvm::dyn_cast_or_null<ttg::DotOperandEncodingAttr>(layout); |
63 | 64 | } |
64 | 65 |
|
65 | 66 | /// If all the transitive uses of the given value are used by a convert to the |
66 | 67 | /// same dot operand encoding, return the encoding. Otherwise return nullptr. |
67 | 68 | static ttg::DotOperandEncodingAttr allTransitiveUsesHaveDotEncoding(Value val) { |
68 | 69 | ttg::DotOperandEncodingAttr attr{nullptr}; |
69 | | - LLVM_DEBUG(llvm::dbgs() << "Checking users of " << val << "\n"); |
| 70 | + LDBG("Checking users of " << val); |
70 | 71 | for (Operation *user : val.getUsers()) { |
71 | | - ttg::DotOperandEncodingAttr dotAttr; |
72 | | - if (isa<triton::DotOp>(user)) { |
73 | | - auto tensorType = cast<RankedTensorType>(val.getType()); |
74 | | - dotAttr = dyn_cast<ttg::DotOperandEncodingAttr>(tensorType.getEncoding()); |
75 | | - } else { |
76 | | - dotAttr = getDotEncodingFromUser(user); |
77 | | - } |
| 72 | + ttg::DotOperandEncodingAttr dotAttr = |
| 73 | + isa<triton::DotOp>(user) |
| 74 | + ? dyn_cast<ttg::DotOperandEncodingAttr>( |
| 75 | + cast<RankedTensorType>(val.getType()).getEncoding()) |
| 76 | + : getDotEncodingFromUser(user); |
78 | 77 | if (!dotAttr || (attr != nullptr && attr != dotAttr)) { |
79 | | - LLVM_DEBUG({ |
80 | | - llvm::dbgs() << "no dot attribute found for user: " << user << "\n"; |
81 | | - }); |
| 78 | + LDBG("no dot attribute found for user: " << *user); |
82 | 79 | return nullptr; |
83 | 80 | } |
84 | 81 | attr = dotAttr; |
@@ -292,14 +289,14 @@ bool ttgi::preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages, |
292 | 289 | SmallVector<LoadDotOperand> loads; |
293 | 290 | collectOpsToPipeline(forOp, loads, supportRegularPtr); |
294 | 291 | if (loads.empty()) { |
295 | | - LLVM_DEBUG(llvm::dbgs() << "No loads to pipeline\n"); |
| 292 | + LDBG("No loads to pipeline"); |
296 | 293 | return false; |
297 | 294 | } |
298 | 295 |
|
299 | 296 | LLVM_DEBUG({ |
300 | | - llvm::dbgs() << "Loads to pipeline:\n"; |
| 297 | + DBGS() << "Loads to pipeline:\n"; |
301 | 298 | for (const LoadDotOperand &load : loads) |
302 | | - llvm::dbgs() << " " << *load.load << "\n"; |
| 299 | + DBGS() << " " << *load.load << "\n"; |
303 | 300 | }); |
304 | 301 |
|
305 | 302 | // 2. Create the prefetching operations for the loads collected. |
|
0 commit comments