Skip to content

Commit 73b9262

Browse files
authored
[Backport to llvm_release_160] initial support for SPV_INTEL_device_barrier (#3554)
1 parent de396f2 commit 73b9262

File tree

13 files changed

+337
-7
lines changed

13 files changed

+337
-7
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ EXT(SPV_INTEL_runtime_aligned)
5757
EXT(SPV_EXT_arithmetic_fence)
5858
EXT(SPV_INTEL_arithmetic_fence)
5959
EXT(SPV_INTEL_bfloat16_conversion)
60+
EXT(SPV_INTEL_device_barrier)
6061
EXT(SPV_INTEL_joint_matrix)
6162
EXT(SPV_INTEL_hw_thread_queries)
6263
EXT(SPV_INTEL_global_variable_decorations)

lib/SPIRV/SPIRVInternal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1013,7 +1013,7 @@ CallInst *setAttrByCalledFunc(CallInst *Call);
10131013
bool isSPIRVBuiltinVariable(GlobalVariable *GV, SPIRVBuiltinVariableKind *Kind);
10141014
// Transform builtin variable from GlobalVariable to builtin call.
10151015
// e.g.
1016-
// - GlobalInvolcationId[x] -> _Z33__spirv_BuiltInGlobalInvocationIdi(x)
1016+
// - GlobalInvocationId[x] -> _Z33__spirv_BuiltInGlobalInvocationIdi(x)
10171017
// - WorkDim -> _Z22__spirv_BuiltInWorkDimv()
10181018
bool lowerBuiltinVariableToCall(GlobalVariable *GV,
10191019
SPIRVBuiltinVariableKind Kind);

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2151,13 +2151,15 @@ static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
21512151
} else if (auto *Load = dyn_cast<LoadInst>(U)) {
21522152
// Figure out which index the accumulated offset corresponds to. If we
21532153
// have a weird offset (e.g., trying to load byte 7), bail out.
2154-
Type *ScalarTy = ReplacementFunc->getReturnType();
21552154
APInt Index;
2156-
uint64_t Remainder;
2157-
APInt::udivrem(AccumulatedOffset, ScalarTy->getScalarSizeInBits() / 8,
2158-
Index, Remainder);
2159-
if (Remainder != 0)
2160-
llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
2155+
Type *ScalarTy = ReplacementFunc->getReturnType();
2156+
if (!ScalarTy->isIntegerTy(1)) {
2157+
uint64_t Remainder;
2158+
APInt::udivrem(AccumulatedOffset, ScalarTy->getScalarSizeInBits() / 8,
2159+
Index, Remainder);
2160+
if (Remainder != 0)
2161+
llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
2162+
}
21612163

21622164
IRBuilder<> Builder(Load);
21632165
Value *Replacement;

lib/SPIRV/libSPIRV/SPIRVEntry.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,6 +911,8 @@ class SPIRVCapability : public SPIRVEntryNoId<OpCapability> {
911911
return ExtensionID::SPV_INTEL_function_variants;
912912
case internal::CapabilityBFloat16ArithmeticINTEL:
913913
return ExtensionID::SPV_INTEL_bfloat16_arithmetic;
914+
case internal::CapabilityDeviceBarrierINTEL:
915+
return ExtensionID::SPV_INTEL_device_barrier;
914916
default:
915917
return {};
916918
}

lib/SPIRV/libSPIRV/SPIRVEnum.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,8 @@ template <> inline void SPIRVMap<BuiltIn, SPIRVCapVec>::init() {
561561
{internal::CapabilityHWThreadQueryINTEL});
562562
ADD_VEC_INIT(internal::BuiltInGlobalHWThreadIDINTEL,
563563
{internal::CapabilityHWThreadQueryINTEL});
564+
ADD_VEC_INIT(internal::BuiltInDeviceBarrierValidINTEL,
565+
{internal::CapabilityDeviceBarrierINTEL});
564566
}
565567

566568
template <> inline void SPIRVMap<MemorySemanticsMask, SPIRVCapVec>::init() {

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2358,13 +2358,40 @@ class SPIRVControlBarrier : public SPIRVInstruction {
23582358
return getValues(Operands);
23592359
}
23602360

2361+
SPIRVCapVec getRequiredCapability() const override {
2362+
if (isDeviceBarrier()) {
2363+
return getVec(internal::CapabilityDeviceBarrierINTEL);
2364+
}
2365+
return SPIRVInstruction::getRequiredCapability();
2366+
}
2367+
std::optional<ExtensionID> getRequiredExtension() const override {
2368+
if (isDeviceBarrier()) {
2369+
return ExtensionID::SPV_INTEL_device_barrier;
2370+
}
2371+
return std::nullopt;
2372+
}
2373+
23612374
protected:
23622375
_SPIRV_DEF_ENCDEC3(ExecScope, MemScope, MemSema)
23632376
void validate() const override {
23642377
assert(OpCode == OC);
23652378
assert(WordCount == 4);
23662379
SPIRVInstruction::validate();
23672380
}
2381+
2382+
bool isDeviceBarrier() const {
2383+
if (!getModule()->isAllowedToUseExtension(
2384+
ExtensionID::SPV_INTEL_device_barrier))
2385+
return false;
2386+
SPIRVValue *ESV = getValue(ExecScope);
2387+
if (ESV && ESV->getOpCode() == OpConstant) {
2388+
if (static_cast<SPIRVConstant *>(ESV)->getZExtIntValue() != ScopeDevice) {
2389+
return false;
2390+
}
2391+
}
2392+
return true;
2393+
}
2394+
23682395
SPIRVId ExecScope;
23692396
SPIRVId MemScope = SPIRVID_INVALID;
23702397
SPIRVId MemSema = SPIRVID_INVALID;

lib/SPIRV/libSPIRV/SPIRVIsValidEnum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ inline bool isValid(spv::BuiltIn V) {
281281
case BuiltInCullMaskKHR:
282282
case internal::BuiltInSubDeviceIDINTEL:
283283
case internal::BuiltInGlobalHWThreadIDINTEL:
284+
case internal::BuiltInDeviceBarrierValidINTEL:
284285
return true;
285286
default:
286287
return false;

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,8 @@ template <> inline void SPIRVMap<BuiltIn, std::string>::init() {
343343
add(BuiltInMax, "BuiltInMax");
344344
add(internal::BuiltInSubDeviceIDINTEL, "BuiltInSubDeviceIDINTEL");
345345
add(internal::BuiltInGlobalHWThreadIDINTEL, "BuiltInGlobalHWThreadIDINTEL");
346+
add(internal::BuiltInDeviceBarrierValidINTEL,
347+
"BuiltInDeviceBarrierValidINTEL");
346348
}
347349
SPIRV_DEF_NAMEMAP(BuiltIn, SPIRVBuiltInNameMap)
348350

@@ -670,6 +672,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
670672
add(CapabilityFloat8CooperativeMatrixEXT, "Float8CooperativeMatrixEXT");
671673
add(internal::CapabilityPredicatedIOINTEL, "PredicatedIOINTEL");
672674
add(internal::CapabilitySigmoidINTEL, "SigmoidINTEL");
675+
add(internal::CapabilityDeviceBarrierINTEL, "DeviceBarrierINTEL");
673676
add(internal::CapabilityFloat4E2M1INTEL, "Float4E2M1INTEL");
674677
add(internal::CapabilityFloat4E2M1CooperativeMatrixINTEL,
675678
"Float4E2M1CooperativeMatrixINTEL");

lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ enum InternalCapability {
110110
ICapabilityHWThreadQueryINTEL = 6134,
111111
ICapGlobalVariableDecorationsINTEL = 6146,
112112
ICapabilitySigmoidINTEL = 6167,
113+
ICapabilityDeviceBarrierINTEL = 6185,
113114
ICapabilityCooperativeMatrixCheckedInstructionsINTEL = 6192,
114115
ICapabilityFloat4E2M1INTEL = 6212,
115116
ICapabilityFloat4E2M1CooperativeMatrixINTEL = 6213,
@@ -150,6 +151,7 @@ enum InternalFPEncoding {
150151
enum InternalBuiltIn {
151152
IBuiltInSubDeviceIDINTEL = 6135,
152153
IBuiltInGlobalHWThreadIDINTEL = 6136,
154+
IBuiltInDeviceBarrierValidINTEL = 6186,
153155
};
154156

155157
#define _SPIRV_OP(x, y) constexpr x x##y = static_cast<x>(I##x##y);
@@ -172,6 +174,9 @@ _SPIRV_OP(Op, CooperativeMatrixConstructCheckedINTEL)
172174
_SPIRV_OP(Capability, CooperativeMatrixPrefetchINTEL)
173175
_SPIRV_OP(Op, CooperativeMatrixPrefetchINTEL)
174176

177+
_SPIRV_OP(Capability, DeviceBarrierINTEL)
178+
_SPIRV_OP(BuiltIn, DeviceBarrierValidINTEL)
179+
175180
_SPIRV_OP(Capability, HWThreadQueryINTEL)
176181
_SPIRV_OP(BuiltIn, SubDeviceIDINTEL)
177182
_SPIRV_OP(BuiltIn, GlobalHWThreadIDINTEL)
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
;; kernel void test(global uint* dst)
2+
;; {
3+
;; int scope = magic_get_scope();
4+
;; __spirv_ControlBarrier(scope, 1, 264); // local
5+
;; __spirv_ControlBarrier(scope, 1, 520); // global
6+
;; __spirv_ControlBarrier(scope, 1, 2056); // image
7+
;;
8+
;; __spirv_ControlBarrier(scope, 0, 520); // global, all_svm_devices
9+
;; __spirv_ControlBarrier(scope, 1, 520); // global, device
10+
;; __spirv_ControlBarrier(scope, 2, 520); // global, work_group
11+
;; __spirv_ControlBarrier(scope, 3, 520); // global, subgroup
12+
;; __spirv_ControlBarrier(scope, 4, 520); // global, work_item
13+
;;}
14+
15+
; Test for SPV_INTEL_device_barrier (SPIR-V friendly LLVM IR)
16+
; RUN: llvm-as %s -o %t.bc
17+
; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_INTEL_device_barrier
18+
; RUN: llvm-spirv %t.spv -o %t.spt --to-text
19+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
20+
21+
; RUN: llvm-spirv %t.spv -o %t.rev.bc -r --spirv-target-env=SPV-IR
22+
; RUN: llvm-dis %t.rev.bc -o %t.rev.ll
23+
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM
24+
25+
; RUN: llvm-spirv %t.bc -o %t.disabled.spv
26+
; RUN: llvm-spirv %t.disabled.spv -o %t.disabled.spt --to-text
27+
; RUN: FileCheck < %t.disabled.spt %s --check-prefix=CHECK-SPIRV-EXTENSION-DISABLED
28+
29+
; ModuleID = 'device_barrier_spirv.cl'
30+
source_filename = "device_barrier_spirv.cl"
31+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
32+
target triple = "spir64"
33+
34+
; CHECK-SPIRV: Capability DeviceBarrierINTEL
35+
; CHECK-SPIRV: Extension "SPV_INTEL_device_barrier"
36+
; CHECK-SPIRV: TypeInt [[UINT:[0-9]+]] 32 0
37+
;
38+
;; When the SPV_INTEL_device_barrier extension is not enabled, a runtime variable
39+
;; should not cause the device barrier extension or capability to be declared.
40+
; CHECK-SPIRV-EXTENSION-DISABLED-NOT: Capability DeviceBarrierINTEL
41+
; CHECK-SPIRV-EXTENSION-DISABLED-NOT: Extension "SPV_INTEL_device_barrier"
42+
;
43+
; Scopes:
44+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[SCOPE_CROSS_DEVICE:[0-9]+]] 0 {{$}}
45+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[SCOPE_DEVICE:[0-9]+]] 1 {{$}}
46+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[SCOPE_WORK_GROUP:[0-9]+]] 2 {{$}}
47+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[SCOPE_SUBGROUP:[0-9]+]] 3 {{$}}
48+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[SCOPE_INVOCATION:[0-9]+]] 4 {{$}}
49+
;
50+
; Memory Semantics:
51+
; 0x8 AcquireRelease + 0x100 WorkgroupMemory
52+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[ACQREL_LOCAL:[0-9]+]] 264
53+
; 0x8 AcquireRelease + 0x200 CrossWorkgroupMemory
54+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[ACQREL_GLOBAL:[0-9]+]] 520
55+
; 0x8 AcquireRelease + 0x800 ImageMemory
56+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[ACQREL_IMAGE:[0-9]+]] 2056
57+
;
58+
; Runtime execution scope:
59+
; CHECK-SPIRV: FunctionCall [[#]] [[EXEC_SCOPE:[0-9]+]] [[#]]
60+
;
61+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_DEVICE]] [[ACQREL_LOCAL]]
62+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_DEVICE]] [[ACQREL_GLOBAL]]
63+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_DEVICE]] [[ACQREL_IMAGE]]
64+
;
65+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_CROSS_DEVICE]] [[ACQREL_GLOBAL]]
66+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_DEVICE]] [[ACQREL_GLOBAL]]
67+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_WORK_GROUP]] [[ACQREL_GLOBAL]]
68+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_SUBGROUP]] [[ACQREL_GLOBAL]]
69+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_INVOCATION]] [[ACQREL_GLOBAL]]
70+
71+
; CHECK-LLVM-LABEL: define spir_kernel void @test
72+
; Function Attrs: convergent norecurse nounwind
73+
define dso_local spir_kernel void @test(ptr addrspace(1) nocapture noundef readnone align 4 %0) local_unnamed_addr #0 !kernel_arg_addr_space !4 !kernel_arg_access_qual !5 !kernel_arg_type !6 !kernel_arg_base_type !6 !kernel_arg_type_qual !7 {
74+
%2 = call noundef i32 @magic_get_scope()
75+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 1, i32 noundef 264) #2
76+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 1, i32 264) #1
77+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 1, i32 noundef 520) #2
78+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 1, i32 520) #1
79+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 1, i32 noundef 2056) #2
80+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 1, i32 2056) #1
81+
82+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 0, i32 noundef 520) #2
83+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 0, i32 520) #1
84+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 1, i32 noundef 520) #2
85+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 1, i32 520) #1
86+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 2, i32 noundef 520) #2
87+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 2, i32 520) #1
88+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 3, i32 noundef 520) #2
89+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 3, i32 520) #1
90+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 4, i32 noundef 520) #2
91+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 4, i32 520) #1
92+
ret void
93+
}
94+
95+
; Function Attrs: convergent
96+
declare dso_local spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #1
97+
98+
declare spir_func i32 @magic_get_scope()
99+
100+
attributes #0 = { convergent norecurse nounwind "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "uniform-work-group-size"="false" }
101+
attributes #1 = { convergent "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
102+
attributes #2 = { convergent nounwind }
103+
104+
!llvm.module.flags = !{!0, !1}
105+
!opencl.ocl.version = !{!2}
106+
!opencl.spir.version = !{!2}
107+
!llvm.ident = !{!3}
108+
109+
!0 = !{i32 1, !"wchar_size", i32 4}
110+
!1 = !{i32 7, !"frame-pointer", i32 2}
111+
!2 = !{i32 2, i32 0}
112+
!3 = !{!"clang version 15.0.0 (https://github.com/llvm/llvm-project 861386dbd6ff0d91636b7c674c2abb2eccd9d3f2)"}
113+
!4 = !{i32 1}
114+
!5 = !{!"none"}
115+
!6 = !{!"uint*"}
116+
!7 = !{!""}

0 commit comments

Comments
 (0)