@@ -24,23 +24,6 @@ namespace mlir::triton::intel {
2424
2525namespace {
2626
27- // Transform:
28- // %one = arith.constant 1 : i64
29- // %ptr = make_tensor_ptr %q_view, [%q, %q_23, %q_24],
30- // [%q_25, %q_26, %one], [%offset_5, %offset_1_13, %q_28]
31- // {order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
32- // %load = tt.load %ptr {boundaryCheck = array<i32: 1, 2>}
33- // : !tt.ptr<tensor<1x512x64xf16>>
34- // %a = tt.reshape %load : tensor<1x512x64xf16> -> tensor<512x64xf16>
35- // tt.dot(%a, ...)
36- // into:
37- // %one = arith.constant 1 : i64
38- // %ptr = make_tensor_ptr %q_view, [%q_23, %q_24], [%q_26, %one],
39- // [%offset_1_13, %offset_5*%q_25+%q_28]
40- // {order = array<i32: 1, 0>} : <tensor<512x64xf16>>
41- // %a = tt.load %ptr {boundaryCheck = array<i32: 0, 1>}
42- // : !tt.ptr<tensor<512x64xf16>>
43- // tt.dot(%a, ...)
4427class FuseReshape {
4528private:
4629 SmallPtrSet<Operation *, 8 > cleanUp;
@@ -250,25 +233,10 @@ class FuseReshape {
250233 auto newLoadOp =
251234 cast<tt::LoadOp>(mapping.lookup (static_cast <Operation *>(loadOp)));
252235 ArrayRef<int > boundaryCheck = newLoadOp.getBoundaryCheck ();
253-
254- switch (boundaryCheck.size ()) {
255- case 0 :
256- break ;
257- case 1 :
258- // intentional fall-through
259- case 2 : {
260- SmallVector<int > newBoundaryCheck;
261- if ((boundaryCheck[0 ] - 1 ) != 0 )
262- newBoundaryCheck.push_back ((boundaryCheck[0 ] - 1 ));
263- if (boundaryCheck.size () == 2 && (boundaryCheck[1 ] - 1 ) != 0 )
264- newBoundaryCheck.push_back (boundaryCheck[1 ] - 1 );
265- newLoadOp.setBoundaryCheck (newBoundaryCheck);
266- } break ;
267- default :
268- // Note: while selecting candidates, we already ensured that the original
269- // load's boundary check doesn't check dim zero. So its max rank should
270- // be 2.
271- assert (boundaryCheck.size () != 3 && " Unexpected boundary check rank" );
236+ for (int idx : boundaryCheck) {
237+ assert (idx == (newInnermostDimIdx + 1 ) &&
238+ " Unexpected boundary check idx" );
239+ newLoadOp.setBoundaryCheck ({static_cast <int >(newInnermostDimIdx)});
272240 }
273241 }
274242
@@ -359,18 +327,6 @@ class FuseReshape {
359327 ++innermostDimIdx;
360328 }
361329
362- // Ensure that the innermost stride is one.
363- auto strides = makeTensorPtrOp->getStrides ();
364- Value innermostStride = strides[innermostDimIdx];
365- if (!innermostStride.getDefiningOp () ||
366- !isa<arith::ConstantIntOp>(innermostStride.getDefiningOp ()))
367- return false ;
368-
369- auto integerCst =
370- cast<arith::ConstantIntOp>(innermostStride.getDefiningOp ());
371- if (integerCst.value () != 1 )
372- return false ;
373-
374330 // Ensure the load operation checks at most the innermost dimension.
375331 return llvm::all_of (loadOp.getBoundaryCheck (),
376332 [&](int idx) { return idx == innermostDimIdx; });
0 commit comments