Skip to content

Commit 8f63d38

Browse files
committed
Merge commit 'c2fd8e1b426b76011479f950a7fb1b7b1e93490e'
2 parents c874647 + c2fd8e1 commit 8f63d38

File tree

20 files changed

+351
-28
lines changed

20 files changed

+351
-28
lines changed

python/test/unit/language/test_matmul.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -902,9 +902,6 @@ def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TR
902902
if (A_DATA_TYPE == 'float4' and not WITH_A_SCALE) or (B_DATA_TYPE == 'float4' and not WITH_B_SCALE):
903903
pytest.skip("Float4 without scale is tested in test_block_scale_fp4")
904904

905-
if B_DATA_TYPE != 'float4' and B_TRANS:
906-
pytest.xfail(f'No need to transpose B for {B_DATA_TYPE}')
907-
908905
if is_xpu():
909906
pytest.skip("FIXME: failed to legalize operation 'tt.dot_scaled' on XPU")
910907

@@ -913,13 +910,21 @@ def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TR
913910

914911
torch.manual_seed(42)
915912

916-
def create_operand(dtype: str, size0: int, size1: int, k_dim: int, transpose: bool = False):
913+
def create_operand(dtype: str, size0: int, size1: int, k_dim: int, transpose: bool = True):
917914
if dtype == "float8e5":
918-
v = torch.randint(20, 40, (size0, size1), dtype=torch.uint8).view(torch.float8_e5m2).to(device)
919-
v_ref = f8_to_f16(v.view(torch.float8_e5m2), dtype).to(torch.float32)
915+
if transpose:
916+
v = torch.randint(20, 40, (size0, size1), dtype=torch.uint8).view(torch.float8_e5m2).to(device)
917+
v_ref = f8_to_f16(v.view(torch.float8_e5m2), dtype).to(torch.float32)
918+
else:
919+
v = torch.randint(20, 40, (size1, size0), dtype=torch.uint8).view(torch.float8_e5m2).to(device).T
920+
v_ref = f8_to_f16(v.view(torch.float8_e5m2).T, dtype).to(torch.float32).T
920921
elif dtype == "float8e4nv":
921-
v = torch.randint(20, 40, (size0, size1), dtype=torch.uint8).view(torch.float8_e4m3fn).to(device)
922-
v_ref = f8_to_f16(v.view(torch.float8_e4m3fn), dtype).to(torch.float32)
922+
if transpose:
923+
v = torch.randint(20, 40, (size0, size1), dtype=torch.uint8).view(torch.float8_e4m3fn).to(device)
924+
v_ref = f8_to_f16(v.view(torch.float8_e4m3fn), dtype).to(torch.float32)
925+
else:
926+
v = torch.randint(20, 40, (size1, size0), dtype=torch.uint8).view(torch.float8_e4m3fn).to(device).T
927+
v_ref = f8_to_f16(v.view(torch.float8_e4m3fn).T, dtype).to(torch.float32).T
923928
else:
924929
# float4
925930
if transpose:
@@ -937,8 +942,8 @@ def create_operand(dtype: str, size0: int, size1: int, k_dim: int, transpose: bo
937942
a, a_ref = create_operand(A_DATA_TYPE, M, K, 1)
938943
b, b_ref = create_operand(B_DATA_TYPE, K, N, 0, B_TRANS)
939944

940-
a_scale_mxfp4 = MXScaleTensor(size=(M, (K + 32 - 1) // 32), device=device).random(high=64.0)
941-
b_scale_mxfp4 = MXScaleTensor(size=(N, (K + 32 - 1) // 32), device=device).random(high=64.0)
945+
a_scale_mxfp4 = MXScaleTensor(size=(M, (K + 32 - 1) // 32), device=device).random(high=32.0)
946+
b_scale_mxfp4 = MXScaleTensor(size=(N, (K + 32 - 1) // 32), device=device).random(high=32.0)
942947
a_scale = a_scale_mxfp4.data
943948
b_scale = b_scale_mxfp4.data
944949

test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
8888
!ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x64xf16, #shared1, #smem> -> tensor<128x64xf32, #mma>
8989
tt.return
9090
}
91+
92+
// CHECK-LABEL: @wgmma_on_subtile
93+
// CHECK: nvgpu.wgmma %{{.*}}, %{{.*}}
94+
tt.func @wgmma_on_subtile(%a: tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>, %b: !ttg.memdesc<16x256xf16, #shared1, #smem, mutable, 3x64x256>){
95+
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
96+
%m = ttng.warp_group_dot %a, %b, %cst {inputPrecision = 0 : i32, isAsync = true} : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<16x256xf16, #shared1, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma>
97+
tt.return
98+
}
9199
}
92100

93101
// -----

test/TritonGPU/amd/amd-block-pingpong.mlir

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,98 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
423423
}
424424
}
425425

426+
// -----
427+
428+
//CHECK-LABEL: pingpong_small_prologue_load
429+
//CHECK: ttg.local_load
430+
//CHECK: rocdl.s.setprio 1
431+
//CHECK: tt.load
432+
//CHECK: rocdl.sched.barrier
433+
//CHECK: ttg.local_load
434+
//CHECK: rocdl.s.setprio 0
435+
//CHECK: tt.load
436+
//CHECK: rocdl.sched.barrier
437+
//CHECK: rocdl.s.setprio 1
438+
//CHECK: tt.dot
439+
//CHECK: rocdl.s.setprio 0
440+
441+
#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
442+
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
443+
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 16], isTransposed = true}>
444+
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
445+
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
446+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
447+
tt.func public @pingpong_small_prologue_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
448+
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
449+
%c1_i32 = arith.constant 1 : i32
450+
%cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked>
451+
%cst_1 = arith.constant dense<64> : tensor<128x64xi32, #blocked1>
452+
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
453+
%c0_i32 = arith.constant 0 : i32
454+
%c64_i32 = arith.constant 64 : i32
455+
%0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked1>
456+
%1 = tt.get_program_id x : i32
457+
%2 = tt.splat %1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
458+
%3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
459+
%4 = arith.addi %2, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
460+
%5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1>
461+
%6 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1>
462+
%7 = arith.muli %5, %6 : tensor<128x1xi32, #blocked1>
463+
%8 = tt.addptr %0, %7 : tensor<128x1x!tt.ptr<f16>, #blocked1>, tensor<128x1xi32, #blocked1>
464+
%9 = tt.broadcast %8 : tensor<128x1x!tt.ptr<f16>, #blocked1> -> tensor<128x64x!tt.ptr<f16>, #blocked1>
465+
%10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
466+
%11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
467+
%12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
468+
%13 = tt.addptr %9, %12 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
469+
%14 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #blocked>
470+
%15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
471+
%16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
472+
%17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi32, #blocked>
473+
%18 = tt.broadcast %17 : tensor<64x1x!tt.ptr<f16>, #blocked> -> tensor<64x128x!tt.ptr<f16>, #blocked>
474+
%19 = tt.splat %arg7 : i32 -> tensor<64x128xi32, #blocked>
475+
%20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
476+
%21 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable>
477+
%22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable>
478+
%23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>
479+
%24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
480+
%25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>) : i32 {
481+
%26 = arith.cmpi eq, %arg10, %c0_i32: i32
482+
%27 = scf.if %26 -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> {
483+
%28 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked1>
484+
%29 = tt.broadcast %28 : tensor<128x1x!tt.ptr<f16>, #blocked1> -> tensor<128x64x!tt.ptr<f16>, #blocked1>
485+
%30 = tt.load %29 : tensor<128x64x!tt.ptr<f16>, #blocked1>
486+
%31 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable>
487+
%32 = ttg.memdesc_subview %31[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>
488+
ttg.local_store %30, %32 : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>
489+
%33 = ttg.local_load %32 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
490+
scf.yield %33 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
491+
} else {
492+
scf.yield %cst_2 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
493+
}
494+
%34 = tt.addptr %arg12, %cst_1 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
495+
%35 = tt.load %34 : tensor<128x64x!tt.ptr<f16>, #blocked1>
496+
%36 = tt.addptr %arg13, %cst_0 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
497+
%37 = tt.load %36 : tensor<64x128x!tt.ptr<f16>, #blocked>
498+
%38 = ttg.local_load %arg15 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
499+
%39 = arith.addf %38, %27: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
500+
%40 = ttg.local_load %arg16 : !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
501+
%41 = tt.dot %39, %40, %arg11 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x128xf32, #mma>
502+
%42 = arith.addi %arg14, %c1_i32 : i32
503+
%43 = arith.cmpi slt, %42, %c1_i32 : i32
504+
%44 = arith.select %43, %42, %c0_i32 : i32
505+
%45 = ttg.memdesc_subview %21[%44, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>
506+
ttg.local_store %35, %45 : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>
507+
%46 = ttg.memdesc_subview %22[%44, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
508+
ttg.local_store %37, %46 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
509+
scf.yield %41, %34, %36, %44, %45, %46 : tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
510+
}
511+
ttg.local_dealloc %21 : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable>
512+
ttg.local_dealloc %22 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable>
513+
tt.return
514+
}
515+
}
516+
517+
426518
// -----
427519
// CHECK-LABEL: pingpong_medium_dependency
428520

third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ class Pingponger {
7979
void appendSlicedLoadAB(int slice);
8080
void appendClusterBarrier(OpBuilder &builder, Location loc);
8181
void appendOpWithPrio(OpBuilder &builder, Operation *Op, Location loc);
82+
void determineDotMemoryOps(tt::DotOp dotOp,
83+
DenseSet<tt::LoadOp> &dotGlobalLoads,
84+
DenseSet<ttg::LocalLoadOp> &dotLocalLoads,
85+
DenseSet<ttg::LocalStoreOp> &dotLocalStores);
86+
template <typename T>
87+
void findClosestPredOps(Value v, DenseSet<T> &matchingOps);
8288
};
8389

8490
void Pingponger::updateOpInsertion(Operation *op) { lastInsertedOp = op; }
@@ -150,6 +156,89 @@ void Pingponger::appendOpWithPrio(OpBuilder &builder, Operation *op,
150156
appendOp(builder.create<ROCDL::SetPrioOp>(loc, lowPriority));
151157
}
152158

159+
// Find all of the "closest" operations that are of a given type T
160+
// in the same basic block. Here "closest" means along any path P,
161+
// the first operation of type T that is encountered when traversing
162+
// P from the given value v. This also includes "later" operations
163+
// for block arguments. Note: That we find all T for every path P.
164+
template <typename T>
165+
void Pingponger::findClosestPredOps(Value v, DenseSet<T> &matchingOps) {
166+
// Create a cache so we can traverse across block arguments.
167+
DenseSet<Operation *> visitedOps;
168+
std::function<void(Value)> impl;
169+
impl = [&matchingOps, &visitedOps, &impl](Value v) {
170+
// If we encounter a block argument we only look at the terminators of the
171+
// current block
172+
if (auto blockArg = dyn_cast<BlockArgument>(v)) {
173+
auto operandNumber = blockArg.getArgNumber();
174+
auto block = blockArg.getOwner();
175+
if (auto yield = dyn_cast<scf::YieldOp>(block->getTerminator())) {
176+
auto parentOp = block->getParentOp();
177+
// Skip the induction variables to find the yield position
178+
if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
179+
if (operandNumber < forOp.getNumInductionVars())
180+
return;
181+
operandNumber -= forOp.getNumInductionVars();
182+
}
183+
impl(yield->getOperand(operandNumber));
184+
}
185+
} else {
186+
auto definingOp = v.getDefiningOp();
187+
if (!definingOp)
188+
return;
189+
else if (visitedOps.contains(definingOp))
190+
return;
191+
visitedOps.insert(definingOp);
192+
if (auto matchOp = dyn_cast<T>(definingOp))
193+
matchingOps.insert(matchOp);
194+
else
195+
for (auto predValue : definingOp->getOperands())
196+
impl(predValue);
197+
}
198+
};
199+
impl(v);
200+
}
201+
202+
// Populate the dotGlobalLoads, dotLocalLoads, and dotLocalStores set with
203+
// any loads that are generated by the current dot product. This occurs in
204+
// steps to:
205+
// 1. Determine which loads are generated by the dot product via getA()
206+
// and getB().
207+
// 2. Determine which local stores are used to populate the inputs to
208+
// the local loads.
209+
// 3. Determine which global loads are used to populate the inputs to
210+
// the local stores.
211+
// Note: This function currently depends on num_stages=2, which is a
212+
// precondition for the pingpong scheduling.
213+
void Pingponger::determineDotMemoryOps(
214+
tt::DotOp dotOp, DenseSet<tt::LoadOp> &dotGlobalLoads,
215+
DenseSet<ttg::LocalLoadOp> &dotLocalLoads,
216+
DenseSet<ttg::LocalStoreOp> &dotLocalStores) {
217+
// Find the locals loads used to compute the dot inputs. These
218+
// must come before the dot op.
219+
findClosestPredOps<ttg::LocalLoadOp>(dotOp.getA(), dotLocalLoads);
220+
findClosestPredOps<ttg::LocalLoadOp>(dotOp.getB(), dotLocalLoads);
221+
222+
// Determine the local stores from the local loads.
223+
// With pipelining we expect this to be a single local
224+
// store within the loop based on a block argument after routing through
225+
// a ttg.MemDescSubviewOp.
226+
DenseSet<ttg::MemDescSubviewOp> subviews;
227+
for (auto &&localLoad : dotLocalLoads)
228+
findClosestPredOps<ttg::MemDescSubviewOp>(localLoad.getSrc(), subviews);
229+
230+
for (auto &&subview : subviews)
231+
for (auto &&user : subview->getUsers())
232+
if (auto localStore = dyn_cast<ttg::LocalStoreOp>(user))
233+
dotLocalStores.insert(localStore);
234+
235+
// Determine the global loads from the local stores.
236+
// We expect this to just be a global load
237+
// within the loop.
238+
for (auto &&localStore : dotLocalStores)
239+
findClosestPredOps<tt::LoadOp>(localStore.getSrc(), dotGlobalLoads);
240+
}
241+
153242
// Transform a loop into one Dot - Memory (ping - pong) clusters
154243
// Each cluster, especially the Dot cluster is guarded with setprio(1->0) so
155244
// each warp can complete the execution of the cluster without being
@@ -473,6 +562,46 @@ void Pingponger::getDotPingponged() {
473562
LDBG(message.str());
474563
return;
475564
}
565+
566+
// The existing code depends on the loads being targeted being safe to move,
567+
// which will not hold if we do not properly have a GEMM. As a result, we
568+
// filter the associated load operations to only those that are associated
569+
// // with the GEMM.
570+
DenseSet<tt::LoadOp> dotGlobalLoads;
571+
DenseSet<ttg::LocalLoadOp> dotLocalLoads;
572+
DenseSet<ttg::LocalStoreOp> dotLocalStores;
573+
determineDotMemoryOps(dotOps[0], dotGlobalLoads, dotLocalLoads,
574+
dotLocalStores);
575+
576+
auto origGlobalLoadCount = gLoadOps.size();
577+
auto origLocalLoadCount = lLoadOps.size();
578+
// Prune Memory operations that may be moved to only those involved in dot
579+
// computation.
580+
auto gLoadIt = llvm::remove_if(gLoadOps, [&dotGlobalLoads](tt::LoadOp op) {
581+
return !dotGlobalLoads.contains(op);
582+
});
583+
gLoadOps.erase(gLoadIt, gLoadOps.end());
584+
auto lLoadIt =
585+
llvm::remove_if(lLoadOps, [&dotLocalLoads](ttg::LocalLoadOp op) {
586+
return !dotLocalLoads.contains(op);
587+
});
588+
lLoadOps.erase(lLoadIt, lLoadOps.end());
589+
auto lStoreIt =
590+
llvm::remove_if(lStoreOps, [&dotLocalStores](ttg::LocalStoreOp op) {
591+
return !dotLocalStores.contains(op);
592+
});
593+
lStoreOps.erase(lStoreIt, lStoreOps.end());
594+
// All PingPong Scheduler assumes there are 2 movable global loads and 2
595+
// movable local loads.
596+
if (gLoadOps.size() != 2 || lLoadOps.size() != 2) {
597+
std::stringstream message;
598+
message << "Unable to match ping pong slicing pattern. Details: "
599+
<< gLoadOps.size() << " global loads in dot computation, "
600+
<< lLoadOps.size() << " local loads in dot computation";
601+
LDBG(message.str());
602+
return;
603+
}
604+
476605
// Pingpong scheduling tries to form two different types of the instruction
477606
// clusters, i.e., Dot clusters and Memory clusters. While each SIMD has
478607
// two concurrent warps, both warps can execute a different type of
@@ -532,14 +661,21 @@ void Pingponger::getDotPingponged() {
532661
// numWarps=4 doesn't need asymmetric sync, return.
533662
return;
534663
} else if (numWarps == 8) { // Pingpong between warps from the same block
535-
if (gLoadOps.size() != 2 || lLoadOps.size() != 2) {
664+
if (origGlobalLoadCount != 2 || origLocalLoadCount != 2) {
536665
std::stringstream message;
537666
message << "Unable to match ping pong slicing pattern. Details: "
538667
<< gLoadOps.size() << " global loads, " << lLoadOps.size()
539668
<< " local loads";
540669
LDBG(message.str());
541670
return;
542671
}
672+
if (lStoreOps.size() != 2) {
673+
std::stringstream message;
674+
message << "Unable to match ping pong slicing pattern. Details: "
675+
<< lStoreOps.size() << " local stores in dot computation ";
676+
LDBG(message.str());
677+
return;
678+
}
543679
// Transform a loop where the tile size requires dots to be sliced
544680
if (tileSize == mediumTile) {
545681
if (transformTwoPPClusters(builder, dotOps[0]->getLoc()).failed()) {

0 commit comments

Comments
 (0)