@@ -32,9 +32,15 @@ namespace xegpu {
3232using namespace mlir ;
3333using namespace mlir ::dataflow;
3434
35- constexpr unsigned subgroupSize = 16 ;
36- constexpr unsigned packedASizeInBits = 16 ;
37- constexpr unsigned packedBSizeInBits = 32 ;
/// HW dependent constants.
/// TODO: These constants should be queried from the uArch interface.

/// How many work items in a subgroup.
constexpr unsigned subgroupSize = 16;

/// If DPAS A or B operands have low precision element types they must be packed
/// according to the following sizes (expressed in bits).

/// Minimum packing size per register for DPAS A.
constexpr unsigned packedSizeInBitsForDefault = 16;

/// Minimum packing size per register for DPAS B.
constexpr unsigned packedSizeInBitsForDpasB = 32;
3844
3945namespace {
4046
@@ -51,7 +57,7 @@ struct Layout {
5157 Layout (std::initializer_list<int64_t > list) : layout(list) {}
5258 void print (llvm::raw_ostream &os) const ;
5359 size_t size () const { return layout.size (); }
54- int64_t operator [](size_t idx) const { return layout[idx]; }
60+ int64_t operator [](size_t idx) const ;
5561};
5662
5763void Layout::print (llvm::raw_ostream &os) const {
@@ -60,6 +66,11 @@ void Layout::print(llvm::raw_ostream &os) const {
6066 os << " ]" ;
6167}
6268
69+ int64_t Layout::operator [](size_t idx) const {
70+ assert (idx < layout.size () && " Index out of bounds." );
71+ return layout[idx];
72+ }
73+
6374// / WiLayout represents the layout of work items within a subgroup when it
6475// / accesses some value. WiData represents the layout of data owned by each work
6576// / item.
@@ -86,14 +97,14 @@ using WiData = Layout;
8697
8798struct SGMap {
8899private:
89- WiLayout layout ;
90- WiData data ;
100+ WiLayout wiLayout ;
101+ WiData wiData ;
91102
92103public:
93104 SGMap () = default ;
94105 SGMap (const SGMap &other) = default ;
95106 SGMap (const WiLayout &layout, const WiData &data)
96- : layout (layout), data (data) {}
107+ : wiLayout (layout), wiData (data) {}
97108
98109 // / Two lattice values are equal if they have `some` layout. The actual
99110 // / content of the layout does not matter.
@@ -107,20 +118,20 @@ struct SGMap {
107118
108119 void print (raw_ostream &os) const ;
109120
110- bool isAssigned () const { return layout .size () > 0 && data .size () > 0 ; }
121+ bool isAssigned () const { return wiLayout .size () > 0 && wiData .size () > 0 ; }
111122
112123 SGMap getTransposedLayout (ArrayRef<int64_t > permutation) const ;
113124
114- const WiLayout &getLayout () const { return layout ; }
115- const WiData &getData () const { return data ; }
125+ const WiLayout &getLayout () const { return wiLayout ; }
126+ const WiData &getData () const { return wiData ; }
116127};
117128
118129void SGMap::print (raw_ostream &os) const {
119130 if (isAssigned ()) {
120131 os << " wi_layout: " ;
121- layout .print (os);
132+ wiLayout .print (os);
122133 os << " , wi_data: " ;
123- data .print (os);
134+ wiData .print (os);
124135 } else
125136 os << " Not assigned." ;
126137}
@@ -143,8 +154,8 @@ SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
143154 WiLayout newLayout;
144155 WiData newData;
145156 for (auto idx : permutation) {
146- newLayout.layout .push_back (layout .layout [idx]);
147- newData.layout .push_back (data .layout [idx]);
157+ newLayout.layout .push_back (wiLayout .layout [idx]);
158+ newData.layout .push_back (wiData .layout [idx]);
148159 }
149160 return SGMap (newLayout, newData);
150161}
@@ -159,54 +170,61 @@ struct SGMapLattice : public Lattice<SGMap> {
159170 using Lattice::Lattice;
160171};
161172
162- // / Helper Functions
163-
164- // / Helper Function to get the expected layouts for DPAS operands.
165- static SGMap getSGMapForDPASOperand (Type operandTy, unsigned operandNum) {
166- int packingFactorForB = packedBSizeInBits / operandTy.getIntOrFloatBitWidth ();
167- int packingFactorForA =
168- operandTy.getIntOrFloatBitWidth () < packedBSizeInBits
169- ? packedASizeInBits / operandTy.getIntOrFloatBitWidth ()
170- : 1 ;
171- return SGMap (WiLayout ({1 , subgroupSize}),
172- WiData ({operandNum == 1 ? packingFactorForB : 1 ,
173- operandNum == 0 ? packingFactorForA : 1 }));
174- }
175-
176- // / Helper Function to get the default layout for a given type. Usually this is,
177- // / wi_layout = [1, subgroupSize] and wi_data = [1, 1].
178- // / However, the minimum granularity of data access per work item is 16-bits.
179- // / So, if the bitwidth of the type is less than 16, we need to pack the data to
180- // / 16-bits.
181- // static SGMap getDefaultSgMap(Type ty, unsigned rank) {
182- // int packingFactor = 1;
183- // if (ty.getIntOrFloatBitWidth() < packedASizeInBits)
184- // packingFactor = packedBSizeInBits / ty.getIntOrFloatBitWidth();
185- // return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
186- // }
173+ // / Helper Functions to get default layouts. A `default layout` is a layout that
174+ // / is assigned to a value when the layout is not fixed by some anchor operation
175+ // / (like DPAS). This is the natural layout work items are arranged in a
176+ // / subgroup.
187177
188178// / Helper Function to get the default layout for uniform values like constants.
179+ // / For 1D vector, wi_layout is [subgroupSize] and wi_data is [1].
180+ // / For 2D vector, wi_layout is [1, subgroupSize] and wi_data is [1, 1].
189181static SGMap getDefaultSgMap (unsigned rank) {
190182 assert ((rank == 1 || rank == 2 ) && " Expected 0D or 1D vector." );
191183 if (rank == 1 )
192184 return SGMap (WiLayout ({subgroupSize}), WiData ({1 }));
193185 return SGMap (WiLayout ({1 , subgroupSize}), WiData ({1 , 1 }));
194186}
195187
188+ // / Helper to get the default layout for a vector type.
196189static SGMap getDefaultSgMap (VectorType vectorTy) {
197190 // / Expecting a 1D or 2D vector.
198191 assert ((vectorTy.getRank () == 1 || vectorTy.getRank () == 2 ) &&
199192 " Expected 1D or 2D vector." );
193+ // / Expecting int or float element type.
194+ assert (vectorTy.getElementType ().isIntOrFloat () &&
195+ " Expected int or float element type." );
200196 // / If the rank is 1, then return default layout for 1D vector.
201197 if (vectorTy.getRank () == 1 )
202198 return getDefaultSgMap (1 );
199+ // / Packing factor is determined by the element type bitwidth.
203200 int packingFactor = 1 ;
204201 auto bitwidth = vectorTy.getElementType ().getIntOrFloatBitWidth ();
205- if (bitwidth < packedASizeInBits )
206- packingFactor = packedBSizeInBits / bitwidth;
202+ if (bitwidth < packedSizeInBitsForDefault )
203+ packingFactor = packedSizeInBitsForDefault / bitwidth;
207204 return SGMap (WiLayout ({1 , subgroupSize}), WiData ({1 , packingFactor}));
208205}
209206
207+ // / Helper Function to get the expected layouts for DPAS operands. `wi_data` is
208+ // / set according to the following criteria:
209+ // / * For A operand, the data must be packed in minimum `packedDpasASizeInBits`
210+ // / * For B operand, the data must be packed in minimum `packedDpasBSizeInBits`
211+ static SGMap getSGMapForDPASOperand (VectorType vectorTy, unsigned operandNum) {
212+ auto elementTy = vectorTy.getElementType ();
213+ assert (elementTy.isIntOrFloat () &&
214+ " Expected int or float type in DPAS operands" );
215+ WiLayout layout ({1 , subgroupSize});
216+ // / For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
217+ // / must have the VNNI format.
218+ if (operandNum == 1 &&
219+ elementTy.getIntOrFloatBitWidth () < packedSizeInBitsForDpasB) {
220+ WiData data (
221+ {packedSizeInBitsForDpasB / elementTy.getIntOrFloatBitWidth (), 1 });
222+ return SGMap (layout, data);
223+ }
224+ // / Otherwise, return the default layout for the vector type.
225+ return getDefaultSgMap (vectorTy);
226+ }
227+
210228// /===----------------------------------------------------------------------===///
211229// / SGMapPropagation
212230// /===----------------------------------------------------------------------===///
@@ -360,14 +378,14 @@ void SGMapPropagation::visitUpdateNdOffsetOp(
360378void SGMapPropagation::visitDpasOp (xegpu::DpasOp dpas,
361379 ArrayRef<SGMapLattice *> operands,
362380 ArrayRef<const SGMapLattice *> results) {
363- auto aTy = dpas.getLhsType (). getElementType () ;
364- auto bTy = dpas.getRhsType (). getElementType () ;
381+ auto aTy = dpas.getLhsType ();
382+ auto bTy = dpas.getRhsType ();
365383 propagateIfChanged (operands[0 ],
366384 operands[0 ]->meet (getSGMapForDPASOperand (aTy, 0 )));
367385 propagateIfChanged (operands[1 ],
368386 operands[1 ]->meet (getSGMapForDPASOperand (bTy, 1 )));
369387 if (operands.size () > 2 ) {
370- auto cTy = dpas.getAccType (). getElementType () ;
388+ auto cTy = dpas.getAccType ();
371389 propagateIfChanged (operands[2 ],
372390 operands[2 ]->meet (getSGMapForDPASOperand (cTy, 2 )));
373391 }
0 commit comments