Skip to content

Commit 8cf1bf9

Browse files
spirv-val: Check OpTypeCooperativeMatrixKHR for bfloat16/fp8 (KhronosGroup#6220)
1 parent aaa9485 commit 8cf1bf9

File tree

4 files changed

+175
-14
lines changed

4 files changed

+175
-14
lines changed

source/val/validate_invalid_type.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,11 @@ spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst) {
6969
case spv::Op::OpGroupNonUniformFMul:
7070
case spv::Op::OpGroupNonUniformFMin: {
7171
const uint32_t result_type = inst->type_id();
72-
if (_.IsBfloat16ScalarType(result_type) ||
73-
_.IsBfloat16VectorType(result_type)) {
72+
if (_.IsBfloat16Type(result_type)) {
7473
return _.diag(SPV_ERROR_INVALID_DATA, inst)
7574
<< spvOpcodeString(opcode) << " doesn't support BFloat16 type.";
7675
}
77-
if (_.IsFP8ScalarOrVectorType(result_type)) {
76+
if (_.IsFP8Type(result_type)) {
7877
return _.diag(SPV_ERROR_INVALID_DATA, inst)
7978
<< spvOpcodeString(opcode)
8079
<< " doesn't support FP8 E4M3/E5M2 types.";
@@ -103,12 +102,11 @@ spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst) {
103102
case spv::Op::OpIsNormal:
104103
case spv::Op::OpSignBitSet: {
105104
const uint32_t operand_type = _.GetOperandTypeId(inst, 2);
106-
if (_.IsBfloat16ScalarType(operand_type) ||
107-
_.IsBfloat16VectorType(operand_type)) {
105+
if (_.IsBfloat16Type(operand_type)) {
108106
return _.diag(SPV_ERROR_INVALID_DATA, inst)
109107
<< spvOpcodeString(opcode) << " doesn't support BFloat16 type.";
110108
}
111-
if (_.IsFP8ScalarOrVectorType(operand_type)) {
109+
if (_.IsFP8Type(operand_type)) {
112110
return _.diag(SPV_ERROR_INVALID_DATA, inst)
113111
<< spvOpcodeString(opcode)
114112
<< " doesn't support FP8 E4M3/E5M2 types.";
@@ -118,12 +116,11 @@ spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst) {
118116

119117
case spv::Op::OpGroupNonUniformAllEqual: {
120118
const auto value_type = _.GetOperandTypeId(inst, 3);
121-
if (_.IsBfloat16ScalarType(value_type) ||
122-
_.IsBfloat16VectorType(value_type)) {
119+
if (_.IsBfloat16Type(value_type)) {
123120
return _.diag(SPV_ERROR_INVALID_DATA, inst)
124121
<< spvOpcodeString(opcode) << " doesn't support BFloat16 type.";
125122
}
126-
if (_.IsFP8ScalarOrVectorType(value_type)) {
123+
if (_.IsFP8Type(value_type)) {
127124
return _.diag(SPV_ERROR_INVALID_DATA, inst)
128125
<< spvOpcodeString(opcode)
129126
<< " doesn't support FP8 E4M3/E5M2 types.";
@@ -140,12 +137,12 @@ spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst) {
140137
uint32_t res_component_type = 0;
141138
if (_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols,
142139
&res_col_type, &res_component_type)) {
143-
if (_.IsBfloat16ScalarType(res_component_type)) {
140+
if (_.IsBfloat16Type(res_component_type)) {
144141
return _.diag(SPV_ERROR_INVALID_DATA, inst)
145142
<< spvOpcodeString(opcode)
146143
<< " doesn't support BFloat16 type.";
147144
}
148-
if (_.IsFP8ScalarOrVectorType(res_component_type)) {
145+
if (_.IsFP8Type(res_component_type)) {
149146
return _.diag(SPV_ERROR_INVALID_DATA, inst)
150147
<< spvOpcodeString(opcode)
151148
<< " doesn't support FP8 E4M3/E5M2 types.";

source/val/validation_state.cpp

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,24 @@ bool ValidationState_t::IsBfloat16VectorType(uint32_t id) const {
984984
return false;
985985
}
986986

987+
bool ValidationState_t::IsBfloat16CoopMatType(uint32_t id) const {
988+
const Instruction* inst = FindDef(id);
989+
if (!inst) {
990+
return false;
991+
}
992+
993+
if (inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR) {
994+
return IsBfloat16ScalarType(inst->word(2));
995+
}
996+
997+
return false;
998+
}
999+
1000+
bool ValidationState_t::IsBfloat16Type(uint32_t id) const {
1001+
return IsBfloat16ScalarType(id) || IsBfloat16VectorType(id) ||
1002+
IsBfloat16CoopMatType(id);
1003+
}
1004+
9871005
bool ValidationState_t::IsFP8ScalarType(uint32_t id) const {
9881006
const Instruction* inst = FindDef(id);
9891007
if (inst && inst->opcode() == spv::Op::OpTypeFloat) {
@@ -1011,8 +1029,21 @@ bool ValidationState_t::IsFP8VectorType(uint32_t id) const {
10111029
return false;
10121030
}
10131031

1014-
bool ValidationState_t::IsFP8ScalarOrVectorType(uint32_t id) const {
1015-
return IsFP8ScalarType(id) || IsFP8VectorType(id);
1032+
bool ValidationState_t::IsFP8CoopMatType(uint32_t id) const {
1033+
const Instruction* inst = FindDef(id);
1034+
if (!inst) {
1035+
return false;
1036+
}
1037+
1038+
if (inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR) {
1039+
return IsFP8ScalarType(inst->word(2));
1040+
}
1041+
1042+
return false;
1043+
}
1044+
1045+
bool ValidationState_t::IsFP8Type(uint32_t id) const {
1046+
return IsFP8ScalarType(id) || IsFP8VectorType(id) || IsFP8CoopMatType(id);
10161047
}
10171048

10181049
bool ValidationState_t::IsFloatScalarType(uint32_t id) const {

source/val/validation_state.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -638,9 +638,12 @@ class ValidationState_t {
638638
bool IsScalarType(uint32_t id) const;
639639
bool IsBfloat16ScalarType(uint32_t id) const;
640640
bool IsBfloat16VectorType(uint32_t id) const;
641+
bool IsBfloat16CoopMatType(uint32_t id) const;
642+
bool IsBfloat16Type(uint32_t id) const;
641643
bool IsFP8ScalarType(uint32_t id) const;
642644
bool IsFP8VectorType(uint32_t id) const;
643-
bool IsFP8ScalarOrVectorType(uint32_t id) const;
645+
bool IsFP8CoopMatType(uint32_t id) const;
646+
bool IsFP8Type(uint32_t id) const;
644647
bool IsFloatScalarType(uint32_t id) const;
645648
bool IsFloatArrayType(uint32_t id) const;
646649
bool IsFloatVectorType(uint32_t id) const;

test/val/val_memory_test.cpp

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8582,6 +8582,136 @@ OpCooperativeVectorReduceSumAccumulateNV %array_ptr %offset %f16c
85828582
HasSubstr("OpCooperativeVectorReduceSumAccumulateNV V type <id> "
85838583
"'28[%v4half]' is not a cooperative vector type."));
85848584
}
8585+
8586+
TEST_F(ValidateMemory, CoopMatMatrixBFloatFAdd) {
8587+
const std::string body =
8588+
R"(
8589+
OpCapability Shader
8590+
OpCapability Float16
8591+
OpCapability BFloat16TypeKHR
8592+
OpCapability BFloat16CooperativeMatrixKHR
8593+
OpCapability VulkanMemoryModel
8594+
OpCapability CooperativeMatrixKHR
8595+
OpExtension "SPV_KHR_bfloat16"
8596+
OpExtension "SPV_KHR_vulkan_memory_model"
8597+
OpExtension "SPV_KHR_cooperative_matrix"
8598+
OpMemoryModel Logical Vulkan
8599+
OpEntryPoint GLCompute %main "main" %_ %__0 %__1
8600+
OpExecutionMode %main LocalSize 32 1 1
8601+
OpDecorate %_arr_bfloat16_uint_64 ArrayStride 2
8602+
OpDecorate %A Block
8603+
OpMemberDecorate %A 0 Offset 0
8604+
OpDecorate %_ Binding 0
8605+
OpDecorate %_ DescriptorSet 0
8606+
OpDecorate %_arr_bfloat16_uint_64_0 ArrayStride 2
8607+
OpDecorate %B Block
8608+
OpMemberDecorate %B 0 Offset 0
8609+
OpDecorate %__0 Binding 1
8610+
OpDecorate %__0 DescriptorSet 0
8611+
OpDecorate %_arr_bfloat16_uint_64_1 ArrayStride 2
8612+
OpDecorate %R Block
8613+
OpMemberDecorate %R 0 Offset 0
8614+
OpDecorate %__1 Binding 2
8615+
OpDecorate %__1 DescriptorSet 0
8616+
%void = OpTypeVoid
8617+
%4 = OpTypeFunction %void
8618+
%bfloat16 = OpTypeFloat 16 BFloat16KHR
8619+
%uint = OpTypeInt 32 0
8620+
%uint_3 = OpConstant %uint 3
8621+
%uint_8 = OpConstant %uint 8
8622+
%uint_0 = OpConstant %uint 0
8623+
%12 = OpTypeCooperativeMatrixKHR %bfloat16 %uint_3 %uint_8 %uint_8 %uint_0
8624+
%_ptr_Function_12 = OpTypePointer Function %12
8625+
%uint_64 = OpConstant %uint 64
8626+
%_arr_bfloat16_uint_64 = OpTypeArray %bfloat16 %uint_64
8627+
%A = OpTypeStruct %_arr_bfloat16_uint_64
8628+
%_ptr_StorageBuffer_A = OpTypePointer StorageBuffer %A
8629+
%_ = OpVariable %_ptr_StorageBuffer_A StorageBuffer
8630+
%int = OpTypeInt 32 1
8631+
%int_0 = OpConstant %int 0
8632+
%_ptr_StorageBuffer_bfloat16 = OpTypePointer StorageBuffer %bfloat16
8633+
%_arr_bfloat16_uint_64_0 = OpTypeArray %bfloat16 %uint_64
8634+
%B = OpTypeStruct %_arr_bfloat16_uint_64_0
8635+
%_ptr_StorageBuffer_B = OpTypePointer StorageBuffer %B
8636+
%__0 = OpVariable %_ptr_StorageBuffer_B StorageBuffer
8637+
%v3uint = OpTypeVector %uint 3
8638+
%uint_32 = OpConstant %uint 32
8639+
%uint_1 = OpConstant %uint 1
8640+
%35 = OpConstantComposite %v3uint %uint_32 %uint_1 %uint_1
8641+
%_arr_bfloat16_uint_64_1 = OpTypeArray %bfloat16 %uint_64
8642+
%R = OpTypeStruct %_arr_bfloat16_uint_64_1
8643+
%_ptr_StorageBuffer_R = OpTypePointer StorageBuffer %R
8644+
%__1 = OpVariable %_ptr_StorageBuffer_R StorageBuffer
8645+
%main = OpFunction %void None %4
8646+
%6 = OpLabel
8647+
%matX = OpVariable %_ptr_Function_12 Function
8648+
%matY = OpVariable %_ptr_Function_12 Function
8649+
%23 = OpAccessChain %_ptr_StorageBuffer_bfloat16 %_ %int_0 %uint_0
8650+
%24 = OpCooperativeMatrixLoadKHR %12 %23 %int_0 %uint_8 None
8651+
OpStore %matX %24
8652+
%30 = OpAccessChain %_ptr_StorageBuffer_bfloat16 %__0 %int_0 %uint_0
8653+
%31 = OpCooperativeMatrixLoadKHR %12 %30 %int_0 %uint_8 None
8654+
OpStore %matY %31
8655+
%32 = OpLoad %12 %matX
8656+
%33 = OpLoad %12 %matY
8657+
%34 = OpFAdd %12 %32 %33
8658+
OpReturn
8659+
OpFunctionEnd
8660+
)";
8661+
8662+
CompileSuccessfully(body.c_str(), SPV_ENV_VULKAN_1_3);
8663+
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_3));
8664+
EXPECT_THAT(getDiagnosticString(),
8665+
HasSubstr("FAdd doesn't support BFloat16 type"));
8666+
}
8667+
8668+
TEST_F(ValidateMemory, CoopMatMatrixFloat8FAdd) {
8669+
const std::string body =
8670+
R"(
8671+
OpCapability Shader
8672+
OpCapability Float8EXT
8673+
OpCapability Float8CooperativeMatrixEXT
8674+
OpCapability VulkanMemoryModel
8675+
OpCapability CooperativeMatrixKHR
8676+
OpExtension "SPV_EXT_float8"
8677+
OpExtension "SPV_KHR_cooperative_matrix"
8678+
OpExtension "SPV_KHR_vulkan_memory_model"
8679+
OpMemoryModel Logical Vulkan
8680+
OpEntryPoint GLCompute %main "main"
8681+
OpExecutionMode %main LocalSize 32 1 1
8682+
OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize
8683+
%void = OpTypeVoid
8684+
%4 = OpTypeFunction %void
8685+
%fp8e4m3 = OpTypeFloat 8 Float8E4M3EXT
8686+
%uint = OpTypeInt 32 0
8687+
%uint_3 = OpConstant %uint 3
8688+
%uint_16 = OpConstant %uint 16
8689+
%uint_0 = OpConstant %uint 0
8690+
%12 = OpTypeCooperativeMatrixKHR %fp8e4m3 %uint_3 %uint_16 %uint_16 %uint_0
8691+
%_ptr_Function_12 = OpTypePointer Function %12
8692+
%v3uint = OpTypeVector %uint 3
8693+
%uint_32 = OpConstant %uint 32
8694+
%uint_1 = OpConstant %uint 1
8695+
%gl_WorkGroupSize = OpConstantComposite %v3uint %uint_32 %uint_1 %uint_1
8696+
%main = OpFunction %void None %4
8697+
%6 = OpLabel
8698+
%matR = OpVariable %_ptr_Function_12 Function
8699+
%matX = OpVariable %_ptr_Function_12 Function
8700+
%matY = OpVariable %_ptr_Function_12 Function
8701+
%16 = OpLoad %12 %matX
8702+
%18 = OpLoad %12 %matY
8703+
%19 = OpFAdd %12 %16 %18
8704+
OpStore %matR %19
8705+
OpReturn
8706+
OpFunctionEnd
8707+
)";
8708+
8709+
CompileSuccessfully(body.c_str(), SPV_ENV_VULKAN_1_3);
8710+
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_3));
8711+
EXPECT_THAT(getDiagnosticString(),
8712+
HasSubstr("FAdd doesn't support FP8 E4M3/E5M2 types"));
8713+
}
8714+
85858715
} // namespace
85868716
} // namespace val
85878717
} // namespace spvtools

0 commit comments

Comments
 (0)