Skip to content

Commit 692565e

Browse files
fix insert/extract; add a test case
1 parent 2fdc350 commit 692565e

File tree

3 files changed

+236
-0
lines changed

3 files changed

+236
-0
lines changed

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1706,6 +1706,11 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
17061706
}
17071707

17081708
Instruction *SPIRVEmitIntrinsics::visitInsertElementInst(InsertElementInst &I) {
1709+
// If it's a <1 x Type> vector type, don't modify it. It's not a legal vector
1710+
// type in LLT and IRTranslator will replace it by the scalar.
1711+
if (isVector1(I.getType()))
1712+
return &I;
1713+
17091714
SmallVector<Type *, 4> Types = {I.getType(), I.getOperand(0)->getType(),
17101715
I.getOperand(1)->getType(),
17111716
I.getOperand(2)->getType()};
@@ -1719,6 +1724,11 @@ Instruction *SPIRVEmitIntrinsics::visitInsertElementInst(InsertElementInst &I) {
17191724

17201725
Instruction *
17211726
SPIRVEmitIntrinsics::visitExtractElementInst(ExtractElementInst &I) {
1727+
// If it's a <1 x Type> vector type, don't modify it. It's not a legal vector
1728+
// type in LLT and IRTranslator will replace it by the scalar.
1729+
if (isVector1(I.getVectorOperandType()))
1730+
return &I;
1731+
17221732
IRBuilder<> B(I.getParent());
17231733
B.SetInsertPoint(&I);
17241734
SmallVector<Type *, 3> Types = {I.getType(), I.getVectorOperandType(),

llvm/lib/Target/SPIRV/SPIRVUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,11 @@ inline const Type *unifyPtrType(const Type *Ty) {
383383
return toTypedPointer(const_cast<Type *>(Ty));
384384
}
385385

386+
// Returns true if Ty is a FixedVectorType with exactly one element
// (i.e. a <1 x Type> vector); returns false for every other type.
inline bool isVector1(Type *Ty) {
387+
auto *FVTy = dyn_cast<FixedVectorType>(Ty);
388+
return FVTy && FVTy->getNumElements() == 1;
389+
}
390+
386391
// Modify an LLVM type to conform with future transformations in IRTranslator.
387392
// At the moment use cases comprise only a <1 x Type> vector. To extend when/if
388393
// needed.
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
; This is an excerpt from the tutorial of the Triton language converted into
2+
; LLVM IR via the Triton XPU backend and cleaned of irrelevant details.
3+
; The only pass criterion is that spirv-val considers the output valid.
4+
5+
; This particular case is related to the translation of <1 x Ty> vectors.
6+
7+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val --target-env spv1.4 %}
8+
9+
define spir_kernel void @softmax_kernel(ptr addrspace(1) nocapture writeonly %0, ptr addrspace(1) nocapture readonly %1, i32 %2, i32 %3, i32 %4, i32 %5, ptr addrspace(3) nocapture %6) {
10+
%8 = tail call spir_func i64 @_Z12get_group_idj(i32 0)
11+
%9 = trunc i64 %8 to i32
12+
%10 = tail call spir_func i64 @_Z14get_num_groupsj(i32 0)
13+
%11 = trunc i64 %10 to i32
14+
%12 = tail call spir_func i64 @_Z12get_local_idj(i32 0)
15+
%13 = trunc i64 %12 to i32
16+
%14 = and i32 %13, 255
17+
%15 = or disjoint i32 %14, 256
18+
%16 = or disjoint i32 %14, 512
19+
%17 = or disjoint i32 %14, 768
20+
%18 = icmp slt i32 %14, %5
21+
%19 = icmp slt i32 %15, %5
22+
%20 = icmp slt i32 %16, %5
23+
%21 = icmp slt i32 %17, %5
24+
%22 = icmp sgt i32 %4, %9
25+
br i1 %22, label %.lr.ph, label %._crit_edge
26+
27+
.lr.ph: ; preds = %7
28+
%23 = lshr i64 %12, 5
29+
%24 = and i32 %13, 31
30+
%25 = zext nneg i32 %15 to i64
31+
%26 = zext nneg i32 %16 to i64
32+
%27 = zext nneg i32 %17 to i64
33+
%28 = and i64 %12, 255
34+
%29 = and i64 %23, 7
35+
%30 = icmp eq i32 %24, 0
36+
%31 = getelementptr float, ptr addrspace(3) %6, i64 %29
37+
%32 = icmp slt i32 %13, 8
38+
%sext = shl i64 %12, 32
39+
%33 = ashr exact i64 %sext, 30
40+
%34 = getelementptr i8, ptr addrspace(3) %6, i64 %33
41+
%35 = and i32 %13, 7
42+
%36 = icmp eq i32 %35, 0
43+
%37 = and i1 %32, %36
44+
br label %38
45+
46+
38: ; preds = %.lr.ph, %123
47+
%39 = phi i32 [ %9, %.lr.ph ], [ %124, %123 ]
48+
%40 = mul i32 %39, %2
49+
%41 = sext i32 %40 to i64
50+
%42 = getelementptr float, ptr addrspace(1) %1, i64 %41
51+
%43 = getelementptr float, ptr addrspace(1) %42, i64 %25
52+
%44 = getelementptr float, ptr addrspace(1) %42, i64 %26
53+
%45 = getelementptr float, ptr addrspace(1) %42, i64 %27
54+
br i1 %18, label %46, label %49
55+
56+
46: ; preds = %38
57+
%47 = getelementptr float, ptr addrspace(1) %42, i64 %28
58+
%48 = load <1 x float>, ptr addrspace(1) %47, align 4
59+
br label %49
60+
61+
49: ; preds = %46, %38
62+
%50 = phi <1 x float> [ %48, %46 ], [ splat (float 0xFFF0000000000000), %38 ]
63+
%51 = extractelement <1 x float> %50, i64 0
64+
br i1 %19, label %52, label %54
65+
66+
52: ; preds = %49
67+
%53 = load <1 x float>, ptr addrspace(1) %43, align 4
68+
br label %54
69+
70+
54: ; preds = %52, %49
71+
%55 = phi <1 x float> [ %53, %52 ], [ splat (float 0xFFF0000000000000), %49 ]
72+
%56 = extractelement <1 x float> %55, i64 0
73+
br i1 %20, label %57, label %59
74+
75+
57: ; preds = %54
76+
%58 = load <1 x float>, ptr addrspace(1) %44, align 4
77+
br label %59
78+
79+
59: ; preds = %57, %54
80+
%60 = phi <1 x float> [ %58, %57 ], [ splat (float 0xFFF0000000000000), %54 ]
81+
%61 = extractelement <1 x float> %60, i64 0
82+
br i1 %21, label %62, label %64
83+
84+
62: ; preds = %59
85+
%63 = load <1 x float>, ptr addrspace(1) %45, align 4
86+
br label %64
87+
88+
64: ; preds = %62, %59
89+
%65 = phi <1 x float> [ %63, %62 ], [ splat (float 0xFFF0000000000000), %59 ]
90+
%66 = extractelement <1 x float> %65, i64 0
91+
tail call spir_func void @_Z7barrierj(i32 1)
92+
%67 = tail call float @llvm.maxnum.f32(float %51, float %56)
93+
%68 = tail call float @llvm.maxnum.f32(float %67, float %61)
94+
%69 = tail call float @llvm.maxnum.f32(float %68, float %66)
95+
%70 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiif(i32 3, i32 0, float %69)
96+
br i1 %30, label %71, label %72
97+
98+
71: ; preds = %64
99+
store float %70, ptr addrspace(3) %31, align 4
100+
br label %72
101+
102+
72: ; preds = %71, %64
103+
tail call spir_func void @_Z7barrierj(i32 1)
104+
br i1 %32, label %74, label %.thread1
105+
106+
.thread1: ; preds = %72
107+
%73 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiifj(i32 3, i32 3, float undef, i32 8)
108+
br label %78
109+
110+
74: ; preds = %72
111+
%75 = load float, ptr addrspace(3) %34, align 4
112+
%76 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiifj(i32 3, i32 3, float %75, i32 8)
113+
br i1 %37, label %77, label %78
114+
115+
77: ; preds = %74
116+
store float %76, ptr addrspace(3) %34, align 4
117+
br label %78
118+
119+
78: ; preds = %.thread1, %77, %74
120+
tail call spir_func void @_Z7barrierj(i32 1)
121+
%79 = load float, ptr addrspace(3) %6, align 4
122+
%80 = fsub float %51, %79
123+
%81 = fsub float %56, %79
124+
%82 = fsub float %61, %79
125+
%83 = fsub float %66, %79
126+
%84 = fmul float %80, 0x3FF7154760000000
127+
%85 = tail call float @llvm.exp2.f32(float %84)
128+
%86 = fmul float %81, 0x3FF7154760000000
129+
%87 = tail call float @llvm.exp2.f32(float %86)
130+
%88 = fmul float %82, 0x3FF7154760000000
131+
%89 = tail call float @llvm.exp2.f32(float %88)
132+
%90 = fmul float %83, 0x3FF7154760000000
133+
%91 = tail call float @llvm.exp2.f32(float %90)
134+
tail call spir_func void @_Z7barrierj(i32 1)
135+
%92 = fadd float %85, %87
136+
%93 = fadd float %89, %92
137+
%94 = fadd float %91, %93
138+
%95 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiif(i32 3, i32 0, float %94)
139+
br i1 %30, label %96, label %97
140+
141+
96: ; preds = %78
142+
store float %95, ptr addrspace(3) %31, align 4
143+
br label %97
144+
145+
97: ; preds = %96, %78
146+
tail call spir_func void @_Z7barrierj(i32 1)
147+
br i1 %32, label %99, label %.thread
148+
149+
.thread: ; preds = %97
150+
%98 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiifj(i32 3, i32 3, float undef, i32 8)
151+
br label %103
152+
153+
99: ; preds = %97
154+
%100 = load float, ptr addrspace(3) %34, align 4
155+
%101 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiifj(i32 3, i32 3, float %100, i32 8)
156+
br i1 %37, label %102, label %103
157+
158+
102: ; preds = %99
159+
store float %101, ptr addrspace(3) %34, align 4
160+
br label %103
161+
162+
103: ; preds = %.thread, %102, %99
163+
tail call spir_func void @_Z7barrierj(i32 1)
164+
%104 = load float, ptr addrspace(3) %6, align 4
165+
%105 = fdiv float %87, %104
166+
%106 = fdiv float %89, %104
167+
%107 = fdiv float %91, %104
168+
%108 = mul i32 %39, %3
169+
%109 = sext i32 %108 to i64
170+
%110 = getelementptr float, ptr addrspace(1) %0, i64 %109
171+
%111 = getelementptr float, ptr addrspace(1) %110, i64 %25
172+
%112 = getelementptr float, ptr addrspace(1) %110, i64 %26
173+
%113 = getelementptr float, ptr addrspace(1) %110, i64 %27
174+
br i1 %18, label %114, label %117
175+
176+
114: ; preds = %103
177+
%115 = fdiv float %85, %104
178+
%116 = getelementptr float, ptr addrspace(1) %110, i64 %28
179+
store float %115, ptr addrspace(1) %116, align 4
180+
br label %117
181+
182+
117: ; preds = %114, %103
183+
br i1 %19, label %118, label %119
184+
185+
118: ; preds = %117
186+
store float %105, ptr addrspace(1) %111, align 4
187+
br label %119
188+
189+
119: ; preds = %118, %117
190+
br i1 %20, label %120, label %121
191+
192+
120: ; preds = %119
193+
store float %106, ptr addrspace(1) %112, align 4
194+
br label %121
195+
196+
121: ; preds = %120, %119
197+
br i1 %21, label %122, label %123
198+
199+
122: ; preds = %121
200+
store float %107, ptr addrspace(1) %113, align 4
201+
br label %123
202+
203+
123: ; preds = %122, %121
204+
%124 = add i32 %39, %11
205+
%125 = icmp slt i32 %124, %4
206+
br i1 %125, label %38, label %._crit_edge
207+
208+
._crit_edge: ; preds = %123, %7
209+
ret void
210+
}
211+
212+
declare float @llvm.maxnum.f32(float, float)
213+
declare spir_func float @_Z27__spirv_GroupNonUniformFAddiifj(i32, i32, float, i32)
214+
declare spir_func float @_Z27__spirv_GroupNonUniformFAddiif(i32, i32, float)
215+
declare spir_func float @_Z27__spirv_GroupNonUniformFMaxiifj(i32, i32, float, i32)
216+
declare spir_func float @_Z27__spirv_GroupNonUniformFMaxiif(i32, i32, float)
217+
declare spir_func void @_Z7barrierj(i32)
218+
declare spir_func i64 @_Z12get_local_idj(i32)
219+
declare spir_func i64 @_Z14get_num_groupsj(i32)
220+
declare spir_func i64 @_Z12get_group_idj(i32)
221+
declare float @llvm.exp2.f32(float)

0 commit comments

Comments
 (0)