@@ -1,19 +1,18 @@
 #include "intel/include/Analysis/AxisInfo.h"
-#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
+// #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
 #include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
 #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Support/LLVM.h"
-#include "triton/Dialect/Triton/IR/Dialect.h"
-#include "triton/Dialect/Triton/IR/Types.h"
+// #include "triton/Dialect/Triton/IR/Dialect.h"
+// #include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Tools/StrUtil.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
-#include <variant>
 
 #define DEBUG_TYPE "tritonintelgpu-coalesce"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -104,8 +103,8 @@ struct CoalescePass
     perThread = std::min<int>(perThread, std::max(numElems / numThreads, 1));
     LDBG("perThread: " << perThread);
 
-    if (perThread <= 1)
-      return;
+    // if (perThread <= 1)
+    //   return;
 
     if (!dyn_cast<triton::LoadOp>(op)) {
       // For ops that can result in a global memory write, we should enforce
@@ -299,7 +298,6 @@ struct CoalescePass
     LDBG("Coalescing op: " << *op);
 
     OpBuilder builder(op);
-    IRRewriter rewriter(builder);
 
     // Convert operands
     // Note: for load/store with a blocked pointers argument we cannot change
@@ -312,14 +310,15 @@ struct CoalescePass
       if (tensorType &&
           !isa<ttg::SharedEncodingAttr>(tensorType.getEncoding())) {
         RankedTensorType newType = getNewType(tensorType, encoding);
-        newArgs.push_back(rewriter.create<ttg::ConvertLayoutOp>(
+        newArgs.push_back(builder.create<ttg::ConvertLayoutOp>(
             op->getLoc(), newType, operand));
       } else {
         assert(isa<tt::PointerType>(operand.getType()) &&
                "Expecting operand to have blocked pointer type");
         auto defOp = findDefiningMakeTensorPtrOp(operand);
         assert(defOp && "Expected a make_tensor_ptr operation");
         LDBG("Found make_tensor_ptr definition: " << *defOp);
+        IRRewriter rewriter(builder);
         changeAndPropagateLayout(*defOp, encoding, rewriter);
         newArgs.push_back(operand);
       }
@@ -335,14 +334,14 @@ struct CoalescePass
 
     // Construct new op with the new encoding.
     Operation *newOp =
-        rewriter.create(op->getLoc(), op->getName().getIdentifier(), newArgs,
-                        newTypes, op->getAttrs());
+        builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs,
+                       newTypes, op->getAttrs());
 
     // Cast the results back to the original layout.
     for (size_t i = 0; i < op->getNumResults(); i++) {
       Value newResult = newOp->getResult(i);
       if (newTypes[i] != op->getResultTypes()[i]) {
-        newResult = rewriter.create<ttg::ConvertLayoutOp>(
+        newResult = builder.create<ttg::ConvertLayoutOp>(
             op->getLoc(), op->getResult(i).getType(), newResult);
       }
       op->getResult(i).replaceAllUsesWith(newResult);
@@ -400,11 +399,7 @@ struct CoalescePass
       coalesceOp(layout, op);
     }
 
-    // Verify the module's functions after the transformation.
-    for (auto op : moduleOp.getOps<tt::FuncOp>()) {
-      for (Operation &op1 : op.getOps())
-        assert(succeeded(verify(&op1)));
-    }
+    assert(succeeded(verify(moduleOp)) && "Module verification failed");
   }
 };
 