Skip to content

Commit a46879b

Browse files
[Backport to llvm_release_200] initial support for SPV_INTEL_device_barrier (#3534)
Backport of PR #3461 into `llvm_release_200`. All commits applied cleanly. Co-authored-by: Ben Ashbaugh <ben.ashbaugh@intel.com>
1 parent fccd096 commit a46879b

File tree

13 files changed

+359
-7
lines changed

13 files changed

+359
-7
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ EXT(SPV_INTEL_runtime_aligned)
5757
EXT(SPV_EXT_arithmetic_fence)
5858
EXT(SPV_INTEL_arithmetic_fence)
5959
EXT(SPV_INTEL_bfloat16_conversion)
60+
EXT(SPV_INTEL_device_barrier)
6061
EXT(SPV_INTEL_joint_matrix)
6162
EXT(SPV_INTEL_hw_thread_queries)
6263
EXT(SPV_INTEL_global_variable_decorations)

lib/SPIRV/SPIRVInternal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1002,7 +1002,7 @@ CallInst *setAttrByCalledFunc(CallInst *Call);
10021002
bool isSPIRVBuiltinVariable(GlobalVariable *GV, SPIRVBuiltinVariableKind *Kind);
10031003
// Transform builtin variable from GlobalVariable to builtin call.
10041004
// e.g.
1005-
// - GlobalInvolcationId[x] -> _Z33__spirv_BuiltInGlobalInvocationIdi(x)
1005+
// - GlobalInvocationId[x] -> _Z33__spirv_BuiltInGlobalInvocationIdi(x)
10061006
// - WorkDim -> _Z22__spirv_BuiltInWorkDimv()
10071007
bool lowerBuiltinVariableToCall(GlobalVariable *GV,
10081008
SPIRVBuiltinVariableKind Kind);

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2057,13 +2057,15 @@ static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
20572057
} else if (auto *Load = dyn_cast<LoadInst>(U)) {
20582058
// Figure out which index the accumulated offset corresponds to. If we
20592059
// have a weird offset (e.g., trying to load byte 7), bail out.
2060-
Type *ScalarTy = ReplacementFunc->getReturnType();
20612060
APInt Index;
2062-
uint64_t Remainder;
2063-
APInt::udivrem(AccumulatedOffset, ScalarTy->getScalarSizeInBits() / 8,
2064-
Index, Remainder);
2065-
if (Remainder != 0)
2066-
llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
2061+
Type *ScalarTy = ReplacementFunc->getReturnType();
2062+
if (!ScalarTy->isIntegerTy(1)) {
2063+
uint64_t Remainder;
2064+
APInt::udivrem(AccumulatedOffset, ScalarTy->getScalarSizeInBits() / 8,
2065+
Index, Remainder);
2066+
if (Remainder != 0)
2067+
llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
2068+
}
20672069

20682070
IRBuilder<> Builder(Load);
20692071
Value *Replacement;

lib/SPIRV/libSPIRV/SPIRVEntry.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -919,6 +919,8 @@ class SPIRVCapability : public SPIRVEntryNoId<OpCapability> {
919919
return ExtensionID::SPV_INTEL_function_variants;
920920
case internal::CapabilityBFloat16ArithmeticINTEL:
921921
return ExtensionID::SPV_INTEL_bfloat16_arithmetic;
922+
case internal::CapabilityDeviceBarrierINTEL:
923+
return ExtensionID::SPV_INTEL_device_barrier;
922924
default:
923925
return {};
924926
}

lib/SPIRV/libSPIRV/SPIRVEnum.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,8 @@ template <> inline void SPIRVMap<BuiltIn, SPIRVCapVec>::init() {
588588
{internal::CapabilityHWThreadQueryINTEL});
589589
ADD_VEC_INIT(internal::BuiltInGlobalHWThreadIDINTEL,
590590
{internal::CapabilityHWThreadQueryINTEL});
591+
ADD_VEC_INIT(internal::BuiltInDeviceBarrierValidINTEL,
592+
{internal::CapabilityDeviceBarrierINTEL});
591593
}
592594

593595
template <> inline void SPIRVMap<MemorySemanticsMask, SPIRVCapVec>::init() {

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2488,13 +2488,40 @@ class SPIRVControlBarrier : public SPIRVInstruction {
24882488
return getValues(Operands);
24892489
}
24902490

2491+
SPIRVCapVec getRequiredCapability() const override {
2492+
if (isDeviceBarrier()) {
2493+
return getVec(internal::CapabilityDeviceBarrierINTEL);
2494+
}
2495+
return SPIRVInstruction::getRequiredCapability();
2496+
}
2497+
std::optional<ExtensionID> getRequiredExtension() const override {
2498+
if (isDeviceBarrier()) {
2499+
return ExtensionID::SPV_INTEL_device_barrier;
2500+
}
2501+
return std::nullopt;
2502+
}
2503+
24912504
protected:
24922505
_SPIRV_DEF_ENCDEC3(ExecScope, MemScope, MemSema)
24932506
void validate() const override {
24942507
assert(OpCode == OC);
24952508
assert(WordCount == 4);
24962509
SPIRVInstruction::validate();
24972510
}
2511+
2512+
bool isDeviceBarrier() const {
2513+
if (!getModule()->isAllowedToUseExtension(
2514+
ExtensionID::SPV_INTEL_device_barrier))
2515+
return false;
2516+
SPIRVValue *ESV = getValue(ExecScope);
2517+
if (ESV && ESV->getOpCode() == OpConstant) {
2518+
if (static_cast<SPIRVConstant *>(ESV)->getZExtIntValue() != ScopeDevice) {
2519+
return false;
2520+
}
2521+
}
2522+
return true;
2523+
}
2524+
24982525
SPIRVId ExecScope;
24992526
SPIRVId MemScope = SPIRVID_INVALID;
25002527
SPIRVId MemSema = SPIRVID_INVALID;

lib/SPIRV/libSPIRV/SPIRVIsValidEnum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ inline bool isValid(spv::BuiltIn V) {
282282
case BuiltInCullMaskKHR:
283283
case internal::BuiltInSubDeviceIDINTEL:
284284
case internal::BuiltInGlobalHWThreadIDINTEL:
285+
case internal::BuiltInDeviceBarrierValidINTEL:
285286
return true;
286287
default:
287288
return false;

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,8 @@ template <> inline void SPIRVMap<BuiltIn, std::string>::init() {
349349
add(BuiltInMax, "BuiltInMax");
350350
add(internal::BuiltInSubDeviceIDINTEL, "BuiltInSubDeviceIDINTEL");
351351
add(internal::BuiltInGlobalHWThreadIDINTEL, "BuiltInGlobalHWThreadIDINTEL");
352+
add(internal::BuiltInDeviceBarrierValidINTEL,
353+
"BuiltInDeviceBarrierValidINTEL");
352354
}
353355
SPIRV_DEF_NAMEMAP(BuiltIn, SPIRVBuiltInNameMap)
354356

@@ -697,6 +699,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
697699
add(CapabilityFloat8CooperativeMatrixEXT, "Float8CooperativeMatrixEXT");
698700
add(internal::CapabilityPredicatedIOINTEL, "PredicatedIOINTEL");
699701
add(internal::CapabilitySigmoidINTEL, "SigmoidINTEL");
702+
add(internal::CapabilityDeviceBarrierINTEL, "DeviceBarrierINTEL");
700703
add(internal::CapabilityFloat4E2M1INTEL, "Float4E2M1INTEL");
701704
add(internal::CapabilityFloat4E2M1CooperativeMatrixINTEL,
702705
"Float4E2M1CooperativeMatrixINTEL");

lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ enum InternalCapability {
117117
ICapGlobalVariableDecorationsINTEL = 6146,
118118
ICapabilityTaskSequenceINTEL = 6162,
119119
ICapabilitySigmoidINTEL = 6167,
120+
ICapabilityDeviceBarrierINTEL = 6185,
120121
ICapabilityCooperativeMatrixCheckedInstructionsINTEL = 6192,
121122
ICapabilityFloat4E2M1INTEL = 6212,
122123
ICapabilityFloat4E2M1CooperativeMatrixINTEL = 6213,
@@ -175,6 +176,7 @@ enum InternalFPEncoding {
175176
enum InternalBuiltIn {
176177
IBuiltInSubDeviceIDINTEL = 6135,
177178
IBuiltInGlobalHWThreadIDINTEL = 6136,
179+
IBuiltInDeviceBarrierValidINTEL = 6186,
178180
};
179181

180182
#define _SPIRV_OP(x, y) constexpr x x##y = static_cast<x>(I##x##y);
@@ -210,6 +212,9 @@ _SPIRV_OP(Op, CooperativeMatrixStoreOffsetINTEL)
210212
_SPIRV_OP(Capability, CooperativeMatrixInvocationInstructionsINTEL)
211213
_SPIRV_OP(Op, CooperativeMatrixApplyFunctionINTEL)
212214

215+
_SPIRV_OP(Capability, DeviceBarrierINTEL)
216+
_SPIRV_OP(BuiltIn, DeviceBarrierValidINTEL)
217+
213218
_SPIRV_OP(Capability, HWThreadQueryINTEL)
214219
_SPIRV_OP(BuiltIn, SubDeviceIDINTEL)
215220
_SPIRV_OP(BuiltIn, GlobalHWThreadIDINTEL)
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
;; kernel void test(global uint* dst)
2+
;; {
3+
;; int scope = magic_get_scope();
4+
;; __spirv_ControlBarrier(scope, 1, 264); // local
5+
;; __spirv_ControlBarrier(scope, 1, 520); // global
6+
;; __spirv_ControlBarrier(scope, 1, 2056); // image
7+
;;
8+
;; __spirv_ControlBarrier(scope, 0, 520); // global, all_svm_devices
9+
;; __spirv_ControlBarrier(scope, 1, 520); // global, device
10+
;; __spirv_ControlBarrier(scope, 2, 520); // global, work_group
11+
;; __spirv_ControlBarrier(scope, 3, 520); // global, subgroup
12+
;; __spirv_ControlBarrier(scope, 4, 520); // global, work_item
13+
;;}
14+
15+
; Test for SPV_INTEL_device_barrier (SPIR-V friendly LLVM IR)
16+
; RUN: llvm-as %s -o %t.bc
17+
; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_INTEL_device_barrier
18+
; RUN: llvm-spirv %t.spv -o %t.spt --to-text
19+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
20+
21+
; RUN: llvm-spirv %t.spv -o %t.rev.bc -r --spirv-target-env=SPV-IR
22+
; RUN: llvm-dis %t.rev.bc -o %t.rev.ll
23+
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM
24+
25+
; RUN: llvm-spirv %t.bc -o %t.disabled.spv
26+
; RUN: llvm-spirv %t.disabled.spv -o %t.disabled.spt --to-text
27+
; RUN: FileCheck < %t.disabled.spt %s --check-prefix=CHECK-SPIRV-EXTENSION-DISABLED
28+
29+
; ModuleID = 'device_barrier_spirv.cl'
30+
source_filename = "device_barrier_spirv.cl"
31+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
32+
target triple = "spir64"
33+
34+
; CHECK-SPIRV: Capability DeviceBarrierINTEL
35+
; CHECK-SPIRV: Extension "SPV_INTEL_device_barrier"
36+
; CHECK-SPIRV: TypeInt [[UINT:[0-9]+]] 32 0
37+
;
38+
;; When the SPV_INTEL_device_barrier extension is not enabled, a runtime variable
39+
;; should not cause the device barrier extension or capability to be declared.
40+
; CHECK-SPIRV-EXTENSION-DISABLED-NOT: Capability DeviceBarrierINTEL
41+
; CHECK-SPIRV-EXTENSION-DISABLED-NOT: Extension "SPV_INTEL_device_barrier"
42+
;
43+
; Scopes:
44+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[SCOPE_CROSS_DEVICE:[0-9]+]] 0 {{$}}
45+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[SCOPE_DEVICE:[0-9]+]] 1 {{$}}
46+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[SCOPE_WORK_GROUP:[0-9]+]] 2 {{$}}
47+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[SCOPE_SUBGROUP:[0-9]+]] 3 {{$}}
48+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[SCOPE_INVOCATION:[0-9]+]] 4 {{$}}
49+
;
50+
; Memory Semantics:
51+
; 0x8 AcquireRelease + 0x100 WorkgroupMemory
52+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[ACQREL_LOCAL:[0-9]+]] 264
53+
; 0x8 AcquireRelease + 0x200 CrossWorkgroupMemory
54+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[ACQREL_GLOBAL:[0-9]+]] 520
55+
; 0x8 AcquireRelease + 0x800 ImageMemory
56+
; CHECK-SPIRV-DAG: Constant [[UINT]] [[ACQREL_IMAGE:[0-9]+]] 2056
57+
;
58+
; Runtime execution scope:
59+
; CHECK-SPIRV: FunctionCall [[#]] [[EXEC_SCOPE:[0-9]+]] [[#]]
60+
;
61+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_DEVICE]] [[ACQREL_LOCAL]]
62+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_DEVICE]] [[ACQREL_GLOBAL]]
63+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_DEVICE]] [[ACQREL_IMAGE]]
64+
;
65+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_CROSS_DEVICE]] [[ACQREL_GLOBAL]]
66+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_DEVICE]] [[ACQREL_GLOBAL]]
67+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_WORK_GROUP]] [[ACQREL_GLOBAL]]
68+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_SUBGROUP]] [[ACQREL_GLOBAL]]
69+
; CHECK-SPIRV: ControlBarrier [[EXEC_SCOPE]] [[SCOPE_INVOCATION]] [[ACQREL_GLOBAL]]
70+
71+
; CHECK-LLVM-LABEL: define spir_kernel void @test
72+
; Function Attrs: convergent norecurse nounwind
73+
define dso_local spir_kernel void @test(ptr addrspace(1) captures(none) noundef readnone align 4 %0) local_unnamed_addr #0 !kernel_arg_addr_space !4 !kernel_arg_access_qual !5 !kernel_arg_type !6 !kernel_arg_base_type !6 !kernel_arg_type_qual !7 {
74+
%2 = call noundef i32 @magic_get_scope()
75+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 1, i32 noundef 264) #2
76+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 1, i32 264) #1
77+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 1, i32 noundef 520) #2
78+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 1, i32 520) #1
79+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 1, i32 noundef 2056) #2
80+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 1, i32 2056) #1
81+
82+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 0, i32 noundef 520) #2
83+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 0, i32 520) #1
84+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 1, i32 noundef 520) #2
85+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 1, i32 520) #1
86+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 2, i32 noundef 520) #2
87+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 2, i32 520) #1
88+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 3, i32 noundef 520) #2
89+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 3, i32 520) #1
90+
tail call spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef %2, i32 noundef 4, i32 noundef 520) #2
91+
; CHECK-LLVM: call spir_func void @_Z22__spirv_ControlBarrieriii(i32 %2, i32 4, i32 520) #1
92+
ret void
93+
}
94+
95+
; Function Attrs: convergent
96+
declare dso_local spir_func void @_Z22__spirv_ControlBarrieriii(i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #1
97+
98+
declare spir_func i32 @magic_get_scope()
99+
100+
attributes #0 = { convergent norecurse nounwind "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "uniform-work-group-size"="false" }
101+
attributes #1 = { convergent "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
102+
attributes #2 = { convergent nounwind }
103+
104+
!llvm.module.flags = !{!0, !1}
105+
!opencl.ocl.version = !{!2}
106+
!opencl.spir.version = !{!2}
107+
!llvm.ident = !{!3}
108+
109+
!0 = !{i32 1, !"wchar_size", i32 4}
110+
!1 = !{i32 7, !"frame-pointer", i32 2}
111+
!2 = !{i32 2, i32 0}
112+
!3 = !{!"clang version 15.0.0 (https://github.com/llvm/llvm-project 861386dbd6ff0d91636b7c674c2abb2eccd9d3f2)"}
113+
!4 = !{i32 1}
114+
!5 = !{!"none"}
115+
!6 = !{!"uint*"}
116+
!7 = !{!""}

0 commit comments

Comments
 (0)