@@ -36,8 +36,7 @@ struct Xe2Plus : public uArch {
3636 const XeCoreInfo &xeCore)
3737 : uArch(archName, archDescription, instructionRegistry), xeCore(xeCore) {}
3838 int getSubgroupSize () const override { return 16 ; }
39- unsigned getPackedFormatBitSize () const override { return 16 ; }
40- unsigned getPackedFormatBitSizeGatherScatter () const override { return 32 ; }
39+ unsigned getGeneralPackedFormatBitSize () const override { return 32 ; }
4140
4241protected:
4342 XeCoreInfo xeCore;
@@ -46,16 +45,15 @@ struct Xe2Plus : public uArch {
4645// ===----------------------------------------------------------------------===//
4746// uArch instructions
4847// ===----------------------------------------------------------------------===//
49- struct StoreNdInstruction : public Instruction {
50- StoreNdInstruction ()
51- : Instruction(InstructionKind::STORE_ND, InstructionScope::Subgroup) {}
48+ struct Subgroup2DBlockStoreInstruction : public Instruction {
49+ Subgroup2DBlockStoreInstruction ()
50+ : Instruction(InstructionKind::Subgroup2DBlockStore,
51+ InstructionScope::Subgroup) {}
5252 static bool classof (const Instruction *B) {
53- return B->getInstructionKind () == InstructionKind::STORE_ND ;
53+ return B->getInstructionKind () == InstructionKind::Subgroup2DBlockStore ;
5454 }
5555 // Source :
5656 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
57- // Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from
58- // the specified pointer
5957 std::optional<
6058 std::tuple<llvm::ArrayRef<int >, llvm::ArrayRef<int >, llvm::ArrayRef<int >>>
6159 getBlockWidthHeightCount (Type elemTy) const {
@@ -74,19 +72,20 @@ struct StoreNdInstruction : public Instruction {
7472 llvm::ArrayRef<int >(kCount ));
7573 return std::nullopt ;
7674 }
75+
76+ int32_t getPackedFormatBitSize () const { return 16 ; }
7777};
7878
79- struct LoadNdInstruction : public Instruction {
80- LoadNdInstruction ()
81- : Instruction(InstructionKind::LOAD_ND, InstructionScope::Subgroup) {}
79+ struct Subgroup2DBlockLoadInstruction : public Instruction {
80+ Subgroup2DBlockLoadInstruction ()
81+ : Instruction(InstructionKind::Subgroup2DBlockLoad,
82+ InstructionScope::Subgroup) {}
8283 static bool classof (const Instruction *B) {
83- return B->getInstructionKind () == InstructionKind::LOAD_ND ;
84+ return B->getInstructionKind () == InstructionKind::Subgroup2DBlockLoad ;
8485 }
8586
8687 // Source :
8788 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
88- // Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to
89- // the specified pointer.
9089 std::optional<
9190 std::tuple<llvm::ArrayRef<int >, llvm::ArrayRef<int >, llvm::ArrayRef<int >>>
9291 getBlockWidthHeightCount (Type elemTy, bool hasTransform, bool hasTranspose,
@@ -126,13 +125,16 @@ struct LoadNdInstruction : public Instruction {
126125 return it->second ;
127126 return std::nullopt ;
128127 }
128+
129+ int32_t getPackedFormatBitSize () const { return 16 ; }
129130};
130131
131- struct PrefetchNdInstruction : public Instruction {
132- PrefetchNdInstruction ()
133- : Instruction(InstructionKind::PREFETCH_ND, InstructionScope::Subgroup) {}
132+ struct Subgroup2DBlockPrefetchInstruction : public Instruction {
133+ Subgroup2DBlockPrefetchInstruction ()
134+ : Instruction(InstructionKind::Subgroup2DBlockPrefetch,
135+ InstructionScope::Subgroup) {}
134136 static bool classof (const Instruction *B) {
135- return B->getInstructionKind () == InstructionKind::PREFETCH_ND ;
137+ return B->getInstructionKind () == InstructionKind::Subgroup2DBlockPrefetch ;
136138 }
137139 // Source :
138140 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
@@ -162,15 +164,20 @@ struct PrefetchNdInstruction : public Instruction {
162164 return it->second ;
163165 return std::nullopt ;
164166 }
167+ int32_t getPackedFormatBitSize () const { return 16 ; }
165168};
166169
167- struct DPASInstruction : public Instruction , public MMAInstructionInterface {
168- DPASInstruction (unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB)
169- : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup),
170+ struct SubgroupMatrixMultiplyAcc : public Instruction ,
171+ public MMAInstructionInterface {
172+ SubgroupMatrixMultiplyAcc (unsigned packedFormatBitSizeA,
173+ unsigned packedFormatBitSizeB)
174+ : Instruction(InstructionKind::SubgroupMatrixMultiplyAcc,
175+ InstructionScope::Subgroup),
170176 packedFormatBitSizeA (packedFormatBitSizeA),
171177 packedFormatBitSizeB(packedFormatBitSizeB) {}
172178 static bool classof (const Instruction *B) {
173- return B->getInstructionKind () == InstructionKind::DPAS;
179+ return B->getInstructionKind () ==
180+ InstructionKind::SubgroupMatrixMultiplyAcc;
174181 }
175182 // Source:
176183 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
@@ -214,10 +221,10 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
214221
215222struct PVCuArch final : public Xe2Plus {
216223 static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr () {
217- static const DPASInstruction dpasInst{16 , 32 };
218- static const StoreNdInstruction loadNdInst;
219- static const StoreNdInstruction storeNdInst;
220- static const PrefetchNdInstruction prefetchNdInst;
224+ static const SubgroupMatrixMultiplyAcc dpasInst{16 , 32 };
225+ static const Subgroup2DBlockLoadInstruction loadNdInst;
226+ static const Subgroup2DBlockStoreInstruction storeNdInst;
227+ static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
221228 static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
222229 &prefetchNdInst};
223230 return arr;
@@ -237,10 +244,10 @@ struct PVCuArch final : public Xe2Plus {
237244
238245struct BMGuArch : public Xe2Plus {
239246 static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr () {
240- static const DPASInstruction dpasInst{16 , 32 };
241- static const StoreNdInstruction loadNdInst;
242- static const StoreNdInstruction storeNdInst;
243- static const PrefetchNdInstruction prefetchNdInst;
247+ static const SubgroupMatrixMultiplyAcc dpasInst{16 , 32 };
248+ static const Subgroup2DBlockLoadInstruction loadNdInst;
249+ static const Subgroup2DBlockStoreInstruction storeNdInst;
250+ static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
244251 static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
245252 &prefetchNdInst};
246253 return arr;
@@ -276,7 +283,8 @@ inline const uArch *getUArch(llvm::StringRef archName) {
276283// ===----------------------------------------------------------------------===//
277284
278285inline llvm::SmallVector<std::pair<uint32_t , uint32_t >, 16 >
279- DPASInstruction::getSupportedShapes (Type dataType, MMAOpndKind matrixType) {
286+ SubgroupMatrixMultiplyAcc::getSupportedShapes (Type dataType,
287+ MMAOpndKind matrixType) {
280288 auto combineVectors = [](const llvm::SmallVector<uint32_t , 8 > &a,
281289 const llvm::SmallVector<uint32_t , 8 > &b)
282290 -> llvm::SmallVector<std::pair<uint32_t , uint32_t >, 16 > {
@@ -312,8 +320,8 @@ DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
312320}
313321
314322inline llvm::SmallVector<Type, 8 >
315- DPASInstruction ::getSupportedTypes (MLIRContext &context,
316- MMAOpndKind matrixType) {
323+ SubgroupMatrixMultiplyAcc ::getSupportedTypes (MLIRContext &context,
324+ MMAOpndKind matrixType) {
317325 Type bf16Type = BFloat16Type::get (&context);
318326 Type f16Type = Float16Type::get (&context);
319327 Type tf32Type = FloatTF32Type::get (&context);
@@ -332,8 +340,10 @@ DPASInstruction::getSupportedTypes(MLIRContext &context,
332340 return {};
333341}
334342
335- inline bool DPASInstruction::checkSupportedTypes (Type AType, Type BType,
336- Type CType, Type DType) {
343+ inline bool SubgroupMatrixMultiplyAcc::checkSupportedTypes (Type AType,
344+ Type BType,
345+ Type CType,
346+ Type DType) {
337347 if (AType.isF16 () || BType.isF16 ()) {
338348 if (AType != BType || (CType && (!CType.isF32 () && !CType.isF16 ())) ||
339349 (!DType.isF32 () && !DType.isF16 ())) {
@@ -363,7 +373,7 @@ inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType,
363373 return true ;
364374}
365375
366- inline bool DPASInstruction ::checkSupportedShapesAndTypes (
376+ inline bool SubgroupMatrixMultiplyAcc ::checkSupportedShapesAndTypes (
367377 std::pair<uint32_t , uint32_t > AShape, std::pair<uint32_t , uint32_t > BShape,
368378 std::pair<uint32_t , uint32_t > CShape, std::pair<uint32_t , uint32_t > DShape,
369379 Type AType, Type BType, Type CType, Type DType) {
@@ -378,23 +388,21 @@ inline bool DPASInstruction::checkSupportedShapesAndTypes(
378388 checkSupportedTypes (AType, BType, CType, DType);
379389}
380390
381- inline bool DPASInstruction::validate (std::pair<uint32_t , uint32_t > AShape,
382- std::pair<uint32_t , uint32_t > BShape,
383- std::pair<uint32_t , uint32_t > CShape,
384- std::pair<uint32_t , uint32_t > DShape,
385- Type AType, Type BType, Type CType,
386- Type DType) {
391+ inline bool SubgroupMatrixMultiplyAcc::validate (
392+ std::pair<uint32_t , uint32_t > AShape, std::pair<uint32_t , uint32_t > BShape,
393+ std::pair<uint32_t , uint32_t > CShape, std::pair<uint32_t , uint32_t > DShape,
394+ Type AType, Type BType, Type CType, Type DType) {
387395 return checkSupportedShapesAndTypes (AShape, BShape, CShape, DShape, AType,
388396 BType, CType, DType);
389397}
390398
391399inline llvm::SmallVector<uint32_t , 8 >
392- DPASInstruction ::getSupportedM (Type type) const {
400+ SubgroupMatrixMultiplyAcc ::getSupportedM (Type type) const {
393401 return {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 };
394402}
395403
396404inline llvm::SmallVector<uint32_t , 8 >
397- DPASInstruction ::getSupportedK (Type type) const {
405+ SubgroupMatrixMultiplyAcc ::getSupportedK (Type type) const {
398406 // assert if data type is not int or float type
399407 assert (type.isIntOrFloat () && " Matrix type must be int or float" );
400408 auto bitWidth = type.getIntOrFloatBitWidth ();
@@ -422,7 +430,7 @@ DPASInstruction::getSupportedK(Type type) const {
422430}
423431
424432inline llvm::SmallVector<uint32_t , 8 >
425- DPASInstruction ::getSupportedN (Type type) const {
433+ SubgroupMatrixMultiplyAcc ::getSupportedN (Type type) const {
426434 return {16 };
427435}
428436
0 commit comments