From 9950a518d5a0ad0d3fa02b3a42f54af02c2d0713 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Thu, 24 Oct 2024 19:13:42 -0700 Subject: [PATCH 01/11] Add a new form of load/store operations for cooperative matrices that accepts two separate arguments: the row index and the column index. --- lib/SPIRV/SPIRVReader.cpp | 1 + lib/SPIRV/libSPIRV/SPIRVEnum.h | 2 + lib/SPIRV/libSPIRV/SPIRVInstruction.h | 21 +++ lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h | 2 + lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h | 4 + lib/SPIRV/libSPIRV/spirv_internal.hpp | 7 + .../joint_matrix_load_store_offset.ll | 152 ++++++++++++++++++ 7 files changed, 189 insertions(+) create mode 100644 test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll diff --git a/lib/SPIRV/SPIRVReader.cpp b/lib/SPIRV/SPIRVReader.cpp index b639ae8cd4..0c4f537938 100644 --- a/lib/SPIRV/SPIRVReader.cpp +++ b/lib/SPIRV/SPIRVReader.cpp @@ -3705,6 +3705,7 @@ Instruction *SPIRVToLLVM::transSPIRVBuiltinFromInst(SPIRVInstruction *BI, case internal::OpJointMatrixLoadINTEL: case OpCooperativeMatrixLoadKHR: case internal::OpCooperativeMatrixLoadCheckedINTEL: + case internal::OpCooperativeMatrixLoadOffsetINTEL: case internal::OpTaskSequenceCreateINTEL: case internal::OpConvertHandleToImageINTEL: case internal::OpConvertHandleToSampledImageINTEL: diff --git a/lib/SPIRV/libSPIRV/SPIRVEnum.h b/lib/SPIRV/libSPIRV/SPIRVEnum.h index 4c318be39b..ce20841574 100644 --- a/lib/SPIRV/libSPIRV/SPIRVEnum.h +++ b/lib/SPIRV/libSPIRV/SPIRVEnum.h @@ -221,6 +221,8 @@ template <> inline void SPIRVMap::init() { {CapabilityCooperativeMatrixKHR}); ADD_VEC_INIT(internal::CapabilityCooperativeMatrixCheckedInstructionsINTEL, {CapabilityCooperativeMatrixKHR}); + ADD_VEC_INIT(internal::CapabilityCooperativeMatrixOffsetInstructionsINTEL, + {CapabilityCooperativeMatrixKHR}); } template <> inline void SPIRVMap::init() { diff --git a/lib/SPIRV/libSPIRV/SPIRVInstruction.h b/lib/SPIRV/libSPIRV/SPIRVInstruction.h index f6d7142942..0a6648b793 100644 --- a/lib/SPIRV/libSPIRV/SPIRVInstruction.h +++ b/lib/SPIRV/libSPIRV/SPIRVInstruction.h @@ -3730,6 +3730,27 @@ _SPIRV_OP(CooperativeMatrixStoreChecked, false, 8, true, 8) _SPIRV_OP(CooperativeMatrixConstructChecked, true, 8) #undef _SPIRV_OP +class SPIRVJointMatrixOffsetInstructionsINTELInstBase + : public SPIRVInstTemplateBase { +protected: + std::optional getRequiredExtension() const override { + return ExtensionID::SPV_INTEL_joint_matrix; + } + SPIRVCapVec getRequiredCapability() const override { + return getVec( + internal::CapabilityCooperativeMatrixOffsetInstructionsINTEL); + } +}; + +#define _SPIRV_OP(x, ...) \ + typedef SPIRVInstTemplate< \ + SPIRVJointMatrixOffsetInstructionsINTELInstBase, \ + internal::Op##x##INTEL, __VA_ARGS__> \ + SPIRV##x##INTEL; +_SPIRV_OP(CooperativeMatrixLoadOffset, true, 8, true, 6) +_SPIRV_OP(CooperativeMatrixStoreOffset, false, 7, true, 7) +#undef _SPIRV_OP + class SPIRVCooperativeMatrixInvocationInstructionsINTELInstBase : public SPIRVInstTemplateBase { protected: diff --git a/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h b/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h index ec442aacfa..6e75fe05bc 100644 --- a/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h +++ b/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h @@ -671,6 +671,8 @@ template <> inline void SPIRVMap::init() { "CooperativeMatrixInvocationInstructionsINTEL"); add(internal::CapabilityCooperativeMatrixCheckedInstructionsINTEL, "CooperativeMatrixCheckedInstructionsINTEL"); + add(internal::CapabilityCooperativeMatrixOffsetInstructionsINTEL, + "CooperativeMatrixOffsetInstructionsINTEL"); add(internal::CapabilitySubgroupRequirementsINTEL, "SubgroupRequirementsINTEL"); add(internal::CapabilityTaskSequenceINTEL, "TaskSequenceINTEL"); diff --git a/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h b/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h index f290080376..754574c4ac 100644 --- a/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h +++ b/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h @@ -22,6 +22,10 @@ _SPIRV_OP_INTERNAL(CooperativeMatrixLoadCheckedINTEL, internal::OpCooperativeMatrixLoadCheckedINTEL) _SPIRV_OP_INTERNAL(CooperativeMatrixStoreCheckedINTEL, internal::OpCooperativeMatrixStoreCheckedINTEL) +_SPIRV_OP_INTERNAL(CooperativeMatrixLoadOffsetINTEL, + internal::OpCooperativeMatrixLoadOffsetINTEL) +_SPIRV_OP_INTERNAL(CooperativeMatrixStoreOffsetINTEL, + internal::OpCooperativeMatrixStoreOffsetINTEL) _SPIRV_OP_INTERNAL(CooperativeMatrixConstructCheckedINTEL, internal::OpCooperativeMatrixConstructCheckedINTEL) _SPIRV_OP_INTERNAL(CooperativeMatrixApplyFunctionINTEL, diff --git a/lib/SPIRV/libSPIRV/spirv_internal.hpp b/lib/SPIRV/libSPIRV/spirv_internal.hpp index cdec3a959c..391f436207 100644 --- a/lib/SPIRV/libSPIRV/spirv_internal.hpp +++ b/lib/SPIRV/libSPIRV/spirv_internal.hpp @@ -77,6 +77,8 @@ enum InternalOp { IOpCooperativeMatrixLoadCheckedINTEL = 6193, IOpCooperativeMatrixStoreCheckedINTEL = 6194, IOpCooperativeMatrixConstructCheckedINTEL = 6195, + IOpCooperativeMatrixLoadOffsetINTEL = 6212, + IOpCooperativeMatrixStoreOffsetINTEL = 6213, IOpJointMatrixWorkItemLengthINTEL = 6410, IOpTypeTaskSequenceINTEL = 6199, IOpComplexFMulINTEL = 6415, @@ -114,6 +116,7 @@ enum InternalCapability { ICapGlobalVariableDecorationsINTEL = 6146, ICapabilityTaskSequenceINTEL = 6162, ICapabilityCooperativeMatrixCheckedInstructionsINTEL = 6192, + ICapabilityCooperativeMatrixOffsetInstructionsINTEL = 6211, ICapabilityCooperativeMatrixPrefetchINTEL = 6411, ICapabilityComplexFloatMulDivINTEL = 6414, ICapabilityTensorFloat32RoundingINTEL = 6425, @@ -187,6 +190,10 @@ _SPIRV_OP(Op, CooperativeMatrixLoadCheckedINTEL) _SPIRV_OP(Op, CooperativeMatrixStoreCheckedINTEL) _SPIRV_OP(Op, CooperativeMatrixConstructCheckedINTEL) +_SPIRV_OP(Capability, CooperativeMatrixOffsetInstructionsINTEL) +_SPIRV_OP(Op, CooperativeMatrixLoadOffsetINTEL) +_SPIRV_OP(Op, CooperativeMatrixStoreOffsetINTEL) + _SPIRV_OP(Capability, CooperativeMatrixInvocationInstructionsINTEL) _SPIRV_OP(Op, CooperativeMatrixApplyFunctionINTEL) diff --git a/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll b/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll new file mode 100644 index 0000000000..34dc472242 --- /dev/null +++ b/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll @@ -0,0 +1,152 @@ +; This is an adapted copy of test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix.ll + +; RUN: llvm-as < %s -o %t.bc +; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_joint_matrix -o %t.spv +; RUN: llvm-spirv %t.spv -to-text -o %t.spt +; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV + +; RUN: llvm-spirv -r %t.spv -o %t.rev.bc +; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM + +; CHECK-SPIRV-DAG: Capability CooperativeMatrixKHR +; CHECK-SPIRV-DAG: Capability CooperativeMatrixOffsetInstructionsINTEL +; CHECK-SPIRV-DAG: Extension "SPV_INTEL_joint_matrix" +; CHECK-SPIRV-DAG: TypeInt [[#Int16Ty:]] 16 0 +; CHECK-SPIRV-DAG: TypeInt [[#Int32Ty:]] 32 0 +; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const16:]] 16 +; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const3:]] 3 +; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const2:]] 2 +; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const0:]] 0 +; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const1:]] 1 +; CHECK-SPIRV-DAG: TypeFloat [[#Float32Ty:]] 32 +; CHECK-SPIRV-DAG: TypeJointMatrixINTEL [[#MatTy1:]] [[#Float32Ty]] [[#Const16]] [[#Const16]] [[#Const3]] [[#Const3]] [[#Const2]] +; CHECK-SPIRV-DAG: TypeJointMatrixINTEL [[#MatTy2:]] [[#Int16Ty]] [[#Const16]] [[#Const16]] [[#Const0]] [[#Const3]] [[#Const0]] [[#Const1]] +; CHECK-SPIRV-DAG: TypeJointMatrixINTEL [[#MatTy3:]] [[#Int16Ty]] [[#Const16]] [[#Const16]] [[#Const2:]] [[#Const3:]] [[#Const1:]] [[#Const1:]] +; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy1]] +; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy2]] +; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy3]] +; CHECK-SPIRV: JointMatrixMadINTEL [[#MatTy1]] +; CHECK-SPIRV: CooperativeMatrixStoreOffsetINTEL + +; CHECK-LLVM: call spir_func target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) @_Z93__spirv_CooperativeMatrixLoadOffsetINTEL_RPU3AS143__spirv_JointMatrixINTEL__float_16_16_3_3_2PU3AS1fiiili(ptr addrspace(1) %add.ptr.i.i, i32 %conv.i, i32 %conv2.i, i32 0, i64 32, i32 0) +; CHECK-LLVM: call spir_func target("spirv.JointMatrixINTEL", i16, 16, 16, 0, 3, 0, 1) @"_Z95__spirv_CooperativeMatrixLoadOffsetINTEL_RPU3AS145__spirv_JointMatrixINTEL__short_16_16_0_3_0_1PU3AS138class.sycl::_V1::ext::oneapi::bfloat16iiili"(ptr addrspace(1) %add.ptr.i.i120, i32 %conv.i, i32 %22, i32 0, i64 32, i32 0) +; CHECK-LLVM: call spir_func target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1, 1) @"_Z95__spirv_CooperativeMatrixLoadOffsetINTEL_RPU3AS145__spirv_JointMatrixINTEL__short_16_16_2_3_1_1PU3AS138class.sycl::_V1::ext::oneapi::bfloat16iiili"(ptr addrspace(1) %add.ptr.i.i128, i32 %23, i32 %conv2.i60, i32 2, i64 64, i32 0) +; CHECK-LLVM: call spir_func target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) @_Z27__spirv_JointMatrixMadINTELPU3AS145__spirv_JointMatrixINTEL__short_16_16_0_3_0_1PU3AS145__spirv_JointMatrixINTEL__short_16_16_2_3_1_1PU3AS143__spirv_JointMatrixINTEL__float_16_16_3_3_2i(target("spirv.JointMatrixINTEL", i16, 16, 16, 0, 3, 0, 1) %call3.i50, target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1, 1) %call3.i61, target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) %sub_c.i.sroa.0.0, i32 3) +; CHECK-LLVM: call spir_func void @_Z41__spirv_CooperativeMatrixStoreOffsetINTELPU3AS1fiiPU3AS143__spirv_JointMatrixINTEL__float_16_16_3_3_2ili(ptr addrspace(1) %add.ptr.i.i, i32 %conv.i, i32 %conv2.i, target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) %sub_c.i.sroa.0.0, i32 0, i64 32, i32 0) + +; ModuleID = 'joint_matrix_bfloat16.cpp' +source_filename = "joint_matrix_bfloat16.cpp" +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1" +target triple = "spir64-unknown-unknown" + +%"class.sycl::_V1::range" = type { %"class.sycl::_V1::detail::array" } +%"class.sycl::_V1::detail::array" = type { [2 x i64] } +%"class.sycl::_V1::id" = type { %"class.sycl::_V1::detail::array" } +%"class.sycl::_V1::ext::oneapi::bfloat16" = type { i16 } + +$_ZTS7imatrixIfLm16ELm16ELm16EE = comdat any + +@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32 +@__spirv_BuiltInLocalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32 + +; Function Attrs: convergent mustprogress norecurse nounwind +define weak_odr dso_local spir_kernel void @_ZTS7imatrixIfLm16ELm16ELm16EE(ptr addrspace(1) noundef align 4 %_arg_accC, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accC2, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accC3, i64 noundef %_arg_sg_size, ptr addrspace(1) noundef align 2 %_arg_accA, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accA5, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accA6, ptr addrspace(1) noundef align 2 %_arg_accB, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accB8, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accB9) local_unnamed_addr #0 comdat { +entry: + %agg.tmp11.sroa.0.sroa.2.0._arg_accC2.ascast.sroa_idx = getelementptr inbounds i8, ptr %_arg_accC2, i64 8 + %agg.tmp11.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp11.sroa.0.sroa.2.0._arg_accC2.ascast.sroa_idx, align 8 + %agg.tmp12.sroa.0.sroa.0.0.copyload = load i64, ptr %_arg_accC3, align 8 + %agg.tmp12.sroa.0.sroa.2.0._arg_accC3.ascast.sroa_idx = getelementptr inbounds i8, ptr %_arg_accC3, i64 8 + %agg.tmp12.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp12.sroa.0.sroa.2.0._arg_accC3.ascast.sroa_idx, align 8 + %mul.i6.i.i.i.i = mul i64 %agg.tmp12.sroa.0.sroa.0.0.copyload, %agg.tmp11.sroa.0.sroa.2.0.copyload + %0 = getelementptr float, ptr addrspace(1) %_arg_accC, i64 %mul.i6.i.i.i.i + %add.ptr.i = getelementptr float, ptr addrspace(1) %0, i64 %agg.tmp12.sroa.0.sroa.2.0.copyload + %agg.tmp15.sroa.0.sroa.2.0._arg_accA5.ascast.sroa_idx = getelementptr inbounds i8, ptr %_arg_accA5, i64 8 + %agg.tmp15.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp15.sroa.0.sroa.2.0._arg_accA5.ascast.sroa_idx, align 8 + %agg.tmp16.sroa.0.sroa.0.0.copyload = load i64, ptr %_arg_accA6, align 8 + %agg.tmp16.sroa.0.sroa.2.0._arg_accA6.ascast.sroa_idx = getelementptr inbounds i8, ptr %_arg_accA6, i64 8 + %agg.tmp16.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp16.sroa.0.sroa.2.0._arg_accA6.ascast.sroa_idx, align 8 + %mul.i6.i.i.i.i91 = mul i64 %agg.tmp16.sroa.0.sroa.0.0.copyload, %agg.tmp15.sroa.0.sroa.2.0.copyload + %1 = getelementptr %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %_arg_accA, i64 %mul.i6.i.i.i.i91 + %add.ptr.i92 = getelementptr %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %1, i64 %agg.tmp16.sroa.0.sroa.2.0.copyload + %agg.tmp19.sroa.0.sroa.2.0._arg_accB8.ascast.sroa_idx = getelementptr inbounds i8, ptr %_arg_accB8, i64 8 + %agg.tmp19.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp19.sroa.0.sroa.2.0._arg_accB8.ascast.sroa_idx, align 8 + %agg.tmp20.sroa.0.sroa.0.0.copyload = load i64, ptr %_arg_accB9, align 8 + %agg.tmp20.sroa.0.sroa.2.0._arg_accB9.ascast.sroa_idx = getelementptr inbounds i8, ptr %_arg_accB9, i64 8 + %agg.tmp20.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp20.sroa.0.sroa.2.0._arg_accB9.ascast.sroa_idx, align 8 + %mul.i6.i.i.i.i107 = mul i64 %agg.tmp20.sroa.0.sroa.0.0.copyload, %agg.tmp19.sroa.0.sroa.2.0.copyload + %2 = getelementptr %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %_arg_accB, i64 %mul.i6.i.i.i.i107 + %add.ptr.i108 = getelementptr %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %2, i64 %agg.tmp20.sroa.0.sroa.2.0.copyload + %3 = load i64, ptr addrspace(1) getelementptr inbounds (i8, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, i64 8), align 8 + %cmp.i28 = icmp ult i64 %3, 2147483648 + tail call void @llvm.assume(i1 %cmp.i28) + %4 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32 + %cmp.i24 = icmp ult i64 %4, 2147483648 + tail call void @llvm.assume(i1 %cmp.i24) + %5 = load i64, ptr addrspace(1) getelementptr inbounds (i8, ptr addrspace(1) @__spirv_BuiltInLocalInvocationId, i64 8), align 8 + %cmp.i35 = icmp ult i64 %5, 2147483648 + tail call void @llvm.assume(i1 %cmp.i35) + %sub.i = sub nsw i64 %3, %5 + %6 = load i64, ptr addrspace(1) @__spirv_BuiltInLocalInvocationId, align 32 + %cmp.i31 = icmp ult i64 %6, 2147483648 + tail call void @llvm.assume(i1 %cmp.i31) + %sub5.i = sub nsw i64 %4, %6 + %add.i7.i.i.i.i.i = add i64 %mul.i6.i.i.i.i, %agg.tmp12.sroa.0.sroa.2.0.copyload + %idx.neg.i.i = sub i64 0, %add.i7.i.i.i.i.i + %add.ptr.i.i = getelementptr inbounds float, ptr addrspace(1) %add.ptr.i, i64 %idx.neg.i.i + %div.i = udiv i64 %sub5.i, %_arg_sg_size + %sub.i.tr = trunc i64 %sub.i to i32 + %conv.i = shl i32 %sub.i.tr, 4 + %div.i.tr = trunc i64 %div.i to i32 + %conv2.i = shl i32 %div.i.tr, 4 + %call4.i = tail call spir_func noundef target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_24__spirv_JointMatrixINTELIT0_XT1_EXT2_EXT4_EXT5_EXT3_EEEPT_iiS3_mi(ptr addrspace(1) noundef %add.ptr.i.i, i32 noundef %conv.i, i32 noundef %conv2.i, i32 noundef 0, i64 noundef 32, i32 noundef 0) #3 + %add.i7.i.i.i.i.i118 = add i64 %mul.i6.i.i.i.i91, %agg.tmp16.sroa.0.sroa.2.0.copyload + %idx.neg.i.i119 = sub i64 0, %add.i7.i.i.i.i.i118 + %add.ptr.i.i120 = getelementptr inbounds %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %add.ptr.i92, i64 %idx.neg.i.i119 + %add.i7.i.i.i.i.i126 = add i64 %mul.i6.i.i.i.i107, %agg.tmp20.sroa.0.sroa.2.0.copyload + %idx.neg.i.i127 = sub i64 0, %add.i7.i.i.i.i.i126 + %add.ptr.i.i128 = getelementptr inbounds %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %add.ptr.i108, i64 %idx.neg.i.i127 + %conv2.i60 = shl i32 %div.i.tr, 5 + br label %for.cond.i + +for.cond.i: ; preds = %for.body.i, %entry + %sub_c.i.sroa.0.0 = phi target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) [ %call4.i, %entry ], [ %call.i63, %for.body.i ] + %k.0.i = phi i32 [ 0, %entry ], [ %add.i, %for.body.i ] + %cmp.i = icmp ult i32 %k.0.i, 2 + br i1 %cmp.i, label %for.body.i, label %_ZZZ15matrix_multiplyIfN4sycl3_V13ext6oneapi8bfloat16ELm32ELm32ELm32ELm16ELm16ELm16EEvR10big_matrixIT_XT1_EXT2_EERS5_IT0_XT1_EXT3_EERS5_IS9_XdvT3_Li2EEXmlT2_Li2EEEENKUlRNS1_7handlerEE_clESF_ENKUlNS1_7nd_itemILi2EEEE_clESI_.exit + +for.body.i: ; preds = %for.cond.i + %7 = shl nuw nsw i32 %k.0.i, 4 + %call3.i50 = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 0, 3, 0, 1) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1N4sycl3_V13ext6oneapi8bfloat16ES4_Lm16ELm16ELN5__spv9MatrixUseE0ELNS6_12MatrixLayoutE0ELNS6_5Scope4FlagE3EEPNS6_24__spirv_JointMatrixINTELIT0_XT1_EXT2_EXT4_EXT5_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef %add.ptr.i.i120, i32 noundef %conv.i, i32 noundef %7, i32 noundef 0, i64 noundef 32, i32 noundef 0) #3 + %8 = shl nuw nsw i32 %k.0.i, 3 + %call3.i61 = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1, 1) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1N4sycl3_V13ext6oneapi8bfloat16ES4_Lm16ELm16ELN5__spv9MatrixUseE1ELNS6_12MatrixLayoutE2ELNS6_5Scope4FlagE3EEPNS6_24__spirv_JointMatrixINTELIT0_XT1_EXT2_EXT4_EXT5_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef %add.ptr.i.i128, i32 noundef %8, i32 noundef %conv2.i60, i32 noundef 2, i64 noundef 64, i32 noundef 0) #3 + %call.i63 = tail call spir_func noundef target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) @_Z27__spirv_JointMatrixMadINTELIN4sycl3_V13ext6oneapi8bfloat16ES4_fLm16ELm16ELm16ELN5__spv9MatrixUseE0ELS6_1ELS6_2ELNS5_12MatrixLayoutE0ELS7_2ELS7_3ELNS5_5Scope4FlagE3EEPNS5_24__spirv_JointMatrixINTELIT1_XT2_EXT4_EXT10_EXT11_EXT7_EEEPNSA_IT_XT2_EXT3_EXT8_EXT11_EXT5_EEEPNSA_IT0_XT3_EXT4_EXT9_EXT11_EXT6_EEESD_S9_(target("spirv.JointMatrixINTEL", i16, 16, 16, 0, 3, 0, 1) noundef %call3.i50, target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1, 1) noundef %call3.i61, target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) noundef %sub_c.i.sroa.0.0, i32 noundef 3) #3 + %add.i = add nuw nsw i32 %k.0.i, 1 + br label %for.cond.i + +_ZZZ15matrix_multiplyIfN4sycl3_V13ext6oneapi8bfloat16ELm32ELm32ELm32ELm16ELm16ELm16EEvR10big_matrixIT_XT1_EXT2_EERS5_IT0_XT1_EXT3_EERS5_IS9_XdvT3_Li2EEXmlT2_Li2EEEENKUlRNS1_7handlerEE_clESF_ENKUlNS1_7nd_itemILi2EEEE_clESI_.exit: ; preds = %for.cond.i + tail call spir_func void @_Z41__spirv_CooperativeMatrixStoreOffsetINTELIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEvPT_iiPNS1_24__spirv_JointMatrixINTELIT0_XT1_EXT2_EXT4_EXT5_EXT3_EEES3_mi(ptr addrspace(1) noundef %add.ptr.i.i, i32 noundef %conv.i, i32 noundef %conv2.i, target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) noundef %sub_c.i.sroa.0.0, i32 noundef 0, i64 noundef 32, i32 noundef 0) #3 + ret void +} + +; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: write) +declare void @llvm.assume(i1 noundef) #1 + +; Function Attrs: convergent nounwind +declare dso_local spir_func noundef target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_24__spirv_JointMatrixINTELIT0_XT1_EXT2_EXT4_EXT5_EXT3_EEEPT_iiS3_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2 + +; Function Attrs: convergent nounwind +declare dso_local spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 0, 3, 0, 1) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1N4sycl3_V13ext6oneapi8bfloat16ES4_Lm16ELm16ELN5__spv9MatrixUseE0ELNS6_12MatrixLayoutE0ELNS6_5Scope4FlagE3EEPNS6_24__spirv_JointMatrixINTELIT0_XT1_EXT2_EXT4_EXT5_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2 + +; Function Attrs: convergent nounwind +declare dso_local spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1, 1) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1N4sycl3_V13ext6oneapi8bfloat16ES4_Lm16ELm16ELN5__spv9MatrixUseE1ELNS6_12MatrixLayoutE2ELNS6_5Scope4FlagE3EEPNS6_24__spirv_JointMatrixINTELIT0_XT1_EXT2_EXT4_EXT5_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2 + +; Function Attrs: convergent nounwind +declare dso_local spir_func noundef target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) @_Z27__spirv_JointMatrixMadINTELIN4sycl3_V13ext6oneapi8bfloat16ES4_fLm16ELm16ELm16ELN5__spv9MatrixUseE0ELS6_1ELS6_2ELNS5_12MatrixLayoutE0ELS7_2ELS7_3ELNS5_5Scope4FlagE3EEPNS5_24__spirv_JointMatrixINTELIT1_XT2_EXT4_EXT10_EXT11_EXT7_EEEPNSA_IT_XT2_EXT3_EXT8_EXT11_EXT5_EEEPNSA_IT0_XT3_EXT4_EXT9_EXT11_EXT6_EEESD_S9_(target("spirv.JointMatrixINTEL", i16, 16, 16, 0, 3, 0, 1) noundef, target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1, 1) noundef, target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) noundef, i32 noundef) local_unnamed_addr #2 + +; Function Attrs: convergent nounwind +declare dso_local spir_func void @_Z41__spirv_CooperativeMatrixStoreOffsetINTELIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEvPT_iiPNS1_24__spirv_JointMatrixINTELIT0_XT1_EXT2_EXT4_EXT5_EXT3_EEES3_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) noundef, i32 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2 + +attributes #0 = { convergent mustprogress norecurse nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="joint_matrix_bfloat16.cpp" "sycl-optlevel"="2" "uniform-work-group-size"="true" } +attributes #1 = { mustprogress nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: write) } +attributes #2 = { convergent nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" } +attributes #3 = { convergent nounwind } From f6b02f7184d17bdb5717995af8926df753cf89ac Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Wed, 20 Nov 2024 10:45:02 -0800 Subject: [PATCH 02/11] code clean up and resolve comment --- lib/SPIRV/SPIRVReader.cpp | 5 +++++ lib/SPIRV/libSPIRV/SPIRVInstruction.h | 6 +++--- .../joint_matrix_load_store_offset.ll | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/lib/SPIRV/SPIRVReader.cpp b/lib/SPIRV/SPIRVReader.cpp index 0c4f537938..d971d74ffd 100644 --- a/lib/SPIRV/SPIRVReader.cpp +++ b/lib/SPIRV/SPIRVReader.cpp @@ -3685,6 +3685,11 @@ Instruction *SPIRVToLLVM::transSPIRVBuiltinFromInst(SPIRVInstruction *BI, BasicBlock *BB) { assert(BB && "Invalid BB"); const auto OC = BI->getOpCode(); + int i = 1; + // if(static_cast(OC) == internal::OpCooperativeMatrixLoadOffsetINTEL) { + // i ++; + // } + // i++; bool AddRetTypePostfix = false; switch (static_cast(OC)) { diff --git a/lib/SPIRV/libSPIRV/SPIRVInstruction.h b/lib/SPIRV/libSPIRV/SPIRVInstruction.h index 0a6648b793..0085c5a4c5 100644 --- a/lib/SPIRV/libSPIRV/SPIRVInstruction.h +++ b/lib/SPIRV/libSPIRV/SPIRVInstruction.h @@ -3730,8 +3730,8 @@ _SPIRV_OP(CooperativeMatrixStoreChecked, false, 8, true, 8) _SPIRV_OP(CooperativeMatrixConstructChecked, true, 8) #undef _SPIRV_OP -class SPIRVJointMatrixOffsetInstructionsINTELInstBase - : public SPIRVInstTemplateBase { +class SPIRVCooperativeMatrixOffsetInstructionsINTELInstBase + : public SPIRVInstTemplateBase { protected: std::optional getRequiredExtension() const override { return ExtensionID::SPV_INTEL_joint_matrix; @@ -3744,7 +3744,7 @@ class SPIRVJointMatrixOffsetInstructionsINTELInstBase #define _SPIRV_OP(x, ...) \ typedef SPIRVInstTemplate< \ - SPIRVJointMatrixOffsetInstructionsINTELInstBase, \ + SPIRVCooperativeMatrixOffsetInstructionsINTELInstBase, \ internal::Op##x##INTEL, __VA_ARGS__> \ SPIRV##x##INTEL; _SPIRV_OP(CooperativeMatrixLoadOffset, true, 8, true, 6) diff --git a/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll b/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll index 34dc472242..a71e30ef28 100644 --- a/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll +++ b/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll @@ -25,7 +25,7 @@ ; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy1]] ; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy2]] ; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy3]] -; CHECK-SPIRV: JointMatrixMadINTEL [[#MatTy1]] +; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]] ; CHECK-SPIRV: CooperativeMatrixStoreOffsetINTEL ; CHECK-LLVM: call spir_func target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) @_Z93__spirv_CooperativeMatrixLoadOffsetINTEL_RPU3AS143__spirv_JointMatrixINTEL__float_16_16_3_3_2PU3AS1fiiili(ptr addrspace(1) %add.ptr.i.i, i32 %conv.i, i32 %conv2.i, i32 0, i64 32, i32 0) From 183bd2d3e93be899a0c8e8bda44d99857e65c4b0 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Wed, 20 Nov 2024 10:54:30 -0800 Subject: [PATCH 03/11] code clean up --- lib/SPIRV/libSPIRV/SPIRVInstruction.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/SPIRV/libSPIRV/SPIRVInstruction.h b/lib/SPIRV/libSPIRV/SPIRVInstruction.h index 0085c5a4c5..c5e5be3964 100644 --- a/lib/SPIRV/libSPIRV/SPIRVInstruction.h +++ b/lib/SPIRV/libSPIRV/SPIRVInstruction.h @@ -3744,7 +3744,7 @@ class SPIRVCooperativeMatrixOffsetInstructionsINTELInstBase #define _SPIRV_OP(x, ...) \ typedef SPIRVInstTemplate< \ - SPIRVCooperativeMatrixOffsetInstructionsINTELInstBase, \ + SPIRVCooperativeMatrixOffsetInstructionsINTELInstBase, \ internal::Op##x##INTEL, __VA_ARGS__> \ SPIRV##x##INTEL; _SPIRV_OP(CooperativeMatrixLoadOffset, true, 8, true, 6) From e98123dfc9cb63cd8915e3b1a6db0615508fca52 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Wed, 20 Nov 2024 10:58:47 -0800 Subject: [PATCH 04/11] code clean up --- lib/SPIRV/SPIRVReader.cpp | 5 ----- .../SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/lib/SPIRV/SPIRVReader.cpp b/lib/SPIRV/SPIRVReader.cpp index d971d74ffd..0c4f537938 100644 --- a/lib/SPIRV/SPIRVReader.cpp +++ b/lib/SPIRV/SPIRVReader.cpp @@ -3685,11 +3685,6 @@ Instruction *SPIRVToLLVM::transSPIRVBuiltinFromInst(SPIRVInstruction *BI, BasicBlock *BB) { assert(BB && "Invalid BB"); const auto OC = BI->getOpCode(); - int i = 1; - // if(static_cast(OC) == internal::OpCooperativeMatrixLoadOffsetINTEL) { - // i ++; - // } - // i++; bool AddRetTypePostfix = false; switch (static_cast(OC)) { diff --git a/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll b/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll index a71e30ef28..b12eb3fe42 100644 --- a/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll +++ b/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll @@ -1,4 +1,4 @@ -; This is an adapted copy of test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix.ll +; This is an adapted copy of test/extensions/KHR/SPV_KHR_cooperative_matrix/cooperative_matrix.ll ; RUN: llvm-as < %s -o %t.bc ; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_joint_matrix -o %t.spv From a8409a46479e1ce5b3479f981de83daefaf11204 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Fri, 6 Dec 2024 08:11:10 -0800 Subject: [PATCH 05/11] update the tokens number --- lib/SPIRV/libSPIRV/spirv_internal.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/SPIRV/libSPIRV/spirv_internal.hpp b/lib/SPIRV/libSPIRV/spirv_internal.hpp index 391f436207..d7367c3fd5 100644 --- a/lib/SPIRV/libSPIRV/spirv_internal.hpp +++ b/lib/SPIRV/libSPIRV/spirv_internal.hpp @@ -77,8 +77,8 @@ enum InternalOp { IOpCooperativeMatrixLoadCheckedINTEL = 6193, IOpCooperativeMatrixStoreCheckedINTEL = 6194, IOpCooperativeMatrixConstructCheckedINTEL = 6195, - IOpCooperativeMatrixLoadOffsetINTEL = 6212, - IOpCooperativeMatrixStoreOffsetINTEL = 6213, + IOpCooperativeMatrixLoadOffsetINTEL = 6239, + IOpCooperativeMatrixStoreOffsetINTEL = 6240, IOpJointMatrixWorkItemLengthINTEL = 6410, IOpTypeTaskSequenceINTEL = 6199, IOpComplexFMulINTEL = 6415, @@ -116,7 +116,7 @@ enum InternalCapability { ICapGlobalVariableDecorationsINTEL = 6146, ICapabilityTaskSequenceINTEL = 6162, ICapabilityCooperativeMatrixCheckedInstructionsINTEL = 6192, - ICapabilityCooperativeMatrixOffsetInstructionsINTEL = 6211, + ICapabilityCooperativeMatrixOffsetInstructionsINTEL = 6238, ICapabilityCooperativeMatrixPrefetchINTEL = 6411, ICapabilityComplexFloatMulDivINTEL = 6414, ICapabilityTensorFloat32RoundingINTEL = 6425, From d999545cc5e74e9d17100aef86deee0d0b0622b8 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Sun, 8 Dec 2024 17:50:26 -0800 Subject: [PATCH 06/11] update the test case --- .../joint_matrix_load_store_offset.ll | 57 +++++++++---------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll b/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll index b12eb3fe42..4eee6c9195 100644 --- a/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll +++ b/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll @@ -19,23 +19,23 @@ ; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const0:]] 0 ; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const1:]] 1 ; CHECK-SPIRV-DAG: TypeFloat [[#Float32Ty:]] 32 -; CHECK-SPIRV-DAG: TypeJointMatrixINTEL [[#MatTy1:]] [[#Float32Ty]] [[#Const16]] [[#Const16]] [[#Const3]] [[#Const3]] [[#Const2]] -; CHECK-SPIRV-DAG: TypeJointMatrixINTEL [[#MatTy2:]] [[#Int16Ty]] [[#Const16]] [[#Const16]] [[#Const0]] [[#Const3]] [[#Const0]] [[#Const1]] -; CHECK-SPIRV-DAG: TypeJointMatrixINTEL [[#MatTy3:]] [[#Int16Ty]] [[#Const16]] [[#Const16]] [[#Const2:]] [[#Const3:]] [[#Const1:]] [[#Const1:]] +; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy1:]] [[#Float32Ty]] [[#Const3]] [[#Const1]] [[#Const16]] [[#Const2]] +; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy2:]] [[#Int16Ty]] [[#Const3]] [[#Const1]] [[#Const16]] [[#Const0:]] +; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int16Ty]] [[#Const3]] [[#Const16]] [[#Const16]] [[#Const1]] ; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy1]] ; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy2]] ; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy3]] ; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]] ; CHECK-SPIRV: CooperativeMatrixStoreOffsetINTEL -; CHECK-LLVM: call spir_func target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) @_Z93__spirv_CooperativeMatrixLoadOffsetINTEL_RPU3AS143__spirv_JointMatrixINTEL__float_16_16_3_3_2PU3AS1fiiili(ptr addrspace(1) %add.ptr.i.i, i32 %conv.i, i32 %conv2.i, i32 0, i64 32, i32 0) -; CHECK-LLVM: call spir_func target("spirv.JointMatrixINTEL", i16, 16, 16, 0, 3, 0, 1) @"_Z95__spirv_CooperativeMatrixLoadOffsetINTEL_RPU3AS145__spirv_JointMatrixINTEL__short_16_16_0_3_0_1PU3AS138class.sycl::_V1::ext::oneapi::bfloat16iiili"(ptr addrspace(1) %add.ptr.i.i120, i32 %conv.i, i32 %22, i32 0, i64 32, i32 0) -; CHECK-LLVM: call spir_func target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1, 1) @"_Z95__spirv_CooperativeMatrixLoadOffsetINTEL_RPU3AS145__spirv_JointMatrixINTEL__short_16_16_2_3_1_1PU3AS138class.sycl::_V1::ext::oneapi::bfloat16iiili"(ptr addrspace(1) %add.ptr.i.i128, i32 %23, i32 %conv2.i60, i32 2, i64 64, i32 0) -; CHECK-LLVM: call spir_func target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) @_Z27__spirv_JointMatrixMadINTELPU3AS145__spirv_JointMatrixINTEL__short_16_16_0_3_0_1PU3AS145__spirv_JointMatrixINTEL__short_16_16_2_3_1_1PU3AS143__spirv_JointMatrixINTEL__float_16_16_3_3_2i(target("spirv.JointMatrixINTEL", i16, 16, 16, 0, 3, 0, 1) %call3.i50, target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1, 1) %call3.i61, target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) %sub_c.i.sroa.0.0, i32 3) -; CHECK-LLVM: call spir_func void @_Z41__spirv_CooperativeMatrixStoreOffsetINTELPU3AS1fiiPU3AS143__spirv_JointMatrixINTEL__float_16_16_3_3_2ili(ptr addrspace(1) %add.ptr.i.i, i32 %conv.i, i32 %conv2.i, target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) %sub_c.i.sroa.0.0, i32 0, i64 32, i32 0) +; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) @_Z94__spirv_CooperativeMatrixLoadOffsetINTEL_RPU3AS144__spirv_CooperativeMatrixKHR__float_3_1_16_2PU3AS1fiiili(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 0, i64 128, i32 0) +; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 1, 16, 0) @"_Z94__spirv_CooperativeMatrixLoadOffsetINTEL_RPU3AS144__spirv_CooperativeMatrixKHR__short_3_1_16_0PU3AS138class.sycl::_V1::ext::oneapi::bfloat16iiili"(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 0, i64 128, i32 0) +; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 16, 16, 1) @"_Z95__spirv_CooperativeMatrixLoadOffsetINTEL_RPU3AS145__spirv_CooperativeMatrixKHR__short_3_16_16_1PU3AS138class.sycl::_V1::ext::oneapi::bfloat16iiili"(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 2, i64 256, i32 0) +; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) @_Z34__spirv_CooperativeMatrixMulAddKHRPU3AS144__spirv_CooperativeMatrixKHR__short_3_1_16_0PU3AS145__spirv_CooperativeMatrixKHR__short_3_16_16_1PU3AS144__spirv_CooperativeMatrixKHR__float_3_1_16_2i(target("spirv.CooperativeMatrixKHR", i16, 3, 1, 16, 0) %{{.*}}, target("spirv.CooperativeMatrixKHR", i16, 3, 16, 16, 1) %{{.*}}, target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) %{{.*}}, i32 64) +; CHECK-LLVM: call spir_func void @_Z41__spirv_CooperativeMatrixStoreOffsetINTELPU3AS1fiiPU3AS144__spirv_CooperativeMatrixKHR__float_3_1_16_2ili(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) %{{.*}}, i32 0, i64 128, i32 0) -; ModuleID = 'joint_matrix_bfloat16.cpp' -source_filename = "joint_matrix_bfloat16.cpp" +; ModuleID = 'joint_matrix_all_sizes.cpp' +source_filename = "joint_matrix_all_sizes.cpp" target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1" target triple = "spir64-unknown-unknown" @@ -44,13 +44,13 @@ target triple = "spir64-unknown-unknown" %"class.sycl::_V1::id" = type { %"class.sycl::_V1::detail::array" } %"class.sycl::_V1::ext::oneapi::bfloat16" = type { i16 } -$_ZTS7imatrixIfLm16ELm16ELm16EE = comdat any +$_ZTSZZ15matrix_multiply = comdat any @__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32 @__spirv_BuiltInLocalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32 ; Function Attrs: convergent mustprogress norecurse nounwind -define weak_odr dso_local spir_kernel void @_ZTS7imatrixIfLm16ELm16ELm16EE(ptr addrspace(1) noundef align 4 %_arg_accC, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accC2, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accC3, i64 noundef %_arg_sg_size, ptr addrspace(1) noundef align 2 %_arg_accA, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accA5, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accA6, ptr addrspace(1) noundef align 2 %_arg_accB, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accB8, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accB9) local_unnamed_addr #0 comdat { +define weak_odr dso_local spir_kernel void @_ZTSZZ15matrix_multiply(ptr addrspace(1) noundef align 4 %_arg_accC, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accC2, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accC3, i64 noundef %_arg_sg_size, ptr addrspace(1) noundef readonly align 2 %_arg_accA, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accA5, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accA6, ptr addrspace(1) noundef readonly align 2 %_arg_accB, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accB8, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accB9) local_unnamed_addr #0 comdat { entry: %agg.tmp11.sroa.0.sroa.2.0._arg_accC2.ascast.sroa_idx = getelementptr inbounds i8, ptr %_arg_accC2, i64 8 %agg.tmp11.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp11.sroa.0.sroa.2.0._arg_accC2.ascast.sroa_idx, align 8 @@ -94,11 +94,10 @@ entry: %idx.neg.i.i = sub i64 0, %add.i7.i.i.i.i.i %add.ptr.i.i = getelementptr inbounds float, ptr addrspace(1) %add.ptr.i, i64 %idx.neg.i.i %div.i = udiv i64 %sub5.i, %_arg_sg_size - %sub.i.tr = trunc i64 %sub.i to i32 - %conv.i = shl i32 %sub.i.tr, 4 + %conv.i = trunc i64 %sub.i to i32 %div.i.tr = trunc i64 %div.i to i32 %conv2.i = shl i32 %div.i.tr, 4 - %call4.i = tail call spir_func noundef target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_24__spirv_JointMatrixINTELIT0_XT1_EXT2_EXT4_EXT5_EXT3_EEEPT_iiS3_mi(ptr addrspace(1) noundef %add.ptr.i.i, i32 noundef %conv.i, i32 noundef %conv2.i, i32 noundef 0, i64 noundef 32, i32 noundef 0) #3 + %call4.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1ffLm1ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_iiS3_mi(ptr addrspace(1) noundef %add.ptr.i.i, i32 noundef %conv.i, i32 noundef %conv2.i, i32 noundef 0, i64 noundef 128, i32 noundef 0) #3 %add.i7.i.i.i.i.i118 = add i64 %mul.i6.i.i.i.i91, %agg.tmp16.sroa.0.sroa.2.0.copyload %idx.neg.i.i119 = sub i64 0, %add.i7.i.i.i.i.i118 %add.ptr.i.i120 = getelementptr inbounds %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %add.ptr.i92, i64 %idx.neg.i.i119 @@ -109,22 +108,22 @@ entry: br label %for.cond.i for.cond.i: ; preds = %for.body.i, %entry - %sub_c.i.sroa.0.0 = phi target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) [ %call4.i, %entry ], [ %call.i63, %for.body.i ] + %sub_c.i.sroa.0.0 = phi target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) [ %call4.i, %entry ], [ %call.i63, %for.body.i ] %k.0.i = phi i32 [ 0, %entry ], [ %add.i, %for.body.i ] - %cmp.i = icmp ult i32 %k.0.i, 2 - br i1 %cmp.i, label %for.body.i, label %_ZZZ15matrix_multiplyIfN4sycl3_V13ext6oneapi8bfloat16ELm32ELm32ELm32ELm16ELm16ELm16EEvR10big_matrixIT_XT1_EXT2_EERS5_IT0_XT1_EXT3_EERS5_IS9_XdvT3_Li2EEXmlT2_Li2EEEENKUlRNS1_7handlerEE_clESF_ENKUlNS1_7nd_itemILi2EEEE_clESI_.exit + %cmp.i = icmp samesign ult i32 %k.0.i, 8 + br i1 %cmp.i, label %for.body.i, label %_ZZZ15matrix_multiplyIfN4sycl3_V13ext6oneapi8bfloat16ELm16ELm128ELm128ELi2ELm1ELm16ELm16E4multIS4_Lm1ELm16ELm16EEEvR10big_matrixIT_XT1_EXT2_EERS7_IT0_XT1_EXT3_EERS7_ISB_XdvT3_T4_EXmlT2_T4_EEENKUlRNS1_7handlerEE_clESH_ENKUlNS1_7nd_itemILi2EEEE_clESK_.exit for.body.i: ; preds = %for.cond.i %7 = shl nuw nsw i32 %k.0.i, 4 - %call3.i50 = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 0, 3, 0, 1) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1N4sycl3_V13ext6oneapi8bfloat16ES4_Lm16ELm16ELN5__spv9MatrixUseE0ELNS6_12MatrixLayoutE0ELNS6_5Scope4FlagE3EEPNS6_24__spirv_JointMatrixINTELIT0_XT1_EXT2_EXT4_EXT5_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef %add.ptr.i.i120, i32 noundef %conv.i, i32 noundef %7, i32 noundef 0, i64 noundef 32, i32 noundef 0) #3 + %call3.i50 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 1, 16, 0) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1KN4sycl3_V13ext6oneapi8bfloat16ES4_Lm1ELm16ELN5__spv9MatrixUseE0ELNS6_12MatrixLayoutE0ELNS6_5Scope4FlagE3EEPNS6_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef %add.ptr.i.i120, i32 noundef %conv.i, i32 noundef %7, i32 noundef 0, i64 noundef 128, i32 noundef 0) #3 %8 = shl nuw nsw i32 %k.0.i, 3 - %call3.i61 = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1, 1) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1N4sycl3_V13ext6oneapi8bfloat16ES4_Lm16ELm16ELN5__spv9MatrixUseE1ELNS6_12MatrixLayoutE2ELNS6_5Scope4FlagE3EEPNS6_24__spirv_JointMatrixINTELIT0_XT1_EXT2_EXT4_EXT5_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef %add.ptr.i.i128, i32 noundef %8, i32 noundef %conv2.i60, i32 noundef 2, i64 noundef 64, i32 noundef 0) #3 - %call.i63 = tail call spir_func noundef target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) @_Z27__spirv_JointMatrixMadINTELIN4sycl3_V13ext6oneapi8bfloat16ES4_fLm16ELm16ELm16ELN5__spv9MatrixUseE0ELS6_1ELS6_2ELNS5_12MatrixLayoutE0ELS7_2ELS7_3ELNS5_5Scope4FlagE3EEPNS5_24__spirv_JointMatrixINTELIT1_XT2_EXT4_EXT10_EXT11_EXT7_EEEPNSA_IT_XT2_EXT3_EXT8_EXT11_EXT5_EEEPNSA_IT0_XT3_EXT4_EXT9_EXT11_EXT6_EEESD_S9_(target("spirv.JointMatrixINTEL", i16, 16, 16, 0, 3, 0, 1) noundef %call3.i50, target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1, 1) noundef %call3.i61, target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) noundef %sub_c.i.sroa.0.0, i32 noundef 3) #3 + %call3.i61 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 16, 16, 1) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1KN4sycl3_V13ext6oneapi8bfloat16ES4_Lm16ELm16ELN5__spv9MatrixUseE1ELNS6_12MatrixLayoutE2ELNS6_5Scope4FlagE3EEPNS6_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef %add.ptr.i.i128, i32 noundef %8, i32 noundef %conv2.i60, i32 noundef 2, i64 noundef 256, i32 noundef 0) #3 + %call.i63 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) @_Z34__spirv_CooperativeMatrixMulAddKHRIN4sycl3_V13ext6oneapi8bfloat16ES4_fLm1ELm16ELm16ELN5__spv9MatrixUseE0ELS6_1ELS6_2ELNS5_12MatrixLayoutE0ELS7_0ELS7_0ELNS5_5Scope4FlagE3EEPNS5_28__spirv_CooperativeMatrixKHRIT1_XT11_EXT2_EXT4_EXT7_EEEPNSA_IT_XT11_EXT2_EXT3_EXT5_EEEPNSA_IT0_XT11_EXT3_EXT4_EXT6_EEESD_m(target("spirv.CooperativeMatrixKHR", i16, 3, 1, 16, 0) noundef %call3.i50, target("spirv.CooperativeMatrixKHR", i16, 3, 16, 16, 1) noundef %call3.i61, target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) noundef %sub_c.i.sroa.0.0, i64 noundef 64) #3 %add.i = add nuw nsw i32 %k.0.i, 1 br label %for.cond.i -_ZZZ15matrix_multiplyIfN4sycl3_V13ext6oneapi8bfloat16ELm32ELm32ELm32ELm16ELm16ELm16EEvR10big_matrixIT_XT1_EXT2_EERS5_IT0_XT1_EXT3_EERS5_IS9_XdvT3_Li2EEXmlT2_Li2EEEENKUlRNS1_7handlerEE_clESF_ENKUlNS1_7nd_itemILi2EEEE_clESI_.exit: ; preds = %for.cond.i - tail call spir_func void @_Z41__spirv_CooperativeMatrixStoreOffsetINTELIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEvPT_iiPNS1_24__spirv_JointMatrixINTELIT0_XT1_EXT2_EXT4_EXT5_EXT3_EEES3_mi(ptr addrspace(1) noundef %add.ptr.i.i, i32 noundef %conv.i, i32 noundef %conv2.i, target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) noundef %sub_c.i.sroa.0.0, i32 noundef 0, i64 noundef 32, i32 noundef 0) #3 +_ZZZ15matrix_multiplyIfN4sycl3_V13ext6oneapi8bfloat16ELm16ELm128ELm128ELi2ELm1ELm16ELm16E4multIS4_Lm1ELm16ELm16EEEvR10big_matrixIT_XT1_EXT2_EERS7_IT0_XT1_EXT3_EERS7_ISB_XdvT3_T4_EXmlT2_T4_EEENKUlRNS1_7handlerEE_clESH_ENKUlNS1_7nd_itemILi2EEEE_clESK_.exit: ; preds = %for.cond.i + tail call spir_func void @_Z41__spirv_CooperativeMatrixStoreOffsetINTELIU3AS1ffLm1ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEvPT_iiPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEES3_mi(ptr addrspace(1) noundef %add.ptr.i.i, i32 noundef %conv.i, i32 noundef %conv2.i, target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) noundef %sub_c.i.sroa.0.0, i32 noundef 0, i64 noundef 128, i32 noundef 0) #3 ret void } @@ -132,21 +131,21 @@ _ZZZ15matrix_multiplyIfN4sycl3_V13ext6oneapi8bfloat16ELm32ELm32ELm32ELm16ELm16EL declare void @llvm.assume(i1 noundef) #1 ; Function Attrs: convergent nounwind -declare dso_local spir_func noundef target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_24__spirv_JointMatrixINTELIT0_XT1_EXT2_EXT4_EXT5_EXT3_EEEPT_iiS3_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2 +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1ffLm1ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_iiS3_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2 ; Function Attrs: convergent nounwind -declare dso_local spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 0, 3, 0, 1) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1N4sycl3_V13ext6oneapi8bfloat16ES4_Lm16ELm16ELN5__spv9MatrixUseE0ELNS6_12MatrixLayoutE0ELNS6_5Scope4FlagE3EEPNS6_24__spirv_JointMatrixINTELIT0_XT1_EXT2_EXT4_EXT5_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2 +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 1, 16, 0) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1KN4sycl3_V13ext6oneapi8bfloat16ES4_Lm1ELm16ELN5__spv9MatrixUseE0ELNS6_12MatrixLayoutE0ELNS6_5Scope4FlagE3EEPNS6_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2 ; Function Attrs: convergent nounwind -declare dso_local spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1, 1) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1N4sycl3_V13ext6oneapi8bfloat16ES4_Lm16ELm16ELN5__spv9MatrixUseE1ELNS6_12MatrixLayoutE2ELNS6_5Scope4FlagE3EEPNS6_24__spirv_JointMatrixINTELIT0_XT1_EXT2_EXT4_EXT5_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2 +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 16, 16, 1) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1KN4sycl3_V13ext6oneapi8bfloat16ES4_Lm16ELm16ELN5__spv9MatrixUseE1ELNS6_12MatrixLayoutE2ELNS6_5Scope4FlagE3EEPNS6_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2 ; Function Attrs: convergent nounwind -declare dso_local spir_func noundef target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) @_Z27__spirv_JointMatrixMadINTELIN4sycl3_V13ext6oneapi8bfloat16ES4_fLm16ELm16ELm16ELN5__spv9MatrixUseE0ELS6_1ELS6_2ELNS5_12MatrixLayoutE0ELS7_2ELS7_3ELNS5_5Scope4FlagE3EEPNS5_24__spirv_JointMatrixINTELIT1_XT2_EXT4_EXT10_EXT11_EXT7_EEEPNSA_IT_XT2_EXT3_EXT8_EXT11_EXT5_EEEPNSA_IT0_XT3_EXT4_EXT9_EXT11_EXT6_EEESD_S9_(target("spirv.JointMatrixINTEL", i16, 16, 16, 0, 3, 0, 1) noundef, target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1, 1) noundef, target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) noundef, i32 noundef) local_unnamed_addr #2 +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) @_Z34__spirv_CooperativeMatrixMulAddKHRIN4sycl3_V13ext6oneapi8bfloat16ES4_fLm1ELm16ELm16ELN5__spv9MatrixUseE0ELS6_1ELS6_2ELNS5_12MatrixLayoutE0ELS7_0ELS7_0ELNS5_5Scope4FlagE3EEPNS5_28__spirv_CooperativeMatrixKHRIT1_XT11_EXT2_EXT4_EXT7_EEEPNSA_IT_XT11_EXT2_EXT3_EXT5_EEEPNSA_IT0_XT11_EXT3_EXT4_EXT6_EEESD_m(target("spirv.CooperativeMatrixKHR", i16, 3, 1, 16, 0) noundef, target("spirv.CooperativeMatrixKHR", i16, 3, 16, 16, 1) noundef, target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) noundef, i64 noundef) local_unnamed_addr #2 ; Function Attrs: convergent nounwind -declare dso_local spir_func void @_Z41__spirv_CooperativeMatrixStoreOffsetINTELIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEvPT_iiPNS1_24__spirv_JointMatrixINTELIT0_XT1_EXT2_EXT4_EXT5_EXT3_EEES3_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) noundef, i32 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2 +declare dso_local spir_func void @_Z41__spirv_CooperativeMatrixStoreOffsetINTELIU3AS1ffLm1ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEvPT_iiPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEES3_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) noundef, i32 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2 -attributes #0 = { convergent mustprogress norecurse nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="joint_matrix_bfloat16.cpp" "sycl-optlevel"="2" "uniform-work-group-size"="true" } +attributes #0 = { convergent mustprogress norecurse nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="../joint_matrix_all_sizes.cpp" "sycl-optlevel"="2" "uniform-work-group-size"="true" } attributes #1 = { mustprogress nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: write) } attributes #2 = { convergent nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" } attributes #3 = { convergent nounwind } From 9d9615b958a24a22657830a12751f38e79b50ba7 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Sun, 8 Dec 2024 18:07:57 -0800 Subject: [PATCH 07/11] clean up test --- .../joint_matrix_load_store_offset.ll | 29 ++++++++----------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll b/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll index 4eee6c9195..310ba45c3c 100644 --- a/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll +++ b/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll @@ -50,7 +50,7 @@ $_ZTSZZ15matrix_multiply = comdat any @__spirv_BuiltInLocalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32 ; Function Attrs: convergent mustprogress norecurse nounwind -define weak_odr dso_local spir_kernel void @_ZTSZZ15matrix_multiply(ptr addrspace(1) noundef align 4 %_arg_accC, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accC2, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accC3, i64 noundef %_arg_sg_size, ptr addrspace(1) noundef readonly align 2 %_arg_accA, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accA5, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accA6, ptr addrspace(1) noundef readonly align 2 %_arg_accB, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accB8, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accB9) local_unnamed_addr #0 comdat { +define weak_odr dso_local spir_kernel void @_ZTSZZ15matrix_multiply(ptr addrspace(1) noundef align 4 %_arg_accC, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accC2, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accC3, i64 noundef %_arg_sg_size, ptr addrspace(1) noundef readonly align 2 %_arg_accA, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accA5, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accA6, ptr addrspace(1) noundef readonly align 2 %_arg_accB, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accB8, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accB9) comdat { entry: %agg.tmp11.sroa.0.sroa.2.0._arg_accC2.ascast.sroa_idx = getelementptr inbounds i8, ptr %_arg_accC2, i64 8 %agg.tmp11.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp11.sroa.0.sroa.2.0._arg_accC2.ascast.sroa_idx, align 8 @@ -97,7 +97,7 @@ entry: %conv.i = trunc i64 %sub.i to i32 %div.i.tr = trunc i64 %div.i to i32 %conv2.i = shl i32 %div.i.tr, 4 - %call4.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1ffLm1ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_iiS3_mi(ptr addrspace(1) noundef %add.ptr.i.i, i32 noundef %conv.i, i32 noundef %conv2.i, i32 noundef 0, i64 noundef 128, i32 noundef 0) #3 + %call4.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1ffLm1ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_iiS3_mi(ptr addrspace(1) noundef %add.ptr.i.i, i32 noundef %conv.i, i32 noundef %conv2.i, i32 noundef 0, i64 noundef 128, i32 noundef 0) %add.i7.i.i.i.i.i118 = add i64 %mul.i6.i.i.i.i91, %agg.tmp16.sroa.0.sroa.2.0.copyload %idx.neg.i.i119 = sub i64 0, %add.i7.i.i.i.i.i118 %add.ptr.i.i120 = getelementptr inbounds %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %add.ptr.i92, i64 %idx.neg.i.i119 @@ -115,37 +115,32 @@ for.cond.i: ; preds = %for.body.i, %entry for.body.i: ; preds = %for.cond.i %7 = shl nuw nsw i32 %k.0.i, 4 - %call3.i50 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 1, 16, 0) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1KN4sycl3_V13ext6oneapi8bfloat16ES4_Lm1ELm16ELN5__spv9MatrixUseE0ELNS6_12MatrixLayoutE0ELNS6_5Scope4FlagE3EEPNS6_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef %add.ptr.i.i120, i32 noundef %conv.i, i32 noundef %7, i32 noundef 0, i64 noundef 128, i32 noundef 0) #3 + %call3.i50 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 1, 16, 0) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1KN4sycl3_V13ext6oneapi8bfloat16ES4_Lm1ELm16ELN5__spv9MatrixUseE0ELNS6_12MatrixLayoutE0ELNS6_5Scope4FlagE3EEPNS6_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef %add.ptr.i.i120, i32 noundef %conv.i, i32 noundef %7, i32 noundef 0, i64 noundef 128, i32 noundef 0) %8 = shl nuw nsw i32 %k.0.i, 3 - %call3.i61 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 16, 16, 1) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1KN4sycl3_V13ext6oneapi8bfloat16ES4_Lm16ELm16ELN5__spv9MatrixUseE1ELNS6_12MatrixLayoutE2ELNS6_5Scope4FlagE3EEPNS6_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef %add.ptr.i.i128, i32 noundef %8, i32 noundef %conv2.i60, i32 noundef 2, i64 noundef 256, i32 noundef 0) #3 - %call.i63 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) @_Z34__spirv_CooperativeMatrixMulAddKHRIN4sycl3_V13ext6oneapi8bfloat16ES4_fLm1ELm16ELm16ELN5__spv9MatrixUseE0ELS6_1ELS6_2ELNS5_12MatrixLayoutE0ELS7_0ELS7_0ELNS5_5Scope4FlagE3EEPNS5_28__spirv_CooperativeMatrixKHRIT1_XT11_EXT2_EXT4_EXT7_EEEPNSA_IT_XT11_EXT2_EXT3_EXT5_EEEPNSA_IT0_XT11_EXT3_EXT4_EXT6_EEESD_m(target("spirv.CooperativeMatrixKHR", i16, 3, 1, 16, 0) noundef %call3.i50, target("spirv.CooperativeMatrixKHR", i16, 3, 16, 16, 1) noundef %call3.i61, target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) noundef %sub_c.i.sroa.0.0, i64 noundef 64) #3 + %call3.i61 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 16, 16, 1) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1KN4sycl3_V13ext6oneapi8bfloat16ES4_Lm16ELm16ELN5__spv9MatrixUseE1ELNS6_12MatrixLayoutE2ELNS6_5Scope4FlagE3EEPNS6_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef %add.ptr.i.i128, i32 noundef %8, i32 noundef %conv2.i60, i32 noundef 2, i64 noundef 256, i32 noundef 0) + %call.i63 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) @_Z34__spirv_CooperativeMatrixMulAddKHRIN4sycl3_V13ext6oneapi8bfloat16ES4_fLm1ELm16ELm16ELN5__spv9MatrixUseE0ELS6_1ELS6_2ELNS5_12MatrixLayoutE0ELS7_0ELS7_0ELNS5_5Scope4FlagE3EEPNS5_28__spirv_CooperativeMatrixKHRIT1_XT11_EXT2_EXT4_EXT7_EEEPNSA_IT_XT11_EXT2_EXT3_EXT5_EEEPNSA_IT0_XT11_EXT3_EXT4_EXT6_EEESD_m(target("spirv.CooperativeMatrixKHR", i16, 3, 1, 16, 0) noundef %call3.i50, target("spirv.CooperativeMatrixKHR", i16, 3, 16, 16, 1) noundef %call3.i61, target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) noundef %sub_c.i.sroa.0.0, i64 noundef 64) %add.i = add nuw nsw i32 %k.0.i, 1 br label %for.cond.i _ZZZ15matrix_multiplyIfN4sycl3_V13ext6oneapi8bfloat16ELm16ELm128ELm128ELi2ELm1ELm16ELm16E4multIS4_Lm1ELm16ELm16EEEvR10big_matrixIT_XT1_EXT2_EERS7_IT0_XT1_EXT3_EERS7_ISB_XdvT3_T4_EXmlT2_T4_EEENKUlRNS1_7handlerEE_clESH_ENKUlNS1_7nd_itemILi2EEEE_clESK_.exit: ; preds = %for.cond.i - tail call spir_func void @_Z41__spirv_CooperativeMatrixStoreOffsetINTELIU3AS1ffLm1ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEvPT_iiPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEES3_mi(ptr addrspace(1) noundef %add.ptr.i.i, i32 noundef %conv.i, i32 noundef %conv2.i, target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) noundef %sub_c.i.sroa.0.0, i32 noundef 0, i64 noundef 128, i32 noundef 0) #3 + tail call spir_func void @_Z41__spirv_CooperativeMatrixStoreOffsetINTELIU3AS1ffLm1ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEvPT_iiPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEES3_mi(ptr addrspace(1) noundef %add.ptr.i.i, i32 noundef %conv.i, i32 noundef %conv2.i, target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) noundef %sub_c.i.sroa.0.0, i32 noundef 0, i64 noundef 128, i32 noundef 0) ret void } ; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: write) -declare void @llvm.assume(i1 noundef) #1 +declare void @llvm.assume(i1 noundef) ; Function Attrs: convergent nounwind -declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1ffLm1ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_iiS3_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2 +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1ffLm1ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_iiS3_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) ; Function Attrs: convergent nounwind -declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 1, 16, 0) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1KN4sycl3_V13ext6oneapi8bfloat16ES4_Lm1ELm16ELN5__spv9MatrixUseE0ELNS6_12MatrixLayoutE0ELNS6_5Scope4FlagE3EEPNS6_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2 +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 1, 16, 0) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1KN4sycl3_V13ext6oneapi8bfloat16ES4_Lm1ELm16ELN5__spv9MatrixUseE0ELNS6_12MatrixLayoutE0ELNS6_5Scope4FlagE3EEPNS6_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) ; Function Attrs: convergent nounwind -declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 16, 16, 1) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1KN4sycl3_V13ext6oneapi8bfloat16ES4_Lm16ELm16ELN5__spv9MatrixUseE1ELNS6_12MatrixLayoutE2ELNS6_5Scope4FlagE3EEPNS6_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2 +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 16, 16, 1) @_Z40__spirv_CooperativeMatrixLoadOffsetINTELIU3AS1KN4sycl3_V13ext6oneapi8bfloat16ES4_Lm16ELm16ELN5__spv9MatrixUseE1ELNS6_12MatrixLayoutE2ELNS6_5Scope4FlagE3EEPNS6_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_iiS8_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) ; Function Attrs: convergent nounwind -declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) @_Z34__spirv_CooperativeMatrixMulAddKHRIN4sycl3_V13ext6oneapi8bfloat16ES4_fLm1ELm16ELm16ELN5__spv9MatrixUseE0ELS6_1ELS6_2ELNS5_12MatrixLayoutE0ELS7_0ELS7_0ELNS5_5Scope4FlagE3EEPNS5_28__spirv_CooperativeMatrixKHRIT1_XT11_EXT2_EXT4_EXT7_EEEPNSA_IT_XT11_EXT2_EXT3_EXT5_EEEPNSA_IT0_XT11_EXT3_EXT4_EXT6_EEESD_m(target("spirv.CooperativeMatrixKHR", i16, 3, 1, 16, 0) noundef, target("spirv.CooperativeMatrixKHR", i16, 3, 16, 16, 1) noundef, target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) noundef, i64 noundef) local_unnamed_addr #2 +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) @_Z34__spirv_CooperativeMatrixMulAddKHRIN4sycl3_V13ext6oneapi8bfloat16ES4_fLm1ELm16ELm16ELN5__spv9MatrixUseE0ELS6_1ELS6_2ELNS5_12MatrixLayoutE0ELS7_0ELS7_0ELNS5_5Scope4FlagE3EEPNS5_28__spirv_CooperativeMatrixKHRIT1_XT11_EXT2_EXT4_EXT7_EEEPNSA_IT_XT11_EXT2_EXT3_EXT5_EEEPNSA_IT0_XT11_EXT3_EXT4_EXT6_EEESD_m(target("spirv.CooperativeMatrixKHR", i16, 3, 1, 16, 0) noundef, target("spirv.CooperativeMatrixKHR", i16, 3, 16, 16, 1) noundef, target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) noundef, i64 noundef) ; Function Attrs: convergent nounwind -declare dso_local spir_func void @_Z41__spirv_CooperativeMatrixStoreOffsetINTELIU3AS1ffLm1ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEvPT_iiPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEES3_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) noundef, i32 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2 - -attributes #0 = { convergent mustprogress norecurse nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="../joint_matrix_all_sizes.cpp" "sycl-optlevel"="2" "uniform-work-group-size"="true" } -attributes #1 = { mustprogress nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: write) } -attributes #2 = { convergent nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" } -attributes #3 = { convergent nounwind } +declare dso_local spir_func void @_Z41__spirv_CooperativeMatrixStoreOffsetINTELIU3AS1ffLm1ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEvPT_iiPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEES3_mi(ptr addrspace(1) noundef, i32 noundef, i32 noundef, target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) noundef, i32 noundef, i64 noundef, i32 noundef) From 3fbd371c8a3c4a6bca7cdb4ce26629152ea01e9f Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Sat, 14 Dec 2024 05:53:51 -0800 Subject: [PATCH 08/11] code clean up --- lib/SPIRV/libSPIRV/SPIRVInstruction.h | 3 +-- lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/lib/SPIRV/libSPIRV/SPIRVInstruction.h b/lib/SPIRV/libSPIRV/SPIRVInstruction.h index c5e5be3964..81a154943a 100644 --- a/lib/SPIRV/libSPIRV/SPIRVInstruction.h +++ b/lib/SPIRV/libSPIRV/SPIRVInstruction.h @@ -3737,8 +3737,7 @@ class SPIRVCooperativeMatrixOffsetInstructionsINTELInstBase return ExtensionID::SPV_INTEL_joint_matrix; } SPIRVCapVec getRequiredCapability() const override { - return getVec( - internal::CapabilityCooperativeMatrixOffsetInstructionsINTEL); + return getVec(internal::CapabilityCooperativeMatrixOffsetInstructionsINTEL); } }; diff --git a/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h b/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h index 754574c4ac..5317be2f87 100644 --- a/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h +++ b/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h @@ -22,12 +22,12 @@ _SPIRV_OP_INTERNAL(CooperativeMatrixLoadCheckedINTEL, internal::OpCooperativeMatrixLoadCheckedINTEL) _SPIRV_OP_INTERNAL(CooperativeMatrixStoreCheckedINTEL, internal::OpCooperativeMatrixStoreCheckedINTEL) +_SPIRV_OP_INTERNAL(CooperativeMatrixConstructCheckedINTEL, + internal::OpCooperativeMatrixConstructCheckedINTEL) _SPIRV_OP_INTERNAL(CooperativeMatrixLoadOffsetINTEL, internal::OpCooperativeMatrixLoadOffsetINTEL) _SPIRV_OP_INTERNAL(CooperativeMatrixStoreOffsetINTEL, internal::OpCooperativeMatrixStoreOffsetINTEL) -_SPIRV_OP_INTERNAL(CooperativeMatrixConstructCheckedINTEL, - internal::OpCooperativeMatrixConstructCheckedINTEL) _SPIRV_OP_INTERNAL(CooperativeMatrixApplyFunctionINTEL, internal::OpCooperativeMatrixApplyFunctionINTEL) _SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL) From e22918ebe86b39eb215f695173c6e0c140a9aa1b Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Tue, 17 Dec 2024 08:00:56 -0800 Subject: [PATCH 09/11] update the number of literal oprand --- lib/SPIRV/libSPIRV/SPIRVInstruction.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/SPIRV/libSPIRV/SPIRVInstruction.h b/lib/SPIRV/libSPIRV/SPIRVInstruction.h index 81a154943a..ab40d08297 100644 --- a/lib/SPIRV/libSPIRV/SPIRVInstruction.h +++ b/lib/SPIRV/libSPIRV/SPIRVInstruction.h @@ -3746,8 +3746,8 @@ class SPIRVCooperativeMatrixOffsetInstructionsINTELInstBase SPIRVCooperativeMatrixOffsetInstructionsINTELInstBase, \ internal::Op##x##INTEL, __VA_ARGS__> \ SPIRV##x##INTEL; -_SPIRV_OP(CooperativeMatrixLoadOffset, true, 8, true, 6) -_SPIRV_OP(CooperativeMatrixStoreOffset, false, 7, true, 7) +_SPIRV_OP(CooperativeMatrixLoadOffset, true, 8, true, 5) +_SPIRV_OP(CooperativeMatrixStoreOffset, false, 7, true, 6) #undef _SPIRV_OP class SPIRVCooperativeMatrixInvocationInstructionsINTELInstBase From 0f4773c5ca12be617517ee4562411dd2de5b6535 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Wed, 18 Dec 2024 07:01:23 -0800 Subject: [PATCH 10/11] update the test --- .../joint_matrix_load_store_offset.ll | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll b/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll index 310ba45c3c..20945322c6 100644 --- a/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll +++ b/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll @@ -13,20 +13,23 @@ ; CHECK-SPIRV-DAG: Extension "SPV_INTEL_joint_matrix" ; CHECK-SPIRV-DAG: TypeInt [[#Int16Ty:]] 16 0 ; CHECK-SPIRV-DAG: TypeInt [[#Int32Ty:]] 32 0 -; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const16:]] 16 -; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const3:]] 3 -; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const2:]] 2 +; CHECK-SPIRV-DAG: TypeInt [[#Int64Ty:]] 64 0 ; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const0:]] 0 ; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const1:]] 1 +; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const2:]] 2 +; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const3:]] 3 +; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const16:]] 16 +; CHECK-SPIRV-DAG: Constant [[#Int64Ty]] [[#Const128:]] 128 0 +; CHECK-SPIRV-DAG: Constant [[#Int64Ty:]] [[#Const256:]] 256 0 ; CHECK-SPIRV-DAG: TypeFloat [[#Float32Ty:]] 32 ; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy1:]] [[#Float32Ty]] [[#Const3]] [[#Const1]] [[#Const16]] [[#Const2]] ; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy2:]] [[#Int16Ty]] [[#Const3]] [[#Const1]] [[#Const16]] [[#Const0:]] ; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int16Ty]] [[#Const3]] [[#Const16]] [[#Const16]] [[#Const1]] -; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy1]] -; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy2]] -; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy3]] -; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]] -; CHECK-SPIRV: CooperativeMatrixStoreOffsetINTEL +; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy1]] [[#]] [[#Ptr1:]] [[#]] [[#Index1:]] [[#Const0]] [[#Const128]] 0 +; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy2]] [[#Ptr2:]] [[#]] [[#Index2:]] [[#]] [[#Const0]] [[#Const128]] 0 +; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy3]] [[#Ptr3:]] [[#]] [[#]] [[#]] [[#Const2:]] [[#Const256:]] 0 +; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]] [[#]] [[#Ptr2]] [[#Ptr3]] [[#Result:]] [[#]] +; CHECK-SPIRV: CooperativeMatrixStoreOffsetINTEL [[#Ptr1]] [[#Index2]] [[#Index1]] [[#Result]] [[#Const0]] [[#Const128]] 0 ; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) @_Z94__spirv_CooperativeMatrixLoadOffsetINTEL_RPU3AS144__spirv_CooperativeMatrixKHR__float_3_1_16_2PU3AS1fiiili(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 0, i64 128, i32 0) ; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 1, 16, 0) @"_Z94__spirv_CooperativeMatrixLoadOffsetINTEL_RPU3AS144__spirv_CooperativeMatrixKHR__short_3_1_16_0PU3AS138class.sycl::_V1::ext::oneapi::bfloat16iiili"(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 0, i64 128, i32 0) From cedd51345443f1f0bc7945e66e9b8b38036f67a9 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Wed, 18 Dec 2024 07:38:29 -0800 Subject: [PATCH 11/11] apply the suggested change --- .../joint_matrix_load_store_offset.ll | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll b/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll index 20945322c6..af6eb15b8a 100644 --- a/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll +++ b/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix_load_store_offset.ll @@ -26,9 +26,9 @@ ; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy2:]] [[#Int16Ty]] [[#Const3]] [[#Const1]] [[#Const16]] [[#Const0:]] ; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int16Ty]] [[#Const3]] [[#Const16]] [[#Const16]] [[#Const1]] ; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy1]] [[#]] [[#Ptr1:]] [[#]] [[#Index1:]] [[#Const0]] [[#Const128]] 0 -; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy2]] [[#Ptr2:]] [[#]] [[#Index2:]] [[#]] [[#Const0]] [[#Const128]] 0 -; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy3]] [[#Ptr3:]] [[#]] [[#]] [[#]] [[#Const2:]] [[#Const256:]] 0 -; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]] [[#]] [[#Ptr2]] [[#Ptr3]] [[#Result:]] [[#]] +; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy2]] [[#Load2:]] [[#]] [[#Index2:]] [[#]] [[#Const0]] [[#Const128]] 0 +; CHECK-SPIRV: CooperativeMatrixLoadOffsetINTEL [[#MatTy3]] [[#Load3:]] [[#]] [[#]] [[#]] [[#Const2:]] [[#Const256:]] 0 +; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]] [[#]] [[#Load2]] [[#Load3]] [[#Result:]] 64 ; CHECK-SPIRV: CooperativeMatrixStoreOffsetINTEL [[#Ptr1]] [[#Index2]] [[#Index1]] [[#Result]] [[#Const0]] [[#Const128]] 0 ; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 1, 16, 2) @_Z94__spirv_CooperativeMatrixLoadOffsetINTEL_RPU3AS144__spirv_CooperativeMatrixKHR__float_3_1_16_2PU3AS1fiiili(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 0, i64 128, i32 0)