 #include "intel/include/Utils/DefUseChain.h"
 #include "intel/include/Utils/Utility.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Verifier.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "llvm/ADT/STLExtras.h"
@@ -22,6 +24,59 @@ namespace mlir::triton::intel {
 
 namespace {
 
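+// Wrap \p loadOp in a newly created scf.if operation: the 'then' region
+// contains the load and yields its result, the 'else' region yields a
+// constant padding tensor (zeros, or NaNs for float types, depending on the
+// load's padding attribute). The generated IR is roughly:
+//
+//   %res = scf.if %cond -> tensor<MxNxT> {
+//     %val = tt.load %ptr ...
+//     scf.yield %val
+//   } else {
+//     %pad = arith.constant dense<...>
+//     scf.yield %pad
+//   }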
+scf::IfOp createIfBlock(OpBuilder &builder, Location loc, arith::CmpIOp condOp,
+                        tt::LoadOp loadOp) {
+  assert(isa<RankedTensorType>(loadOp.getType()) &&
+         "Unexpected load result type");
+
+  auto tensorType = cast<RankedTensorType>(loadOp.getType());
+  assert(tensorType.getShape().size() == 2);
+  Type elemType = tensorType.getElementType();
+
+  builder.setInsertionPointAfter(loadOp);
+  auto ifOp = builder.create<scf::IfOp>(loc, tensorType, condOp, true, true);
+  loadOp->moveBefore(ifOp.thenBlock(), ifOp.thenBlock()->end());
+  builder.setInsertionPointAfter(loadOp);
+  builder.create<scf::YieldOp>(loc, loadOp->getResult(0));
+
+  builder.setInsertionPointToStart(ifOp.elseBlock());
+  tt::PaddingOption padding = (!loadOp.getPadding())
+                                  ? tt::PaddingOption::PAD_ZERO
+                                  : *loadOp.getPadding();
+  DenseElementsAttr denseAttr = nullptr;
+  switch (padding) {
+  case tt::PaddingOption::PAD_ZERO: {
+    denseAttr = DenseElementsAttr::get(cast<ShapedType>(tensorType),
+                                       builder.getZeroAttr(elemType));
+  } break;
+  case tt::PaddingOption::PAD_NAN: {
+    assert(isa<FloatType>(elemType) && "Expecting a floating point type");
+    auto NaNVal =
+        APFloat::getNaN(cast<FloatType>(elemType).getFloatSemantics());
+    denseAttr = DenseElementsAttr::get(cast<ShapedType>(tensorType),
+                                       builder.getFloatAttr(elemType, NaNVal));
+  } break;
+  default:
+    llvm_unreachable("Unhandled padding kind");
+  }
+  assert(denseAttr && "Expecting a valid attribute");
+
+  Value other = builder.create<arith::ConstantOp>(loc, tensorType, denseAttr);
+  builder.create<scf::YieldOp>(loc, other);
+  return ifOp;
+}
+
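+// Guard \p loadOp with the runtime condition \p condOp: the load is moved
+// into the 'then' region of a new scf.if, and all of its uses outside the
+// scf.if are redirected to the scf.if result.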
+scf::IfOp createCheckedLoad(OpBuilder &builder, arith::CmpIOp condOp,
+                            tt::LoadOp loadOp) {
+  scf::IfOp ifOp = createIfBlock(builder, loadOp.getLoc(), condOp, loadOp);
+  loadOp->replaceUsesWithIf(ifOp, [&](OpOperand &operand) {
+    if (auto yieldOp = dyn_cast<scf::YieldOp>(operand.getOwner()))
+      return yieldOp->getParentOp() != ifOp;
+    return true;
+  });
+  return ifOp;
+}
+
 // Transform:
 //   %one = arith.constant 1 : i64
 //   %ptr = make_tensor_ptr %q_view, [%q, %q_23, %q_24],
@@ -87,7 +142,7 @@ class FuseReshape {
 
     LLVM_DEBUG(llvm::dbgs() << "[Before fusion]:\n" << manager << "\n");
 
-    // Fuse tt.LoadOp->tt.reshapeOp operations.
+    // Fuse tt.LoadOp->tt.ReshapeOp operations.
     fuse(manager.getChains());
 
     // Remove operations that are no longer used.
@@ -238,14 +293,57 @@ class FuseReshape {
 
     LLVM_DEBUG(llvm::dbgs() << "newMakeTensorPtrOp:\n" << ptr << "\n");
 
-    // Adjust the boundary check on the load operation.
-    ArrayRef<int> boundaryCheck = loadOp.getBoundaryCheck();
-    assert(boundaryCheck.size() == 2 && "Expecting a 2-dim load");
-    loadOp.setBoundaryCheck({boundaryCheck[0] - 1, boundaryCheck[1] - 1});
-
     // Propagate the new ptr through the def-use chain.
-    propagateToUsers(ptr, chain);
+    IRMapping mapping;
+    propagateToUsers(ptr, chain, mapping);
     cleanUp.insert(makeTensorPtrOp);
+
+    // We have collapsed 2 dimensions into one, therefore we might have to
+    // materialize the boundary check for the new collapsed dimension. There
+    // are 2 possibilities:
+    //  a) if the load checks only the innermost dimension, we are OK because
+    //     we haven't collapsed that dimension
+    //  b) if the load checks the new outermost dimension, the boundary check
+    //     on the load is not sufficient and we have to materialize the
+    //     correct boundary check. Example:
+    //                 OLD PTR           NEW PTR
+    //       shape:   [20, 10, 5]   ->  [210, 5]
+    //       strides: [50,  5, 1]   ->  [  5, 1]
+    //
+    //       Consider a load offset of [1, 11, 1]; this access is clearly
+    //       out-of-bound in dim 1 (11 > 10). However, the new offset is no
+    //       longer out-of-bound (11 < 210).
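+    //       To preserve correctness we materialize an explicit comparison of
+    //       the new outermost offset against the original shape of the
+    //       collapsed dimension (11 < 10 here) and guard the load with it,
+    //       yielding the padding value when the comparison fails (see
+    //       createCheckedLoad).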
+    auto newLoadOp =
+        cast<tt::LoadOp>(mapping.lookup(static_cast<Operation *>(loadOp)));
+    ArrayRef<int> boundaryCheck = newLoadOp.getBoundaryCheck();
+    switch (boundaryCheck.size()) {
+    case 0:
+      break;
+    case 1:
+      // intentional fall-through
+    case 2: {
+      SmallVector<int> newBoundaryCheck{boundaryCheck[0] - 1};
+      if (boundaryCheck.size() == 2)
+        newBoundaryCheck.push_back(boundaryCheck[1] - 1);
+      newLoadOp.setBoundaryCheck(newBoundaryCheck);
+
+      if (llvm::any_of(newBoundaryCheck, [&](unsigned boundIdx) {
+            return boundIdx == newOutermostDimIdx;
+          })) {
+        Value lhs = newOffsets[newOutermostDimIdx];
+        Value rhs = shapes[newOutermostDimIdx + 1];
+        auto cmpOp = builder.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::ult, lhs,
+            builder.create<arith::TruncIOp>(loc, lhs.getType(), rhs));
+        createCheckedLoad(builder, cmpOp, newLoadOp);
+      }
+    } break;
+    default:
+      // Note: while selecting candidates, we already ensured that the
+      // original load's boundary check doesn't check dim zero, so the
+      // boundary check rank can be at most 2.
+      llvm_unreachable("Unexpected boundary check rank");
+    }
   }
 
   // Candidate is of the form:
@@ -276,7 +374,8 @@ class FuseReshape {
       return false;
     }
 
-    // Check whether \p reshapeOp is used by a `dotOp` (directly or indirectly).
+    // Check whether \p reshapeOp is used by a `dotOp` (directly or
+    // indirectly).
     auto usedByDotOp = [](tt::ReshapeOp reshapeOp) {
       if (!reshapeOp->hasOneUse())
         return false;
@@ -347,10 +446,10 @@ class FuseReshape {
                         [](int val) { return val == 0; });
   }
 
-  // Prune chains that cannot be handled during fusion. For example, operations
-  // in the def-use chain should have a single user, except in special
-  // circumstances (e.g. the root operation of a chain might have more than one
-  // user).
+  // Prune chains that cannot be handled during fusion. For example,
+  // operations in the def-use chain should have a single user, except in
+  // special circumstances (e.g. the root operation of a chain might have more
+  // than one user).
   void pruneInvalid(DefUseChains &chains) const {
     assert(!chains.empty() && "Expecting at least one candidate chain");
@@ -368,11 +467,10 @@ class FuseReshape {
 
   // Determine whether all operations in the given def-use chain have a single
   // user.
-  // Note: we allow an operation in the def-use chain to have an additional user
-  // if the operation is in a for loop, and the additional user is the loop
-  // yield operation, provided that the result yielded is not used after the
-  // loop.
-  // Example:
+  // Note: we allow an operation in the def-use chain to have an additional
+  // user if the operation is in a for loop, and the additional user is the
+  // loop yield operation, provided that the result yielded is not used after
+  // the loop. Example:
   //   make_tensor_ptr -> advance -> load (OK)
   //   make_tensor_ptr -> for init_arg -> advance -> load (OK)
   //                                   -> yield (OK)
@@ -461,7 +559,8 @@ class FuseReshape {
   }
 
   // Propagate \p newVal to operations in the given def-use chain.
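+  // Every operation recreated during the propagation is recorded in
+  // \p mapping (original op -> new op) so callers can retrieve the rewritten
+  // operations afterwards.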
-  void propagateToUsers(Value newVal, const DefUseChain &chain) {
+  void propagateToUsers(Value newVal, const DefUseChain &chain,
+                        IRMapping &mapping) {
     auto start = cast<tt::MakeTensorPtrOp>(chain.getStart());
     Operation *end = chain.getEnd();
     auto it = llvm::find_if(start->getUsers(), [&](Operation *user) {
@@ -470,22 +569,22 @@ class FuseReshape {
     assert(it != start->getUsers().end() && "Expecting valid iterator");
 
     Operation *nextOp = *it;
-    propagateToUser(newVal, start.getResult(), nextOp, end);
+    propagateToUser(newVal, start.getResult(), nextOp, end, mapping);
   }
 
   // Propagate \p newVal to users of \p origOp.
   void propagateToUsers(Value newVal, Value origVal, Operation *origOp,
-                        Operation *sentinel) {
+                        Operation *sentinel, IRMapping &mapping) {
     assert(origOp && sentinel && "Expecting valid operations");
     const SmallVector<Operation *> users(origOp->getUsers());
     for (Operation *user : users)
-      propagateToUser(newVal, origVal, user, sentinel);
+      propagateToUser(newVal, origVal, user, sentinel, mapping);
   }
 
   // If \p user is not \p sentinel, propagate \p newVal to \p user. Otherwise
   // terminate the propagation.
   void propagateToUser(Value newVal, Value origVal, Operation *user,
-                       Operation *sentinel) {
+                       Operation *sentinel, IRMapping &mapping) {
     assert(user && sentinel && "Expecting valid operations");
     assert(llvm::is_contained(origVal.getUsers(), user) && "Invalid usage");
 
@@ -515,11 +614,13 @@ class FuseReshape {
       SmallVector<Value> newOffsets(advanceOp.getOffsets().drop_front());
       auto newAdvanceOp = rewriter.create<tt::AdvanceOp>(loc, newVal.getType(),
                                                          newVal, newOffsets);
+      mapping.map(static_cast<Operation *>(advanceOp),
+                  static_cast<Operation *>(newAdvanceOp));
       LLVM_DEBUG(llvm::dbgs().indent(2)
                  << "newAdvanceOp: " << newAdvanceOp << "\n");
       cleanUp.insert(advanceOp);
       return propagateToUsers(newAdvanceOp, advanceOp.getResult(), advanceOp,
-                              sentinel);
+                              sentinel, mapping);
     }
 
     if (auto loadOp = dyn_cast<tt::LoadOp>(user)) {
@@ -529,10 +630,12 @@ class FuseReshape {
           loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
           loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
       newLoadOp->setAttrs(loadOp->getAttrs());
-
+      mapping.map(static_cast<Operation *>(loadOp),
+                  static_cast<Operation *>(newLoadOp));
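+      // Note: fuse() looks up this mapping to retrieve the cloned load when
+      // it adjusts the boundary check and materializes the guarded load.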
       LLVM_DEBUG(llvm::dbgs().indent(2) << "newLoadOp: " << newLoadOp << "\n");
       cleanUp.insert(loadOp);
-      return propagateToUsers(newLoadOp, loadOp.getResult(), loadOp, sentinel);
+      return propagateToUsers(newLoadOp, loadOp.getResult(), loadOp, sentinel,
+                              mapping);
     }
 
     if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
@@ -553,13 +656,13 @@ class FuseReshape {
     }
 
     if (auto forOp = dyn_cast<scf::ForOp>(user))
-      return propagateToLoop(newVal, origVal, forOp, sentinel);
+      return propagateToLoop(newVal, origVal, forOp, sentinel, mapping);
 
     llvm_unreachable("Unexpected kind of user");
   }
 
   void propagateToLoop(Value newVal, Value origVal, LoopLikeOpInterface loopOp,
-                       Operation *sentinel) {
+                       Operation *sentinel, IRMapping &mapping) {
     assert(sentinel && sentinel != loopOp && "Unexpected sentinel kind");
     LLVM_DEBUG({
       llvm::dbgs() << "In " << __func__ << "\n";
@@ -573,7 +676,7 @@ class FuseReshape {
         rgnInitArg.setType(initArg.get().getType());
         const SmallVector<Operation *> users(rgnInitArg.getUsers());
         for (Operation *user : users)
-          propagateToUser(rgnInitArg, rgnInitArg, user, sentinel);
+          propagateToUser(rgnInitArg, rgnInitArg, user, sentinel, mapping);
       }
     }
   }