// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline="num_stages=2" -canonicalize | FileCheck %s

// Pick a common shared memory layout with vec = max kWidth of all users.
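// In this test the staged load feeds dot operands with kWidth = 4 and kWidth = 8,
// so the pipeliner is expected to emit exactly one swizzled layout whose vec matches
// the larger kWidth (the CHECK-NOT below guards against a second #shared layout).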
// CHECK{LITERAL}: #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 8, order = [0, 1]}>
// CHECK-NOT: #ttg.swizzled_shared
// CHECK{LITERAL}: #smem = #ttg.shared_memory
// CHECK-LABEL: test_lds_layout_selection

// CHECK: %[[ALLOC:.+]] = ttg.local_alloc : () -> !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable>
// CHECK: %[[MEMDESC_IDX:.+]] = ttg.memdesc_index %[[ALLOC]]

// CHECK: scf.for {{.+}} iter_args({{.*}}, %[[MEMDESC_IDX_ITER:.+]] = %[[MEMDESC_IDX]]) -> ({{.+}})
// CHECK: %[[LOAD:.+]] = tt.load {{.+}} : tensor<64x16x!tt.ptr<f16>, #blocked>
// CHECK: %[[LOCAL_LOAD_TRANS:.+]] = ttg.local_load %[[MEMDESC_IDX_ITER]] : {{.+}} -> tensor<64x16xf16, #linear>
// CHECK: %[[LOCAL_LOAD_DIRECT:.+]] = ttg.local_load %[[MEMDESC_IDX_ITER]] : {{.+}} -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
// CHECK: tt.dot {{.+}}, %[[LOCAL_LOAD_DIRECT]], {{.+}}
// CHECK: %[[TRANS:.+]] = tt.trans %[[LOCAL_LOAD_TRANS]] {{.+}} : {{.+}} -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 8}>>
// CHECK: tt.dot {{.+}}, %[[TRANS]], {{.+}}
// CHECK: %[[MEMDESC_IDX:.+]] = ttg.memdesc_index %[[ALLOC]]
// CHECK: ttg.local_store %[[LOAD]], %[[MEMDESC_IDX]]
// CHECK: scf.yield

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [32, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[0, 0], [0, 0]], block = []}>
#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
#mma1 = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @test_lds_layout_selection(
    %arg0: tensor<64x16x!tt.ptr<f16>, #blocked>,
    %out0 : tensor<128x16x!tt.ptr<f32>, #blocked>,
    %out1 : tensor<128x64x!tt.ptr<f32>, #blocked>
  ) {
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
    %cst_1 = arith.constant dense<0.693147182> : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
    %cst_2 = arith.constant dense<0.581374812> : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c8_i32 = arith.constant 8 : i32

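    // The same tt.load result is used twice below: directly as an opIdx = 1 dot
    // operand with kWidth = 4, and through tt.trans as an opIdx = 1 operand with
    // kWidth = 8, so both users must be able to read from one shared buffer.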
    %0:2 = scf.for %arg1 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg2 = %cst_0, %arg3 = %cst_3) -> (tensor<128x16xf32, #mma1>, tensor<128x64xf32, #mma>) : i32 {
      %1 = tt.load %arg0 : tensor<64x16x!tt.ptr<f16>, #blocked>
      %2 = ttg.convert_layout %1 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #linear>
      %3 = ttg.convert_layout %1 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>>
      %4 = tt.dot %cst_1, %3, %arg2 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1>
      %5 = tt.trans %2 {order = array<i32: 1, 0>} : tensor<64x16xf16, #linear> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
      %6 = tt.dot %cst_2, %5, %arg3 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x64xf32, #mma>
      scf.yield %4, %6 : tensor<128x16xf32, #mma1>, tensor<128x64xf32, #mma>
    }

    %7 = ttg.convert_layout %0#0 : tensor<128x16xf32, #mma1> -> tensor<128x16xf32, #blocked>
    %8 = ttg.convert_layout %0#1 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked>
    tt.store %out0, %7 : tensor<128x16x!tt.ptr<f32>, #blocked>
    tt.store %out1, %8 : tensor<128x64x!tt.ptr<f32>, #blocked>
    tt.return
  }
}
// -----

// Verify that a common shared memory layout is chosen for users with different kWidth and opIdx.
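// Here the staged load is consumed as an opIdx = 0 operand with kWidth = 4 and,
// through tt.trans, as an opIdx = 1 operand with kWidth = 8; a single #shared
// layout with vec = 8 should still be selected.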
// CHECK{LITERAL}: #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 8, order = [0, 1]}>
// CHECK-NOT: #ttg.swizzled_shared
// CHECK{LITERAL}: #smem = #ttg.shared_memory
// CHECK-LABEL: test_lds_layout_selection_different_opIdx

// CHECK: %[[ALLOC:.+]] = ttg.local_alloc : () -> !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable>
// CHECK: %[[MEMDESC_IDX:.+]] = ttg.memdesc_index %[[ALLOC]]

// CHECK: scf.for {{.+}} iter_args({{.*}}, %[[MEMDESC_IDX_ITER:.+]] = %[[MEMDESC_IDX]]) -> ({{.+}})
// CHECK: %[[LOAD:.+]] = tt.load {{.+}} : tensor<64x16x!tt.ptr<f16>, #blocked>
// CHECK: %[[LOCAL_LOAD_TRANS:.+]] = ttg.local_load %[[MEMDESC_IDX_ITER]] : {{.+}} -> tensor<64x16xf16, #linear>
// CHECK: %[[LOCAL_LOAD_DIRECT:.+]] = ttg.local_load %[[MEMDESC_IDX_ITER]] : {{.+}} -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
// CHECK: tt.dot %[[LOCAL_LOAD_DIRECT]], {{.+}}
// CHECK: %[[TRANS:.+]] = tt.trans %[[LOCAL_LOAD_TRANS]] {{.+}} : {{.+}} -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 8}>>
// CHECK: tt.dot {{.+}}, %[[TRANS]], {{.+}}
// CHECK: %[[MEMDESC_IDX:.+]] = ttg.memdesc_index %[[ALLOC]]
// CHECK: ttg.local_store %[[LOAD]], %[[MEMDESC_IDX]]
// CHECK: scf.yield

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [32, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[0, 0], [0, 0]], block = []}>
#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
#mma1 = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @test_lds_layout_selection_different_opIdx(
    %arg0: tensor<64x16x!tt.ptr<f16>, #blocked>,
    %out0 : tensor<64x64x!tt.ptr<f32>, #blocked>,
    %out1 : tensor<128x64x!tt.ptr<f32>, #blocked>
  ) {
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma1>
    %cst_1 = arith.constant dense<0.693147182> : tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>>
    %cst_2 = arith.constant dense<0.581374812> : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c8_i32 = arith.constant 8 : i32

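    // Same pattern as the previous test, but the direct use of the load is an
    // opIdx = 0 dot operand (kWidth = 4) while the transposed use is an
    // opIdx = 1 operand (kWidth = 8).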
    %0:2 = scf.for %arg1 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg2 = %cst_0, %arg3 = %cst_3) -> (tensor<64x64xf32, #mma1>, tensor<128x64xf32, #mma>) : i32 {
      %1 = tt.load %arg0 : tensor<64x16x!tt.ptr<f16>, #blocked>
      %2 = ttg.convert_layout %1 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #linear>
      %3 = ttg.convert_layout %1 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
      %4 = tt.dot %3, %cst_1, %arg2 : tensor<64x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<64x64xf32, #mma1>
      %5 = tt.trans %2 {order = array<i32: 1, 0>} : tensor<64x16xf16, #linear> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
      %6 = tt.dot %cst_2, %5, %arg3 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x64xf32, #mma>
      scf.yield %4, %6 : tensor<64x64xf32, #mma1>, tensor<128x64xf32, #mma>
    }

    %7 = ttg.convert_layout %0#0 : tensor<64x64xf32, #mma1> -> tensor<64x64xf32, #blocked>
    %8 = ttg.convert_layout %0#1 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked>
    tt.store %out0, %7 : tensor<64x64x!tt.ptr<f32>, #blocked>
    tt.store %out1, %8 : tensor<128x64x!tt.ptr<f32>, #blocked>
    tt.return
  }
}