// RUN: triton-opt %s -triton-intel-remove-masks | FileCheck %s
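
// The -triton-intel-remove-masks pass versions the K-loop of this GEMM-style
// kernel: when the K extent (%arg5) is an exact multiple of 32 and larger
// than 32, the boundary masks on the in-loop loads are provably all-true, so
// the pass emits an scf.if whose "then" region runs the loop with unmasked
// loads and whose "else" region keeps the original masked loads.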

module {
  tt.func public @test_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
    %c31_i32 = arith.constant 31 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<64x128xf32>
    %c127_i32 = arith.constant 127 : i32
    %c63_i32 = arith.constant 63 : i32
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x128xf16>
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %cst_2 = arith.constant dense<32> : tensor<64x32xi32>
    %c32_i32 = arith.constant 32 : i32
    %c128_i32 = arith.constant 128 : i32
    %c64_i32 = arith.constant 64 : i32
    %c4_i32 = arith.constant 4 : i32
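    // Map the program id to the coordinates of the 64x128 output tile
    // (grouped launch order).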
    %0 = tt.get_program_id x : i32
    %1 = arith.addi %arg3, %c63_i32 : i32
    %2 = arith.divsi %1, %c64_i32 : i32
    %3 = arith.addi %arg4, %c127_i32 : i32
    %4 = arith.divsi %3, %c128_i32 : i32
    %5 = arith.muli %4, %c4_i32 : i32
    %6 = arith.divsi %0, %5 : i32
    %7 = arith.muli %6, %c4_i32 : i32
    %8 = arith.subi %2, %7 : i32
    %9 = arith.minsi %8, %c4_i32 : i32
    %10 = arith.remsi %0, %5 : i32
    %11 = arith.remsi %10, %9 : i32
    %12 = arith.addi %7, %11 : i32
    %13 = arith.divsi %10, %9 : i32
    %14 = arith.muli %12, %c64_i32 : i32
    %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %16 = tt.splat %14 : i32 -> tensor<64xi32>
    %17 = arith.addi %16, %15 : tensor<64xi32>
    %18 = tt.splat %arg3 : i32 -> tensor<64xi32>
    %19 = arith.remsi %17, %18 : tensor<64xi32>
    %20 = arith.muli %13, %c128_i32 : i32
    %21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
    %22 = tt.splat %20 : i32 -> tensor<128xi32>
    %23 = arith.addi %22, %21 : tensor<128xi32>
    %24 = tt.splat %arg4 : i32 -> tensor<128xi32>
    %25 = arith.remsi %23, %24 : tensor<128xi32>
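    // Build the pointer tensors for the A (64x32) and B (32x128) input tiles.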
    %26 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
    %27 = tt.expand_dims %19 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
    %28 = tt.splat %arg6 : i32 -> tensor<64x1xi32>
    %29 = arith.muli %27, %28 : tensor<64x1xi32>
    %30 = tt.expand_dims %26 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
    %31 = tt.broadcast %29 : tensor<64x1xi32> -> tensor<64x32xi32>
    %32 = tt.broadcast %30 : tensor<1x32xi32> -> tensor<64x32xi32>
    %33 = arith.addi %31, %32 : tensor<64x32xi32>
    %34 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<64x32x!tt.ptr<f16>>
    %35 = tt.addptr %34, %33 : tensor<64x32x!tt.ptr<f16>>, tensor<64x32xi32>
    %36 = tt.expand_dims %26 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32>
    %37 = tt.splat %arg7 : i32 -> tensor<32x1xi32>
    %38 = arith.muli %36, %37 : tensor<32x1xi32>
    %39 = tt.expand_dims %25 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32>
    %40 = tt.broadcast %38 : tensor<32x1xi32> -> tensor<32x128xi32>
    %41 = tt.broadcast %39 : tensor<1x128xi32> -> tensor<32x128xi32>
    %42 = arith.addi %40, %41 : tensor<32x128xi32>
    %43 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>>
    %44 = tt.addptr %43, %42 : tensor<32x128x!tt.ptr<f16>>, tensor<32x128xi32>
    %45 = arith.addi %arg5, %c31_i32 : i32
    %46 = arith.divsi %45, %c32_i32 : i32
    %47 = arith.muli %arg7, %c32_i32 : i32
    %48 = tt.splat %47 : i32 -> tensor<32x128xi32>
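    // K-loop: both loads are masked against the remaining K extent
    // (%arg5 - 32 * iteration), with zero padding out of bounds.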
    %49:3 = scf.for %arg9 = %c0_i32 to %46 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %35, %arg12 = %44) -> (tensor<64x128xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x128x!tt.ptr<f16>>) : i32 {
      %67 = arith.muli %arg9, %c32_i32 : i32
      %68 = arith.subi %arg5, %67 : i32
      %69 = tt.splat %68 : i32 -> tensor<1x32xi32>
      %70 = arith.cmpi slt, %30, %69 : tensor<1x32xi32>
      %71 = tt.broadcast %70 : tensor<1x32xi1> -> tensor<64x32xi1>
      %72 = tt.load %arg11, %71, %cst_1 : tensor<64x32x!tt.ptr<f16>>
      %73 = tt.splat %68 : i32 -> tensor<32x1xi32>
      %74 = arith.cmpi slt, %36, %73 : tensor<32x1xi32>
      %75 = tt.broadcast %74 : tensor<32x1xi1> -> tensor<32x128xi1>
      %76 = tt.load %arg12, %75, %cst_0 : tensor<32x128x!tt.ptr<f16>>
      %77 = tt.dot %72, %76, %arg10, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x128xf16> -> tensor<64x128xf32>
      %78 = tt.addptr %arg11, %cst_2 : tensor<64x32x!tt.ptr<f16>>, tensor<64x32xi32>
      %79 = tt.addptr %arg12, %48 : tensor<32x128x!tt.ptr<f16>>, tensor<32x128xi32>
      scf.yield %77, %78, %79 : tensor<64x128xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x128x!tt.ptr<f16>>
    }
    %50 = arith.truncf %49#0 : tensor<64x128xf32> to tensor<64x128xf16>
    tt.return
  }

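  // Expected output: the loop is versioned on the condition that %arg5 is a
  // non-trivial multiple of 32; the "then" loop loads without masks, while
  // the "else" loop retains the original masked loads.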
  // CHECK: tt.func public @test_kernel([[PARAM_0_:%.+]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, [[PARAM_2_:%.+]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, [[PARAM_3_:%.+]]: i32 {tt.divisibility = 16 : i32}, [[PARAM_4_:%.+]]: i32 {tt.divisibility = 16 : i32}, [[PARAM_5_:%.+]]: i32 {tt.divisibility = 16 : i32}, [[PARAM_6_:%.+]]: i32 {tt.divisibility = 16 : i32}, [[PARAM_7_:%.+]]: i32 {tt.divisibility = 16 : i32}, [[PARAM_8_:%.+]]: i32 {tt.divisibility = 16 : i32}) {
  // CHECK: [[CST_0_i32:%.+]] = arith.constant 0 : i32
  // CHECK: [[CST_32_i32:%.+]] = arith.constant 32 : i32
  // CHECK: [[REM:%.+]] = arith.remsi [[PARAM_5_]], [[CST_32_i32]] : i32
  // CHECK: [[CMP1:%.+]] = arith.cmpi eq, [[REM]], [[CST_0_i32]] : i32
  // CHECK: [[CMP2:%.+]] = arith.cmpi sgt, [[PARAM_5_]], [[CST_32_i32]] : i32
  // CHECK: [[VER_COND:%.+]] = arith.andi [[CMP1]], [[CMP2]] : i1
  // CHECK: [[LOOP_VER:%.+]] = scf.if [[VER_COND]] -> (tensor<64x128xf32>) {
  // CHECK: [[THEN_LOOP_RES:%.+]]:3 = scf.for {{.*}} iter_args([[VAR_arg10:%.+]] = {{.*}}, [[VAR_arg11:%.+]] = {{.*}}, [[VAR_arg12:%.+]] = {{.*}}) -> (tensor<64x128xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x128x!tt.ptr<f16>>) : i32 {
  // CHECK: [[LOAD_A1:%.+]] = tt.load [[VAR_arg11]] : tensor<64x32x!tt.ptr<f16>>
  // CHECK: [[LOAD_B1:%.+]] = tt.load [[VAR_arg12]] : tensor<32x128x!tt.ptr<f16>>
  // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}} : tensor<64x128xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x128x!tt.ptr<f16>>
  // CHECK: }
  // CHECK: scf.yield [[THEN_LOOP_RES]]#0 : tensor<64x128xf32>
  // CHECK: } else {
  // CHECK: [[ELSE_LOOP_RES:%.+]]:3 = scf.for {{.*}} iter_args([[VAR_arg10:%.+]] = {{.*}}, [[VAR_arg11:%.+]] = {{.*}}, [[VAR_arg12:%.+]] = {{.*}}) -> (tensor<64x128xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x128x!tt.ptr<f16>>) : i32 {
  // CHECK: [[LOAD_A2:%.+]] = tt.load [[VAR_arg11]], {{.*}}, {{.*}} : tensor<64x32x!tt.ptr<f16>>
  // CHECK: [[LOAD_B2:%.+]] = tt.load [[VAR_arg12]], {{.*}}, {{.*}} : tensor<32x128x!tt.ptr<f16>>
  // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}} : tensor<64x128xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x128x!tt.ptr<f16>>
  // CHECK: }
  // CHECK: scf.yield [[ELSE_LOOP_RES]]#0 : tensor<64x128xf32>
  // CHECK: }
  // CHECK: tt.return
  // CHECK: }
}