Skip to content

Commit 4c243ab

Browse files
committed
[SPIR-V] Support SPV_INTEL_int4 extension
Adds support for native 4-bit type. Spec: https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_int4.asciidoc
1 parent 09fd8f0 commit 4c243ab

File tree

8 files changed

+100
-7
lines changed

8 files changed

+100
-7
lines changed

llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
9999
{"SPV_INTEL_ternary_bitwise_function",
100100
SPIRV::Extension::Extension::SPV_INTEL_ternary_bitwise_function},
101101
{"SPV_INTEL_2d_block_io",
102-
SPIRV::Extension::Extension::SPV_INTEL_2d_block_io}};
102+
SPIRV::Extension::Extension::SPV_INTEL_2d_block_io},
103+
{"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4}};
103104

104105
bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
105106
StringRef ArgValue,

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
154154
report_fatal_error("Unsupported integer width!");
155155
const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
156156
if (ST.canUseExtension(
157-
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
157+
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
158+
ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4))
158159
return Width;
159160
if (Width <= 8)
160161
Width = 8;
@@ -174,9 +175,14 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
174175
const SPIRVSubtarget &ST =
175176
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
176177
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
177-
if ((!isPowerOf2_32(Width) || Width < 8) &&
178-
ST.canUseExtension(
179-
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
178+
if (Width == 4 && ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
179+
MIRBuilder.buildInstr(SPIRV::OpExtension)
180+
.addImm(SPIRV::Extension::SPV_INTEL_int4);
181+
MIRBuilder.buildInstr(SPIRV::OpCapability)
182+
.addImm(SPIRV::Capability::Int4TypeINTEL);
183+
} else if ((!isPowerOf2_32(Width) || Width < 8) &&
184+
ST.canUseExtension(
185+
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
180186
MIRBuilder.buildInstr(SPIRV::OpExtension)
181187
.addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
182188
MIRBuilder.buildInstr(SPIRV::OpCapability)
@@ -1563,6 +1569,13 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
15631569
const MachineInstr *NewMI =
15641570
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
15651571
SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
1572+
const Type *ET = getTypeForSPIRVType(ElemType);
1573+
if (ET->isIntegerTy() && ET->getIntegerBitWidth() == 4 &&
1574+
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget())
1575+
.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
1576+
MIRBuilder.buildInstr(SPIRV::OpCapability)
1577+
.addImm(SPIRV::Capability::Int4CooperativeMatrixINTEL);
1578+
}
15661579
return MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
15671580
.addDef(createTypeVReg(MIRBuilder))
15681581
.addUse(getSPIRVTypeID(ElemType))

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
128128
bool IsExtendedInts =
129129
ST.canUseExtension(
130130
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
131-
ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
131+
ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
132+
ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
132133
auto extendedScalarsAndVectors =
133134
[IsExtendedInts](const LegalityQuery &Query) {
134135
const LLT Ty = Query.Types[0];

llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,8 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
492492
bool IsExtendedInts =
493493
ST->canUseExtension(
494494
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
495-
ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
495+
ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
496+
ST->canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
496497

497498
for (MachineBasicBlock *MBB : post_order(&MF)) {
498499
if (MBB->empty())

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ defm SPV_INTEL_fp_max_error : ExtensionOperand<119>;
316316
defm SPV_INTEL_ternary_bitwise_function : ExtensionOperand<120>;
317317
defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;
318318
defm SPV_INTEL_2d_block_io : ExtensionOperand<122>;
319+
defm SPV_INTEL_int4 : ExtensionOperand<123>;
319320

320321
//===----------------------------------------------------------------------===//
321322
// Multiclass used to define Capabilities enum values and at the same time
@@ -521,6 +522,8 @@ defm SubgroupMatrixMultiplyAccumulateINTEL : CapabilityOperand<6236, 0, 0, [SPV_
521522
defm Subgroup2DBlockIOINTEL : CapabilityOperand<6228, 0, 0, [SPV_INTEL_2d_block_io], []>;
522523
defm Subgroup2DBlockTransformINTEL : CapabilityOperand<6229, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
523524
defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
525+
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
526+
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
524527

525528
//===----------------------------------------------------------------------===//
526529
// Multiclass used to define SourceLanguage enum values and at the same time
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
2+
; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: Capability Int4TypeINTEL
5+
; CHECK-DAG: Capability CooperativeMatrixKHR
6+
; CHECK-DAG: Extension "SPV_INTEL_int4"
7+
; CHECK-DAG: Capability Int4CooperativeMatrixINTEL
8+
; CHECK-DAG: Extension "SPV_KHR_cooperative_matrix"
9+
10+
; CHECK: %[[#Int4Ty:]] = OpTypeInt 4 0
11+
; CHECK: %[[#CoopMatTy:]] = OpTypeCooperativeMatrixKHR %[[#Int4Ty]]
12+
; CHECK: CompositeConstruct %[[#CoopMatTy]]
13+
14+
define spir_kernel void @foo() {
15+
entry:
16+
%call.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i4, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i4 noundef 0)
17+
ret void
18+
}
19+
20+
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i4, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i4 noundef)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_arbitrary_precision_integers %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-INT-4
2+
3+
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-INT-8
4+
; No error would be reported in comparison to Khronos llvm-spirv, because type adjustments to integer size are made
5+
; in case no appropriate extension is enabled. Here we expect that the type is adjusted to 8 bits.
6+
7+
; CHECK-SPIRV: Capability ArbitraryPrecisionIntegersINTEL
8+
; CHECK-SPIRV: Extension "SPV_INTEL_arbitrary_precision_integers"
9+
; CHECK-INT-4: %[[#Int4:]] = OpTypeInt 4 0
10+
; CHECK-INT-8: %[[#Int4:]] = OpTypeInt 8 0
11+
; CHECK: OpTypeFunction %[[#]] %[[#Int4]]
12+
; CHECK: %[[#Int4PtrTy:]] = OpTypePointer Function %[[#Int4]]
13+
; CHECK: %[[#Const:]] = OpConstant %[[#Int4]] 1
14+
15+
; CHECK: %[[#Int4Ptr:]] = OpVariable %[[#Int4PtrTy]] Function
16+
; CHECK: OpStore %[[#Int4Ptr]] %[[#Const]]
17+
; CHECK: %[[#Load:]] = OpLoad %[[#Int4]] %[[#Int4Ptr]]
18+
; CHECK: OpFunctionCall %[[#]] %[[#]] %[[#Load]]
19+
20+
define spir_kernel void @foo() {
21+
entry:
22+
%0 = alloca i4
23+
store i4 1, ptr %0
24+
%1 = load i4, ptr %0
25+
call spir_func void @boo(i4 %1)
26+
ret void
27+
}
28+
29+
declare spir_func void @boo(i4)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4 %s -o - | FileCheck %s
2+
; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4 %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK: Capability Int4TypeINTEL
5+
; CHECK: Extension "SPV_INTEL_int4"
6+
; CHECK: %[[#Int4:]] = OpTypeInt 4 0
7+
; CHECK: OpTypeFunction %[[#]] %[[#Int4]]
8+
; CHECK: %[[#Int4PtrTy:]] = OpTypePointer Function %[[#Int4]]
9+
; CHECK: %[[#Const:]] = OpConstant %[[#Int4]] 1
10+
11+
; CHECK: %[[#Int4Ptr:]] = OpVariable %[[#Int4PtrTy]] Function
12+
; CHECK: OpStore %[[#Int4Ptr]] %[[#Const]]
13+
; CHECK: %[[#Load:]] = OpLoad %[[#Int4]] %[[#Int4Ptr]]
14+
; CHECK: OpFunctionCall %[[#]] %[[#]] %[[#Load]]
15+
16+
define spir_kernel void @foo() {
17+
entry:
18+
%0 = alloca i4
19+
store i4 1, ptr %0
20+
%1 = load i4, ptr %0
21+
call spir_func void @boo(i4 %1)
22+
ret void
23+
}
24+
25+
declare spir_func void @boo(i4)

0 commit comments

Comments
 (0)