|
| 1 | +// RUN: triton-opt %s -tritongpu-pipeline | FileCheck %s |
| 2 | + |
| 3 | +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> |
| 4 | + |
| 5 | +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { |
| 6 | + |
| 7 | +// CHECK-LABEL: @softmax_kernel |
| 8 | +tt.func public @softmax_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { |
| 9 | + %cst = arith.constant dense<0xFF800000> : tensor<128xf32, #blocked> |
| 10 | + %0 = tt.get_program_id x : i32 |
| 11 | + %1 = tt.get_num_programs x : i32 |
| 12 | + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked> |
| 13 | + %3 = tt.splat %arg5 : i32 -> tensor<128xi32, #blocked> |
| 14 | + // CHECK: [[MASK:%.*]] = arith.cmpi slt, {{.*}} tensor<128xi32, |
| 15 | + %4 = arith.cmpi slt, %2, %3 : tensor<128xi32, #blocked> |
| 16 | + // CHECK: scf.for |
| 17 | + scf.for %arg6 = %0 to %arg4 step %1 : i32 { |
| 18 | + %5 = tt.splat %arg1 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>, #blocked> |
| 19 | + %6 = tt.addptr %5, %2 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x!tt.ptr<f32>, #blocked>, tensor<128xi32, #blocked> |
| 20 | + // CHECK: [[RESULT:%.*]] = triton_gpu.local_load |
| 21 | + // CHECK-NEXT: arith.select [[MASK]], [[RESULT]], %cst |
| 22 | + %7 = tt.load %6, %4, %cst {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x!tt.ptr<f32>, #blocked> |
| 23 | + %8 = tt.splat %arg0 {loop.cluster = 1 : i32, loop.stage = 1 : i32} : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>, #blocked> |
| 24 | + %9 = tt.addptr %8, %2 {loop.cluster = 1 : i32, loop.stage = 1 : i32} : tensor<128x!tt.ptr<f32>, #blocked>, tensor<128xi32, #blocked> |
| 25 | + tt.store %9, %7, %4 {loop.cluster = 1 : i32, loop.stage = 1 : i32} : tensor<128x!tt.ptr<f32>, #blocked> |
| 26 | + } {tt.num_stages = 2 : i32} |
| 27 | + tt.return |
| 28 | +} |
| 29 | + |
| 30 | +} |
0 commit comments