@@ -24,59 +24,6 @@ namespace mlir::triton::intel {
 
 namespace {
 
-scf::IfOp createIfBlock(OpBuilder &builder, Location loc, arith::CmpIOp condOp,
-                        tt::LoadOp loadOp) {
-  assert(isa<RankedTensorType>(loadOp.getType()) &&
-         "Unexpected load result type");
-
-  auto tensorType = cast<RankedTensorType>(loadOp.getType());
-  assert(tensorType.getShape().size() == 2);
-  Type elemType = tensorType.getElementType();
-
-  builder.setInsertionPointAfter(loadOp);
-  auto ifOp = builder.create<scf::IfOp>(loc, tensorType, condOp, true, true);
-  loadOp->moveBefore(ifOp.thenBlock(), ifOp.thenBlock()->end());
-  builder.setInsertionPointAfter(loadOp);
-  builder.create<scf::YieldOp>(loc, loadOp->getResult(0));
-
-  builder.setInsertionPointToStart(ifOp.elseBlock());
-  tt::PaddingOption padding = (!loadOp.getPadding())
-                                  ? tt::PaddingOption::PAD_ZERO
-                                  : *loadOp.getPadding();
-  DenseElementsAttr denseAttr = nullptr;
-  switch (padding) {
-  case tt::PaddingOption::PAD_ZERO: {
-    denseAttr = DenseElementsAttr::get(cast<ShapedType>(tensorType),
-                                       builder.getZeroAttr(elemType));
-  } break;
-  case tt::PaddingOption::PAD_NAN: {
-    assert(isa<FloatType>(elemType) && "Expecting a floating point type");
-    auto NaNVal =
-        APFloat::getNaN(cast<FloatType>(elemType).getFloatSemantics());
-    denseAttr = DenseElementsAttr::get(cast<ShapedType>(tensorType),
-                                       builder.getFloatAttr(elemType, NaNVal));
-  } break;
-  default:
-    llvm_unreachable("Unhandled padding kind");
-  }
-  assert(denseAttr && "Expecting a valid attribute");
-
-  Value other = builder.create<arith::ConstantOp>(loc, tensorType, denseAttr);
-  builder.create<scf::YieldOp>(loc, other);
-  return ifOp;
-}
-
-scf::IfOp createCheckedLoad(OpBuilder &builder, arith::CmpIOp cmpOp,
-                            tt::LoadOp loadOp) {
-  scf::IfOp ifOp = createIfBlock(builder, loadOp.getLoc(), cmpOp, loadOp);
-  loadOp->replaceUsesWithIf(ifOp, [&](OpOperand &operand) {
-    if (auto yieldOp = dyn_cast<scf::YieldOp>(operand.getOwner()))
-      return yieldOp->getParentOp() != ifOp;
-    return true;
-  });
-  return ifOp;
-}
-
 // Transform:
 //   %one = arith.constant 1 : i64
 //   %ptr = make_tensor_ptr %q_view, [%q, %q_23, %q_24],
@@ -298,24 +245,12 @@ class FuseReshape {
     propagateToUsers(ptr, chain, mapping);
     cleanUp.insert(makeTensorPtrOp);
 
-    // We have collapsed 2 dimensions into one, therefore we might have to
-    // materialize the boundary check for the new collapsed dimension. There
-    // are 2 possibilities:
-    //  a) if the load checks only the innermost dimension, we are OK because
-    //     we haven't collapsed that dimension
-    //  b) if the load checks the new outermost dimension, the boundary check
-    //     on the load is not sufficient and we have to materialize the
-    //     correct boundary check. Example:
-    //                OLD PTR           NEW PTR
-    //      shape:   [20, 10, 5]   ->   [210, 5]
-    //      strides: [50,  5, 1]   ->   [  5, 1]
-    //
-    //     Consider a load offset of [1, 11, 1]; this access is clearly
-    //     out-of-bounds in dim 1 (11 > 10). However, the new offset is no
-    //     longer out-of-bounds (21 < 210).
+    // We have collapsed 2 dimensions into one, therefore we need to adjust the
+    // boundary check of the new load.
     auto newLoadOp =
         cast<tt::LoadOp>(mapping.lookup(static_cast<Operation *>(loadOp)));
     ArrayRef<int> boundaryCheck = newLoadOp.getBoundaryCheck();
+
     switch (boundaryCheck.size()) {
     case 0:
       break;
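
The extent of the collapsed dimension in the removed comment's example follows a conservative folding rule. A minimal standalone sketch of it (the struct, the helper name, and the exact extent formula are illustrative assumptions inferred from the [20, 10, 5] -> [210, 5] numbers above, not code taken from the pass):

// Standalone sketch: fold the two outermost dimensions of a 3-d strided view.
// The extent formula is inferred from the removed comment's example
// ([20, 10, 5] / [50, 5, 1] -> [210, 5] / [5, 1]); the pass may compute it
// differently.
#include <array>
#include <cassert>
#include <cstdint>

struct View2D {
  std::array<int64_t, 2> shape;
  std::array<int64_t, 2> strides;
};

View2D collapseOuterDims(std::array<int64_t, 3> shape,
                         std::array<int64_t, 3> strides) {
  assert(strides[0] % strides[1] == 0 && "outer stride must be a multiple");
  int64_t rowSlots = strides[0] / strides[1]; // slots one outer step spans
  // Conservative extent covering every reachable element:
  // 20 * 10 + 10 == 210 in the example above.
  return {{shape[0] * rowSlots + shape[1], shape[2]},
          {strides[1], strides[2]}};
}
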
@@ -327,26 +262,7 @@ class FuseReshape {
       newBoundaryCheck.push_back(boundaryCheck[0] - 1);
       if (boundaryCheck.size() == 2 && (boundaryCheck[1] - 1) != 0)
         newBoundaryCheck.push_back(boundaryCheck[1] - 1);
-
       newLoadOp.setBoundaryCheck(newBoundaryCheck);
-
-      if (llvm::any_of(newBoundaryCheck, [&](unsigned boundIdx) {
-            return boundIdx == newOutermostDimIdx + 1;
-          })) {
-        unsigned oldIdx = newOutermostDimIdx + 1;
-        auto tensorType = cast<RankedTensorType>(loadOp.getResult().getType());
-        Type elemType = tensorType.getElementType();
-        ArrayRef<int64_t> resShape = tensorType.getShape();
-        auto add = builder.create<arith::AddIOp>(
-            loc, offsets[oldIdx],
-            builder.create<arith::ConstantIntOp>(
-                loc, offsets[oldIdx].getType(), resShape[oldIdx]));
-        auto cmpOp = builder.create<arith::CmpIOp>(
-            loc, arith::CmpIPredicate::ult, add,
-            builder.create<arith::TruncIOp>(loc, add.getResult().getType(),
-                                            shapes[oldIdx]));
-        createCheckedLoad(builder, cmpOp, newLoadOp);
-      }
     } break;
     default:
       // Note: while selecting candidates, we already ensured that the original
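
Read in isolation, the index arithmetic in the switch above amounts to the following; a minimal sketch (remapBoundaryCheck is an illustrative name, and dims are assumed to collapse as in this pass, old dims 0 and 1 into new dim 0):

// Standalone sketch of the boundary-check remapping performed above: after
// dims 0 and 1 collapse into new dim 0, old dim i becomes new dim i - 1, and
// a second entry landing on the collapsed dim 0 is dropped (mirroring the
// `(boundaryCheck[1] - 1) != 0` guard).
#include <cassert>
#include <vector>

std::vector<int> remapBoundaryCheck(const std::vector<int> &oldCheck) {
  assert(oldCheck.size() <= 2 && "a rank-3 load checks at most 2 dims here");
  std::vector<int> newCheck;
  if (!oldCheck.empty())
    newCheck.push_back(oldCheck[0] - 1);
  if (oldCheck.size() == 2 && oldCheck[1] - 1 != 0)
    newCheck.push_back(oldCheck[1] - 1);
  return newCheck;
}
// e.g. {2} -> {1}, and {1, 2} -> {0, 1}.
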
@@ -361,11 +277,13 @@ class FuseReshape {
   // Where:
   //  - the reshape operation drops the outermost dimension of the operand,
   //    which is a 3-dim tensor with outermost dimension extent equal to one
-  //  - the reshape result is used by the dot operation
+  //  - the reshape result is used by a dot operation
   //  - the reshape operation uses the result of a 3-dim load operation on a
   //    block pointer (transitively) defined by a `make_tensor_ptr` operation
   //  - the block pointer points to a tensor that has extent equal to 1 on the
   //    outermost dimension
+  //  - the load operation doesn't have boundary checks on either of the
+  //    collapsed dimensions
   bool isCandidate(tt::ReshapeOp reshapeOp) const {
     assert(reshapeOp && "Expecting a valid reshape operation");
 
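
The last candidate condition above exists because a boundary check on a collapsed dimension cannot be re-expressed on the collapsed pointer. A worked version of the removed comment's numbers (the offset folding formula is an inference from that example, not the pass's code):

// shape [20, 10, 5], strides [50, 5, 1], collapsed to [210, 5] / [5, 1].
// Load offset [1, 11, 1]: dim 1 is out of bounds (11 > 10), yet folding the
// two outer offsets lands inside the collapsed extent, so the violation can
// no longer be detected on the collapsed pointer. Hence such loads are
// rejected as fusion candidates.
int main() {
  constexpr int rowSlots = 50 / 5;              // stride0 / stride1
  constexpr int folded = 1 * rowSlots + 11;     // fold offsets [1, 11]
  static_assert(11 > 10, "the original dim-1 offset is out of bounds");
  static_assert(folded == 21 && folded < 210,
                "the folded offset aliases to an in-bounds access");
  return 0;
}
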
@@ -384,8 +302,7 @@ class FuseReshape {
       return false;
     }
 
-    // Check whether \p reshapeOp is used by a `dotOp` (directly or
-    // indirectly).
+    // Check whether \p reshapeOp is used by a `dotOp`.
     auto usedByDotOp = [](tt::ReshapeOp reshapeOp) {
       if (!reshapeOp->hasOneUse())
         return false;
@@ -405,6 +322,7 @@ class FuseReshape {
     if (!usedByDotOp(reshapeOp))
       return false;
 
+    // The reshape operation uses the result of a load operation.
     Operation *defOp = reshapeOp.getSrc().getDefiningOp();
     if (!defOp || !isa<tt::LoadOp>(defOp))
       return false;
@@ -413,6 +331,8 @@ class FuseReshape {
     if (!loadOp->hasOneUse())
       return false;
 
+    // The load uses a 3-dim block pointer defined by a `make_tensor_ptr`
+    // operation.
     Type ptrType = loadOp.getPtr().getType();
     if (!tt::isTensorPointerType(ptrType))
       return false;
@@ -432,14 +352,14 @@ class FuseReshape {
     if (order.front() != tensorTy.getRank() - 1)
       return false;
 
-    // Ensure that the innermost stride is one.
     unsigned innermostDimIdx = 0;
-    for (int i : order) {
-      if (i == 0)
+    for (int idx : order) {
+      if (idx == 0)
         break;
       ++innermostDimIdx;
     }
 
+    // Ensure that the innermost stride is one.
     auto strides = makeTensorPtrOp->getStrides();
     Value innermostStride = strides[innermostDimIdx];
     if (!innermostStride.getDefiningOp() ||
@@ -451,9 +371,9 @@ class FuseReshape {
     if (integerCst.value() != 1)
       return false;
 
-    // Ensure the load boundary check doesn't check the outermost dimension.
-    return llvm::none_of(loadOp.getBoundaryCheck(),
-                         [](int val) { return val == 0; });
+    // Ensure the load operation checks at most the innermost dimension.
+    return llvm::all_of(loadOp.getBoundaryCheck(),
+                        [&](int idx) { return idx == innermostDimIdx; });
   }
 
   // Prune chains that cannot be handled during fusion. For example,
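
The tightened predicate at the end of isCandidate is strictly stronger than the old one. A standalone comparison, with std:: algorithms standing in for the llvm:: ones (function names are illustrative):

#include <algorithm>
#include <vector>

// Old condition: only rejected boundary checks on the outermost dim (dim 0).
bool oldAccepts(const std::vector<int> &check) {
  return std::none_of(check.begin(), check.end(),
                      [](int dim) { return dim == 0; });
}

// New condition: every checked dim must be the innermost one.
bool newAccepts(const std::vector<int> &check, unsigned innermostDimIdx) {
  return std::all_of(check.begin(), check.end(), [&](int dim) {
    return dim == static_cast<int>(innermostDimIdx);
  });
}

// e.g. with rank 3 and innermostDimIdx == 2: check {1, 2} passed the old test
// but fails the new one, since dim 1 is now one of the collapsed dimensions.
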