|
| 1 | +; This is an excerpt from the tutorial of the Triton language converted into |
| 2 | +; LLVM IR via the Triton XPU backend and cleaned of irrelevant details. |
| 3 | +; The only pass criterion is that spirv-val considers output valid. |
| 4 | + |
| 5 | +; Ths particular case is related to translation of <1 x Ty> vectors. |
| 6 | + |
| 7 | +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val --target-env spv1.4 %} |
| 8 | + |
| 9 | +define spir_kernel void @softmax_kernel(ptr addrspace(1) nocapture writeonly %0, ptr addrspace(1) nocapture readonly %1, i32 %2, i32 %3, i32 %4, i32 %5, ptr addrspace(3) nocapture %6) { |
| 10 | + %8 = tail call spir_func i64 @_Z12get_group_idj(i32 0) |
| 11 | + %9 = trunc i64 %8 to i32 |
| 12 | + %10 = tail call spir_func i64 @_Z14get_num_groupsj(i32 0) |
| 13 | + %11 = trunc i64 %10 to i32 |
| 14 | + %12 = tail call spir_func i64 @_Z12get_local_idj(i32 0) |
| 15 | + %13 = trunc i64 %12 to i32 |
| 16 | + %14 = and i32 %13, 255 |
| 17 | + %15 = or disjoint i32 %14, 256 |
| 18 | + %16 = or disjoint i32 %14, 512 |
| 19 | + %17 = or disjoint i32 %14, 768 |
| 20 | + %18 = icmp slt i32 %14, %5 |
| 21 | + %19 = icmp slt i32 %15, %5 |
| 22 | + %20 = icmp slt i32 %16, %5 |
| 23 | + %21 = icmp slt i32 %17, %5 |
| 24 | + %22 = icmp sgt i32 %4, %9 |
| 25 | + br i1 %22, label %.lr.ph, label %._crit_edge |
| 26 | + |
| 27 | +.lr.ph: ; preds = %7 |
| 28 | + %23 = lshr i64 %12, 5 |
| 29 | + %24 = and i32 %13, 31 |
| 30 | + %25 = zext nneg i32 %15 to i64 |
| 31 | + %26 = zext nneg i32 %16 to i64 |
| 32 | + %27 = zext nneg i32 %17 to i64 |
| 33 | + %28 = and i64 %12, 255 |
| 34 | + %29 = and i64 %23, 7 |
| 35 | + %30 = icmp eq i32 %24, 0 |
| 36 | + %31 = getelementptr float, ptr addrspace(3) %6, i64 %29 |
| 37 | + %32 = icmp slt i32 %13, 8 |
| 38 | + %sext = shl i64 %12, 32 |
| 39 | + %33 = ashr exact i64 %sext, 30 |
| 40 | + %34 = getelementptr i8, ptr addrspace(3) %6, i64 %33 |
| 41 | + %35 = and i32 %13, 7 |
| 42 | + %36 = icmp eq i32 %35, 0 |
| 43 | + %37 = and i1 %32, %36 |
| 44 | + br label %38 |
| 45 | + |
| 46 | +38: ; preds = %.lr.ph, %123 |
| 47 | + %39 = phi i32 [ %9, %.lr.ph ], [ %124, %123 ] |
| 48 | + %40 = mul i32 %39, %2 |
| 49 | + %41 = sext i32 %40 to i64 |
| 50 | + %42 = getelementptr float, ptr addrspace(1) %1, i64 %41 |
| 51 | + %43 = getelementptr float, ptr addrspace(1) %42, i64 %25 |
| 52 | + %44 = getelementptr float, ptr addrspace(1) %42, i64 %26 |
| 53 | + %45 = getelementptr float, ptr addrspace(1) %42, i64 %27 |
| 54 | + br i1 %18, label %46, label %49 |
| 55 | + |
| 56 | +46: ; preds = %38 |
| 57 | + %47 = getelementptr float, ptr addrspace(1) %42, i64 %28 |
| 58 | + %48 = load <1 x float>, ptr addrspace(1) %47, align 4 |
| 59 | + br label %49 |
| 60 | + |
| 61 | +49: ; preds = %46, %38 |
| 62 | + %50 = phi <1 x float> [ %48, %46 ], [ splat (float 0xFFF0000000000000), %38 ] |
| 63 | + %51 = extractelement <1 x float> %50, i64 0 |
| 64 | + br i1 %19, label %52, label %54 |
| 65 | + |
| 66 | +52: ; preds = %49 |
| 67 | + %53 = load <1 x float>, ptr addrspace(1) %43, align 4 |
| 68 | + br label %54 |
| 69 | + |
| 70 | +54: ; preds = %52, %49 |
| 71 | + %55 = phi <1 x float> [ %53, %52 ], [ splat (float 0xFFF0000000000000), %49 ] |
| 72 | + %56 = extractelement <1 x float> %55, i64 0 |
| 73 | + br i1 %20, label %57, label %59 |
| 74 | + |
| 75 | +57: ; preds = %54 |
| 76 | + %58 = load <1 x float>, ptr addrspace(1) %44, align 4 |
| 77 | + br label %59 |
| 78 | + |
| 79 | +59: ; preds = %57, %54 |
| 80 | + %60 = phi <1 x float> [ %58, %57 ], [ splat (float 0xFFF0000000000000), %54 ] |
| 81 | + %61 = extractelement <1 x float> %60, i64 0 |
| 82 | + br i1 %21, label %62, label %64 |
| 83 | + |
| 84 | +62: ; preds = %59 |
| 85 | + %63 = load <1 x float>, ptr addrspace(1) %45, align 4 |
| 86 | + br label %64 |
| 87 | + |
| 88 | +64: ; preds = %62, %59 |
| 89 | + %65 = phi <1 x float> [ %63, %62 ], [ splat (float 0xFFF0000000000000), %59 ] |
| 90 | + %66 = extractelement <1 x float> %65, i64 0 |
| 91 | + tail call spir_func void @_Z7barrierj(i32 1) |
| 92 | + %67 = tail call float @llvm.maxnum.f32(float %51, float %56) |
| 93 | + %68 = tail call float @llvm.maxnum.f32(float %67, float %61) |
| 94 | + %69 = tail call float @llvm.maxnum.f32(float %68, float %66) |
| 95 | + %70 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiif(i32 3, i32 0, float %69) |
| 96 | + br i1 %30, label %71, label %72 |
| 97 | + |
| 98 | +71: ; preds = %64 |
| 99 | + store float %70, ptr addrspace(3) %31, align 4 |
| 100 | + br label %72 |
| 101 | + |
| 102 | +72: ; preds = %71, %64 |
| 103 | + tail call spir_func void @_Z7barrierj(i32 1) |
| 104 | + br i1 %32, label %74, label %.thread1 |
| 105 | + |
| 106 | +.thread1: ; preds = %72 |
| 107 | + %73 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiifj(i32 3, i32 3, float undef, i32 8) |
| 108 | + br label %78 |
| 109 | + |
| 110 | +74: ; preds = %72 |
| 111 | + %75 = load float, ptr addrspace(3) %34, align 4 |
| 112 | + %76 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiifj(i32 3, i32 3, float %75, i32 8) |
| 113 | + br i1 %37, label %77, label %78 |
| 114 | + |
| 115 | +77: ; preds = %74 |
| 116 | + store float %76, ptr addrspace(3) %34, align 4 |
| 117 | + br label %78 |
| 118 | + |
| 119 | +78: ; preds = %.thread1, %77, %74 |
| 120 | + tail call spir_func void @_Z7barrierj(i32 1) |
| 121 | + %79 = load float, ptr addrspace(3) %6, align 4 |
| 122 | + %80 = fsub float %51, %79 |
| 123 | + %81 = fsub float %56, %79 |
| 124 | + %82 = fsub float %61, %79 |
| 125 | + %83 = fsub float %66, %79 |
| 126 | + %84 = fmul float %80, 0x3FF7154760000000 |
| 127 | + %85 = tail call float @llvm.exp2.f32(float %84) |
| 128 | + %86 = fmul float %81, 0x3FF7154760000000 |
| 129 | + %87 = tail call float @llvm.exp2.f32(float %86) |
| 130 | + %88 = fmul float %82, 0x3FF7154760000000 |
| 131 | + %89 = tail call float @llvm.exp2.f32(float %88) |
| 132 | + %90 = fmul float %83, 0x3FF7154760000000 |
| 133 | + %91 = tail call float @llvm.exp2.f32(float %90) |
| 134 | + tail call spir_func void @_Z7barrierj(i32 1) |
| 135 | + %92 = fadd float %85, %87 |
| 136 | + %93 = fadd float %89, %92 |
| 137 | + %94 = fadd float %91, %93 |
| 138 | + %95 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiif(i32 3, i32 0, float %94) |
| 139 | + br i1 %30, label %96, label %97 |
| 140 | + |
| 141 | +96: ; preds = %78 |
| 142 | + store float %95, ptr addrspace(3) %31, align 4 |
| 143 | + br label %97 |
| 144 | + |
| 145 | +97: ; preds = %96, %78 |
| 146 | + tail call spir_func void @_Z7barrierj(i32 1) |
| 147 | + br i1 %32, label %99, label %.thread |
| 148 | + |
| 149 | +.thread: ; preds = %97 |
| 150 | + %98 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiifj(i32 3, i32 3, float undef, i32 8) |
| 151 | + br label %103 |
| 152 | + |
| 153 | +99: ; preds = %97 |
| 154 | + %100 = load float, ptr addrspace(3) %34, align 4 |
| 155 | + %101 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiifj(i32 3, i32 3, float %100, i32 8) |
| 156 | + br i1 %37, label %102, label %103 |
| 157 | + |
| 158 | +102: ; preds = %99 |
| 159 | + store float %101, ptr addrspace(3) %34, align 4 |
| 160 | + br label %103 |
| 161 | + |
| 162 | +103: ; preds = %.thread, %102, %99 |
| 163 | + tail call spir_func void @_Z7barrierj(i32 1) |
| 164 | + %104 = load float, ptr addrspace(3) %6, align 4 |
| 165 | + %105 = fdiv float %87, %104 |
| 166 | + %106 = fdiv float %89, %104 |
| 167 | + %107 = fdiv float %91, %104 |
| 168 | + %108 = mul i32 %39, %3 |
| 169 | + %109 = sext i32 %108 to i64 |
| 170 | + %110 = getelementptr float, ptr addrspace(1) %0, i64 %109 |
| 171 | + %111 = getelementptr float, ptr addrspace(1) %110, i64 %25 |
| 172 | + %112 = getelementptr float, ptr addrspace(1) %110, i64 %26 |
| 173 | + %113 = getelementptr float, ptr addrspace(1) %110, i64 %27 |
| 174 | + br i1 %18, label %114, label %117 |
| 175 | + |
| 176 | +114: ; preds = %103 |
| 177 | + %115 = fdiv float %85, %104 |
| 178 | + %116 = getelementptr float, ptr addrspace(1) %110, i64 %28 |
| 179 | + store float %115, ptr addrspace(1) %116, align 4 |
| 180 | + br label %117 |
| 181 | + |
| 182 | +117: ; preds = %114, %103 |
| 183 | + br i1 %19, label %118, label %119 |
| 184 | + |
| 185 | +118: ; preds = %117 |
| 186 | + store float %105, ptr addrspace(1) %111, align 4 |
| 187 | + br label %119 |
| 188 | + |
| 189 | +119: ; preds = %118, %117 |
| 190 | + br i1 %20, label %120, label %121 |
| 191 | + |
| 192 | +120: ; preds = %119 |
| 193 | + store float %106, ptr addrspace(1) %112, align 4 |
| 194 | + br label %121 |
| 195 | + |
| 196 | +121: ; preds = %120, %119 |
| 197 | + br i1 %21, label %122, label %123 |
| 198 | + |
| 199 | +122: ; preds = %121 |
| 200 | + store float %107, ptr addrspace(1) %113, align 4 |
| 201 | + br label %123 |
| 202 | + |
| 203 | +123: ; preds = %122, %121 |
| 204 | + %124 = add i32 %39, %11 |
| 205 | + %125 = icmp slt i32 %124, %4 |
| 206 | + br i1 %125, label %38, label %._crit_edge |
| 207 | + |
| 208 | +._crit_edge: ; preds = %123, %7 |
| 209 | + ret void |
| 210 | +} |
| 211 | + |
| 212 | +declare float @llvm.maxnum.f32(float, float) |
| 213 | +declare spir_func float @_Z27__spirv_GroupNonUniformFAddiifj(i32, i32, float, i32) |
| 214 | +declare spir_func float @_Z27__spirv_GroupNonUniformFAddiif(i32, i32, float) |
| 215 | +declare spir_func float @_Z27__spirv_GroupNonUniformFMaxiifj(i32, i32, float, i32) |
| 216 | +declare spir_func float @_Z27__spirv_GroupNonUniformFMaxiif(i32, i32, float) |
| 217 | +declare spir_func void @_Z7barrierj(i32) |
| 218 | +declare spir_func i64 @_Z12get_local_idj(i32) |
| 219 | +declare spir_func i64 @_Z14get_num_groupsj(i32) |
| 220 | +declare spir_func i64 @_Z12get_group_idj(i32) |
| 221 | +declare float @llvm.exp2.f32(float) |
0 commit comments