Skip to content

Commit f64581a

Browse files
committed
save work
1 parent f7b7bdf commit f64581a

File tree

1 file changed

+62
-44
lines changed

1 file changed

+62
-44
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 62 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,15 @@ namespace xegpu {
3232
using namespace mlir;
3333
using namespace mlir::dataflow;
3434

35-
constexpr unsigned subgroupSize = 16;
36-
constexpr unsigned packedASizeInBits = 16;
37-
constexpr unsigned packedBSizeInBits = 32;
35+
/// HW dependent constants.
36+
/// TODO: These constants should be queried from the uArch interface.
37+
constexpr unsigned subgroupSize = 16; // How many work items in a subgroup.
38+
/// If DPAS A or B operands have low precision element types they must be packed
39+
/// according to the following sizes.
40+
constexpr unsigned packedSizeInBitsForDefault =
41+
16; // Minimum packing size per register for DPAS A.
42+
constexpr unsigned packedSizeInBitsForDpasB =
43+
32; // Minimum packing size per register for DPAS B.
3844

3945
namespace {
4046

@@ -51,7 +57,7 @@ struct Layout {
5157
Layout(std::initializer_list<int64_t> list) : layout(list) {}
5258
void print(llvm::raw_ostream &os) const;
5359
size_t size() const { return layout.size(); }
54-
int64_t operator[](size_t idx) const { return layout[idx]; }
60+
int64_t operator[](size_t idx) const;
5561
};
5662

5763
void Layout::print(llvm::raw_ostream &os) const {
@@ -60,6 +66,11 @@ void Layout::print(llvm::raw_ostream &os) const {
6066
os << "]";
6167
}
6268

69+
int64_t Layout::operator[](size_t idx) const {
70+
assert(idx < layout.size() && "Index out of bounds.");
71+
return layout[idx];
72+
}
73+
6374
/// WiLayout represents the layout of work items within a subgroup when it
6475
/// accesses some value. WiData represents the layout of data owned by each work
6576
/// item.
@@ -86,14 +97,14 @@ using WiData = Layout;
8697

8798
struct SGMap {
8899
private:
89-
WiLayout layout;
90-
WiData data;
100+
WiLayout wiLayout;
101+
WiData wiData;
91102

92103
public:
93104
SGMap() = default;
94105
SGMap(const SGMap &other) = default;
95106
SGMap(const WiLayout &layout, const WiData &data)
96-
: layout(layout), data(data) {}
107+
: wiLayout(layout), wiData(data) {}
97108

98109
/// Two lattice values are equal if they have `some` layout. The actual
99110
/// content of the layout does not matter.
@@ -107,20 +118,20 @@ struct SGMap {
107118

108119
void print(raw_ostream &os) const;
109120

110-
bool isAssigned() const { return layout.size() > 0 && data.size() > 0; }
121+
bool isAssigned() const { return wiLayout.size() > 0 && wiData.size() > 0; }
111122

112123
SGMap getTransposedLayout(ArrayRef<int64_t> permutation) const;
113124

114-
const WiLayout &getLayout() const { return layout; }
115-
const WiData &getData() const { return data; }
125+
const WiLayout &getLayout() const { return wiLayout; }
126+
const WiData &getData() const { return wiData; }
116127
};
117128

118129
void SGMap::print(raw_ostream &os) const {
119130
if (isAssigned()) {
120131
os << "wi_layout: ";
121-
layout.print(os);
132+
wiLayout.print(os);
122133
os << ", wi_data: ";
123-
data.print(os);
134+
wiData.print(os);
124135
} else
125136
os << "Not assigned.";
126137
}
@@ -143,8 +154,8 @@ SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
143154
WiLayout newLayout;
144155
WiData newData;
145156
for (auto idx : permutation) {
146-
newLayout.layout.push_back(layout.layout[idx]);
147-
newData.layout.push_back(data.layout[idx]);
157+
newLayout.layout.push_back(wiLayout.layout[idx]);
158+
newData.layout.push_back(wiData.layout[idx]);
148159
}
149160
return SGMap(newLayout, newData);
150161
}
@@ -159,54 +170,61 @@ struct SGMapLattice : public Lattice<SGMap> {
159170
using Lattice::Lattice;
160171
};
161172

162-
/// Helper Functions
163-
164-
/// Helper Function to get the expected layouts for DPAS operands.
165-
static SGMap getSGMapForDPASOperand(Type operandTy, unsigned operandNum) {
166-
int packingFactorForB = packedBSizeInBits / operandTy.getIntOrFloatBitWidth();
167-
int packingFactorForA =
168-
operandTy.getIntOrFloatBitWidth() < packedBSizeInBits
169-
? packedASizeInBits / operandTy.getIntOrFloatBitWidth()
170-
: 1;
171-
return SGMap(WiLayout({1, subgroupSize}),
172-
WiData({operandNum == 1 ? packingFactorForB : 1,
173-
operandNum == 0 ? packingFactorForA : 1}));
174-
}
175-
176-
/// Helper Function to get the default layout for a given type. Usually this is,
177-
/// wi_layout = [1, subgroupSize] and wi_data = [1, 1].
178-
/// However, the minimum granularity of data access per work item is 16-bits.
179-
/// So, if the bitwidth of the type is less than 16, we need to pack the data to
180-
/// 16-bits.
181-
// static SGMap getDefaultSgMap(Type ty, unsigned rank) {
182-
// int packingFactor = 1;
183-
// if (ty.getIntOrFloatBitWidth() < packedASizeInBits)
184-
// packingFactor = packedBSizeInBits / ty.getIntOrFloatBitWidth();
185-
// return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
186-
// }
173+
/// Helper Functions to get default layouts. A `default layout` is a layout that
174+
/// is assigned to a value when the layout is not fixed by some anchor operation
175+
/// (like DPAS). This is the natural layout in which work items are arranged in a
176+
/// subgroup.
187177

188178
/// Helper Function to get the default layout for uniform values like constants.
179+
/// For 1D vector, wi_layout is [subgroupSize] and wi_data is [1].
180+
/// For 2D vector, wi_layout is [1, subgroupSize] and wi_data is [1, 1].
189181
static SGMap getDefaultSgMap(unsigned rank) {
190182
assert((rank == 1 || rank == 2) && "Expected 0D or 1D vector.");
191183
if (rank == 1)
192184
return SGMap(WiLayout({subgroupSize}), WiData({1}));
193185
return SGMap(WiLayout({1, subgroupSize}), WiData({1, 1}));
194186
}
195187

188+
/// Helper to get the default layout for a vector type.
196189
static SGMap getDefaultSgMap(VectorType vectorTy) {
197190
/// Expecting a 1D or 2D vector.
198191
assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
199192
"Expected 1D or 2D vector.");
193+
/// Expecting int or float element type.
194+
assert(vectorTy.getElementType().isIntOrFloat() &&
195+
"Expected int or float element type.");
200196
/// If the rank is 1, then return default layout for 1D vector.
201197
if (vectorTy.getRank() == 1)
202198
return getDefaultSgMap(1);
199+
/// Packing factor is determined by the element type bitwidth.
203200
int packingFactor = 1;
204201
auto bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
205-
if (bitwidth < packedASizeInBits)
206-
packingFactor = packedBSizeInBits / bitwidth;
202+
if (bitwidth < packedSizeInBitsForDefault)
203+
packingFactor = packedSizeInBitsForDefault / bitwidth;
207204
return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
208205
}
209206

207+
/// Helper Function to get the expected layouts for DPAS operands. `wi_data` is
208+
/// set according to the following criteria:
209+
/// * For A operand, the data must be packed in minimum `packedSizeInBitsForDefault`
210+
/// * For B operand, the data must be packed in minimum `packedSizeInBitsForDpasB`
211+
static SGMap getSGMapForDPASOperand(VectorType vectorTy, unsigned operandNum) {
212+
auto elementTy = vectorTy.getElementType();
213+
assert(elementTy.isIntOrFloat() &&
214+
"Expected int or float type in DPAS operands");
215+
WiLayout layout({1, subgroupSize});
216+
/// For B operand, data must be packed in minimum `packedSizeInBitsForDpasB` and
217+
/// must have the VNNI format.
218+
if (operandNum == 1 &&
219+
elementTy.getIntOrFloatBitWidth() < packedSizeInBitsForDpasB) {
220+
WiData data(
221+
{packedSizeInBitsForDpasB / elementTy.getIntOrFloatBitWidth(), 1});
222+
return SGMap(layout, data);
223+
}
224+
/// Otherwise, return the default layout for the vector type.
225+
return getDefaultSgMap(vectorTy);
226+
}
227+
210228
///===----------------------------------------------------------------------===///
211229
/// SGMapPropagation
212230
///===----------------------------------------------------------------------===///
@@ -360,14 +378,14 @@ void SGMapPropagation::visitUpdateNdOffsetOp(
360378
void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
361379
ArrayRef<SGMapLattice *> operands,
362380
ArrayRef<const SGMapLattice *> results) {
363-
auto aTy = dpas.getLhsType().getElementType();
364-
auto bTy = dpas.getRhsType().getElementType();
381+
auto aTy = dpas.getLhsType();
382+
auto bTy = dpas.getRhsType();
365383
propagateIfChanged(operands[0],
366384
operands[0]->meet(getSGMapForDPASOperand(aTy, 0)));
367385
propagateIfChanged(operands[1],
368386
operands[1]->meet(getSGMapForDPASOperand(bTy, 1)));
369387
if (operands.size() > 2) {
370-
auto cTy = dpas.getAccType().getElementType();
388+
auto cTy = dpas.getAccType();
371389
propagateIfChanged(operands[2],
372390
operands[2]->meet(getSGMapForDPASOperand(cTy, 2)));
373391
}

0 commit comments

Comments
 (0)