Skip to content

Commit 6248fda

Browse files
authored
Handle coop matrix in fix storage class (KhronosGroup#5729)
1 parent 7c77897 commit 6248fda

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

source/opt/fix_storage_class.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ uint32_t FixStorageClass::WalkAccessChainType(Instruction* inst, uint32_t id) {
312312
case spv::Op::OpTypeRuntimeArray:
313313
case spv::Op::OpTypeMatrix:
314314
case spv::Op::OpTypeVector:
315+
case spv::Op::OpTypeCooperativeMatrixKHR:
315316
id = type_inst->GetSingleWordInOperand(0);
316317
break;
317318
case spv::Op::OpTypeStruct: {

test/opt/fix_storage_class_test.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,43 @@ TEST_F(FixStorageClassTest, SupportsU64Index) {
916916
SinglePassRunAndMatch<FixStorageClass>(text, false);
917917
}
918918

919+
TEST_F(FixStorageClassTest, CorrectlyProcessAccessChainOnCoopMatrix) {
920+
const std::string text = R"(OpCapability CooperativeMatrixKHR
921+
OpCapability Shader
922+
OpExtension "SPV_KHR_cooperative_matrix"
923+
OpMemoryModel Logical GLSL450
924+
OpEntryPoint GLCompute %1 "main"
925+
OpExecutionMode %1 LocalSize 64 1 1
926+
OpSource HLSL 600
927+
%int = OpTypeInt 32 1
928+
%int_0 = OpConstant %int 0
929+
%uint = OpTypeInt 32 0
930+
%uint_0 = OpConstant %uint 0
931+
%uint_3 = OpConstant %uint 3
932+
%uint_16 = OpConstant %uint 16
933+
%uint_4 = OpConstant %uint 4
934+
%9 = OpTypeCooperativeMatrixKHR %int %uint_3 %uint_16 %uint_4 %uint_0
935+
%void = OpTypeVoid
936+
%11 = OpTypeFunction %void
937+
%_struct_12 = OpTypeStruct %9
938+
%_ptr_Function__struct_12 = OpTypePointer Function %_struct_12
939+
%_ptr_Function_9 = OpTypePointer Function %9
940+
%_ptr_Function_int = OpTypePointer Function %int
941+
%_ptr_Function__ptr_Function_int = OpTypePointer Function %_ptr_Function_int
942+
%1 = OpFunction %void None %11
943+
%17 = OpLabel
944+
%18 = OpVariable %_ptr_Function__ptr_Function_int Function
945+
%19 = OpVariable %_ptr_Function__struct_12 Function
946+
%20 = OpAccessChain %_ptr_Function_9 %19 %int_0
947+
%21 = OpAccessChain %_ptr_Function_int %20 %uint_4
948+
OpStore %18 %21
949+
OpReturn
950+
OpFunctionEnd
951+
)";
952+
953+
SinglePassRunAndCheck<FixStorageClass>(text, text, false, false);
954+
}
955+
919956
} // namespace
920957
} // namespace opt
921958
} // namespace spvtools

0 commit comments

Comments
 (0)