Skip to content

Commit 5a9241b

Browse files
[Backport to llvm_release_140] Add FP4/FP8 operand support for SubgroupMatrixMultiplyAccumulateINTEL (#3609) (#3634)
Extend SubgroupMatrixMultiplyAccumulateINTEL to support packed 4-bit and 8-bit floating-point matrix operands by implementing extensions: - SPV_INTEL_subgroup_matrix_multiply_accumulate_float4 - SPV_INTEL_subgroup_matrix_multiply_accumulate_float8 These extensions add operand flags that interpret packed integer data as FP4/FP8 without requiring actual FP4/FP8 type support added by SPV_INTEL_float4 or SPV_EXT_float8. FP4 operands: `MatrixAPackedFloat4E2M1INTEL` (0x40000) / `MatrixBPackedFloat4E2M1INTEL` (0x80000) FP8 operands: `MatrixAPackedFloat8E4M3INTEL` (0x4000) / `MatrixBPackedFloat8E4M3INTEL` (0x8000) `MatrixAPackedFloat8E5M2INTEL` (0x10000) / `MatrixBPackedFloat8E5M2INTEL` (0x20000) Specs: https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_subgroup_matrix_multiply_accumulate_float4.asciidoc https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_subgroup_matrix_multiply_accumulate_float8.asciidoc Co-authored-by: Viktoria Maximova <viktoria.maksimova@intel.com>
1 parent f508931 commit 5a9241b

File tree

5 files changed

+187
-0
lines changed

5 files changed

+187
-0
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ EXT(SPV_INTEL_maximum_registers)
7474
EXT(SPV_INTEL_bindless_images)
7575
EXT(SPV_INTEL_2d_block_io)
7676
EXT(SPV_INTEL_subgroup_matrix_multiply_accumulate)
77+
EXT(SPV_INTEL_subgroup_matrix_multiply_accumulate_float4)
78+
EXT(SPV_INTEL_subgroup_matrix_multiply_accumulate_float8)
7779
EXT(SPV_KHR_bfloat16)
7880
EXT(SPV_INTEL_bfloat16_arithmetic)
7981
EXT(SPV_INTEL_16bit_atomics)

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4092,6 +4092,66 @@ class SPIRVSubgroupMatrixMultiplyAccumulateINTELInst
40924092
llvm::Optional<ExtensionID> getRequiredExtension() const override {
40934093
return ExtensionID::SPV_INTEL_subgroup_matrix_multiply_accumulate;
40944094
}
4095+
4096+
protected:
4097+
void validate() const override {
4098+
SPIRVInstTemplateBase::validate();
4099+
4100+
// Check if FP4 or FP8 matrix operands are used
4101+
// Operands parameter is the last operand (index 4)
4102+
auto *NonConstThis =
4103+
const_cast<SPIRVSubgroupMatrixMultiplyAccumulateINTELInst *>(this);
4104+
if (NonConstThis->getOperands().size() > 4) {
4105+
const SPIRVConstant *OperandsConst =
4106+
static_cast<const SPIRVConstant *>(NonConstThis->getOperand(4));
4107+
uint64_t OperandsMask = OperandsConst->getZExtIntValue();
4108+
4109+
// FP4 operand bits
4110+
constexpr uint64_t FP4Mask =
4111+
spv::internal::
4112+
IMatrixMultiplyAccumulateOperandsMatrixAPackedFloat4E2M1INTELMask |
4113+
spv::internal::
4114+
IMatrixMultiplyAccumulateOperandsMatrixBPackedFloat4E2M1INTELMask;
4115+
4116+
// FP8 operand bits
4117+
constexpr uint64_t FP8Mask =
4118+
spv::internal::
4119+
IMatrixMultiplyAccumulateOperandsMatrixAPackedFloat8E4M3INTELMask |
4120+
spv::internal::
4121+
IMatrixMultiplyAccumulateOperandsMatrixBPackedFloat8E4M3INTELMask |
4122+
spv::internal::
4123+
IMatrixMultiplyAccumulateOperandsMatrixAPackedFloat8E5M2INTELMask |
4124+
spv::internal::
4125+
IMatrixMultiplyAccumulateOperandsMatrixBPackedFloat8E5M2INTELMask;
4126+
4127+
if ((OperandsMask & FP4Mask) != 0) {
4128+
getModule()->getErrorLog().checkError(
4129+
getModule()->isAllowedToUseExtension(
4130+
ExtensionID::
4131+
SPV_INTEL_subgroup_matrix_multiply_accumulate_float4),
4132+
SPIRVEC_RequiresExtension,
4133+
"SPV_INTEL_subgroup_matrix_multiply_accumulate_float4\n"
4134+
"SubgroupMatrixMultiplyAccumulateINTEL with FP4 operand flags "
4135+
"requires this extension");
4136+
getModule()->addExtension(
4137+
ExtensionID::SPV_INTEL_subgroup_matrix_multiply_accumulate_float4);
4138+
}
4139+
4140+
if ((OperandsMask & FP8Mask) != 0) {
4141+
getModule()->getErrorLog().checkError(
4142+
getModule()->isAllowedToUseExtension(
4143+
ExtensionID::
4144+
SPV_INTEL_subgroup_matrix_multiply_accumulate_float8),
4145+
SPIRVEC_RequiresExtension,
4146+
"SPV_INTEL_subgroup_matrix_multiply_accumulate_float8\n"
4147+
"SubgroupMatrixMultiplyAccumulateINTEL with FP8 operand flags "
4148+
"requires this extension");
4149+
getModule()->addExtension(
4150+
ExtensionID::SPV_INTEL_subgroup_matrix_multiply_accumulate_float8);
4151+
}
4152+
}
4153+
}
4154+
40954155
SPIRVCapVec getRequiredCapability() const override {
40964156
return getVec(CapabilitySubgroupMatrixMultiplyAccumulateINTEL);
40974157
}

lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,17 @@ enum InternalBuiltIn {
178178
IBuiltInDeviceBarrierValidINTEL = 6186,
179179
};
180180

181+
enum InternalMatrixMultiplyAccumulateOperandsMask {
182+
// FP8 matrix operands
183+
IMatrixMultiplyAccumulateOperandsMatrixAPackedFloat8E4M3INTELMask = 0x4000,
184+
IMatrixMultiplyAccumulateOperandsMatrixBPackedFloat8E4M3INTELMask = 0x8000,
185+
IMatrixMultiplyAccumulateOperandsMatrixAPackedFloat8E5M2INTELMask = 0x10000,
186+
IMatrixMultiplyAccumulateOperandsMatrixBPackedFloat8E5M2INTELMask = 0x20000,
187+
// FP4 matrix operands
188+
IMatrixMultiplyAccumulateOperandsMatrixAPackedFloat4E2M1INTELMask = 0x40000,
189+
IMatrixMultiplyAccumulateOperandsMatrixBPackedFloat4E2M1INTELMask = 0x80000,
190+
};
191+
181192
enum class LoadCacheControlINTEL {
182193
Uncached = 0,
183194
Cached = 1,
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
; This test checks that SubgroupMatrixMultiplyAccumulateINTEL with FP4 operand flags
2+
; requires the SPV_INTEL_subgroup_matrix_multiply_accumulate_float4 extension.
3+
4+
; RUN: llvm-as %s -o %t.bc
5+
; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_INTEL_subgroup_matrix_multiply_accumulate,+SPV_INTEL_subgroup_matrix_multiply_accumulate_float4
6+
; RUN: llvm-spirv %t.spv -o %t.spt --to-text
7+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
8+
9+
; RUN: not llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_INTEL_subgroup_matrix_multiply_accumulate 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
10+
11+
; CHECK-ERROR: RequiresExtension: Feature requires the following SPIR-V extension:
12+
; CHECK-ERROR: SPV_INTEL_subgroup_matrix_multiply_accumulate_float4
13+
14+
; CHECK-SPIRV-DAG: Capability SubgroupMatrixMultiplyAccumulateINTEL
15+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_subgroup_matrix_multiply_accumulate"
16+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_subgroup_matrix_multiply_accumulate_float4"
17+
; CHECK-SPIRV-DAG: SubgroupMatrixMultiplyAccumulateINTEL {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} 262144
18+
; CHECK-SPIRV-DAG: SubgroupMatrixMultiplyAccumulateINTEL {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} 524288
19+
; CHECK-SPIRV-DAG: SubgroupMatrixMultiplyAccumulateINTEL {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} 786432
20+
21+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
22+
target triple = "spir64-unknown-unknown"
23+
24+
; Test MatrixAPackedFloat4E2M1INTEL operand (0x40000 = 262144)
25+
define spir_func <4 x float> @test_fp4_matrix_a(<4 x float> %c, <4 x i8> %a, <8 x i8> %b) {
26+
entry:
27+
%result = call spir_func <4 x float> @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_hDv8_hDv4_fi(i32 8, <4 x i8> %a, <8 x i8> %b, <4 x float> %c, i32 262144)
28+
ret <4 x float> %result
29+
}
30+
31+
; Test MatrixBPackedFloat4E2M1INTEL operand (0x80000 = 524288)
32+
define spir_func <4 x float> @test_fp4_matrix_b(<4 x float> %c, <4 x i8> %a, <8 x i8> %b) {
33+
entry:
34+
%result = call spir_func <4 x float> @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_hDv8_hDv4_fi(i32 8, <4 x i8> %a, <8 x i8> %b, <4 x float> %c, i32 524288)
35+
ret <4 x float> %result
36+
}
37+
38+
; Test both FP4 operands (0xC0000 = 786432)
39+
define spir_func <4 x float> @test_fp4_matrix_both(<4 x float> %c, <4 x i8> %a, <8 x i8> %b) {
40+
entry:
41+
%result = call spir_func <4 x float> @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_hDv8_hDv4_fi(i32 8, <4 x i8> %a, <8 x i8> %b, <4 x float> %c, i32 786432)
42+
ret <4 x float> %result
43+
}
44+
45+
declare spir_func <4 x float> @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_hDv8_hDv4_fi(i32, <4 x i8>, <8 x i8>, <4 x float>, i32)
46+
47+
!opencl.spir.version = !{!0}
48+
!spirv.Source = !{!1}
49+
!llvm.ident = !{!2}
50+
51+
!0 = !{i32 1, i32 0}
52+
!1 = !{i32 4, i32 100000}
53+
!2 = !{!"clang version 17.0.0"}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
; This test checks that SubgroupMatrixMultiplyAccumulateINTEL with FP8 operand flags
2+
; requires the SPV_INTEL_subgroup_matrix_multiply_accumulate_float8 extension.
3+
4+
; RUN: llvm-as %s -o %t.bc
5+
; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_INTEL_subgroup_matrix_multiply_accumulate,+SPV_INTEL_subgroup_matrix_multiply_accumulate_float8
6+
; RUN: llvm-spirv %t.spv -o %t.spt --to-text
7+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
8+
9+
; RUN: not llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_INTEL_subgroup_matrix_multiply_accumulate 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
10+
11+
; CHECK-ERROR: RequiresExtension: Feature requires the following SPIR-V extension:
12+
; CHECK-ERROR: SPV_INTEL_subgroup_matrix_multiply_accumulate_float8
13+
14+
; CHECK-SPIRV-DAG: Capability SubgroupMatrixMultiplyAccumulateINTEL
15+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_subgroup_matrix_multiply_accumulate"
16+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_subgroup_matrix_multiply_accumulate_float8"
17+
; CHECK-SPIRV-DAG: SubgroupMatrixMultiplyAccumulateINTEL {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} 16384
18+
; CHECK-SPIRV-DAG: SubgroupMatrixMultiplyAccumulateINTEL {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} 32768
19+
; CHECK-SPIRV-DAG: SubgroupMatrixMultiplyAccumulateINTEL {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} 65536
20+
; CHECK-SPIRV-DAG: SubgroupMatrixMultiplyAccumulateINTEL {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} 131072
21+
22+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
23+
target triple = "spir64-unknown-unknown"
24+
25+
; Test MatrixAPackedFloat8E4M3INTEL operand (0x4000 = 16384)
26+
define spir_func <4 x float> @test_fp8_e4m3_matrix_a(<4 x float> %c, <4 x i8> %a, <8 x i8> %b) {
27+
entry:
28+
%result = call spir_func <4 x float> @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_hDv8_hDv4_fi(i32 8, <4 x i8> %a, <8 x i8> %b, <4 x float> %c, i32 16384)
29+
ret <4 x float> %result
30+
}
31+
32+
; Test MatrixBPackedFloat8E4M3INTEL operand (0x8000 = 32768)
33+
define spir_func <4 x float> @test_fp8_e4m3_matrix_b(<4 x float> %c, <4 x i8> %a, <8 x i8> %b) {
34+
entry:
35+
%result = call spir_func <4 x float> @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_hDv8_hDv4_fi(i32 8, <4 x i8> %a, <8 x i8> %b, <4 x float> %c, i32 32768)
36+
ret <4 x float> %result
37+
}
38+
39+
; Test MatrixAPackedFloat8E5M2INTEL operand (0x10000 = 65536)
40+
define spir_func <4 x float> @test_fp8_e5m2_matrix_a(<4 x float> %c, <4 x i8> %a, <8 x i8> %b) {
41+
entry:
42+
%result = call spir_func <4 x float> @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_hDv8_hDv4_fi(i32 8, <4 x i8> %a, <8 x i8> %b, <4 x float> %c, i32 65536)
43+
ret <4 x float> %result
44+
}
45+
46+
; Test MatrixBPackedFloat8E5M2INTEL operand (0x20000 = 131072)
47+
define spir_func <4 x float> @test_fp8_e5m2_matrix_b(<4 x float> %c, <4 x i8> %a, <8 x i8> %b) {
48+
entry:
49+
%result = call spir_func <4 x float> @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_hDv8_hDv4_fi(i32 8, <4 x i8> %a, <8 x i8> %b, <4 x float> %c, i32 131072)
50+
ret <4 x float> %result
51+
}
52+
53+
declare spir_func <4 x float> @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_hDv8_hDv4_fi(i32, <4 x i8>, <8 x i8>, <4 x float>, i32)
54+
55+
!opencl.spir.version = !{!0}
56+
!spirv.Source = !{!1}
57+
!llvm.ident = !{!2}
58+
59+
!0 = !{i32 1, i32 0}
60+
!1 = !{i32 4, i32 100000}
61+
!2 = !{!"clang version 17.0.0"}

0 commit comments

Comments
 (0)