Skip to content

Commit 87f4e80

Browse files
authored
[SPIRV] Add support for CodeSectionINTEL storage class in legalizer (#167961)
The [SPV_INTEL_function_pointers](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_function_pointers.asciidoc) extension defines a new storage class `CodeSectionINTEL` that is represented in LLVM IR as `addrspace(9)`. Per the spec, it is basically not allowed to be casted to or interact with pointers with other storage classes. Add `addrspace(9)` as a known pointer type to the legalizer, and then add some error cases for IR that is impossible to legalize. Right now, if you try to run the backend on input with SPIR-V, basically everything errors saying that it is unable to legalize because `ptr addrspace(9)` is not considered a pointer type. Ideally the FE should not generate the illegal IR or error out earlier, but we should catch it before generating invalid SPIR-V. --------- Signed-off-by: Nick Sarnie <[email protected]>
1 parent 3c6864a commit 87f4e80

File tree

7 files changed

+93
-14
lines changed

7 files changed

+93
-14
lines changed

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -90,17 +90,19 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
9090
const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
9191
const LLT p7 = LLT::pointer(7, PSize); // Input
9292
const LLT p8 = LLT::pointer(8, PSize); // Output
93+
const LLT p9 =
94+
LLT::pointer(9, PSize); // CodeSectionINTEL, SPV_INTEL_function_pointers
9395
const LLT p10 = LLT::pointer(10, PSize); // Private
9496
const LLT p11 = LLT::pointer(11, PSize); // StorageBuffer
9597
const LLT p12 = LLT::pointer(12, PSize); // Uniform
9698

9799
// TODO: remove copy-pasting here by using concatenation in some way.
98100
auto allPtrsScalarsAndVectors = {
99-
p0, p1, p2, p3, p4, p5, p6, p7, p8,
100-
p10, p11, p12, s1, s8, s16, s32, s64, v2s1,
101-
v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64,
102-
v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16, v8s32,
103-
v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
101+
p0, p1, p2, p3, p4, p5, p6, p7, p8,
102+
p9, p10, p11, p12, s1, s8, s16, s32, s64,
103+
v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32,
104+
v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16,
105+
v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
104106

105107
auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
106108
v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32,
@@ -131,10 +133,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
131133
s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
132134
v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
133135

134-
auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, p2, p3,
135-
p4, p5, p6, p7, p8, p10, p11, p12};
136+
auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, p2, p3, p4,
137+
p5, p6, p7, p8, p9, p10, p11, p12};
136138

137-
auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12};
139+
auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12};
138140

139141
auto &allowedVectorTypes = ST.isShader() ? allShaderVectors : allVectors;
140142

@@ -246,15 +248,22 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
246248
.legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize));
247249

248250
getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
251+
.unsupportedIf(LegalityPredicates::any(typeIs(0, p9), typeIs(1, p9)))
249252
.legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs)));
250253

251-
getActionDefinitionsBuilder(G_MEMSET).legalIf(
252-
all(typeInSet(0, allPtrs), typeInSet(1, allIntScalars)));
254+
getActionDefinitionsBuilder(G_MEMSET)
255+
.unsupportedIf(typeIs(0, p9))
256+
.legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allIntScalars)));
253257

254258
getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
259+
.unsupportedIf(
260+
LegalityPredicates::any(all(typeIs(0, p9), typeIsNot(1, p9)),
261+
all(typeIsNot(0, p9), typeIs(1, p9))))
255262
.legalForCartesianProduct(allPtrs, allPtrs);
256263

257-
getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
264+
getActionDefinitionsBuilder({G_LOAD, G_STORE})
265+
.unsupportedIf(typeIs(1, p9))
266+
.legalIf(typeInSet(1, allPtrs));
258267

259268
getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS,
260269
G_BITREVERSE, G_SADDSAT, G_UADDSAT, G_SSUBSAT,
@@ -323,9 +332,12 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
323332

324333
// ST.canDirectlyComparePointers() for pointer args is supported in
325334
// legalizeCustom().
326-
getActionDefinitionsBuilder(G_ICMP).customIf(
327-
all(typeInSet(0, allBoolScalarsAndVectors),
328-
typeInSet(1, allPtrsScalarsAndVectors)));
335+
getActionDefinitionsBuilder(G_ICMP)
336+
.unsupportedIf(LegalityPredicates::any(
337+
all(typeIs(0, p9), typeInSet(1, allPtrs), typeIsNot(1, p9)),
338+
all(typeInSet(0, allPtrs), typeIsNot(0, p9), typeIs(1, p9))))
339+
.customIf(all(typeInSet(0, allBoolScalarsAndVectors),
340+
typeInSet(1, allPtrsScalarsAndVectors)));
329341

330342
getActionDefinitionsBuilder(G_FCMP).legalIf(
331343
all(typeInSet(0, allBoolScalarsAndVectors),
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
; RUN: not llc --global-isel %s -filetype=null 2>&1 | FileCheck %s
2+
target triple = "spirv64"
3+
4+
define void @addrspacecast(ptr addrspace(9) %a) {
5+
; CHECK: unable to legalize instruction: %{{.*}}:pid(p4) = G_ADDRSPACE_CAST %{{.*}}:pid(p9)
6+
%res1 = addrspacecast ptr addrspace(9) %a to ptr addrspace(4)
7+
ret void
8+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
; RUN: not llc --global-isel %s -filetype=null 2>&1 | FileCheck %s
2+
target triple = "spirv64"
3+
4+
define void @do_load(ptr addrspace(9) %a) {
5+
; CHECK: unable to legalize instruction: %{{.*}}:iid(s32) = G_LOAD %{{.*}}:pid(p9)
6+
%val = load i32, ptr addrspace(9) %a
7+
ret void
8+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
; RUN: not llc --global-isel %s -filetype=null 2>&1 | FileCheck %s
2+
target triple = "spirv64"
3+
4+
define void @memcpy(ptr addrspace(9) %a) {
5+
; CHECK: unable to legalize instruction: G_MEMCPY %{{.*}}:pid(p9), %{{.*}}:pid(p0), %{{.*}}:iid(s64), 0
6+
call void @llvm.memcpy.p9.p0.i64(ptr addrspace(9) %a, ptr null, i32 1, i1 0)
7+
ret void
8+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
; RUN: not llc --global-isel %s -filetype=null 2>&1 | FileCheck %s
2+
target triple = "spirv64"
3+
4+
define void @memset(ptr addrspace(9) %a) {
5+
; CHECK: unable to legalize instruction: G_MEMSET %{{.*}}:pid(p9), %{{.*}}:iid(s8), %{{.*}}:iid(s64)
6+
call void @llvm.memset.p9.i32(ptr addrspace(9) %a, i8 0, i32 1, i1 0)
7+
ret void
8+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
; RUN: not llc --global-isel %s -filetype=null 2>&1 | FileCheck %s
2+
target triple = "spirv64"
3+
4+
define void @do_store(ptr addrspace(9) %a) {
5+
; CHECK: unable to legalize instruction: G_STORE %{{.*}}:iid(s32), %{{.*}}:pid(p9)
6+
store i32 5, ptr addrspace(9) %a
7+
ret void
8+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
; RUN: llc -verify-machineinstrs -O0 --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s
2+
; TODO: %if spirv-tools %{ llc -O0 %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: OpCapability FunctionPointersINTEL
5+
; CHECK: OpExtension "SPV_INTEL_function_pointers"
6+
7+
; CHECK: OpName %[[F1:.*]] "f1"
8+
; CHECK: OpName %[[ARG:.*]] "arg"
9+
10+
; CHECK: %[[TyBool:.*]] = OpTypeBool
11+
12+
; CHECK: %[[F1Ptr:.*]] = OpConstantFunctionPointerINTEL %{{.*}} %[[F1]]
13+
14+
; CHECK: OpPtrEqual %[[TyBool]] %[[F1Ptr]] %[[ARG]]
15+
16+
target triple = "spirv64"
17+
18+
define spir_func void @f1() addrspace(9) {
19+
entry:
20+
ret void
21+
}
22+
23+
define spir_func i1 @foo(ptr addrspace(9) %arg) addrspace(9) {
24+
entry:
25+
%a = icmp eq ptr addrspace(9) @f1, %arg
26+
ret i1 %a
27+
}

0 commit comments

Comments
 (0)