|
1 | 1 | // RUN: triton-opt %s -split-input-file -allow-unregistered-dialect --tritonintelgpu-optimize-block-io-encoding | FileCheck %s
|
2 | 2 |
|
| 3 | +// COM: test complete example |
3 | 4 | #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
|
4 | 5 | #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
|
5 | 6 | #blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
|
@@ -59,6 +60,57 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th
|
59 | 60 |
|
60 | 61 | // -----
|
61 | 62 |
|
| 63 | +// COM: Test while loop / tt.advance before tt.load (TODO) |
| 64 | +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> |
| 65 | +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> |
| 66 | +// CHECK-DAG: #[[$BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> |
| 67 | +// CHECK-DAG: #[[$SUBGROUP_2D_BLOCK:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}> |
| 68 | +// CHECK-DAG: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> |
| 69 | +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} { |
| 70 | + tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>) { |
| 71 | + %c1024_i64 = arith.constant 1024 : i64 |
| 72 | + %c5120_i64 = arith.constant 5120 : i64 |
| 73 | + %c1_i64 = arith.constant 1 : i64 |
| 74 | + %c256_i32 = arith.constant 256 : i32 |
| 75 | + %c0_i32 = arith.constant 0 : i32 |
| 76 | + %c32_i32 = arith.constant 32 : i32 |
| 77 | + |
| 78 | + // CHECK: %[[A_PTR:.*]] = tt.make_tensor_ptr %arg0, {{.*}} : <tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>> |
| 79 | + %a_ptr = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%c256_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #blocked1>> |
| 80 | + |
| 81 | + // CHECK: scf.while {{.*}} : (!tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>) -> !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>> |
| 82 | + %1 = scf.while (%a_ptr_crt = %a_ptr) : (!tt.ptr<tensor<256x32xf16, #blocked1>>) -> (!tt.ptr<tensor<256x32xf16, #blocked1>>) { |
| 83 | + %2 = "dummy.evaluate_condition"() : () -> i1 |
| 84 | + // CHECK: scf.condition({{.*}}) {{.*}} : !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>> |
| 85 | + scf.condition(%2) %a_ptr_crt : !tt.ptr<tensor<256x32xf16, #blocked1>> |
| 86 | + } do { |
| 87 | + ^bb0(%a_ptr_crt: !tt.ptr<tensor<256x32xf16, #blocked1>>): |
| 88 | + // CHECK: ^bb0({{.*}}: !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>): |
| 89 | + |
| 90 | + // CHECK: %[[A_LOAD:.*]] = tt.load {{.*}} : !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>> |
| 91 | + %3 = tt.load %a_ptr_crt {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #blocked1>> |
| 92 | + // CHECK: ttg.convert_layout %[[A_LOAD]] : tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]> -> tensor<256x32xf16, #[[$BLOCKED]]> |
| 93 | + // CHECK: ttg.convert_layout {{.*}} : tensor<256x32xf16, #[[$BLOCKED]]> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> |
| 94 | + %4 = ttg.convert_layout %3 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> |
| 95 | + |
| 96 | + %cstB = arith.constant dense<0.000000e+00> : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> |
| 97 | + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> |
| 98 | + |
| 99 | + // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[$DPAS]]> |
| 100 | + %5 = tt.dot %4, %cstB, %cst, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> |
| 101 | + %6 = ttg.convert_layout %5 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked1> |
| 102 | + // COM: TODO: support nested tt.advance |
| 103 | + // %3 = tt.advance %a_ptr_crt, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #blocked1>> |
| 104 | + |
| 105 | + // CHECK: scf.yield {{.*}} : !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>> |
| 106 | + scf.yield %a_ptr_crt : !tt.ptr<tensor<256x32xf16, #blocked1>> |
| 107 | + } |
| 108 | + tt.return |
| 109 | + } |
| 110 | +} |
| 111 | + |
| 112 | +// ----- |
| 113 | + |
62 | 114 | // COM: test complex control flow
|
63 | 115 | // COM: Note that instead of using tt.advance we make a new tensor ptr each time. This is nice, because it lets us test that we can find MakeTensorPtr op inside the scf.if.
|
64 | 116 | #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
|
|
0 commit comments