Skip to content

Commit bb9b4c3

Browse files
committed
Address code review comments
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent a40844b commit bb9b4c3

File tree

5 files changed

+14
-23
lines changed

5 files changed

+14
-23
lines changed

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def make_ttgir(mod, metadata, opt, properties):
235235
intel.passes.ttgpuir.add_accelerate_matmul(pm)
236236
intel.passes.ttgpuir.add_remove_layout_conversions(pm)
237237
intel.passes.ttgpuir.add_materialize_block_pointer(pm)
238-
intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm)
238+
# intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm)
239239
intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False)
240240

241241
intel.passes.ttgpuir.add_coalesce(pm)

third_party/intel/include/Analysis/AxisInfo.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
1212

1313
#include <optional>
14-
#include <type_traits>
1514

1615
namespace mlir::triton::intel {
1716

third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ inline unsigned getNumElementsPerThread(
4646
inline bool applyTransposedReduction() {
4747
return tools::getBoolEnv("TRITON_INTEL_REDUCE_TRANSPOSE");
4848
}
49-
5049
} // namespace mlir::triton::gpu::intel
5150

5251
#endif // TRITON_DIALECT_TRITON_INTEL_GPU_IR_UTILS_H

third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1+
#include "intel/include/Analysis/AxisInfo.h"
12
#include "mlir/Analysis/DataFlowFramework.h"
23
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
4+
#include "triton/Dialect/Triton/IR/Dialect.h"
35
#include "llvm/Support/Debug.h"
46
#include "llvm/Support/raw_ostream.h"
57

6-
#include "intel/include/Analysis/AxisInfo.h"
7-
#include "mlir/IR/BuiltinTypes.h"
8-
#include "triton/Dialect/Triton/IR/Dialect.h"
9-
108
#define DEBUG_TYPE "intel-axis-info"
119
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
1210
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
11
#include "intel/include/Analysis/AxisInfo.h"
2-
#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
2+
// #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
33
#include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
44
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
55
#include "mlir/IR/Operation.h"
66
#include "mlir/IR/Value.h"
77
#include "mlir/IR/Verifier.h"
88
#include "mlir/Support/LLVM.h"
9-
#include "triton/Dialect/Triton/IR/Dialect.h"
10-
#include "triton/Dialect/Triton/IR/Types.h"
9+
// #include "triton/Dialect/Triton/IR/Dialect.h"
10+
// #include "triton/Dialect/Triton/IR/Types.h"
1111
#include "triton/Dialect/Triton/IR/Utility.h"
1212
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
1313
#include "triton/Tools/StrUtil.h"
1414
#include "llvm/Support/Debug.h"
1515
#include "llvm/Support/raw_ostream.h"
16-
#include <variant>
1716

1817
#define DEBUG_TYPE "tritonintelgpu-coalesce"
1918
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -104,8 +103,8 @@ struct CoalescePass
104103
perThread = std::min<int>(perThread, std::max(numElems / numThreads, 1));
105104
LDBG("perThread: " << perThread);
106105

107-
if (perThread <= 1)
108-
return;
106+
// if (perThread <= 1)
107+
// return;
109108

110109
if (!dyn_cast<triton::LoadOp>(op)) {
111110
// For ops that can result in a global memory write, we should enforce
@@ -299,7 +298,6 @@ struct CoalescePass
299298
LDBG("Coalescing op: " << *op);
300299

301300
OpBuilder builder(op);
302-
IRRewriter rewriter(builder);
303301

304302
// Convert operands
305303
// Note: for load/store with a blocked pointers argument we cannot change
@@ -312,14 +310,15 @@ struct CoalescePass
312310
if (tensorType &&
313311
!isa<ttg::SharedEncodingAttr>(tensorType.getEncoding())) {
314312
RankedTensorType newType = getNewType(tensorType, encoding);
315-
newArgs.push_back(rewriter.create<ttg::ConvertLayoutOp>(
313+
newArgs.push_back(builder.create<ttg::ConvertLayoutOp>(
316314
op->getLoc(), newType, operand));
317315
} else {
318316
assert(isa<tt::PointerType>(operand.getType()) &&
319317
"Expecting operand to have blocked pointer type");
320318
auto defOp = findDefiningMakeTensorPtrOp(operand);
321319
assert(defOp && "Expected a make_tensor_ptr operation");
322320
LDBG("Found make_tensor_ptr definition: " << *defOp);
321+
IRRewriter rewriter(builder);
323322
changeAndPropagateLayout(*defOp, encoding, rewriter);
324323
newArgs.push_back(operand);
325324
}
@@ -335,14 +334,14 @@ struct CoalescePass
335334

336335
// Construct new op with the new encoding.
337336
Operation *newOp =
338-
rewriter.create(op->getLoc(), op->getName().getIdentifier(), newArgs,
339-
newTypes, op->getAttrs());
337+
builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs,
338+
newTypes, op->getAttrs());
340339

341340
// Cast the results back to the original layout.
342341
for (size_t i = 0; i < op->getNumResults(); i++) {
343342
Value newResult = newOp->getResult(i);
344343
if (newTypes[i] != op->getResultTypes()[i]) {
345-
newResult = rewriter.create<ttg::ConvertLayoutOp>(
344+
newResult = builder.create<ttg::ConvertLayoutOp>(
346345
op->getLoc(), op->getResult(i).getType(), newResult);
347346
}
348347
op->getResult(i).replaceAllUsesWith(newResult);
@@ -400,11 +399,7 @@ struct CoalescePass
400399
coalesceOp(layout, op);
401400
}
402401

403-
// Verify the module's functions after the transformation.
404-
for (auto op : moduleOp.getOps<tt::FuncOp>()) {
405-
for (Operation &op1 : op.getOps())
406-
assert(succeeded(verify(&op1)));
407-
}
402+
assert(succeeded(verify(moduleOp)) && "Module verification failed");
408403
}
409404
};
410405

0 commit comments

Comments
 (0)