Skip to content

Conversation

@aadeshps-mcw
Copy link
Contributor

--Added support for the extension SPV_INTEL_subgroup_matrix_multiply_accumulate

--Added test files for the extension SPV_INTEL_subgroup_matrix_multiply_accumulate

…accumulate

--Added test files for the extension SPV_INTEL_subgroup_matrix_multiply_accumulate
@github-actions
Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Member

llvmbot commented Mar 17, 2025

@llvm/pr-subscribers-backend-spir-v

Author: Aadesh PremKumar (aadeshps-mcw)

Changes

--Added support for the extension SPV_INTEL_subgroup_matrix_multiply_accumulate

--Added test files for the extension SPV_INTEL_subgroup_matrix_multiply_accumulate


Patch is 41.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131572.diff

9 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h (+5)
  • (modified) llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp (+34)
  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+37-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.td (+4)
  • (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+4-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+4)
  • (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+13)
  • (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+41)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_subgroup_matrix_multiply_accumulate/subgroup_matrix_multiply_accumulate_generic.ll (+261)
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
index d009244a92259..66c58ba9d6ba3 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
@@ -52,6 +52,11 @@ namespace ExecutionModel {
 #include "SPIRVGenTables.inc"
 } // namespace ExecutionModel
 
+namespace MatrixMultiplyAccumulate {
+#define GET_MatrixMultiplyAccumulate_DECL
+#include "SPIRVGenTables.inc"
+} // namespace MatrixMultiplyAccumulate
+
 namespace MemoryModel {
 #define GET_MemoryModel_DECL
 #include "SPIRVGenTables.inc"
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
index 63dcf0067b515..b77637bdeb869 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
@@ -237,6 +237,40 @@ void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
           }
           break;
         }
+        case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
+          const unsigned NumOps = MI->getNumOperands();
+          if (NumFixedOps == NumOps)
+            break; // No extra operands, so no flags to process
+
+          OS << ' ';
+
+          // Extract the last operand only if it exists
+          if (NumOps > NumFixedOps) {
+            const unsigned Flags = MI->getOperand(NumOps - 1).getImm();
+
+            if (Flags == 0) {
+              printSymbolicOperand<
+                  OperandCategory::MatrixMultiplyAccumulateOperand>(
+                  MI, NumOps - 1, OS);
+            } else {
+              std::string Buffer;
+              for (unsigned Mask = 0x1;
+                   Mask != SPIRV::MatrixMultiplyAccumulate::
+                               MatrixBPackedBFloat16INTEL; // Replace with
+                                                           // actual last flag
+                   Mask <<= 1) {
+                if (Flags & Mask) {
+                  if (!Buffer.empty())
+                    Buffer += '|';
+                  Buffer += getSymbolicOperandMnemonic(
+                      OperandCategory::MatrixMultiplyAccumulateOperand, Mask);
+                }
+              }
+              OS << Buffer;
+            }
+          }
+          break;
+        }
         default:
           printRemainingVariableOps(MI, NumFixedOps, OS);
           break;
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 579e37f68d5d8..d5f4927e866ac 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -701,7 +701,8 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
                                  MachineIRBuilder &MIRBuilder,
                                  SPIRVGlobalRegistry *GR) {
   if (Call->isSpirvOp())
-    return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));
+    return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call,
+                              Register(0));
 
   Register ScopeRegister =
       buildConstantIntReg32(SPIRV::Scope::Device, MIRBuilder, GR);
@@ -2266,6 +2267,38 @@ static bool generateBindlessImageINTELInst(const SPIRV::IncomingCall *Call,
   return buildBindlessImageINTELInst(Call, Opcode, MIRBuilder, GR);
 }
 
+static bool
+generateSubgroupMatrixMultiplyAccumulateInst(const SPIRV::IncomingCall *Call,
+                                             MachineIRBuilder &MIRBuilder,
+                                             SPIRVGlobalRegistry *GR) {
+  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+  unsigned Opcode =
+      SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
+
+  auto MIB = MIRBuilder.buildInstr(Opcode);
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
+  MIB.addDef(Call->ReturnRegister).addUse(TypeReg);
+
+  size_t size = Call->Arguments.size();
+
+  if (size > 4) {
+    // Add first 4 arguments normally
+    for (size_t i = 0; i < 4; i++) {
+      MIB.addUse(Call->Arguments[i]);
+    }
+    const uint32_t memop = getConstFromIntrinsic(Call->Arguments.back(), MRI);
+    MIB.addImm(memop);
+  } else {
+    // Add all arguments if there are ≤ 4
+    for (size_t i = 0; i < size; i++) {
+      MIB.addUse(Call->Arguments[i]);
+    }
+  }
+
+  return true;
+}
+
 static bool buildNDRange(const SPIRV::IncomingCall *Call,
                          MachineIRBuilder &MIRBuilder,
                          SPIRVGlobalRegistry *GR) {
@@ -2847,6 +2880,9 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
     return generateExtendedBitOpsInst(Call.get(), MIRBuilder, GR);
   case SPIRV::BindlessINTEL:
     return generateBindlessImageINTELInst(Call.get(), MIRBuilder, GR);
+  case SPIRV::SubgroupMatrixMultiplyAccumulate:
+    return generateSubgroupMatrixMultiplyAccumulateInst(Call.get(), MIRBuilder,
+                                                        GR);
   }
   return false;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index c9a5c92ee3a66..eab3dec4da25d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -67,6 +67,7 @@ def CoopMatr : BuiltinGroup;
 def ICarryBorrow : BuiltinGroup;
 def ExtendedBitOps : BuiltinGroup;
 def BindlessINTEL : BuiltinGroup;
+def SubgroupMatrixMultiplyAccumulate : BuiltinGroup;
 
 //===----------------------------------------------------------------------===//
 // Class defining a demangled builtin record. The information in the record
@@ -1128,6 +1129,9 @@ defm : DemangledNativeBuiltin<"clock_read_hilo_device", OpenCL_std, KernelClock,
 defm : DemangledNativeBuiltin<"clock_read_hilo_work_group", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>;
 defm : DemangledNativeBuiltin<"clock_read_hilo_sub_group", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>;
 
+//SPV_INTEL_subgroup_matrix_multiply_accumulate
+defm : DemangledNativeBuiltin<"__spirv_SubgroupMatrixMultiplyAccumulateINTEL", OpenCL_std, SubgroupMatrixMultiplyAccumulate, 4, 8, OpSubgroupMatrixMultiplyAccumulateINTEL>;
+
 //===----------------------------------------------------------------------===//
 // Class defining an atomic instruction on floating-point numbers.
 //
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 37119bf01545c..02e4f2241e868 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -92,7 +92,10 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
         {"SPV_INTEL_long_composites",
          SPIRV::Extension::Extension::SPV_INTEL_long_composites},
         {"SPV_INTEL_fp_max_error",
-         SPIRV::Extension::Extension::SPV_INTEL_fp_max_error}};
+         SPIRV::Extension::Extension::SPV_INTEL_fp_max_error},
+        {"SPV_INTEL_subgroup_matrix_multiply_accumulate",
+         SPIRV::Extension::Extension::
+             SPV_INTEL_subgroup_matrix_multiply_accumulate}};
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, llvm::StringRef ArgName,
                                   llvm::StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index a8f862271dbab..2e83f651b3c71 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -956,3 +956,7 @@ def OpAliasScopeDeclINTEL: Op<5912, (outs ID:$res), (ins ID:$AliasDomain, variab
                   "$res = OpAliasScopeDeclINTEL $AliasDomain">;
 def OpAliasScopeListDeclINTEL: Op<5913, (outs ID:$res), (ins variable_ops),
                   "$res = OpAliasScopeListDeclINTEL">;
+
+//SPV_INTEL_subgroup_matrix_multiply_accumulate
+def OpSubgroupMatrixMultiplyAccumulateINTEL:Op<6237, (outs ID:$res), (ins TYPE:$result_type, ID:$dim, ID:$a, ID:$b, ID:$c, variable_ops),
+                  "$res = OpSubgroupMatrixMultiplyAccumulateINTEL $result_type $dim $a $b $c">;
\ No newline at end of file
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 63894acacbc73..4ddb0f1ed1305 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1699,6 +1699,19 @@ void addInstrRequirements(const MachineInstr &MI,
     Reqs.addCapability(
         SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL);
     break;
+  case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
+    if (!ST.canUseExtension(
+            SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate))
+      report_fatal_error("This matrix instructions require the "
+                         "following SPIR-V extension: "
+                         "SPV_INTEL_subgroup_matrix_multiply_accumulate",
+                         false);
+    Reqs.addExtension(
+        SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate);
+    Reqs.addCapability(
+        SPIRV::Capability::SubgroupMatrixMultiplyAccumulateINTEL);
+    break;
+  }
   case SPIRV::OpConvertHandleToImageINTEL:
   case SPIRV::OpConvertHandleToSamplerINTEL:
   case SPIRV::OpConvertHandleToSampledImageINTEL:
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index caee778eddbc4..f1e73eaacd29d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -172,6 +172,7 @@ def KernelProfilingInfoOperand : OperandCategory;
 def OpcodeOperand : OperandCategory;
 def CooperativeMatrixLayoutOperand : OperandCategory;
 def CooperativeMatrixOperandsOperand : OperandCategory;
+def MatrixMultiplyAccumulateOperand :OperandCategory;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Extesions enum values and at the same time
@@ -313,6 +314,7 @@ defm SPV_INTEL_bindless_images : ExtensionOperand<116>;
 defm SPV_INTEL_long_composites : ExtensionOperand<117>;
 defm SPV_INTEL_memory_access_aliasing : ExtensionOperand<118>;
 defm SPV_INTEL_fp_max_error : ExtensionOperand<119>;
+defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<120>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -512,6 +514,7 @@ defm FunctionFloatControlINTEL : CapabilityOperand<5821, 0, 0, [SPV_INTEL_float_
 defm LongCompositesINTEL : CapabilityOperand<6089, 0, 0, [SPV_INTEL_long_composites], []>;
 defm BindlessImagesINTEL : CapabilityOperand<6528, 0, 0, [SPV_INTEL_bindless_images], []>;
 defm MemoryAccessAliasingINTEL : CapabilityOperand<5910, 0, 0, [SPV_INTEL_memory_access_aliasing], []>;
+defm SubgroupMatrixMultiplyAccumulateINTEL : CapabilityOperand<6236, 0, 0, [SPV_INTEL_subgroup_matrix_multiply_accumulate], []>;
 defm FPMaxErrorINTEL : CapabilityOperand<6169, 0, 0, [SPV_INTEL_fp_max_error], []>;
 
 //===----------------------------------------------------------------------===//
@@ -1741,3 +1744,41 @@ defm MatrixAAndBTF32ComponentsINTEL : CooperativeMatrixOperandsOperand<0x20, [SP
 defm MatrixAAndBBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x40, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;
 defm MatrixCBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x80, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;
 defm MatrixResultBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x100, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;
+
+//===----------------------------------------------------------------------===//
+// Multiclass used to define Matrix Multiply Accumulate Operands enum values and at the same time
+// SymbolicOperand entries with string mnemonics and capabilities.
+//===----------------------------------------------------------------------===//
+def MatrixMultiplyAccumulate : GenericEnum, Operand<i32> {
+  let FilterClass = "MatrixMultiplyAccumulate";
+  let NameField = "Name";
+  let ValueField = "Value";
+  let PrintMethod = !strconcat("printSymbolicOperand<OperandCategory::", FilterClass, "Operand>");
+}
+
+class MatrixMultiplyAccumulate<string name, bits<32> value> {
+  string Name = name;
+  bits<32> Value = value;
+}
+
+multiclass  MatrixMultiplyAccumulateOperand<bits<32> value, list<Extension> reqExtensions> {
+def : MatrixMultiplyAccumulate<NAME, value>;
+  defm : SymbolicOperandWithRequirements< MatrixMultiplyAccumulateOperand, value, NAME, 0, 0, reqExtensions, []>;
+}
+
+defm None :  MatrixMultiplyAccumulateOperand<0x0, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+defm MatrixASignedComponentsINTEL :  MatrixMultiplyAccumulateOperand<0x1, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+defm MatrixBSignedComponentsINTEL :  MatrixMultiplyAccumulateOperand<0x2, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+defm MatrixCBFloat16INTEL :  MatrixMultiplyAccumulateOperand<0x4, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+defm MatrixResultBFloat16INTEL :  MatrixMultiplyAccumulateOperand<0x8, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+defm MatrixAPackedInt8INTEL :  MatrixMultiplyAccumulateOperand<0x10, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+defm MatrixBPackedInt8INTEL :  MatrixMultiplyAccumulateOperand<0x20, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+defm MatrixAPackedInt4INTEL :  MatrixMultiplyAccumulateOperand<0x40, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+defm MatrixBPackedInt4INTEL :  MatrixMultiplyAccumulateOperand<0x80, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+defm MatrixATF32INTEL :  MatrixMultiplyAccumulateOperand<0x100, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+defm MatrixBTF32INTEL :  MatrixMultiplyAccumulateOperand<0x200, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+defm MatrixAPackedFloat16INTEL :  MatrixMultiplyAccumulateOperand<0x400, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+defm MatrixBPackedFloat16INTEL :  MatrixMultiplyAccumulateOperand<0x800, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+defm MatrixAPackedBFloat16INTEL :  MatrixMultiplyAccumulateOperand<0x1000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+defm MatrixBPackedBFloat16INTEL :  MatrixMultiplyAccumulateOperand<0x2000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_subgroup_matrix_multiply_accumulate/subgroup_matrix_multiply_accumulate_generic.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_subgroup_matrix_multiply_accumulate/subgroup_matrix_multiply_accumulate_generic.ll
new file mode 100644
index 0000000000000..925cb34ab3cd2
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_subgroup_matrix_multiply_accumulate/subgroup_matrix_multiply_accumulate_generic.ll
@@ -0,0 +1,261 @@
+; generated with mma.cl:
+; #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+; 
+; // all combinations of parameter types
+; int  __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, int  Matrix_A, int8 Matrix_B, int  Matrix_C, int Operands);
+; int2 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, int2 Matrix_A, int8 Matrix_B, int2 Matrix_C, int Operands);
+; int4 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, int4 Matrix_A, int8 Matrix_B, int4 Matrix_C, int Operands);
+; int8 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, int8 Matrix_A, int8 Matrix_B, int8 Matrix_C, int Operands);
+; 
+; float  __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, int  Matrix_A, int8 Matrix_B, float  Matrix_C, int Operands);
+; float2 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, int2 Matrix_A, int8 Matrix_B, float2 Matrix_C, int Operands);
+; float4 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, int4 Matrix_A, int8 Matrix_B, float4 Matrix_C, int Operands);
+; float8 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, int8 Matrix_A, int8 Matrix_B, float8 Matrix_C, int Operands);
+; 
+; int  __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short  Matrix_A, int8 Matrix_B, int  Matrix_C, int Operands);
+; int2 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short2 Matrix_A, int8 Matrix_B, int2 Matrix_C, int Operands);
+; int4 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short4 Matrix_A, int8 Matrix_B, int4 Matrix_C, int Operands);
+; int8 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short8 Matrix_A, int8 Matrix_B, int8 Matrix_C, int Operands);
+; 
+; float  __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short  Matrix_A, int8 Matrix_B, float  Matrix_C, int Operands);
+; float2 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short2 Matrix_A, int8 Matrix_B, float2 Matrix_C, int Operands);
+; float4 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short4 Matrix_A, int8 Matrix_B, float4 Matrix_C, int Operands);
+; float8 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short8 Matrix_A, int8 Matrix_B, float8 Matrix_C, int Operands);
+; 
+; half  __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short  Matrix_A, int8 Matrix_B, half  Matrix_C, int Operands);
+; half2 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short2 Matrix_A, int8 Matrix_B, half2 Matrix_C, int Operands);
+; half4 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short4 Matrix_A, int8 Matrix_B, half4 Matrix_C, int Operands);
+; half8 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short8 Matrix_A, int8 Matrix_B, half8 Matrix_C, int Operands);
+; 
+; short  __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short  Matrix_A, int8 Matrix_B, short  Matrix_C, int Operands);
+; short2 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short2 Matrix_A, int8 Matrix_B, short2 Matrix_C, int Operands);
+; short4 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short4 Matrix_A, int8 Matrix_B, short4 Matrix_C, int Operands);
+; short8 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short8 Matrix_A, int8 Matrix_B, short8 Matrix_C, int Operands);
+; 
+; float  __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, float  Matrix_A, float8 Matrix_B, float  Matrix_C, int Operands);
+; float2 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, float2 Matrix_A, float8 Matrix_B, float2 Matrix_C, int Operands);
+; float4 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, float4 Matrix_A, float8 Matrix_B, float4 Matrix_C, int Operands);
+; float8 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, float8 Matrix_A, float8 Matrix_B, float8 Matrix_C, int Operands);
+; 
+; // no operands
+; float4 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short4 Matrix_A, int8 Matrix_B, float4 Matrix_C);
+; 
+; void foo(int iM, int2 iM2, int4 iM4, int8 iM8,
+;          short sM, short2 sM2, short4 sM4, short8 sM8,
+;          float fM, float2 fM2, float4 fM4, float8 fM8,
+;          half hM, half2 hM2, half4 hM4, half8 hM8) {
+;     const int i = 42;
+;     int D = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, iM, iM8, iM, 0xA);
+;     int2 D2 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, iM2, iM8, iM2, 0xA);
+;     int4 D4 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, iM4, iM8, iM4, 0xA);
+;     int8 D8 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, iM8, iM8, iM8, 0xA);
+; 
+;     float fD = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, iM, iM8, fM, 0xA);
+;     float2 fD2 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, iM2, iM8, fM2, 0xA);
+;     float4 fD4 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, iM4, iM8, fM4, 0xA);
+;     float8 fD8 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, iM8, iM8, fM8, 0xA);
+; 
+;     int sD = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM, iM8, iM, 0xA);
+;     int2 sD2 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM2, iM8, iM2, 0xA);
+;     int4 sD4 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM4, iM8, iM4, 0xA);
+;     int8 sD8 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM8, iM8, iM8, 0xA);
+; 
+;     float sfD = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM, iM8, fM, 0xA);
+;     float2 sfD2 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM2, iM8, fM2, 0xA);
+;     float4 sfD4 = __spirv_Subgro...
[truncated]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants