@@ -8582,6 +8582,136 @@ OpCooperativeVectorReduceSumAccumulateNV %array_ptr %offset %f16c
8582
8582
HasSubstr (" OpCooperativeVectorReduceSumAccumulateNV V type <id> "
8583
8583
" '28[%v4half]' is not a cooperative vector type." ));
8584
8584
}
8585
+
8586
+ TEST_F (ValidateMemory, CoopMatMatrixBFloatFAdd) {
8587
+ const std::string body =
8588
+ R"(
8589
+ OpCapability Shader
8590
+ OpCapability Float16
8591
+ OpCapability BFloat16TypeKHR
8592
+ OpCapability BFloat16CooperativeMatrixKHR
8593
+ OpCapability VulkanMemoryModel
8594
+ OpCapability CooperativeMatrixKHR
8595
+ OpExtension "SPV_KHR_bfloat16"
8596
+ OpExtension "SPV_KHR_vulkan_memory_model"
8597
+ OpExtension "SPV_KHR_cooperative_matrix"
8598
+ OpMemoryModel Logical Vulkan
8599
+ OpEntryPoint GLCompute %main "main" %_ %__0 %__1
8600
+ OpExecutionMode %main LocalSize 32 1 1
8601
+ OpDecorate %_arr_bfloat16_uint_64 ArrayStride 2
8602
+ OpDecorate %A Block
8603
+ OpMemberDecorate %A 0 Offset 0
8604
+ OpDecorate %_ Binding 0
8605
+ OpDecorate %_ DescriptorSet 0
8606
+ OpDecorate %_arr_bfloat16_uint_64_0 ArrayStride 2
8607
+ OpDecorate %B Block
8608
+ OpMemberDecorate %B 0 Offset 0
8609
+ OpDecorate %__0 Binding 1
8610
+ OpDecorate %__0 DescriptorSet 0
8611
+ OpDecorate %_arr_bfloat16_uint_64_1 ArrayStride 2
8612
+ OpDecorate %R Block
8613
+ OpMemberDecorate %R 0 Offset 0
8614
+ OpDecorate %__1 Binding 2
8615
+ OpDecorate %__1 DescriptorSet 0
8616
+ %void = OpTypeVoid
8617
+ %4 = OpTypeFunction %void
8618
+ %bfloat16 = OpTypeFloat 16 BFloat16KHR
8619
+ %uint = OpTypeInt 32 0
8620
+ %uint_3 = OpConstant %uint 3
8621
+ %uint_8 = OpConstant %uint 8
8622
+ %uint_0 = OpConstant %uint 0
8623
+ %12 = OpTypeCooperativeMatrixKHR %bfloat16 %uint_3 %uint_8 %uint_8 %uint_0
8624
+ %_ptr_Function_12 = OpTypePointer Function %12
8625
+ %uint_64 = OpConstant %uint 64
8626
+ %_arr_bfloat16_uint_64 = OpTypeArray %bfloat16 %uint_64
8627
+ %A = OpTypeStruct %_arr_bfloat16_uint_64
8628
+ %_ptr_StorageBuffer_A = OpTypePointer StorageBuffer %A
8629
+ %_ = OpVariable %_ptr_StorageBuffer_A StorageBuffer
8630
+ %int = OpTypeInt 32 1
8631
+ %int_0 = OpConstant %int 0
8632
+ %_ptr_StorageBuffer_bfloat16 = OpTypePointer StorageBuffer %bfloat16
8633
+ %_arr_bfloat16_uint_64_0 = OpTypeArray %bfloat16 %uint_64
8634
+ %B = OpTypeStruct %_arr_bfloat16_uint_64_0
8635
+ %_ptr_StorageBuffer_B = OpTypePointer StorageBuffer %B
8636
+ %__0 = OpVariable %_ptr_StorageBuffer_B StorageBuffer
8637
+ %v3uint = OpTypeVector %uint 3
8638
+ %uint_32 = OpConstant %uint 32
8639
+ %uint_1 = OpConstant %uint 1
8640
+ %35 = OpConstantComposite %v3uint %uint_32 %uint_1 %uint_1
8641
+ %_arr_bfloat16_uint_64_1 = OpTypeArray %bfloat16 %uint_64
8642
+ %R = OpTypeStruct %_arr_bfloat16_uint_64_1
8643
+ %_ptr_StorageBuffer_R = OpTypePointer StorageBuffer %R
8644
+ %__1 = OpVariable %_ptr_StorageBuffer_R StorageBuffer
8645
+ %main = OpFunction %void None %4
8646
+ %6 = OpLabel
8647
+ %matX = OpVariable %_ptr_Function_12 Function
8648
+ %matY = OpVariable %_ptr_Function_12 Function
8649
+ %23 = OpAccessChain %_ptr_StorageBuffer_bfloat16 %_ %int_0 %uint_0
8650
+ %24 = OpCooperativeMatrixLoadKHR %12 %23 %int_0 %uint_8 None
8651
+ OpStore %matX %24
8652
+ %30 = OpAccessChain %_ptr_StorageBuffer_bfloat16 %__0 %int_0 %uint_0
8653
+ %31 = OpCooperativeMatrixLoadKHR %12 %30 %int_0 %uint_8 None
8654
+ OpStore %matY %31
8655
+ %32 = OpLoad %12 %matX
8656
+ %33 = OpLoad %12 %matY
8657
+ %34 = OpFAdd %12 %32 %33
8658
+ OpReturn
8659
+ OpFunctionEnd
8660
+ )" ;
8661
+
8662
+ CompileSuccessfully (body.c_str (), SPV_ENV_VULKAN_1_3);
8663
+ ASSERT_EQ (SPV_ERROR_INVALID_DATA, ValidateInstructions (SPV_ENV_VULKAN_1_3));
8664
+ EXPECT_THAT (getDiagnosticString (),
8665
+ HasSubstr (" FAdd doesn't support BFloat16 type" ));
8666
+ }
8667
+
8668
+ TEST_F (ValidateMemory, CoopMatMatrixFloat8FAdd) {
8669
+ const std::string body =
8670
+ R"(
8671
+ OpCapability Shader
8672
+ OpCapability Float8EXT
8673
+ OpCapability Float8CooperativeMatrixEXT
8674
+ OpCapability VulkanMemoryModel
8675
+ OpCapability CooperativeMatrixKHR
8676
+ OpExtension "SPV_EXT_float8"
8677
+ OpExtension "SPV_KHR_cooperative_matrix"
8678
+ OpExtension "SPV_KHR_vulkan_memory_model"
8679
+ OpMemoryModel Logical Vulkan
8680
+ OpEntryPoint GLCompute %main "main"
8681
+ OpExecutionMode %main LocalSize 32 1 1
8682
+ OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize
8683
+ %void = OpTypeVoid
8684
+ %4 = OpTypeFunction %void
8685
+ %fp8e4m3 = OpTypeFloat 8 Float8E4M3EXT
8686
+ %uint = OpTypeInt 32 0
8687
+ %uint_3 = OpConstant %uint 3
8688
+ %uint_16 = OpConstant %uint 16
8689
+ %uint_0 = OpConstant %uint 0
8690
+ %12 = OpTypeCooperativeMatrixKHR %fp8e4m3 %uint_3 %uint_16 %uint_16 %uint_0
8691
+ %_ptr_Function_12 = OpTypePointer Function %12
8692
+ %v3uint = OpTypeVector %uint 3
8693
+ %uint_32 = OpConstant %uint 32
8694
+ %uint_1 = OpConstant %uint 1
8695
+ %gl_WorkGroupSize = OpConstantComposite %v3uint %uint_32 %uint_1 %uint_1
8696
+ %main = OpFunction %void None %4
8697
+ %6 = OpLabel
8698
+ %matR = OpVariable %_ptr_Function_12 Function
8699
+ %matX = OpVariable %_ptr_Function_12 Function
8700
+ %matY = OpVariable %_ptr_Function_12 Function
8701
+ %16 = OpLoad %12 %matX
8702
+ %18 = OpLoad %12 %matY
8703
+ %19 = OpFAdd %12 %16 %18
8704
+ OpStore %matR %19
8705
+ OpReturn
8706
+ OpFunctionEnd
8707
+ )" ;
8708
+
8709
+ CompileSuccessfully (body.c_str (), SPV_ENV_VULKAN_1_3);
8710
+ ASSERT_EQ (SPV_ERROR_INVALID_DATA, ValidateInstructions (SPV_ENV_VULKAN_1_3));
8711
+ EXPECT_THAT (getDiagnosticString (),
8712
+ HasSubstr (" FAdd doesn't support FP8 E4M3/E5M2 types" ));
8713
+ }
8714
+
8585
8715
} // namespace
8586
8716
} // namespace val
8587
8717
} // namespace spvtools
0 commit comments