Commit a360660
[AMD] Fix InThreadTranspose data flow traversal (triton-lang#6320)
This PR fixes multiple problems in DF traversal:

- Fix handling of nested scf structures of the following kind (a sketch of the intended traversal rule follows after this list):

  ```
  %1 = for {
    %2 = for {
      yield 2
    }
    yield 1
  }
  use %1
  ```

  While processing %1, the data flow analysis must traverse only `yield 1`; before this PR it considered both yields.
- Fix traversal of the `scf.while` operation.
- Fix traversal of loops in the data flow graph that contain no non-scf operations.
- Fix intermittent memory corruption caused by duplicates in the list of loads staged for replacement.
- Handle unsupported scf operations more strictly: the optimization now aborts when it encounters them.
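For illustration only, the sketch below shows one way to pick out just the directly yielded value for a given loop result when walking def-use chains backward through scf ops, which is the traversal rule described in the first bullet. It is a hypothetical standalone helper built on upstream MLIR APIs, not the code added by this PR; the name `getDirectlyYieldedValue` is made up for this example.

```cpp
#include "mlir/Dialect/SCF/IR/SCF.h"

using namespace mlir;

// Hypothetical helper (not the pass's actual implementation): given a result
// of an scf.for, return only the value that the loop's own scf.yield passes
// for that result. Terminators of loops nested inside the body belong to
// other regions and are never visited here.
static Value getDirectlyYieldedValue(OpResult res) {
  auto forOp = cast<scf::ForOp>(res.getOwner());
  // The single-block body of scf.for is terminated by the scf.yield whose
  // i-th operand defines the i-th loop result.
  auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
  return yield.getOperand(res.getResultNumber());
}
```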
1 parent 07478c2 commit a360660

File tree

2 files changed: +402 −57 lines changed


test/TritonGPU/amd/in-thread-transpose.mlir

Lines changed: 250 additions & 0 deletions
@@ -269,3 +269,253 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
    tt.return
  }
}

// -----

// Test that backward SCF traversal correctly processes nested CF structures.
// CHECK-LABEL: inThreadTranspose_nested_scf_traversal_regression

// CHECK: [[IF:%.*]] = scf.if {{.*}} -> (!ttg.memdesc<32x128xf16, #shared, #smem>) {
// CHECK: scf.if {{.*}} -> (tensor<32x128xf16, #blocked>) {
// CHECK: } else {
// CHECK: }
// CHECK: [[TRANS1:%.*]] = amdgpu.in_thread_transpose {{.*}} : tensor<32x128xf16
// CHECK: [[ALLOC1:%.*]] = ttg.local_alloc [[TRANS1]] : {{.*}} !ttg.memdesc<32x128xf16
// CHECK: scf.yield [[ALLOC1]] : !ttg.memdesc<32x128xf16, #shared, #smem>
// CHECK: } else {
// CHECK: [[TRANS2:%.*]] = amdgpu.in_thread_transpose {{.*}} : tensor<32x128xf16
// CHECK: [[ALLOC2:%.*]] = ttg.local_alloc [[TRANS2]] : {{.*}} -> !ttg.memdesc<32x128xf16
// CHECK: scf.yield [[ALLOC2]] : !ttg.memdesc<32x128xf16
// CHECK: }
// CHECK: ttg.local_load [[IF]]
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#smem = #ttg.shared_memory
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [32, 32], isTransposed = true}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @inThreadTranspose_nested_scf_traversal_regression(%arg0: tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: i1) {
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma>
    %0 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>, #blocked>
    %5 = scf.if %arg2 -> (!ttg.memdesc<32x128xf16, #shared, #smem>) {
      %10 = scf.if %arg2 -> (tensor<32x128xf16, #blocked>) {
        %11 = tt.load %0 : tensor<32x128x!tt.ptr<f16>, #blocked>
        scf.yield %11 : tensor<32x128xf16, #blocked>
      } else {
        %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked>
        scf.yield %cst_1 : tensor<32x128xf16, #blocked>
      }
      %2 = ttg.local_alloc %10 : (tensor<32x128xf16, #blocked>) -> !ttg.memdesc<32x128xf16, #shared, #smem>
      scf.yield %2 : !ttg.memdesc<32x128xf16, #shared, #smem>
    } else {
      %3 = tt.load %0 : tensor<32x128x!tt.ptr<f16>, #blocked>
      %4 = ttg.local_alloc %3 : (tensor<32x128xf16, #blocked>) -> !ttg.memdesc<32x128xf16, #shared, #smem>
      scf.yield %4 : !ttg.memdesc<32x128xf16, #shared, #smem>
    }
    %6 = ttg.local_load %5 : !ttg.memdesc<32x128xf16, #shared, #smem> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
    %7 = tt.dot %arg0, %6, %cst_0 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma>
    tt.return
  }
}

// -----

// Test that ITT does not crash on the following data flow:
//
// %v = define mem ref
// while (%arg = %v) {
//   use %arg
// }
//
// CHECK-LABEL: inThreadTranspose_inbound_df_while_regression
// CHECK: [[TRANS1:%.*]] = amdgpu.in_thread_transpose
// CHECK: ttg.local_alloc [[TRANS1]] : (tensor<32x128xf16
// CHECK: scf.while
// CHECK: } do {
// CHECK: [[TRANS2:%.*]] = amdgpu.in_thread_transpose
// CHECK: ttg.local_store [[TRANS2]], {{.*}} : tensor<32x128xf16
// CHECK: }
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#smem = #ttg.shared_memory
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [32, 32], isTransposed = true}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @inThreadTranspose_inbound_df_while_regression(%arg0: tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: i1) {
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma>
    %0 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>, #blocked>
    %1 = tt.load %0 : tensor<32x128x!tt.ptr<f16>, #blocked>
    %2 = ttg.local_alloc %1 : (tensor<32x128xf16, #blocked>) -> !ttg.memdesc<32x128xf16, #shared, #smem, mutable>
    %3:1 = scf.while (%arg10 = %2, %arg11 = %arg2) : (!ttg.memdesc<32x128xf16, #shared, #smem, mutable>, i1) -> (!ttg.memdesc<32x128xf16, #shared, #smem, mutable>) {
      scf.condition(%arg11) %arg10 : !ttg.memdesc<32x128xf16, #shared, #smem, mutable>
    } do {
    ^bb0(%arg20: !ttg.memdesc<32x128xf16, #shared, #smem, mutable>):
      %10 = tt.load %0 : tensor<32x128x!tt.ptr<f16>, #blocked>
      %11 = ttg.local_load %arg20 : !ttg.memdesc<32x128xf16, #shared, #smem, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
      ttg.local_store %10, %arg20 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared, #smem, mutable>
      %12 = tt.dot %arg0, %11, %cst_0 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma>
      scf.yield %arg20, %arg2 : !ttg.memdesc<32x128xf16, #shared, #smem, mutable>, i1
    }
    tt.return
  }
}

// -----

// Test that ITT does not crash on the following data flow:
//
// %w = while () {
//   %v = define mem ref
//   yield %v
// }
// use %w
//
// CHECK-LABEL: inThreadTranspose_outbound_df_while_regression
// CHECK: [[TRANS1:%.*]] = amdgpu.in_thread_transpose
// CHECK: ttg.local_alloc [[TRANS1]] : (tensor<32x128xf16
// CHECK: scf.while
// CHECK: } do {
// CHECK: }
// CHECK: [[TRANS2:%.*]] = amdgpu.in_thread_transpose
// CHECK: ttg.local_store [[TRANS2]], {{.*}} : tensor<32x128xf16
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#smem = #ttg.shared_memory
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [32, 32], isTransposed = true}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @inThreadTranspose_outbound_df_while_regression(%arg0: tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: i1) {
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma>
    %0 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>, #blocked>
    %1 = tt.load %0 : tensor<32x128x!tt.ptr<f16>, #blocked>
    %2 = ttg.local_alloc %1 : (tensor<32x128xf16, #blocked>) -> !ttg.memdesc<32x128xf16, #shared, #smem, mutable>
    %3:1 = scf.while (%arg10 = %2, %arg11 = %arg2) : (!ttg.memdesc<32x128xf16, #shared, #smem, mutable>, i1) -> (!ttg.memdesc<32x128xf16, #shared, #smem, mutable>) {
      scf.condition(%arg11) %arg10 : !ttg.memdesc<32x128xf16, #shared, #smem, mutable>
    } do {
    ^bb0(%arg20: !ttg.memdesc<32x128xf16, #shared, #smem, mutable>):
      scf.yield %arg20, %arg2 : !ttg.memdesc<32x128xf16, #shared, #smem, mutable>, i1
    }
    ttg.local_store %1, %3#0 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared, #smem, mutable>
    %4 = ttg.local_load %3#0 : !ttg.memdesc<32x128xf16, #shared, #smem, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
    %5 = tt.dot %arg0, %4, %cst_0 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma>
    tt.return
  }
}

// -----

// Test that ITT does not crash on the following data flow:
//
// %v = define mem ref
// for (%arg = %v) {
//   use %arg
// }
//
// CHECK-LABEL: inThreadTranspose_inbound_df_for_regression
// CHECK: [[TRANS1:%.*]] = amdgpu.in_thread_transpose
// CHECK: ttg.local_alloc [[TRANS1]] : (tensor<32x128xf16
// CHECK: scf.for
// CHECK: [[TRANS2:%.*]] = amdgpu.in_thread_transpose
// CHECK: ttg.local_store [[TRANS2]], {{.*}} : tensor<32x128xf16
// CHECK: }
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#smem = #ttg.shared_memory
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [32, 32], isTransposed = true}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @inThreadTranspose_inbound_df_for_regression(%arg0: tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: i1) {
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 0 : i32
    %c10_i32 = arith.constant 10 : i32
    %0 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>, #blocked>
    %1 = tt.load %0 : tensor<32x128x!tt.ptr<f16>, #blocked>
    %2 = ttg.local_alloc %1 : (tensor<32x128xf16, #blocked>) -> !ttg.memdesc<32x128xf16, #shared, #smem, mutable>
    %3:1 = scf.for %arg10 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg11 = %2) -> (!ttg.memdesc<32x128xf16, #shared, #smem, mutable>) : i32 {
      %10 = tt.load %0 : tensor<32x128x!tt.ptr<f16>, #blocked>
      %11 = ttg.local_load %arg11 : !ttg.memdesc<32x128xf16, #shared, #smem, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
      ttg.local_store %10, %arg11 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared, #smem, mutable>
      scf.yield %arg11 : !ttg.memdesc<32x128xf16, #shared, #smem, mutable>
    }
    tt.return
  }
}

// -----

// Test that ITT does not crash on the following data flow:
//
// %f = for () {
//   %v = define mem ref
//   yield %v
// }
// use %f
//
// CHECK-LABEL: inThreadTranspose_outbound_df_for_regression
// CHECK: scf.for
// CHECK: [[TRANS:%.*]] = amdgpu.in_thread_transpose
// CHECK: ttg.local_store [[TRANS]], {{.*}} : tensor<32x128xf16
// CHECK: }
// CHECK: ttg.local_load
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#smem = #ttg.shared_memory
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [32, 32], isTransposed = true}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @inThreadTranspose_outbound_df_for_regression(%arg0: tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: i1) {
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 0 : i32
    %c10_i32 = arith.constant 10 : i32
    %0 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>, #blocked>
    %1 = ttg.local_alloc : () -> !ttg.memdesc<32x128xf16, #shared, #smem, mutable>
    %2:1 = scf.for %arg10 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg11 = %1) -> (!ttg.memdesc<32x128xf16, #shared, #smem, mutable>) : i32 {
      %10 = tt.load %0 : tensor<32x128x!tt.ptr<f16>, #blocked>
      %11 = ttg.local_load %arg11 : !ttg.memdesc<32x128xf16, #shared, #smem, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
      ttg.local_store %10, %arg11 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared, #smem, mutable>
      scf.yield %arg11 : !ttg.memdesc<32x128xf16, #shared, #smem, mutable>
    }
    %3 = ttg.local_load %2#0 : !ttg.memdesc<32x128xf16, #shared, #smem, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
    tt.return
  }
}

// -----

// Test that ITT does not crash on the following data flow:
//
// %i = if () {
//   %v1 = define mem ref
//   yield %v1
// } else {
//   %v2 = define mem ref
//   yield %v2
// }
// use %i
//
// CHECK-LABEL: inThreadTranspose_outbound_df_for_regression
// CHECK: [[IF:%.*]] = scf.if
// CHECK: [[ALLOC1:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<32x128xf16
// CHECK: scf.yield [[ALLOC1]]
// CHECK: } else {
// CHECK: [[ALLOC2:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<32x128xf16
// CHECK: scf.yield [[ALLOC2]]
// CHECK: }
// CHECK: [[TRANS:%.*]] = amdgpu.in_thread_transpose
// CHECK: ttg.local_store [[TRANS]], [[IF]] : tensor<32x128xf16
// CHECK: ttg.local_load [[IF]] : !ttg.memdesc<32x128xf16
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#smem = #ttg.shared_memory
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [32, 32], isTransposed = true}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @inThreadTranspose_outbound_df_for_regression(%arg0: tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: i1) {
    %0 = scf.if %arg2 -> (!ttg.memdesc<32x128xf16, #shared, #smem, mutable>) {
      %1 = ttg.local_alloc : () -> !ttg.memdesc<32x128xf16, #shared, #smem, mutable>
      scf.yield %1 : !ttg.memdesc<32x128xf16, #shared, #smem, mutable>
    } else {
      %2 = ttg.local_alloc : () -> !ttg.memdesc<32x128xf16, #shared, #smem, mutable>
      scf.yield %2 : !ttg.memdesc<32x128xf16, #shared, #smem, mutable>
    }
    %3 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>, #blocked>
    %4 = tt.load %3 : tensor<32x128x!tt.ptr<f16>, #blocked>
    ttg.local_store %4, %0 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared, #smem, mutable>
    %5 = ttg.local_load %0 : !ttg.memdesc<32x128xf16, #shared, #smem, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
    tt.return
  }
}
