Skip to content

Commit 387ac93

Browse files
committed
refactor
1 parent 061b6e0 commit 387ac93

File tree

3 files changed

+88
-108
lines changed

3 files changed

+88
-108
lines changed

mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class OpOperand;
1717
class OpResult;
1818
class OpBuilder;
1919
class ValueRange;
20+
class TypeConverter;
2021

2122
namespace xegpu {
2223
class LayoutAttr;
@@ -96,10 +97,12 @@ Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
9697
ValueRange values,
9798
ArrayRef<int64_t> shape);
9899

99-
/// Do type conversion for SCF structural ops, e.g., scf.for. Since VectorType
100-
/// cannot carry the layout attribute, they are converted into RankedTensorType
101-
/// first, which will convert back to VectorType in the second round.
102-
void doSCFStructuralTypeConversionWithTensorType(Operation *op);
100+
/// Do type conversion for SCF structural ops, e.g., scf.for using SCF structural type
101+
/// conversion patterns. Since VectorType cannot carry the layout attribute, which is
102+
/// needed to guide the type conversion for XeGPU, they are first converted into
103+
/// RankedTensorType, where the layout attribute can be attached. And then upstream
104+
/// SCF structural type conversion patterns are applied with the provided converter.
105+
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter);
103106

104107
} // namespace xegpu
105108

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Interfaces/LoopLikeInterface.h"
1717
#include "mlir/Pass/Pass.h"
1818
#include "mlir/Pass/PassManager.h"
19+
#include "mlir/Transforms/DialectConversion.h"
1920
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2021

2122
namespace mlir {
@@ -207,7 +208,63 @@ void XeGPUBlockingPass::runOnOperation() {
207208
xegpu::setLayoutAttrs(mod, [&](Value v) { return xegpu::getLayoutAttr(v); });
208209

209210
// Perform type conversion for SCF control flow ops
210-
xegpu::doSCFStructuralTypeConversionWithTensorType(mod);
211+
TypeConverter converter;
212+
converter.addConversion([&](Type type) -> Type { return type; });
213+
converter.addConversion(
214+
[&](RankedTensorType type,
215+
SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
216+
Type elemTy = type.getElementType();
217+
ArrayRef<int64_t> shape = type.getShape();
218+
219+
// init count and subShape to the default value. If the LayoutAttr
220+
// is not present, it will return a VectorType with original shape.
221+
int count = 1;
222+
SmallVector<int64_t> subShape(shape);
223+
if (auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding())) {
224+
if (layout.isWgLayout())
225+
return failure();
226+
if (DenseI32ArrayAttr instData = layout.getInstData()) {
227+
// for unrolling, the subShape is determined by inst_data
228+
subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
229+
count = computeProduct(shape) / computeProduct(subShape);
230+
}
231+
}
232+
auto newTy = VectorType::get(subShape, elemTy);
233+
result.append(count, newTy);
234+
return success();
235+
});
236+
237+
converter.addConversion(
238+
[&](xegpu::TensorDescType type,
239+
SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
240+
MLIRContext *ctx = type.getContext();
241+
Type elemTy = type.getElementType();
242+
Attribute encoding = type.getEncoding();
243+
ArrayRef<int64_t> shape = type.getShape();
244+
245+
// init count and newTy to the default value. If the layout attribute
246+
// is not present, it will return the original type.
247+
int count = 1;
248+
SmallVector<int64_t> subShape(shape);
249+
250+
xegpu::LayoutAttr layout = type.getLayoutAttr();
251+
252+
if (layout) {
253+
if (layout.isWgLayout())
254+
return failure();
255+
256+
if (DenseI32ArrayAttr instData = layout.getInstData()) {
257+
// for unrolling, the subShape is determined by inst_data
258+
subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
259+
count = computeProduct(shape) / computeProduct(subShape);
260+
layout = layout.dropInstData();
261+
}
262+
}
263+
auto newTy = xegpu::TensorDescType::get(ctx, subShape, elemTy, encoding, layout);
264+
result.append(count, newTy);
265+
return success();
266+
});
267+
xegpu::doSCFStructuralTypeConversionWithTensorType(mod, converter);
211268

212269
xegpu::UnrollOptions options;
213270
options.setFilterConstraint([&](Operation *op) -> LogicalResult {

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 23 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
225225
return result;
226226
}
227227

228-
void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
228+
void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter) {
229229
MLIRContext *context = op->getContext();
230230

231231
auto materializeCast = [&](OpBuilder &builder, Type type, ValueRange inputs,
@@ -307,109 +307,11 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
307307

308308
{ // perform the conversion from RankedTensorType to VectorType based on the
309309
// LayoutAttr
310-
auto computeTileShapeAndCount = [&](ArrayRef<int64_t> shape,
311-
DenseI32ArrayAttr sgDataAttr,
312-
DenseI32ArrayAttr sgLayoutAttr) {
313-
SmallVector<int64_t> tileShape;
314-
auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
315-
if (sgDataAttr)
316-
tileShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
317-
else
318-
tileShape = computeShapeRatio(shape, sgLayout).value_or(tileShape);
319-
assert(tileShape.size() && "failed to compute tileShape");
320-
SmallVector<int64_t> distUnit =
321-
computeElementwiseMul(sgLayout, tileShape);
322-
int count = computeProduct(shape) / computeProduct(distUnit);
323-
return std::make_pair(tileShape, count);
324-
};
325-
326-
TypeConverter converter;
327-
converter.addConversion([&](Type type) -> Type { return type; });
328-
converter.addConversion(
329-
[&](RankedTensorType type,
330-
SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
331-
ArrayRef<int64_t> shape = type.getShape();
332-
auto encoding = type.getEncoding();
333-
Type elemTy = type.getElementType();
334-
335-
// init count and subShape to the default value. If the LayoutAttr
336-
// is not present, it will return a VectorType with original shape.
337-
int count = 1;
338-
SmallVector<int64_t> subShape(shape);
339-
340-
if (auto layout =
341-
llvm::dyn_cast_if_present<xegpu::LayoutAttr>(encoding)) {
342-
if (layout.isWgLayout()) {
343-
// for WgToSg, the subShape is either from sgData or computed as
344-
// shape/sgLayout
345-
std::tie(subShape, count) = computeTileShapeAndCount(
346-
shape, layout.getSgData(), layout.getSgLayout());
347-
} else if (DenseI32ArrayAttr instData = layout.getInstData()) {
348-
// for unrolling, the subShape is determined by inst_data
349-
subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
350-
count = computeProduct(shape) / computeProduct(subShape);
351-
}
352-
}
353-
auto newTy = VectorType::get(subShape, elemTy);
354-
result.append(count, newTy);
355-
return success();
356-
});
357-
358-
converter.addConversion(
359-
[&](xegpu::TensorDescType type,
360-
SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
361-
MLIRContext *ctx = type.getContext();
362-
Type elemTy = type.getElementType();
363-
Attribute encoding = type.getEncoding();
364-
ArrayRef<int64_t> shape = type.getShape();
365-
366-
// init count and newTy to the default value. If the layout attribute
367-
// is not present, it will return the original type.
368-
int count = 1;
369-
Type newTy = type;
370-
371-
if (xegpu::LayoutAttr layout = type.getLayoutAttr()) {
372-
SmallVector<int64_t> subShape(shape);
373-
if (layout.isWgLayout()) {
374-
// for WgToSg, the subShape is either from sgData or computed as
375-
// shape/sgLayout
376-
std::tie(subShape, count) = computeTileShapeAndCount(
377-
shape, layout.getSgData(), layout.getSgLayout());
378-
layout = layout.dropSgLayoutAndData();
379-
} else if (DenseI32ArrayAttr instData = layout.getInstData()) {
380-
// for unrolling, the subShape is determined by inst_data
381-
subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
382-
count = computeProduct(shape) / computeProduct(subShape);
383-
layout = layout.dropInstData();
384-
}
385-
386-
newTy = xegpu::TensorDescType::get(ctx, subShape, elemTy, encoding,
387-
layout);
388-
}
389-
390-
result.append(count, newTy);
391-
return success();
392-
});
393-
394-
converter.addSourceMaterialization(materializeCast);
395-
converter.addTargetMaterialization([&](OpBuilder &builder, TypeRange type,
396-
ValueRange inputs, Location loc) {
397-
return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
398-
.getResults();
399-
});
400-
401-
mlir::ConversionTarget target(*context);
402-
target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
403-
[&](UnrealizedConversionCastOp op) {
404-
auto isTensorTy = [&](Type type) {
405-
return isa<RankedTensorType>(type);
406-
};
407-
if (llvm::any_of(op->getOperandTypes(), isTensorTy) ||
408-
llvm::any_of(op->getResultTypes(), isTensorTy))
409-
return false;
410-
return true;
411-
});
412310

311+
// Handle the UnrealizedConversionCastOp introduced by the first step.
312+
// For vector->RankedTensorType, it will simply forward the inputs.
313+
// For RankedTensorType->vector, it will update the inputs with the
314+
// one from the adaptor.
413315
class UnrealizedConversionCastOpPattern
414316
: public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
415317
using OpConversionPattern<
@@ -444,6 +346,24 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
444346
}
445347
};
446348

349+
converter.addSourceMaterialization(materializeCast);
350+
converter.addTargetMaterialization([&](OpBuilder &builder, TypeRange type,
351+
ValueRange inputs, Location loc) {
352+
return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
353+
.getResults();
354+
});
355+
356+
mlir::ConversionTarget target(*context);
357+
target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
358+
[&](UnrealizedConversionCastOp op) {
359+
auto isTensorTy = [&](Type type) {
360+
return isa<RankedTensorType>(type);
361+
};
362+
if (llvm::any_of(op->getOperandTypes(), isTensorTy) ||
363+
llvm::any_of(op->getResultTypes(), isTensorTy))
364+
return false;
365+
return true;
366+
});
447367
mlir::RewritePatternSet patterns(context);
448368
patterns.insert<UnrealizedConversionCastOpPattern>(context);
449369
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,

0 commit comments

Comments
 (0)