
Commit 5c1a0ac

address review comments
1 parent: 99b82c6

2 files changed: +90 -41 lines changed

test/TritonIntelGPU/rewrite-tensor-pointer.mlir

Lines changed: 50 additions & 0 deletions
@@ -335,3 +335,53 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32
     tt.return
   }
 }
+
+// -----
+
+// COM: Case 5:
+// COM: Check that a make tensor ptr with no loads is handled properly
+// CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
+  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    // CHECK: @matmul_kernel_with_block_pointers
+    %c4_i32 = arith.constant 4 : i32
+    %c256_i32 = arith.constant 256 : i32
+    %c1024_i64 = arith.constant 1024 : i64
+    %c5120_i64 = arith.constant 5120 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %c4096_i64 = arith.constant 4096 : i64
+    %c32_i32 = arith.constant 32 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %c5120_i32 = arith.constant 5120 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #dpas>
+    %0 = tt.get_program_id x : i32
+    %1 = arith.divsi %0, %c64_i32 : i32
+    %2 = arith.muli %1, %c4_i32 : i32
+    %3 = arith.subi %c4_i32, %2 : i32
+    %4 = arith.minsi %3, %c4_i32 : i32
+    %5 = arith.remsi %0, %4 : i32
+    %6 = arith.addi %2, %5 : i32
+    %7 = arith.remsi %0, %c64_i32 : i32
+    %8 = arith.divsi %7, %4 : i32
+    %9 = arith.muli %6, %c256_i32 : i32
+    // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
+    %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>
+    %11 = arith.muli %8, %c256_i32 : i32
+    // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 0, 1>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
+    %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c1_i64, %c5120_i64], [%c0_i32, %11] {order = array<i32: 0, 1>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
+    %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>) : i32 {
+      // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
+      // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
+      %19 = tt.advance %arg5, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>
+      %20 = tt.advance %arg6, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
+      scf.yield %arg4, %19, %20 : tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
+    }
+    %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #dpas>>
+    %15 = arith.truncf %13#0 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas>
+    // CHECK: tt.store {{.*}}, {{.*}}, {{.*}} : !tt.ptr<tensor<256x256xf16, #[[DPAS]]>
+    tt.store %14, %15 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #dpas>>
+    tt.return
+  }
+}

third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp

Lines changed: 40 additions & 41 deletions
@@ -675,52 +675,51 @@ class TritonIntelGPURewriteTensorPointerPass
     ModuleOp mod = getOperation();

    DenseSet<Operation *> tensorPointersToRemove;
-    mod.walk([&](Operation *op) {
-      if (isa<tt::MakeTensorPtrOp>(op)) {
-        DenseSet<Operation *> workingSet;
+    mod.walk([&](tt::MakeTensorPtrOp makeTensorPtrOp) {
+      DenseSet<Operation *> workingSet;

-        auto makeTensorPtrOp = dyn_cast<tt::MakeTensorPtrOp>(op);
-        LDBG("Considering: " << *op);
-        Value result = op->getResult(0);
-        for (auto user : result.getUsers()) {
-          workingSet.insert(user);
-        }
-        while (!workingSet.empty()) {
-          auto crtOpItr = workingSet.begin();
-          auto crtOp = *crtOpItr;
-          LDBG("Processing op: " << *crtOp);
-          if (isa<tt::LoadOp, tt::StoreOp>(crtOp)) {
-            if (shouldRemove(makeTensorPtrOp,
-                             /*isUsedByStoreOp=*/isa<tt::StoreOp>(crtOp),
-                             /*isBlockLoad=*/
-                             isa<tt::LoadOp>(crtOp) &&
-                                 crtOp->hasAttr(ttgi::TritonIntelGPUDialect::
-                                                    getBlockIOAttrName()))) {
-              tensorPointersToRemove.insert(makeTensorPtrOp);
-            }
-          } else if (auto forOp = dyn_cast<scf::ForOp>(crtOp)) {
-            for (auto [arg, blockArg] :
-                 llvm::zip(forOp.getInitArgs(),
-                           forOp.getBody()->getArguments().drop_front(
-                               forOp.getNumInductionVars()))) {
-              if (arg == makeTensorPtrOp) {
-                // add users of block arg
-                for (auto user : blockArg.getUsers()) {
-                  workingSet.insert(user);
-                }
+      LDBG("Considering: " << makeTensorPtrOp);
+      Value result = makeTensorPtrOp.getResult();
+      for (auto user : result.getUsers()) {
+        workingSet.insert(user);
+      }
+      while (!workingSet.empty()) {
+        auto crtOpItr = workingSet.begin();
+        auto crtOp = *crtOpItr;
+        LDBG("Processing op: " << *crtOp);
+        if (isa<tt::LoadOp, tt::StoreOp>(crtOp)) {
+          if (shouldRemove(
+                  makeTensorPtrOp,
+                  /*isUsedByStoreOp=*/isa<tt::StoreOp>(crtOp),
+                  /*isBlockLoad=*/
+                  isa<tt::LoadOp>(crtOp) &&
+                      crtOp->hasAttr(
+                          ttgi::TritonIntelGPUDialect::getBlockIOAttrName()))) {
+            tensorPointersToRemove.insert(makeTensorPtrOp);
+            return;
+          }
+        } else if (auto forOp = dyn_cast<scf::ForOp>(crtOp)) {
+          for (auto [arg, blockArg] :
+               llvm::zip(forOp.getInitArgs(),
+                         forOp.getBody()->getArguments().drop_front(
+                             forOp.getNumInductionVars()))) {
+            if (arg == makeTensorPtrOp) {
+              // add users of block arg
+              for (auto user : blockArg.getUsers()) {
+                workingSet.insert(user);
               }
             }
-          } else if (crtOp->getNumResults() > 0) {
-            // TODO: should we handle more than one result?
-            auto crtOpResult = crtOp->getResult(0);
-            LDBG("Not a load store and not a loop, adding users to working "
-                 "set.");
-            for (auto user : crtOpResult.getUsers()) {
-              workingSet.insert(user);
-            }
           }
-          workingSet.erase(crtOpItr);
+        } else if (crtOp->getNumResults() > 0) {
+          // TODO: should we handle more than one result?
+          auto crtOpResult = crtOp->getResult(0);
+          LDBG("Not a load store and not a loop, adding users to working "
+               "set.");
+          for (auto user : crtOpResult.getUsers()) {
+            workingSet.insert(user);
+          }
         }
+        workingSet.erase(crtOpItr);
       }
     });

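A note on the RewriteTensorPointer.cpp change above: MLIR's walk() accepts a callback typed on a concrete op class, in which case the traversal itself does the filtering, so the old isa<tt::MakeTensorPtrOp> guard and dyn_cast become redundant. A plain `return` from the typed lambda only moves on to the next matched op, which is what makes the new early `return;` after tensorPointersToRemove.insert(...) safe: the remaining users of an already-marked pointer are simply not scanned. A minimal sketch of the idiom, assuming only core MLIR plus the func dialect (countDefinedFuncs is a made-up name for illustration, not part of this patch):

#include "mlir/Dialect/Func/IR/FuncOps.h" // mlir::func::FuncOp
#include "mlir/IR/BuiltinOps.h"           // mlir::ModuleOp

// Typed walk: the lambda's parameter type restricts the traversal to one op
// class, so no isa<>/dyn_cast<> is needed in the body, and a bare `return`
// continues with the next matched op instead of aborting the walk.
static unsigned countDefinedFuncs(mlir::ModuleOp mod) {
  unsigned numDefined = 0;
  mod.walk([&](mlir::func::FuncOp fn) {
    if (fn.isExternal())
      return; // skip declarations; the walk keeps going
    ++numDefined;
  });
  return numDefined;
}

The pass uses the same shape with tt::MakeTensorPtrOp; to abort the entire walk early, rather than skip a single op, the callback would instead return mlir::WalkResult::interrupt().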