// RUN: triton-opt %s -split-input-file --nvgpu-test-ws-data-partition=num-warp-groups=3 | FileCheck %s

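// With num-warp-groups=3 (one producer plus two consumer warp groups), the pass is
// expected to partition along the M dimension: the 128x64 A-tile load, the
// warp_group_dot, and the 128x256 store should each be split into two 64-row halves,
// while the 64x256 B-tile is shared by both halves (see the CHECK lines below).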
// CHECK-LABEL: @matmul_persistent_ws_cooperative_kernel
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 0}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @matmul_persistent_ws_cooperative_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %c0_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 0 : i32
    %c1_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 1 : i32
    %c64_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 64 : i32
    %cst = arith.constant {async_task_id = array<i32: 1, 2>} dense<0.000000e+00> : tensor<128x256xf32, #mma>
    %0 = tt.get_program_id x {async_task_id = array<i32: 0, 1, 2>} : i32
    %1 = tt.get_num_programs x {async_task_id = array<i32: 0, 1, 2>} : i32
    scf.for %arg6 = %0 to %arg3 step %1 : i32 {
      %2 = tt.splat %arg0 {async_task_id = array<i32: 0>} : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #blocked>
      %3 = tt.splat %arg1 {async_task_id = array<i32: 0>} : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>, #blocked1>
      %4:2 = scf.for %arg7 = %c0_i32 to %arg4 step %c1_i32 iter_args(%arg8 = %cst, %arg9 = %c0_i32) -> (tensor<128x256xf32, #mma>, i32) : i32 {
        // CHECK: %[[#GA1:]] = tt.load {{.*}} : tensor<64x64x!tt.ptr<f16>
        // CHECK: %[[#GA2:]] = tt.load {{.*}} : tensor<64x64x!tt.ptr<f16>
        %8 = tt.load %2 {async_task_id = array<i32: 0>} : tensor<128x64x!tt.ptr<f16>, #blocked>
        // CHECK: %[[#LA1:]] = ttg.local_alloc %[[#GA1]]
        // CHECK: %[[#LA2:]] = ttg.local_alloc %[[#GA2]]
        %9 = ttg.local_alloc %8 {async_task_id = array<i32: 1, 2>} : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
        // CHECK: %[[#GB:]] = tt.load {{.*}} : tensor<64x256x!tt.ptr<f16>
        %10 = tt.load %3 {async_task_id = array<i32: 0>} : tensor<64x256x!tt.ptr<f16>, #blocked1>
        // CHECK: %[[#LB:]] = ttg.local_alloc %[[#GB]]
        %11 = ttg.local_alloc %10 {async_task_id = array<i32: 1, 2>} : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #smem>
        // CHECK: %[[#C1:]] = ttng.warp_group_dot %[[#LA1]], %[[#LB]], {{.*}} : !ttg.memdesc<64x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<64x256xf32, #mma>
        // CHECK: %[[#C2:]] = ttng.warp_group_dot %[[#LA2]], %[[#LB]], {{.*}} : !ttg.memdesc<64x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<64x256xf32, #mma>
        %12 = ttng.warp_group_dot %9, %11, %arg8 {async_task_id = array<i32: 1, 2>, inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<128x256xf32, #mma>
        %13 = arith.addi %arg9, %c64_i32 {async_task_id = array<i32: 0>} : i32
        scf.yield {async_task_id = array<i32: 0, 1, 2>} %12, %13 : tensor<128x256xf32, #mma>, i32
      } {async_task_id = array<i32: 0, 1, 2>}
      %5 = arith.truncf %4#0 {async_task_id = array<i32: 1, 2>} : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma>
      %6 = ttg.convert_layout %5 {async_task_id = array<i32: 1, 2>} : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1>
      %7 = tt.splat %arg2 {async_task_id = array<i32: 1, 2>} : !tt.ptr<f16> -> tensor<128x256x!tt.ptr<f16>, #blocked1>
      // CHECK: tt.store {{.*}} : tensor<64x256x!tt.ptr<f16>, #blocked1>
      // CHECK: tt.store {{.*}} : tensor<64x256x!tt.ptr<f16>, #blocked1>
      tt.store %7, %6 {async_task_id = array<i32: 1, 2>} : tensor<128x256x!tt.ptr<f16>, #blocked1>
    }
    tt.return
  }
}

// -----

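// This test exercises partitioning that crosses a transpose: the TMA descriptor loads
// and the first dot are expected to be split along the first (row) dimension, and after
// ttg.memdesc_trans the split dimension becomes the K dimension of the second dot (see
// the CHECK lines below).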
// CHECK-LABEL: @cross_dim_partition
#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 0}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 0}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @cross_dim_partition(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg7: f32, %arg8: i32, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32) attributes {noinline = false} {
    %cst = arith.constant {async_task_id = array<i32: 1, 2>} dense<0.000000e+00> : tensor<128x128xf32, #mma>
    %cst_0 = arith.constant {async_task_id = array<i32: 1, 2>} dense<true> : tensor<128x128xi1, #blocked>
    %c1_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 1 : i32
    %c128_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 128 : i32
    %c64_i32 = arith.constant {async_task_id = array<i32: 0>} 64 : i32
    %0 = tt.get_program_id x {async_task_id = array<i32: 0, 1, 2>} : i32
    %1 = tt.get_program_id y {async_task_id = array<i32: 0, 1, 2>} : i32
    %2 = tt.load %arg1 {async_task_id = array<i32: 0, 1, 2>} : !tt.ptr<i32>
    %3 = arith.extsi %arg8 {async_task_id = array<i32: 0>} : i32 to i64
    tt.experimental_tensormap_create %arg6, %arg0, [%c64_i32, %c64_i32], [%arg8, %2], [%3], [%c1_i32, %c1_i32] {async_task_id = array<i32: 0>, elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<bf16>, i32, i32, i32, i32, i64, i32, i32) -> ()
    tt.experimental_tensormap_create %arg6, %arg2, [%c64_i32, %c128_i32], [%arg8, %arg9], [%3], [%c1_i32, %c1_i32] {async_task_id = array<i32: 0>, elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<bf16>, i32, i32, i32, i32, i64, i32, i32) -> ()
    tt.experimental_tensormap_create %arg6, %arg3, [%c64_i32, %c64_i32], [%arg8, %2], [%3], [%c1_i32, %c1_i32] {async_task_id = array<i32: 0>, elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<bf16>, i32, i32, i32, i32, i64, i32, i32) -> ()
    tt.experimental_tensormap_create %arg6, %arg5, [%c64_i32, %c64_i32], [%arg8, %2], [%3], [%c1_i32, %c1_i32] {async_task_id = array<i32: 0>, elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr<i8>, !tt.ptr<bf16>, i32, i32, i32, i32, i64, i32, i32) -> ()
    %4 = tt.reinterpret_tensor_descriptor %arg6 {async_task_id = array<i32: 0>} : !tt.ptr<i8> to !tt.tensordesc<tensor<128x128xbf16>>
    %5 = tt.reinterpret_tensor_descriptor %arg6 {async_task_id = array<i32: 0>} : !tt.ptr<i8> to !tt.tensordesc<tensor<128x128xbf16>>
    %6 = tt.reinterpret_tensor_descriptor %arg6 {async_task_id = array<i32: 0>} : !tt.ptr<i8> to !tt.tensordesc<tensor<128x128xbf16>>
    %7 = tt.reinterpret_tensor_descriptor %arg6 {async_task_id = array<i32: 0>} : !tt.ptr<i8> to !tt.tensordesc<tensor<128x128xbf16>>
    // CHECK: tt.descriptor_load {{.*}} -> tensor<64x128xbf16
    // CHECK: tt.descriptor_load {{.*}} -> tensor<64x128xbf16
    %8 = tt.descriptor_load %4[%0, %1] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<128x128xbf16>> -> tensor<128x128xbf16, #blocked1>
    %9 = ttg.local_alloc %8 {async_task_id = array<i32: 1, 2>} : (tensor<128x128xbf16, #blocked1>) -> !ttg.memdesc<128x128xbf16, #shared, #smem>
    // CHECK: tt.descriptor_load {{.*}} -> tensor<128x128xbf16
    %10 = tt.descriptor_load %5[%1, %1] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<128x128xbf16>> -> tensor<128x128xbf16, #blocked1>
    %11 = ttg.local_alloc %10 {async_task_id = array<i32: 1, 2>} : (tensor<128x128xbf16, #blocked1>) -> !ttg.memdesc<128x128xbf16, #shared, #smem>
    // CHECK: ttng.warp_group_dot {{.*}} : !ttg.memdesc<64x128xbf16, {{.*}} * !ttg.memdesc<128x128xbf16, {{.*}} -> tensor<64x128xf32, {{.*}}
    // CHECK: ttng.warp_group_dot {{.*}} : !ttg.memdesc<64x128xbf16, {{.*}} * !ttg.memdesc<128x128xbf16, {{.*}} -> tensor<64x128xf32, {{.*}}
    %12 = ttng.warp_group_dot %9, %11, %cst {async_task_id = array<i32: 1, 2>, inputPrecision = 0 : i32} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x128xbf16, #shared, #smem> -> tensor<128x128xf32, #mma>
    %13 = arith.truncf %12 {async_task_id = array<i32: 1, 2>} : tensor<128x128xf32, #mma> to tensor<128x128xbf16, #mma>
    %14 = ttg.local_alloc %13 {async_task_id = array<i32: 1, 2>} : (tensor<128x128xbf16, #mma>) -> !ttg.memdesc<128x128xbf16, #shared, #smem>
    // CHECK: tt.descriptor_load {{.*}} -> tensor<64x128xbf16
    // CHECK: tt.descriptor_load {{.*}} -> tensor<64x128xbf16
    %15 = tt.descriptor_load %6[%0, %1] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<128x128xbf16>> -> tensor<128x128xbf16, #blocked1>
    %16 = ttg.local_alloc %15 {async_task_id = array<i32: 1, 2>} : (tensor<128x128xbf16, #blocked1>) -> !ttg.memdesc<128x128xbf16, #shared, #smem>
    %17 = ttg.memdesc_trans %16 {async_task_id = array<i32: 1, 2>, order = array<i32: 1, 0>} : !ttg.memdesc<128x128xbf16, #shared, #smem> -> !ttg.memdesc<128x128xbf16, #shared1, #smem>
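    // The LHS was split along its rows before the transpose, so after ttg.memdesc_trans
    // the split is expected to land in the K dimension of this dot: each half computes a
    // 128x64 * 64x128 partial product for the full 128x128 accumulator.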
    // CHECK: ttng.warp_group_dot {{.*}} : !ttg.memdesc<128x64xbf16, {{.*}} * !ttg.memdesc<64x128xbf16, {{.*}} -> tensor<128x128xf32, {{.*}}
    // CHECK: ttng.warp_group_dot {{.*}} : !ttg.memdesc<128x64xbf16, {{.*}} * !ttg.memdesc<64x128xbf16, {{.*}} -> tensor<128x128xf32, {{.*}}
    %18 = ttng.warp_group_dot %17, %14, %cst {async_task_id = array<i32: 1, 2>, inputPrecision = 0 : i32} : !ttg.memdesc<128x128xbf16, #shared1, #smem> * !ttg.memdesc<128x128xbf16, #shared, #smem> -> tensor<128x128xf32, #mma>
    %19 = ttg.convert_layout %18 {async_task_id = array<i32: 1, 2>} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked>
    %20 = arith.truncf %19 {async_task_id = array<i32: 1, 2>} : tensor<128x128xf32, #blocked> to tensor<128x128xbf16, #blocked>
    %21 = tt.splat %arg4 {async_task_id = array<i32: 1, 2>} : !tt.ptr<bf16> -> tensor<1x128x!tt.ptr<bf16>, #blocked>
    %22 = tt.broadcast %21 {async_task_id = array<i32: 1, 2>} : tensor<1x128x!tt.ptr<bf16>, #blocked> -> tensor<128x128x!tt.ptr<bf16>, #blocked>
    %23 = tt.atomic_rmw fadd, relaxed, gpu, %22, %20, %cst_0 {async_task_id = array<i32: 1, 2>} : (tensor<128x128x!tt.ptr<bf16>, #blocked>, tensor<128x128xbf16, #blocked>, tensor<128x128xi1, #blocked>) -> tensor<128x128xbf16, #blocked>
    tt.return
  }
}