Skip to content

Commit ad63e7f

Browse files
committed
Add support for some FP16 vector atomics, via the SPV_NV_shader_atomic_fp16_vector extension.
1 parent 9edbf83 commit ad63e7f

File tree

8 files changed

+148
-3
lines changed

8 files changed

+148
-3
lines changed

llvm/docs/SPIRVUsage.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e
169169
- Adds atomic min and max instruction on floating-point numbers.
170170
* - ``SPV_INTEL_16bit_atomics``
171171
- Extends the SPV_EXT_shader_atomic_float_add and SPV_EXT_shader_atomic_float_min_max to support addition, minimum and maximum on 16-bit `bfloat16` floating-point numbers in memory.
172+
* - ``SPV_NV_shader_atomic_fp16_vector``
173+
- Adds atomic add, min and max instructions on 2 or 4-component vectors with 16-bit float components.
172174
* - ``SPV_INTEL_2d_block_io``
173175
- Adds additional subgroup block prefetch, load, load transposed, load transformed and store instructions to read two-dimensional blocks of data from a two-dimensional region of memory, or to write two-dimensional blocks of data to a two dimensional region of memory.
174176
* - ``SPV_ALTERA_arbitrary_precision_integers``

llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
3131
SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_min_max},
3232
{"SPV_INTEL_16bit_atomics",
3333
SPIRV::Extension::Extension::SPV_INTEL_16bit_atomics},
34+
{"SPV_NV_shader_atomic_fp16_vector",
35+
SPIRV::Extension::Extension::SPV_NV_shader_atomic_fp16_vector},
3436
{"SPV_EXT_arithmetic_fence",
3537
SPIRV::Extension::Extension::SPV_EXT_arithmetic_fence},
3638
{"SPV_EXT_demote_to_helper_invocation",

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1193,7 +1193,8 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
11931193
case TargetOpcode::G_ATOMICRMW_FSUB:
11941194
// Translate G_ATOMICRMW_FSUB to OpAtomicFAddEXT with negative value operand
11951195
return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFAddEXT,
1196-
SPIRV::OpFNegate);
1196+
ResType->getOpcode() == SPIRV::OpTypeVector
1197+
? SPIRV::OpFNegateV : SPIRV::OpFNegate);
11971198
case TargetOpcode::G_ATOMICRMW_FMIN:
11981199
return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFMinEXT);
11991200
case TargetOpcode::G_ATOMICRMW_FMAX:

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
131131
s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
132132
v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
133133

134+
auto allFloatScalarsAndF16Vector2AndVector4s = {s16, s32, s64, v2s16, v4s16};
135+
134136
auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, p2, p3,
135137
p4, p5, p6, p7, p8, p10, p11, p12};
136138

@@ -339,10 +341,12 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
339341

340342
getActionDefinitionsBuilder(
341343
{G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
342-
.legalForCartesianProduct(allFloatScalars, allPtrs);
344+
.legalForCartesianProduct(allFloatScalarsAndF16Vector2AndVector4s,
345+
allPtrs);
343346

344347
getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
345-
.legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allPtrs);
348+
.legalForCartesianProduct(allFloatScalarsAndF16Vector2AndVector4s,
349+
allPtrs);
346350

347351
getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
348352
// TODO: add proper legalization rules.

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
//
1515
//===----------------------------------------------------------------------===//
1616

17+
// TODO: uses or report_fatal_error (which is also deprecated) /
18+
// ReportFatalUsageError in this file should be refactored, as per LLVM
19+
// best practices, to rely on the Diagnostic infrastructure.
20+
1721
#include "SPIRVModuleAnalysis.h"
1822
#include "MCTargetDesc/SPIRVBaseInfo.h"
1923
#include "MCTargetDesc/SPIRVMCTargetDesc.h"
@@ -1071,13 +1075,50 @@ static bool isBFloat16Type(const SPIRVType *TypeDef) {
10711075
#define ATOM_FLT_REQ_EXT_MSG(ExtName) \
10721076
"The atomic float instruction requires the following SPIR-V " \
10731077
"extension: SPV_EXT_shader_atomic_float" ExtName
1078+
static void AddAtomicVectorFloatRequirements(const MachineInstr &MI,
1079+
SPIRV::RequirementHandler &Reqs,
1080+
const SPIRVSubtarget &ST) {
1081+
SPIRVType *VecTypeDef =
1082+
MI.getMF()->getRegInfo().getVRegDef(MI.getOperand(1).getReg());
1083+
1084+
const unsigned Rank = VecTypeDef->getOperand(2).getImm();
1085+
if (Rank != 2 && Rank != 4)
1086+
reportFatalUsageError("Result type of an atomic vector float instruction "
1087+
"must be a 2-component or 4 component vector");
1088+
1089+
SPIRVType *EltTypeDef =
1090+
MI.getMF()->getRegInfo().getVRegDef(VecTypeDef->getOperand(1).getReg());
1091+
1092+
if (EltTypeDef->getOpcode() != SPIRV::OpTypeFloat ||
1093+
EltTypeDef->getOperand(1).getImm() != 16)
1094+
reportFatalUsageError(
1095+
"The element type for the result type of an atomic vector float "
1096+
"instruction must be a 16-bit floating-point scalar");
1097+
1098+
if (isBFloat16Type(EltTypeDef))
1099+
reportFatalUsageError(
1100+
"The element type for the result type of an atomic vector float "
1101+
"instruction cannot be a bfloat16 scalar");
1102+
if (!ST.canUseExtension(SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector))
1103+
reportFatalUsageError(
1104+
"The atomic float16 vector instruction requires the following SPIR-V "
1105+
"extension: SPV_NV_shader_atomic_fp16_vector");
1106+
1107+
Reqs.addExtension(SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector);
1108+
Reqs.addCapability(SPIRV::Capability::AtomicFloat16VectorNV);
1109+
}
1110+
10741111
static void AddAtomicFloatRequirements(const MachineInstr &MI,
10751112
SPIRV::RequirementHandler &Reqs,
10761113
const SPIRVSubtarget &ST) {
10771114
assert(MI.getOperand(1).isReg() &&
10781115
"Expect register operand in atomic float instruction");
10791116
Register TypeReg = MI.getOperand(1).getReg();
10801117
SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
1118+
1119+
if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
1120+
return AddAtomicVectorFloatRequirements(MI, Reqs, ST);
1121+
10811122
if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
10821123
report_fatal_error("Result type of an atomic float instruction must be a "
10831124
"floating-point type scalar");

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,8 @@ defm SPV_INTEL_bfloat16_arithmetic
391391
: ExtensionOperand<129, [EnvVulkan, EnvOpenCL]>;
392392
defm SPV_INTEL_16bit_atomics : ExtensionOperand<130, [EnvVulkan, EnvOpenCL]>;
393393
defm SPV_ALTERA_arbitrary_precision_fixed_point : ExtensionOperand<131, [EnvOpenCL, EnvVulkan]>;
394+
defm SPV_NV_shader_atomic_fp16_vector
395+
: ExtensionOperand<132, [EnvVulkan, EnvOpenCL]>;
394396

395397
//===----------------------------------------------------------------------===//
396398
// Multiclass used to define Capabilities enum values and at the same time
@@ -573,6 +575,7 @@ defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atom
573575
defm AtomicFloat32MinMaxEXT : CapabilityOperand<5612, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
574576
defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
575577
defm AtomicBFloat16MinMaxINTEL : CapabilityOperand<6256, 0, 0, [SPV_INTEL_16bit_atomics], []>;
578+
defm AtomicFloat16VectorNV : CapabilityOperand<5404, 0, 0, [SPV_NV_shader_atomic_fp16_vector], []>;
576579
defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variable_length_array], []>;
577580
defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
578581
defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
2+
3+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_NV_shader_atomic_fp16_vector %s -o - | FileCheck %s
4+
5+
; CHECK-ERROR: LLVM ERROR: The atomic float16 vector instruction requires the following SPIR-V extension: SPV_NV_shader_atomic_fp16_vector
6+
7+
; CHECK: Capability Float16
8+
; CHECK-DAG: Capability AtomicFloat16VectorNV
9+
; CHECK: Extension "SPV_NV_shader_atomic_fp16_vector"
10+
; CHECK-DAG: %[[TyF16:[0-9]+]] = OpTypeFloat 16
11+
; CHECK: %[[TyF16Vec2:[0-9]+]] = OpTypeVector %[[TyF16]] 2
12+
; CHECK: %[[TyF16Vec4:[0-9]+]] = OpTypeVector %[[TyF16]] 4
13+
; CHECK: %[[TyF16Vec4Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyF16Vec4]]
14+
; CHECK: %[[TyF16Vec2Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyF16Vec2]]
15+
; CHECK: %[[TyInt32:[0-9]+]] = OpTypeInt 32 0
16+
; CHECK: %[[ConstF16:[0-9]+]] = OpConstant %[[TyF16]] 20800{{$}}
17+
; CHECK: %[[Const0F16Vec2:[0-9]+]] = OpConstantNull %[[TyF16Vec2]]
18+
; CHECK: %[[f:[0-9]+]] = OpVariable %[[TyF16Vec2Ptr]] CrossWorkgroup %[[Const0F16Vec2]]
19+
; CHECK: %[[Const0F16Vec4:[0-9]+]] = OpConstantNull %[[TyF16Vec4]]
20+
; CHECK: %[[g:[0-9]+]] = OpVariable %[[TyF16Vec4Ptr]] CrossWorkgroup %[[Const0F16Vec4]]
21+
; CHECK: %[[ConstF16Vec2:[0-9]+]] = OpConstantComposite %[[TyF16Vec2]] %[[ConstF16]] %[[ConstF16]]
22+
; CHECK: %[[ScopeAllSvmDevices:[0-9]+]] = OpConstantNull %[[TyInt32]]
23+
; CHECK: %[[MemSeqCst:[0-9]+]] = OpConstant %[[TyInt32]] 16{{$}}
24+
; CHECK: %[[ConstF16Vec4:[0-9]+]] = OpConstantComposite %[[TyF16Vec4]] %[[ConstF16]] %[[ConstF16]] %[[ConstF16]] %[[ConstF16]]
25+
26+
@f = common dso_local local_unnamed_addr addrspace(1) global <2 x half> <half 0.000000e+00, half 0.000000e+00>
27+
@g = common dso_local local_unnamed_addr addrspace(1) global <4 x half> <half 0.000000e+00, half 0.000000e+00, half 0.000000e+00, half 0.000000e+00>
28+
29+
; CHECK-DAG: OpAtomicFAddEXT %[[TyF16Vec2]] %[[f]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec2]]
30+
; CHECK: %[[NegatedConstF16Vec2:[0-9]+]] = OpFNegate %[[TyF16Vec2]] %[[ConstF16Vec2]]
31+
; CHECK: OpAtomicFAddEXT %[[TyF16Vec2]] %[[f]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[NegatedConstF16Vec2]]
32+
define dso_local spir_func void @test1() local_unnamed_addr {
33+
entry:
34+
%addval = atomicrmw fadd ptr addrspace(1) @f, <2 x half> <half 42.000000e+00, half 42.000000e+00> seq_cst
35+
%subval = atomicrmw fsub ptr addrspace(1) @f, <2 x half> <half 42.000000e+00, half 42.000000e+00> seq_cst
36+
ret void
37+
}
38+
39+
; CHECK-DAG: OpAtomicFAddEXT %[[TyF16Vec4]] %[[g]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec4]]
40+
; CHECK: %[[NegatedConstF16Vec4:[0-9]+]] = OpFNegate %[[TyF16Vec4]] %[[ConstF16Vec4]]
41+
; CHECK: OpAtomicFAddEXT %[[TyF16Vec4]] %[[g]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[NegatedConstF16Vec4]]
42+
define dso_local spir_func void @test2() local_unnamed_addr {
43+
entry:
44+
%addval = atomicrmw fadd ptr addrspace(1) @g, <4 x half> <half 42.000000e+00, half 42.000000e+00, half 42.000000e+00, half 42.000000e+00> seq_cst
45+
%subval = atomicrmw fsub ptr addrspace(1) @g, <4 x half> <half 42.000000e+00, half 42.000000e+00, half 42.000000e+00, half 42.000000e+00> seq_cst
46+
ret void
47+
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
2+
3+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_NV_shader_atomic_fp16_vector %s -o - | FileCheck %s
4+
5+
; CHECK-ERROR: LLVM ERROR: The atomic float16 vector instruction requires the following SPIR-V extension: SPV_NV_shader_atomic_fp16_vector
6+
7+
; CHECK: Capability Float16
8+
; CHECK-DAG: Capability AtomicFloat16VectorNV
9+
; CHECK: Extension "SPV_NV_shader_atomic_fp16_vector"
10+
; CHECK-DAG: %[[TyF16:[0-9]+]] = OpTypeFloat 16
11+
; CHECK: %[[TyF16Vec2:[0-9]+]] = OpTypeVector %[[TyF16]] 2
12+
; CHECK: %[[TyF16Vec4:[0-9]+]] = OpTypeVector %[[TyF16]] 4
13+
; CHECK: %[[TyF16Vec4Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyF16Vec4]]
14+
; CHECK: %[[TyF16Vec2Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyF16Vec2]]
15+
; CHECK: %[[TyInt32:[0-9]+]] = OpTypeInt 32 0
16+
; CHECK: %[[ConstF16:[0-9]+]] = OpConstant %[[TyF16]] 20800{{$}}
17+
; CHECK: %[[Const0F16Vec2:[0-9]+]] = OpConstantNull %[[TyF16Vec2]]
18+
; CHECK: %[[f:[0-9]+]] = OpVariable %[[TyF16Vec2Ptr]] CrossWorkgroup %[[Const0F16Vec2]]
19+
; CHECK: %[[Const0F16Vec4:[0-9]+]] = OpConstantNull %[[TyF16Vec4]]
20+
; CHECK: %[[g:[0-9]+]] = OpVariable %[[TyF16Vec4Ptr]] CrossWorkgroup %[[Const0F16Vec4]]
21+
; CHECK: %[[ConstF16Vec2:[0-9]+]] = OpConstantComposite %[[TyF16Vec2]] %[[ConstF16]] %[[ConstF16]]
22+
; CHECK: %[[ScopeAllSvmDevices:[0-9]+]] = OpConstantNull %[[TyInt32]]
23+
; CHECK: %[[MemSeqCst:[0-9]+]] = OpConstant %[[TyInt32]] 16{{$}}
24+
; CHECK: %[[ConstF16Vec4:[0-9]+]] = OpConstantComposite %[[TyF16Vec4]] %[[ConstF16]] %[[ConstF16]] %[[ConstF16]] %[[ConstF16]]
25+
26+
@f = common dso_local local_unnamed_addr addrspace(1) global <2 x half> <half 0.000000e+00, half 0.000000e+00>
27+
@g = common dso_local local_unnamed_addr addrspace(1) global <4 x half> <half 0.000000e+00, half 0.000000e+00, half 0.000000e+00, half 0.000000e+00>
28+
29+
; CHECK-DAG: OpAtomicFMinEXT %[[TyF16Vec2]] %[[f]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec2]]
30+
; CHECK: OpAtomicFMaxEXT %[[TyF16Vec2]] %[[f]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec2]]
31+
define dso_local spir_func void @test1() local_unnamed_addr {
32+
entry:
33+
%minval = atomicrmw fmin ptr addrspace(1) @f, <2 x half> <half 42.000000e+00, half 42.000000e+00> seq_cst
34+
%maxval = atomicrmw fmax ptr addrspace(1) @f, <2 x half> <half 42.000000e+00, half 42.000000e+00> seq_cst
35+
ret void
36+
}
37+
38+
; CHECK-DAG: OpAtomicFMinEXT %[[TyF16Vec4]] %[[g]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec4]]
39+
; CHECK: OpAtomicFMaxEXT %[[TyF16Vec4]] %[[g]] %[[ScopeAllSvmDevices]] %[[MemSeqCst]] %[[ConstF16Vec4]]
40+
define dso_local spir_func void @test2() local_unnamed_addr {
41+
entry:
42+
%minval = atomicrmw fmin ptr addrspace(1) @g, <4 x half> <half 42.000000e+00, half 42.000000e+00, half 42.000000e+00, half 42.000000e+00> seq_cst
43+
%maxval = atomicrmw fmax ptr addrspace(1) @g, <4 x half> <half 42.000000e+00, half 42.000000e+00, half 42.000000e+00, half 42.000000e+00> seq_cst
44+
ret void
45+
}

0 commit comments

Comments
 (0)