module {
  tt.func public @attn_fwd(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32 {tt.divisibility = 16 : i32}, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: f32, %arg24: i32, %arg25: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg26: i32) attributes {noinline = false} {
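    // Triton IR for an attention forward kernel: by the data flow, %arg0-%arg2
    // hold Q, K and V in f16, %arg3 an f32 per-row softmax-statistics buffer,
    // and %arg4 the f16 output. The tile shape is 256 query rows x 64 keys per
    // step x 128 head-dim lanes; the constant 0.127517432 below is consistent
    // with (1/sqrt(128)) * log2(e), i.e. the softmax scale folded with log2(e)
    // so that math.exp2 can replace exp.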
    %c8192_i32 = arith.constant 8192 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32>
    %cst_0 = arith.constant dense<0.127517432> : tensor<256xf32>
    %cst_1 = arith.constant dense<0.127517432> : tensor<256x64xf32>
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32>
    %c16640_i32 = arith.constant 16640 : i32
    %c786432_i32 = arith.constant 786432 : i32
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<256x128xf16>
    %cst_4 = arith.constant dense<true> : tensor<256x128xi1>
    %cst_5 = arith.constant dense<1.000000e+00> : tensor<256x1xf32>
    %cst_6 = arith.constant dense<16384> : tensor<256x1xi32>
    %cst_7 = arith.constant dense<1.000000e+00> : tensor<256xf32>
    %cst_8 = arith.constant dense<0xFF800000> : tensor<256xf32>
    %c64_i32 = arith.constant 64 : i32
    %c16384_i32 = arith.constant 16384 : i32
    %c256_i32 = arith.constant 256 : i32
    %c1_i32 = arith.constant 1 : i32
    %true = arith.constant true
    %c0_i32 = arith.constant 0 : i32
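    // Tell the backend every stride argument is non-negative; the bare
    // `llvm.intr.assume %true` lines are presumably the same checks on strides
    // the compiler already folded to a constant.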
    %0 = arith.cmpi sge, %arg5, %c0_i32 : i32
    llvm.intr.assume %0 : i1
    %1 = arith.cmpi sge, %arg6, %c0_i32 : i32
    llvm.intr.assume %1 : i1
    %2 = arith.cmpi sge, %arg7, %c0_i32 : i32
    llvm.intr.assume %2 : i1
    llvm.intr.assume %true : i1
    %3 = arith.cmpi sge, %arg8, %c0_i32 : i32
    llvm.intr.assume %3 : i1
    %4 = arith.cmpi sge, %arg9, %c0_i32 : i32
    llvm.intr.assume %4 : i1
    %5 = arith.cmpi sge, %arg10, %c0_i32 : i32
    llvm.intr.assume %5 : i1
    llvm.intr.assume %true : i1
    %6 = arith.cmpi sge, %arg17, %c0_i32 : i32
    llvm.intr.assume %6 : i1
    %7 = arith.cmpi sge, %arg18, %c0_i32 : i32
    llvm.intr.assume %7 : i1
    %8 = arith.cmpi sge, %arg19, %c0_i32 : i32
    llvm.intr.assume %8 : i1
    %9 = arith.cmpi sge, %arg20, %c0_i32 : i32
    llvm.intr.assume %9 : i1
    %10 = arith.cmpi sge, %arg11, %c0_i32 : i32
    llvm.intr.assume %10 : i1
    %11 = arith.cmpi sge, %arg12, %c0_i32 : i32
    llvm.intr.assume %11 : i1
    %12 = arith.cmpi sge, %arg13, %c0_i32 : i32
    llvm.intr.assume %12 : i1
    llvm.intr.assume %true : i1
    %13 = arith.cmpi sge, %arg14, %c0_i32 : i32
    llvm.intr.assume %13 : i1
    %14 = arith.cmpi sge, %arg15, %c0_i32 : i32
    llvm.intr.assume %14 : i1
    %15 = arith.cmpi sge, %arg16, %c0_i32 : i32
    llvm.intr.assume %15 : i1
    llvm.intr.assume %true : i1
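    // Grid mapping: program id x picks the 256-row query tile; y and z scale
    // the per-head and per-batch strides (y = head, z = batch, by all appearances).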
    %16 = tt.get_program_id x : i32
    %17 = tt.get_program_id y : i32
    %18 = tt.get_program_id z : i32
    %19 = arith.muli %16, %c256_i32 : i32
    %20 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
    %21 = tt.splat %19 : i32 -> tensor<256xi32>
    %22 = arith.addi %21, %20 : tensor<256xi32>
    %23 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %24 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
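    // Q tile pointers: 256 query rows (row stride %arg7) x 128 contiguous
    // head-dim columns, offset by the z/y strides %arg5 and %arg6.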
    %25 = arith.muli %18, %arg5 : i32
    %26 = tt.addptr %arg0, %25 : !tt.ptr<f16>, i32
    %27 = arith.muli %17, %arg6 : i32
    %28 = tt.addptr %26, %27 : !tt.ptr<f16>, i32
    %29 = tt.expand_dims %22 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32>
    %30 = tt.splat %arg7 : i32 -> tensor<256x1xi32>
    %31 = arith.muli %29, %30 : tensor<256x1xi32>
    %32 = tt.splat %28 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>>
    %33 = tt.addptr %32, %31 : tensor<256x1x!tt.ptr<f16>>, tensor<256x1xi32>
    %34 = tt.expand_dims %24 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32>
    %35 = tt.broadcast %33 : tensor<256x1x!tt.ptr<f16>> -> tensor<256x128x!tt.ptr<f16>>
    %36 = tt.broadcast %34 : tensor<1x128xi32> -> tensor<256x128xi32>
    %37 = tt.addptr %35, %36 : tensor<256x128x!tt.ptr<f16>>, tensor<256x128xi32>
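    // K tile pointers, laid out transposed (128 head-dim rows x 64 key columns,
    // key stride %arg10) so the tt.dot below computes Q x K^T directly.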
    %38 = arith.muli %18, %arg8 : i32
    %39 = tt.addptr %arg1, %38 : !tt.ptr<f16>, i32
    %40 = arith.muli %17, %arg9 : i32
    %41 = tt.addptr %39, %40 : !tt.ptr<f16>, i32
    %42 = tt.expand_dims %24 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32>
    %43 = tt.splat %41 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>>
    %44 = tt.addptr %43, %42 : tensor<128x1x!tt.ptr<f16>>, tensor<128x1xi32>
    %45 = tt.expand_dims %23 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
    %46 = tt.splat %arg10 : i32 -> tensor<1x64xi32>
    %47 = arith.muli %45, %46 : tensor<1x64xi32>
    %48 = tt.broadcast %44 : tensor<128x1x!tt.ptr<f16>> -> tensor<128x64x!tt.ptr<f16>>
    %49 = tt.broadcast %47 : tensor<1x64xi32> -> tensor<128x64xi32>
    %50 = tt.addptr %48, %49 : tensor<128x64x!tt.ptr<f16>>, tensor<128x64xi32>
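    // V tile pointers: 64 value rows (row stride %arg13) x 128 contiguous
    // head-dim columns.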
    %51 = arith.muli %18, %arg11 : i32
    %52 = tt.addptr %arg2, %51 : !tt.ptr<f16>, i32
    %53 = arith.muli %17, %arg12 : i32
    %54 = tt.addptr %52, %53 : !tt.ptr<f16>, i32
    %55 = tt.expand_dims %23 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
    %56 = tt.splat %arg13 : i32 -> tensor<64x1xi32>
    %57 = arith.muli %55, %56 : tensor<64x1xi32>
    %58 = tt.splat %54 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>>
    %59 = tt.addptr %58, %57 : tensor<64x1x!tt.ptr<f16>>, tensor<64x1xi32>
    %60 = tt.broadcast %59 : tensor<64x1x!tt.ptr<f16>> -> tensor<64x128x!tt.ptr<f16>>
    %61 = tt.broadcast %34 : tensor<1x128xi32> -> tensor<64x128xi32>
    %62 = tt.addptr %60, %61 : tensor<64x128x!tt.ptr<f16>>, tensor<64x128xi32>
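    // %64 masks query rows past the sequence length (16384); %66/%68 are the
    // per-iteration increments that advance the K and V tiles by 64 positions.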
    %63 = arith.cmpi slt, %29, %cst_6 : tensor<256x1xi32>
    %64 = tt.broadcast %63 : tensor<256x1xi1> -> tensor<256x128xi1>
    %65 = arith.muli %arg10, %c64_i32 : i32
    %66 = tt.splat %65 : i32 -> tensor<128x64xi32>
    %67 = arith.muli %arg13, %c64_i32 : i32
    %68 = tt.splat %67 : i32 -> tensor<64x128xi32>
    %69 = arith.addi %16, %c1_i32 : i32
    %70 = arith.muli %69, %c256_i32 : i32
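    // Per-row statistics buffer: %arg3 is indexed as [z, y, row] with strides
    // 786432 = 48 * 16384 and 16384, suggesting 48 heads of 16384 rows.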
    %71 = arith.muli %18, %c786432_i32 : i32
    %72 = tt.addptr %arg3, %71 : !tt.ptr<f32>, i32
    %73 = arith.muli %17, %c16384_i32 : i32
    %74 = tt.addptr %72, %73 : !tt.ptr<f32>, i32
    %75 = tt.splat %74 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>>
    %76 = tt.addptr %75, %22 : tensor<256x!tt.ptr<f32>>, tensor<256xi32>
    %77 = arith.subi %70, %c16384_i32 : i32
    %78 = arith.cmpi sgt, %77, %c0_i32 : i32
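    // %78 is true only for the tail tile whose 256 rows overrun row 16384.
    // Output tile pointers into %arg4: 256 rows (row stride %arg16) x 128
    // columns; %89 applies the row mask only when the tail tile needs it.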
    %79 = arith.muli %18, %arg14 : i32
    %80 = tt.addptr %arg4, %79 : !tt.ptr<f16>, i32
    %81 = arith.muli %17, %arg15 : i32
    %82 = tt.addptr %80, %81 : !tt.ptr<f16>, i32
    %83 = tt.splat %arg16 : i32 -> tensor<256x1xi32>
    %84 = arith.muli %29, %83 : tensor<256x1xi32>
    %85 = tt.splat %82 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>>
    %86 = tt.addptr %85, %84 : tensor<256x1x!tt.ptr<f16>>, tensor<256x1xi32>
    %87 = tt.broadcast %86 : tensor<256x1x!tt.ptr<f16>> -> tensor<256x128x!tt.ptr<f16>>
    %88 = tt.addptr %87, %36 : tensor<256x128x!tt.ptr<f16>>, tensor<256x128xi32>
    %89 = scf.if %78 -> (tensor<256x128xi1>) {
      scf.yield %64 : tensor<256x128xi1>
    } else {
      scf.yield %cst_4 : tensor<256x128xi1>
    }
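    // The while loop below has a trip count of exactly one (%arg27 steps
    // 0 -> 1); it reads like leftover scaffolding from a persistent-kernel
    // transform.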
    scf.while (%arg27 = %c0_i32) : (i32) -> () {
      %90 = arith.cmpi slt, %arg27, %c1_i32 : i32
      scf.condition(%90)
    } do {
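      // Load the Q tile once; it is reused against every K/V block.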
      %90 = tt.load %37, %64, %cst_3 : tensor<256x128x!tt.ptr<f16>>
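      // Online-softmax (FlashAttention-style) recurrence over 8192 key/value
      // positions in steps of 64, carrying the accumulator, the running row
      // sum l (%arg29), the running row max m (%arg30), and the K/V pointers.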
      %91:5 = scf.for %arg27 = %c0_i32 to %c8192_i32 step %c64_i32 iter_args(%arg28 = %cst_2, %arg29 = %cst_7, %arg30 = %cst_8, %arg31 = %50, %arg32 = %62) -> (tensor<256x128xf32>, tensor<256xf32>, tensor<256xf32>, tensor<128x64x!tt.ptr<f16>>, tensor<64x128x!tt.ptr<f16>>) : i32 {
        %97 = tt.load %arg31 : tensor<128x64x!tt.ptr<f16>>
        %98 = tt.dot %90, %97, %cst : tensor<256x128xf16> * tensor<128x64xf16> -> tensor<256x64xf32>
        %99 = "tt.reduce"(%98) <{axis = 1 : i32}> ({
        ^bb0(%arg33: f32, %arg34: f32):
          %121 = arith.maxnumf %arg33, %arg34 : f32
          tt.reduce.return %121 : f32
        }) : (tensor<256x64xf32>) -> tensor<256xf32>
        %100 = arith.maxnumf %arg30, %99 : tensor<256xf32>
        %101 = arith.mulf %100, %cst_0 : tensor<256xf32>
        %102 = arith.mulf %98, %cst_1 : tensor<256x64xf32>
        %103 = tt.expand_dims %101 {axis = 1 : i32} : tensor<256xf32> -> tensor<256x1xf32>
        %104 = tt.broadcast %103 : tensor<256x1xf32> -> tensor<256x64xf32>
        %105 = arith.subf %102, %104 : tensor<256x64xf32>
        %106 = math.exp2 %105 : tensor<256x64xf32>
        %107 = "tt.reduce"(%106) <{axis = 1 : i32}> ({
        ^bb0(%arg33: f32, %arg34: f32):
          %121 = arith.addf %arg33, %arg34 : f32
          tt.reduce.return %121 : f32
        }) : (tensor<256x64xf32>) -> tensor<256xf32>
        %108 = arith.mulf %arg30, %cst_0 : tensor<256xf32>
        %109 = arith.subf %108, %101 : tensor<256xf32>
        %110 = math.exp2 %109 : tensor<256xf32>
        %111 = tt.expand_dims %110 {axis = 1 : i32} : tensor<256xf32> -> tensor<256x1xf32>
        %112 = tt.broadcast %111 : tensor<256x1xf32> -> tensor<256x128xf32>
        %113 = arith.mulf %arg28, %112 : tensor<256x128xf32>
        %114 = tt.load %arg32 : tensor<64x128x!tt.ptr<f16>>
        %115 = arith.mulf %arg29, %110 : tensor<256xf32>
        %116 = arith.addf %115, %107 : tensor<256xf32>
        %117 = arith.truncf %106 : tensor<256x64xf32> to tensor<256x64xf16>
        %118 = tt.dot %117, %114, %113 : tensor<256x64xf16> * tensor<64x128xf16> -> tensor<256x128xf32>
        %119 = tt.addptr %arg31, %66 : tensor<128x64x!tt.ptr<f16>>, tensor<128x64xi32>
        %120 = tt.addptr %arg32, %68 : tensor<64x128x!tt.ptr<f16>>, tensor<64x128xi32>
        scf.yield %118, %116, %100, %119, %120 : tensor<256x128xf32>, tensor<256xf32>, tensor<256xf32>, tensor<128x64x!tt.ptr<f16>>, tensor<64x128x!tt.ptr<f16>>
      }
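      // Epilogue: divide the accumulator by the softmax denominator l and
      // truncate back to f16.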
      gpu.barrier
      %92 = tt.expand_dims %91#1 {axis = 1 : i32} : tensor<256xf32> -> tensor<256x1xf32>
      %93 = arith.divf %cst_5, %92 : tensor<256x1xf32>
      %94 = tt.broadcast %93 : tensor<256x1xf32> -> tensor<256x128xf32>
      %95 = arith.mulf %91#0, %94 : tensor<256x128xf32>
      %96 = arith.truncf %95 : tensor<256x128xf32> to tensor<256x128xf16>
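      // Store the per-row statistic m + log2(l) (a base-2 log-sum-exp),
      // masking out rows past the sequence end when this is the tail tile.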
      scf.if %78 {
        %97 = arith.subi %c16640_i32, %70 : i32
        %98 = tt.splat %97 : i32 -> tensor<256xi32>
        %99 = arith.cmpi slt, %20, %98 : tensor<256xi32>
        %100 = math.log2 %91#1 : tensor<256xf32>
        %101 = arith.addf %91#2, %100 : tensor<256xf32>
        tt.store %76, %101, %99 : tensor<256x!tt.ptr<f32>>
      } else {
        %97 = math.log2 %91#1 : tensor<256xf32>
        %98 = arith.addf %91#2, %97 : tensor<256xf32>
        tt.store %76, %98 : tensor<256x!tt.ptr<f32>>
      }
      tt.store %88, %96, %89 : tensor<256x128x!tt.ptr<f16>>
      scf.yield %c1_i32 : i32
    }
    tt.return
  }
}