|
1 | 1 | #include "mlir/Analysis/Liveness.h"
|
| 2 | +#include "mlir/Dialect/Arith/IR/Arith.h" |
| 3 | +#include "mlir/Interfaces/ControlFlowInterfaces.h" |
2 | 4 | #include "mlir/Support/LogicalResult.h"
|
3 | 5 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
4 | 6 | #include "mlir/Transforms/Passes.h"
|
5 | 7 | #include "triton/Analysis/Allocation.h"
|
6 | 8 | #include "triton/Dialect/Triton/IR/Utility.h"
|
| 9 | +#include "triton/Dialect/TritonGPU/IR/Dialect.h" |
| 10 | +#include "triton/Dialect/TritonGPU/IR/Traits.h" |
7 | 11 | #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
8 | 12 | #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
|
9 | 13 | #include "llvm/ADT/EquivalenceClasses.h"
|
@@ -175,30 +179,92 @@ static TMemChunk allocFirstFit(MemoryBitMap &memoryMap,
|
175 | 179 | return chunk;
|
176 | 180 | }
|
177 | 181 |
|
178 |
| -static Operation *getAlloc(Value value) { |
179 |
| - while (true) { |
180 |
| - if (auto allocOp = value.getDefiningOp<TMEMAllocOp>()) |
181 |
| - return allocOp; |
182 |
| - if (auto indexOp = value.getDefiningOp<ttg::MemDescIndexOp>()) { |
183 |
| - value = indexOp.getSrc(); |
| 182 | +static SmallVector<Operation *> getAlloc(Value value) { |
| 183 | + SmallVector<Operation *> allocs; |
| 184 | + DenseSet<Value> seen; |
| 185 | + SmallVector<Value> worklist{value}; |
| 186 | + |
| 187 | + while (!worklist.empty()) { |
| 188 | + Value v = worklist.pop_back_val(); |
| 189 | + if (!seen.insert(v).second) |
184 | 190 | continue;
|
185 |
| - } |
186 |
| - if (auto reinterpOp = value.getDefiningOp<ttg::MemDescReinterpretOp>()) { |
187 |
| - value = reinterpOp.getSrc(); |
| 191 | + |
| 192 | + // Handle block arguments. |
| 193 | + if (auto arg = dyn_cast<BlockArgument>(v)) { |
| 194 | + Block *block = arg.getOwner(); |
| 195 | + Operation *parentOp = block->getParentOp(); |
| 196 | + |
| 197 | + // Handle block with predecessors. |
| 198 | + if (!block->isEntryBlock()) { |
| 199 | + for (Block *pred : block->getPredecessors()) { |
| 200 | + Operation *predOp = pred->getTerminator(); |
| 201 | + auto br = dyn_cast<BranchOpInterface>(predOp); |
| 202 | + if (!br) { |
| 203 | + llvm::report_fatal_error("unhandled branch op: " + |
| 204 | + predOp->getName().getStringRef()); |
| 205 | + } |
| 206 | + SmallVector<Attribute> operands(br->getNumOperands()); |
| 207 | + auto it = llvm::find(br->getSuccessors(), block); |
| 208 | + unsigned idx = std::distance(br->getSuccessors().begin(), it); |
| 209 | + SuccessorOperands args = br.getSuccessorOperands(idx); |
| 210 | + Value operand = |
| 211 | + args.getForwardedOperands()[arg.getArgNumber() - |
| 212 | + args.getProducedOperandCount()]; |
| 213 | + worklist.push_back(operand); |
| 214 | + } |
| 215 | + continue; |
| 216 | + } |
| 217 | + |
| 218 | + // Handle region entry arguments. |
| 219 | + if (auto wsOp = dyn_cast<ttg::WarpSpecializePartitionsOp>(parentOp)) { |
| 220 | + worklist.push_back( |
| 221 | + wsOp.getParentOp().getExplicitCaptures()[arg.getArgNumber()]); |
| 222 | + } else if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) { |
| 223 | + unsigned idx = arg.getArgNumber() - 1; |
| 224 | + worklist.push_back(forOp.getYieldedValues()[idx]); |
| 225 | + worklist.push_back(forOp.getInits()[idx]); |
| 226 | + } else if (auto whileOp = dyn_cast<scf::WhileOp>(parentOp)) { |
| 227 | + unsigned idx = arg.getArgNumber(); |
| 228 | + if (arg.getParentRegion() == &whileOp.getAfter()) { |
| 229 | + worklist.push_back(whileOp.getConditionOp().getArgs()[idx]); |
| 230 | + } else { |
| 231 | + worklist.push_back(whileOp.getYieldedValues()[idx]); |
| 232 | + worklist.push_back(whileOp.getInits()[idx]); |
| 233 | + } |
| 234 | + } else { |
| 235 | + llvm::report_fatal_error( |
| 236 | + "unhandled parent op when looking for TMEM alloc: " + |
| 237 | + parentOp->getName().getStringRef()); |
| 238 | + } |
188 | 239 | continue;
|
189 | 240 | }
|
190 |
| - if (auto slice = value.getDefiningOp<TMEMSubSliceOp>()) { |
191 |
| - value = slice.getSrc(); |
192 |
| - continue; |
| 241 | + |
| 242 | + Operation *defOp = v.getDefiningOp(); |
| 243 | + unsigned idx = cast<OpResult>(v).getResultNumber(); |
| 244 | + if (isa<TMEMAllocOp>(defOp)) { |
| 245 | + allocs.push_back(defOp); |
| 246 | + } else if (defOp->hasTrait<OpTrait::MemDescViewTrait>()) { |
| 247 | + worklist.push_back(defOp->getOperand(0)); |
| 248 | + } else if (auto sliceOp = dyn_cast<TMEMSubSliceOp>(defOp)) { |
| 249 | + worklist.push_back(sliceOp.getSrc()); |
| 250 | + } else if (auto selectOp = dyn_cast<arith::SelectOp>(defOp)) { |
| 251 | + worklist.push_back(selectOp.getTrueValue()); |
| 252 | + worklist.push_back(selectOp.getFalseValue()); |
| 253 | + } else if (auto ifOp = dyn_cast<scf::IfOp>(defOp)) { |
| 254 | + worklist.push_back(ifOp.thenYield().getOperand(idx)); |
| 255 | + worklist.push_back(ifOp.elseYield().getOperand(idx)); |
| 256 | + } else if (auto forOp = dyn_cast<scf::ForOp>(defOp)) { |
| 257 | + worklist.push_back(forOp.getYieldedValues()[idx]); |
| 258 | + worklist.push_back(forOp.getInits()[idx]); |
| 259 | + } else if (auto whileOp = dyn_cast<scf::WhileOp>(defOp)) { |
| 260 | + worklist.push_back(whileOp.getConditionOp().getArgs()[idx]); |
| 261 | + } else { |
| 262 | + llvm::report_fatal_error("unhandled op when looking for TMEM alloc: " + |
| 263 | + defOp->getName().getStringRef()); |
193 | 264 | }
|
194 |
| - auto arg = dyn_cast<BlockArgument>(value); |
195 |
| - if (!arg || !isa<triton::gpu::WarpSpecializePartitionsOp>( |
196 |
| - arg.getOwner()->getParentOp())) |
197 |
| - llvm::report_fatal_error("expected to find a TMEM alloc op"); |
198 |
| - auto partitions = cast<triton::gpu::WarpSpecializePartitionsOp>( |
199 |
| - arg.getOwner()->getParentOp()); |
200 |
| - value = partitions.getParentOp().getExplicitCaptures()[arg.getArgNumber()]; |
201 | 265 | }
|
| 266 | + |
| 267 | + return allocs; |
202 | 268 | }
|
203 | 269 |
|
204 | 270 | class RowIdConstraints {
|
@@ -245,8 +311,11 @@ allocateTMem(Operation *parentOp,
|
245 | 311 | if (allocSize.numRows == 64) {
|
246 | 312 | // HW restriction, the A alloc and accumulator needs to be in the same
|
247 | 313 | // rows.
|
248 |
| - rowIdConstraints.joinOps(getAlloc(mmaOp.getA()), |
249 |
| - getAlloc(mmaOp.getAccumulator())); |
| 314 | + SmallVector<Operation *> lhsAllocs = getAlloc(mmaOp.getA()); |
| 315 | + SmallVector<Operation *> accAllocs = getAlloc(mmaOp.getAccumulator()); |
| 316 | + for (Operation *lhsAlloc : lhsAllocs) |
| 317 | + for (Operation *accAlloc : accAllocs) |
| 318 | + rowIdConstraints.joinOps(lhsAlloc, accAlloc); |
250 | 319 | } else {
|
251 | 320 | // TODO: we need to handle cases where the format is blockM and we
|
252 | 321 | // have multiple blocks.
|
|
0 commit comments