@@ -53,12 +53,26 @@ struct StoreNdInstruction : public Instruction {
5353 return B->getInstructionKind () == InstructionKind::STORE_ND;
5454 }
5555 // Source :
56- // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups .html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
56+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io .html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
5757 // Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from
5858 // the specified pointer
59- llvm::ArrayRef<int > getSortedLaneVectorLengths () const {
60- const static int sortedLaneVectorLengths[] = {1 , 2 , 4 , 8 };
61- return sortedLaneVectorLengths;
59+ std::optional<
60+ std::tuple<llvm::ArrayRef<int >, llvm::ArrayRef<int >, llvm::ArrayRef<int >>>
61+ getBlockWidthHeightCount (Type elemTy) const {
62+ const static int kHeight [] = {1 , 2 , 4 , 8 };
63+ const static int kWidth16 [] = {16 };
64+ const static int kWidth32 [] = {16 };
65+ const static int kCount [] = {1 };
66+ const int elemByteSize = elemTy.getIntOrFloatBitWidth () / 8 ;
67+ if (elemByteSize == 1 )
68+ return std::make_tuple (llvm::ArrayRef<int >(kWidth32 ),
69+ llvm::ArrayRef<int >(kHeight ),
70+ llvm::ArrayRef<int >(kCount ));
71+ else if (elemByteSize == 2 || elemByteSize == 4 )
72+ return std::make_tuple (llvm::ArrayRef<int >(kWidth16 ),
73+ llvm::ArrayRef<int >(kHeight ),
74+ llvm::ArrayRef<int >(kCount ));
75+ return std::nullopt ;
6276 }
6377};
6478
@@ -68,13 +82,49 @@ struct LoadNdInstruction : public Instruction {
6882 static bool classof (const Instruction *B) {
6983 return B->getInstructionKind () == InstructionKind::LOAD_ND;
7084 }
85+
7186 // Source :
72- // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups .html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
87+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io .html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
7388 // Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to
7489 // the specified pointer.
75- llvm::ArrayRef<int > getSortedLaneVectorLengths () const {
76- const static int sortedLaneVectorLengths[] = {1 , 2 , 4 , 8 };
77- return sortedLaneVectorLengths;
90+ std::optional<
91+ std::tuple<llvm::ArrayRef<int >, llvm::ArrayRef<int >, llvm::ArrayRef<int >>>
92+ getBlockWidthHeightCount (Type elemTy, bool hasTransform, bool hasTranspose,
93+ bool upConv = false ) const {
94+ static const int kHeightAtLeast1 [] = {1 , 2 , 4 , 8 , 16 , 32 };
95+ static const int kHeightAtLeast8 [] = {8 , 16 , 32 };
96+ static const int kHeightAtLeast16 [] = {16 , 32 };
97+ static const int kHeightAtLeast32 [] = {32 };
98+
99+ static const int kWidth32 [] = {32 };
100+ static const int kWidth16 [] = {16 };
101+ static const int kWidth8 [] = {8 };
102+
103+ static const int32_t kCount1 [] = {1 };
104+ static const int32_t kCount2 [] = {1 , 2 };
105+ static const int32_t kCount4 [] = {1 , 2 , 4 };
106+ static const int32_t kCount4Only [] = {4 };
107+ // (elemBytes, transform, transpose, upConvert)
108+ using Key = std::tuple<int , uint8_t , uint8_t , uint8_t >;
109+ // (widths, heights, counts)
110+ using Value = std::tuple<llvm::ArrayRef<int32_t >, llvm::ArrayRef<int32_t >,
111+ llvm::ArrayRef<int32_t >>;
112+ static const llvm::DenseMap<Key, Value> kMap = {
113+ {{1 , false , false , false }, {kWidth32 , kHeightAtLeast1 , kCount2 }},
114+ {{1 , false , false , true }, {kWidth16 , kHeightAtLeast8 , kCount4Only }},
115+ {{2 , false , false , false }, {kWidth16 , kHeightAtLeast1 , kCount2 }},
116+ {{4 , false , false , false }, {kWidth16 , kHeightAtLeast1 , kCount1 }},
117+ // Block Loads with Transform:
118+ {{1 , true , false , false }, {kWidth16 , kHeightAtLeast32 , kCount4 }},
119+ {{2 , true , false , false }, {kWidth16 , kHeightAtLeast16 , kCount2 }},
120+ // Block Loads with Transpose:
121+ {{4 , false , true , false }, {kWidth8 , kHeightAtLeast16 , kCount1 }},
122+ };
123+ const int elemByteSize = elemTy.getIntOrFloatBitWidth () / 8 ;
124+ auto it = kMap .find ({elemByteSize, hasTransform, hasTranspose, upConv});
125+ if (it != kMap .end ())
126+ return it->second ;
127+ return std::nullopt ;
78128 }
79129};
80130
@@ -86,29 +136,39 @@ struct PrefetchNdInstruction : public Instruction {
86136 }
87137 // Source :
88138 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
89- llvm::ArrayRef<int > getSortedLaneVectorLengths (int elementBitwidth) const {
90- const static int sortedNarrowTypesLengths[] = {1 , 2 , 4 , 8 , 16 };
91- const static int sortedWideTypesLengths[] = {1 , 2 , 4 , 8 };
92- switch (elementBitwidth) {
93- case 8 :
94- case 16 :
95- return sortedNarrowTypesLengths;
96- case 32 :
97- case 64 :
98- return sortedWideTypesLengths;
99- default :
100- llvm_unreachable (" Unsupported element bitwidth" );
101- }
139+ std::optional<
140+ std::tuple<llvm::ArrayRef<int >, llvm::ArrayRef<int >, llvm::ArrayRef<int >>>
141+ getBlockWidthHeightCount (Type elemTy) const {
142+ static const int kHeightAtLeast1 [] = {1 , 2 , 4 , 8 , 16 , 32 };
143+
144+ static const int kWidth32 [] = {32 };
145+ static const int kWidth16 [] = {16 };
146+
147+ static const int32_t kCount1 [] = {1 };
148+ static const int32_t kCount2 [] = {1 , 2 };
149+ // elemBytes
150+ using Key = int ;
151+ // (widths, heights, counts)
152+ using Value = std::tuple<llvm::ArrayRef<int32_t >, llvm::ArrayRef<int32_t >,
153+ llvm::ArrayRef<int32_t >>;
154+ static const llvm::DenseMap<Key, Value> kMap = {
155+ {1 , {kWidth32 , kHeightAtLeast1 , kCount2 }},
156+ {2 , {kWidth16 , kHeightAtLeast1 , kCount2 }},
157+ {4 , {kWidth16 , kHeightAtLeast1 , kCount1 }},
158+ };
159+ const int elemByteSize = elemTy.getIntOrFloatBitWidth () / 8 ;
160+ auto it = kMap .find (elemByteSize);
161+ if (it != kMap .end ())
162+ return it->second ;
163+ return std::nullopt ;
102164 }
103165};
104166
105167struct DPASInstruction : public Instruction , public MMAInstructionInterface {
106- DPASInstruction (unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB,
107- unsigned packedFormatBitSizeC)
168+ DPASInstruction (unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB)
108169 : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup),
109170 packedFormatBitSizeA (packedFormatBitSizeA),
110- packedFormatBitSizeB(packedFormatBitSizeB),
111- packedFormatBitSizeC(packedFormatBitSizeC) {}
171+ packedFormatBitSizeB(packedFormatBitSizeB) {}
112172 static bool classof (const Instruction *B) {
113173 return B->getInstructionKind () == InstructionKind::DPAS;
114174 }
@@ -142,12 +202,10 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
142202
143203 unsigned getPackedFormatBitSizeA () const { return packedFormatBitSizeA; }
144204 unsigned getPackedFormatBitSizeB () const { return packedFormatBitSizeB; }
145- unsigned getPackedFormatBitSizeC () const { return packedFormatBitSizeC; }
146205
147206protected:
148207 const unsigned packedFormatBitSizeA;
149208 const unsigned packedFormatBitSizeB;
150- const unsigned packedFormatBitSizeC;
151209};
152210
153211// ===----------------------------------------------------------------------===//
@@ -156,7 +214,7 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
156214
157215struct PVCuArch final : public Xe2Plus {
158216 static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr () {
159- static const DPASInstruction dpasInst{16 , 32 , 32 };
217+ static const DPASInstruction dpasInst{16 , 32 };
160218 static const StoreNdInstruction loadNdInst;
161219 static const StoreNdInstruction storeNdInst;
162220 static const PrefetchNdInstruction prefetchNdInst;
@@ -179,7 +237,7 @@ struct PVCuArch final : public Xe2Plus {
179237
180238struct BMGuArch : public Xe2Plus {
181239 static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr () {
182- static const DPASInstruction dpasInst{16 , 32 , 32 };
240+ static const DPASInstruction dpasInst{16 , 32 };
183241 static const StoreNdInstruction loadNdInst;
184242 static const StoreNdInstruction storeNdInst;
185243 static const PrefetchNdInstruction prefetchNdInst;
0 commit comments