|
8 | 8 | #include "triton-shared/Analysis/MaskAnalysis.h" |
9 | 9 | #include "mlir/Dialect/Arith/IR/Arith.h" |
10 | 10 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 11 | +#include "mlir/IR/Builders.h" |
11 | 12 | #include "mlir/Support/LogicalResult.h" |
12 | 13 |
|
13 | 14 | #include "triton-shared/Analysis/OpFoldResultUtils.h" |
@@ -452,32 +453,65 @@ LogicalResult MaskState::parseLoopIterArg(Value v, const Location loc, |
452 | 453 | return failure(); |
453 | 454 | } |
454 | 455 |
|
| 456 | + // This is a bit of a hack!! |
| 457 | + // |
| 458 | + // The offset (MaskState::start) of a mask can now depend on a loop's |
| 459 | + // iter-arg like the following example: |
| 460 | + // |
| 461 | + // idx = offset + tl.arange(0, 4) |
| 462 | + // for it in range(n): |
| 463 | + // mask = idx < size |
| 464 | + // x = tl.load(x_ptr + idx, mask=mask) |
| 465 | + // tl.store(y_ptr + idx, x, mask=mask) |
| 466 | + // idx += 4 |
| 467 | + // |
| 468 | + // See |
| 469 | + // test/Conversion/TritonToStructured/mask_loop_iter_arg.mlir |
| 470 | + // and |
| 471 | + // python/examples/test_mask_loop_iter_arg.py |
| 472 | + // for IR and full triton code. |
| 473 | + // |
| 474 | + // To support this case, we first make the following assumptions: |
| 475 | + // - MaskAnalysis runs after PtrAnalysis's prepass finishes, which means |
| 476 | + // the offsets for the load and store pointers have already been set up |
| 477 | + // at `argIndex + 1` |
| 478 | + // - The tensor of indices used by the load / store and the mask are the same |
| 479 | + // (see above where `idx` appears in both the mask and the pointer |
| 480 | + // arithmetic). This allows us to use the offset at `argIndex + 1` in the |
| 481 | + // above assumption. In the future, to make this more robust, we need to |
| 482 | + // verify that the offsets are indeed the same. Or alternatively, make sure |
| 483 | + // to generate a separate start and end offset for each mask that is being |
| 484 | + // updated in loops. |
| 485 | + // |
| 486 | + // Now to generate the mask state in each loop iteration, we first construct |
| 487 | + // the mask state *before* coming into the loop by parsing the init-arg. A |
| 488 | + // mask's dimensions stay consistent throughout each loop iteration, but its |
| 489 | + // starting offset (`MaskState::start`) will change. So to construct the mask |
| 490 | + // state for each iteration, we need to make MaskState::start be the offset |
| 491 | + // iter-arg at `argIndex + 1`. Now for `MaskState::end`, we can first compute |
| 492 | + // the distance between `start` and `end` before coming into the loop, then |
| 493 | + // use this distance to compute the actual `end` in each loop. |
455 | 494 | auto argIndex = std::distance(forOp.getRegionIterArgs().begin(), it); |
456 | 495 | auto initArg = forOp.getInitArgs()[argIndex]; |
457 | 496 | if (auto getStateOp = initArg.getDefiningOp<tts::GetStructuredStateOp>()) { |
458 | 497 | auto tritonValue = getStateOp->getOperand(0); |
459 | 498 | MaskState lhsState; |
460 | | - if (failed(lhsState.parse(tritonValue, loc, builder))) { |
461 | | - return failure(); |
462 | | - } |
463 | 499 |
|
464 | | - // This is a bit of a hack!! |
465 | | - // |
466 | | - // The offsets and dimensions of a MaskState can now depend on a loop's |
467 | | - // iter-arg. |
468 | | - // |
469 | | - // Because the PtrAnalysis's pre-pass already sets up the offsets, |
470 | | - // we can create a new MaskState for each loop iteration by adding the |
471 | | - // original MaskState with the current iter-arg, which is at `argIndex + |
472 | | - // 1`. |
473 | | - // |
474 | | - // This will not work for nested loop scenarios, which would need a |
475 | | - // more robust implementation. |
476 | | - if (failed(this->addStateScalar( |
477 | | - lhsState, forOp.getRegionIterArgs()[argIndex + 1], loc, builder))) { |
478 | | - return failure(); |
| 500 | + { |
| 501 | + OpBuilder::InsertionGuard guard(builder); |
| 502 | + // Make sure all ops generated for the mask state are inserted before |
| 503 | + // the current loop |
| 504 | + builder.setInsertionPoint(forOp); |
| 505 | + if (failed(lhsState.parse(tritonValue, loc, builder))) { |
| 506 | + return failure(); |
| 507 | + } |
479 | 508 | } |
480 | 509 |
|
| 510 | + auto dist = subOFRs(lhsState.end, lhsState.start, loc, builder); |
| 511 | + this->start = forOp.getRegionIterArg(argIndex + 1); |
| 512 | + this->end = addOFRs(this->start, dist, loc, builder); |
| 513 | + this->dims = lhsState.dims; |
| 514 | + |
481 | 515 | return success(); |
482 | 516 | } |
483 | 517 |
|
|
0 commit comments