Skip to content

Commit 6ec3604

Browse files
committed
cleanup layout attr
1 parent e2eb9e6 commit 6ec3604

File tree

2 files changed

+50
-28
lines changed

2 files changed

+50
-28
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,39 @@ using namespace mlir;
3232

3333
namespace {
3434

35+
void resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
36+
ValueRange inputs = castOp.getInputs();
37+
ValueRange outputs = castOp.getOutputs();
38+
39+
if (inputs.size() == 1 && outputs.size() == 1) {
40+
castOp->replaceAllUsesWith(inputs);
41+
castOp->erase();
42+
}
43+
44+
VectorType inputTy = dyn_cast<VectorType>(inputs[0].getType());
45+
VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
46+
if (inputTy && outputTy) {
47+
OpBuilder builder(castOp);
48+
// unpack
49+
if (inputs.size() > 1 && outputs.size() == 1) {
50+
ArrayRef<int64_t> shape = outputTy.getShape();
51+
Value result = xegpu::createVectorWithShapeFromValues(
52+
builder, castOp.getLoc(), inputs, shape);
53+
castOp->replaceAllUsesWith(ValueRange(result));
54+
castOp->erase();
55+
}
56+
57+
// pack
58+
if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
59+
ArrayRef<int64_t> tileShape = outputTy.getShape();
60+
SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
61+
builder, castOp.getLoc(), inputs[0], tileShape);
62+
castOp->replaceAllUsesWith(results);
63+
castOp->erase();
64+
}
65+
}
66+
}
67+
3568
/// Unroll XeGPU ops to their instruction-level representation.
3669
class XeGPUInstructionlizePass final
3770
: public xegpu::impl::XeGPUInstructionlizeBase<XeGPUInstructionlizePass> {
@@ -200,35 +233,22 @@ void XeGPUInstructionlizePass::runOnOperation() {
200233
populateXeGPUUnrollPatterns(patterns, options);
201234
(void)applyPatternsGreedily(mod, std::move(patterns));
202235

203-
mod->walk([&](UnrealizedConversionCastOp castOp) {
204-
ValueRange inputs = castOp.getInputs();
205-
ValueRange outputs = castOp.getOutputs();
236+
mod->walk([&](Operation *op) {
237+
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
238+
resolveUnrealizedConversionCastOp(castOp);
206239

207-
if (inputs.size() == 1 && outputs.size() == 1) {
208-
castOp->replaceAllUsesWith(inputs);
209-
castOp->erase();
240+
for (OpOperand &opr : op->getOpOperands()) {
241+
std::string name = xegpu::getLayoutName(opr);
242+
if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name))
243+
op->removeAttr(name);
210244
}
211245

212-
VectorType inputTy = dyn_cast<VectorType>(inputs[0].getType());
213-
VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
214-
if (inputTy && outputTy) {
215-
OpBuilder builder(castOp);
216-
// unpack
217-
if (inputs.size() > 1 && outputs.size() == 1) {
218-
ArrayRef<int64_t> shape = outputTy.getShape();
219-
Value result = xegpu::createVectorWithShapeFromValues(
220-
builder, castOp.getLoc(), inputs, shape);
221-
castOp->replaceAllUsesWith(ValueRange(result));
222-
castOp->erase();
223-
}
224-
225-
// pack
226-
if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
227-
ArrayRef<int64_t> tileShape = outputTy.getShape();
228-
SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
229-
builder, castOp.getLoc(), inputs[0], tileShape);
230-
castOp->replaceAllUsesWith(results);
231-
castOp->erase();
246+
for (OpResult result : op->getOpResults()) {
247+
std::string name = xegpu::getLayoutName(result);
248+
if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
249+
op->removeAttr(name);
250+
if (!isa<LoopLikeOpInterface>(op))
251+
xegpu::setLayoutAttr(result, layout.dropInstData());
232252
}
233253
}
234254
});

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(Value value) {
115115
if (!value)
116116
return nullptr;
117117

118-
if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(value.getType()))
118+
if (auto tdescTy =
119+
dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
119120
return tdescTy.getLayoutAttr();
120121

121122
if (auto result = dyn_cast<OpResult>(value)) {
@@ -366,7 +367,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
366367
Type newTy = type;
367368

368369
if (xegpu::LayoutAttr layout = type.getLayoutAttr()) {
369-
SmallVector<int64_t> subShape, distUnit;
370+
SmallVector<int64_t> subShape(shape);
370371
if (layout.isWgLayout()) {
371372
// for WgToSg, the subShape is either from sgData or computed as
372373
// shape/sgLayout
@@ -378,6 +379,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
378379
count = computeProduct(shape) / computeProduct(subShape);
379380
layout = layout.dropInstData();
380381
}
382+
381383
newTy = xegpu::TensorDescType::get(ctx, subShape, elemTy, encoding,
382384
layout);
383385
}

0 commit comments

Comments
 (0)