Commit ab448a3

add scf type conversion util
1 parent 3f73fda commit ab448a3

4 files changed, +215 -14 lines changed

mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 5 additions & 0 deletions
@@ -65,6 +65,11 @@ std::string getLayoutName(OpOperand &opr);
 /// Retrieves the name for the LayoutAttr associated with a given OpResult.
 std::string getLayoutName(OpResult res);
 
+/// Do type conversion for SCF structural ops, e.g., scf.for. Since VectorType
+/// cannot carry the layout attribute, such values are converted into
+/// RankedTensorType first and converted back to VectorType in the second round.
+void doSCFStructuralTypeConversionWithTensorType(Operation *op);
+
 } // namespace xegpu
 
 } // namespace mlir
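
For orientation, a minimal usage sketch of the new utility (the wrapper function and call site below are assumptions for illustration; only the xegpu utility itself comes from this commit):

```cpp
// Hypothetical call site: everything except
// xegpu::doSCFStructuralTypeConversionWithTensorType is assumed for the sketch.
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Operation.h"

// Round-trips VectorType values through RankedTensorType (whose encoding can
// carry the xegpu::LayoutAttr) across scf.for / scf.while boundaries, then
// converts them back to VectorType in the second round.
void convertSCFStructuralTypes(mlir::Operation *root) {
  mlir::xegpu::doSCFStructuralTypeConversionWithTensorType(root);
}
```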

mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp

Lines changed: 27 additions & 14 deletions
@@ -38,21 +38,33 @@ class XeGPUInstructionlizePass final
   void runOnOperation() override;
 
 private:
-  SmallVector<int64_t> getTileShape(TypedValue<ShapedType> value) const;
+  // Get the tile shape for a given value. If the value has a layout
+  // attribute and it is an SG layout, return the inst_data as the tile shape
+  // if inst_data is available; otherwise, return the original shape of the
+  // value. If the value does not have an SG layout, return std::nullopt.
+  std::optional<SmallVector<int64_t>>
+  getTileShape(TypedValue<ShapedType> value) const;
+
+  // Get the tile shape for a given operation.
   std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;
+
+  // Determine if the operation requires unrolling. Return false if all
+  // operands and results have tile shapes identical to their original types;
+  // otherwise, return true.
   bool needsUnroll(Operation *op) const;
 };
 } // namespace
 
-SmallVector<int64_t>
+std::optional<SmallVector<int64_t>>
 XeGPUInstructionlizePass::getTileShape(TypedValue<ShapedType> value) const {
   assert(value && "value must be non-null");
   xegpu::LayoutAttr layout = xegpu::getLayoutAttr(value);
   if (layout && layout.isSgLayout()) {
     if (auto inst_data = layout.getInstData())
       return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());
+    return llvm::to_vector(value.getType().getShape());
   }
-  return llvm::to_vector(value.getType().getShape());
+  return std::nullopt;
 }
 
 std::optional<SmallVector<int64_t>>
@@ -67,38 +79,39 @@ XeGPUInstructionlizePass::getTileShape(Operation *op) const {
   if (isa<xegpu::DpasOp>(op)) {
     auto a = cast<TypedValue<ShapedType>>(op->getOperand(0));
     auto b = cast<TypedValue<ShapedType>>(op->getOperand(1));
-    SmallVector<int64_t> aTileShape = getTileShape(a);
-    SmallVector<int64_t> bTileShape = getTileShape(b);
+    std::optional<SmallVector<int64_t>> aTile = getTileShape(a);
+    std::optional<SmallVector<int64_t>> bTile = getTileShape(b);
 
-    if (aTileShape.size() != 2 || bTileShape.size() != 2)
+    if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
       return std::nullopt;
 
     // semantic check for A and B
-    if (aTileShape[1] != bTileShape[0])
+    if ((*aTile)[1] != (*bTile)[0])
       return std::nullopt;
 
     // semantic check for C
     if (op->getNumOperands() == 3) {
       auto c = cast<TypedValue<ShapedType>>(op->getOperand(2));
-      SmallVector<int64_t> cTileShape = getTileShape(c);
-      int64_t expectedShape[2] = {aTileShape[0], bTileShape[1]};
-      if (!llvm::equal(cTileShape, expectedShape))
+      std::optional<SmallVector<int64_t>> cTile = getTileShape(c);
+      int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
+      if (!cTile || !llvm::equal(*cTile, expectedCTile))
        return std::nullopt;
     }
 
-    return SmallVector<int64_t>({aTileShape[0], aTileShape[1], bTileShape[1]});
+    return SmallVector<int64_t>({(*aTile)[0], (*aTile)[1], (*bTile)[1]});
   }
   return std::nullopt;
 }
 
 bool XeGPUInstructionlizePass::needsUnroll(Operation *op) const {
   for (Value opr : op->getOperands()) {
     if (auto value = dyn_cast<TypedValue<ShapedType>>(opr)) {
-      auto tileShape = getTileShape(value);
+      std::optional<SmallVector<int64_t>> tileShape = getTileShape(value);
       // the tile should have the same rank as the original type
-      if (tileShape.size() != static_cast<size_t>(value.getType().getRank()))
+      if (!tileShape ||
+          tileShape->size() != static_cast<size_t>(value.getType().getRank()))
         return false;
-      if (!llvm::equal(tileShape, value.getType().getShape()))
+      if (!llvm::equal(*tileShape, value.getType().getShape()))
         return true;
     }
   }
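
To make the DPAS tile-shape check concrete, here is a standalone sketch in plain C++ (not MLIR code from this commit) that mirrors the same semantic checks; the inst_data values of [8, 16] for A and [16, 16] for B are assumed for the example:

```cpp
// Standalone illustration of the DPAS tile-shape combination above,
// using plain std types instead of MLIR values.
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

static std::optional<std::vector<int64_t>>
combineDpasTiles(const std::vector<int64_t> &aTile,
                 const std::vector<int64_t> &bTile) {
  if (aTile.size() != 2 || bTile.size() != 2)
    return std::nullopt;
  if (aTile[1] != bTile[0]) // K dimension of A must match K dimension of B.
    return std::nullopt;
  return std::vector<int64_t>{aTile[0], aTile[1], bTile[1]}; // {M, K, N}
}

int main() {
  auto tile = combineDpasTiles({8, 16}, {16, 16});
  assert(tile && (*tile == std::vector<int64_t>{8, 16, 16}));
  assert(!combineDpasTiles({8, 16}, {8, 16})); // K mismatch: 16 vs 8.
  return 0;
}
```

The combined {M, K, N} tile corresponds to what getTileShape(Operation *) returns for a dpas op when the per-operand tiles are consistent.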

mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -6,5 +6,6 @@ add_mlir_dialect_library(MLIRXeGPUUtils
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRSCFTransforms
   MLIRXeGPUDialect
 )

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 182 additions & 0 deletions
@@ -11,9 +11,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cstdint>
 #include <numeric>
@@ -127,3 +130,182 @@ std::string xegpu::getLayoutName(OpResult res) {
   const StringRef prefix = "layout_result_";
   return llvm::formatv("{0}{1}", prefix, res.getResultNumber()).str();
 }
+
+void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
+  MLIRContext *context = op->getContext();
+
+  auto materializeCast = [&](OpBuilder &builder, Type type, ValueRange inputs,
+                             Location loc) -> Value {
+    return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
+        .getResult(0);
+  };
+
+  { // convert VectorType to RankedTensorType for SCF structural ops
+    TypeConverter converter;
+    converter.addConversion([&](Type type) -> Type { return type; });
+    converter.addConversion([&](VectorType type) -> Type {
+      return RankedTensorType::get(type.getShape(), type.getElementType());
+    });
+    converter.addSourceMaterialization(materializeCast);
+    converter.addTargetMaterialization(materializeCast);
+
+    mlir::ConversionTarget target(*context);
+    target.addLegalOp<UnrealizedConversionCastOp>();
+
+    mlir::RewritePatternSet patterns(context);
+    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                         target);
+    (void)mlir::applyPartialConversion(op, target, std::move(patterns));
+  }
+
+  { // propagate the layout attribute to RankedTensorType by checking
+    // UnrealizedConversionCastOps that cast
+    // from VectorType to RankedTensorType.
+    op->walk([&](UnrealizedConversionCastOp castOp) {
+      if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)
+        return WalkResult::skip();
+
+      Value input = castOp.getInputs()[0];
+      Value result = castOp.getResults()[0];
+      auto inputTy = dyn_cast<VectorType>(input.getType());
+      auto resultTy = dyn_cast<RankedTensorType>(result.getType());
+
+      // Only look at ops casting from VectorType to RankedTensorType.
+      if (!inputTy || !resultTy)
+        return WalkResult::skip();
+
+      xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input);
+      if (!layout)
+        return WalkResult::skip();
+
+      RankedTensorType newTy = resultTy.cloneWithEncoding(layout);
+      result.setType(newTy);
+
+      // update the arguments if the user is a LoopLike op.
+      for (OpOperand &use : result.getUses()) {
+        if (auto loop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
+          BlockArgument arg = loop.getTiedLoopRegionIterArg(&use);
+          arg.setType(newTy);
+        }
+        // scf.while has two regions; the BlockArguments of the after region
+        // are not exposed by LoopLikeOpInterface.
+        if (auto whileOp = dyn_cast<scf::WhileOp>(use.getOwner())) {
+          unsigned idx = use.getOperandNumber();
+          BlockArgument arg = whileOp.getAfterArguments()[idx];
+          arg.setType(newTy);
+        }
+      }
+      return WalkResult::advance();
+    });
+
+    // use yieldOp as the anchor to update the result types of its parent op.
+    op->walk([&](scf::YieldOp yieldOp) {
+      Operation *parentOp = yieldOp->getParentOp();
+      for (OpResult r : parentOp->getOpResults()) {
+        unsigned idx = r.getResultNumber();
+        Type resultTy = r.getType();
+        Type yieldTy = yieldOp.getResults()[idx].getType();
+        if (isa<RankedTensorType>(resultTy) && yieldTy != resultTy)
+          r.setType(yieldTy);
+      }
+    });
+  }
+
+  { // perform the conversion from RankedTensorType to VectorType based on the
+    // LayoutAttr
+
+    auto computeTileShapeAndCount = [&](ArrayRef<int64_t> shape,
+                                        DenseI32ArrayAttr sgDataAttr,
+                                        DenseI32ArrayAttr sgLayoutAttr) {
+      SmallVector<int64_t> tileShape;
+      auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+      if (sgDataAttr)
+        tileShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+      else
+        tileShape = computeShapeRatio(shape, sgLayout).value_or(tileShape);
+      assert(tileShape.size() && "failed to compute tileShape");
+      SmallVector<int64_t> distUnit =
+          computeElementwiseMul(sgLayout, tileShape);
+      int count = computeProduct(shape) / computeProduct(distUnit);
+      return std::make_pair(tileShape, count);
+    };
+
+    TypeConverter converter;
+    converter.addConversion([&](Type type) -> Type { return type; });
+    converter.addConversion(
+        [&](RankedTensorType type,
+            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+          ArrayRef<int64_t> shape = type.getShape();
+          auto encoding = type.getEncoding();
+          Type elemTy = type.getElementType();
+
+          // init count and subShape to the default values. If the LayoutAttr
+          // is not present, the result is a VectorType with the original shape.
+          int count = 1;
+          SmallVector<int64_t> subShape(shape);
+
+          if (auto layout =
+                  llvm::dyn_cast_if_present<xegpu::LayoutAttr>(encoding)) {
+            if (layout.isWgLayout()) {
+              // for WgToSg, the subShape is either from sgData or computed as
+              // shape/sgLayout
+              std::tie(subShape, count) = computeTileShapeAndCount(
+                  shape, layout.getSgData(), layout.getSgLayout());
+            } else if (DenseI32ArrayAttr instData = layout.getInstData()) {
+              // for unrolling, the subShape is determined by inst_data
+              subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
+              count = computeProduct(shape) / computeProduct(subShape);
+            }
+          }
+          auto newTy = VectorType::get(subShape, elemTy);
+          result.append(count, newTy);
+          return success();
+        });
+
+    converter.addConversion(
+        [&](xegpu::TensorDescType type,
+            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+          MLIRContext *ctx = type.getContext();
+          Type elemTy = type.getElementType();
+          Attribute encoding = type.getEncoding();
+          ArrayRef<int64_t> shape = type.getShape();
+
+          // init count and newTy to the default values. If the layout
+          // attribute is not present, the result is the original type.
+          int count = 1;
+          Type newTy = type;
+
+          if (xegpu::LayoutAttr layout = type.getLayoutAttr()) {
+            SmallVector<int64_t> subShape, distUnit;
+            if (layout.isWgLayout()) {
+              // for WgToSg, the subShape is either from sgData or computed as
+              // shape/sgLayout
+              std::tie(subShape, count) = computeTileShapeAndCount(
+                  shape, layout.getSgData(), layout.getSgLayout());
+              layout = layout.dropSgLayoutAndData();
+            } else if (DenseI32ArrayAttr instData = layout.getInstData()) {
+              // for unrolling, the subShape is determined by inst_data
+              subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
+              count = computeProduct(shape) / computeProduct(subShape);
+              layout = layout.dropInstData();
+            }
+            newTy = xegpu::TensorDescType::get(ctx, subShape, elemTy, encoding,
+                                               layout);
+          }
+
+          result.append(count, newTy);
+          return success();
+        });
+
+    converter.addSourceMaterialization(materializeCast);
+    converter.addTargetMaterialization(materializeCast);
+
+    mlir::ConversionTarget target(*context);
+    target.addLegalOp<UnrealizedConversionCastOp>();
+
+    mlir::RewritePatternSet patterns(context);
+    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                         target);
+    (void)mlir::applyPartialConversion(op, target, std::move(patterns));
+  }
+}
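
For a sense of the shape/count arithmetic behind computeTileShapeAndCount and the inst_data path, here is a small standalone sketch with assumed example shapes (plain C++ with std types, not the MLIR helpers used above):

```cpp
// Standalone sketch of the tile-shape/count arithmetic; the shapes,
// sg_layout, and inst_data values below are made-up examples.
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

static int64_t product(const std::vector<int64_t> &v) {
  return std::accumulate(v.begin(), v.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  // Workgroup-to-subgroup case: with sg_layout [4, 4] and no sg_data, the
  // per-subgroup tile is shape / sg_layout and the distribution unit covers
  // the whole shape, so the tensor maps to a single 32x32 vector.
  std::vector<int64_t> shape = {128, 128};
  std::vector<int64_t> sgLayout = {4, 4};
  std::vector<int64_t> tile = {shape[0] / sgLayout[0], shape[1] / sgLayout[1]};
  std::vector<int64_t> distUnit = {sgLayout[0] * tile[0],
                                   sgLayout[1] * tile[1]};
  int64_t count = product(shape) / product(distUnit);
  assert((tile == std::vector<int64_t>{32, 32}) && count == 1);

  // Unrolling case: inst_data [8, 16] on a 32x32 value yields
  // (32*32) / (8*16) = 8 converted vector<8x16> results.
  std::vector<int64_t> valShape = {32, 32};
  std::vector<int64_t> instData = {8, 16};
  int64_t tiles = product(valShape) / product(instData);
  assert(tiles == 8);
  return 0;
}
```

The count is what drives `result.append(count, newTy)` in the one-to-N type conversions: a single RankedTensorType or TensorDescType expands into that many smaller types.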
