Skip to content

Commit 8589fa2

Browse files
committed
save work
1 parent 141c551 commit 8589fa2

File tree

3 files changed

+61
-31
lines changed

3 files changed

+61
-31
lines changed

mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,17 @@ LayoutAttr getLayoutAttr(const Value value);
7676
/// it will check the operand itself and its defining op.
7777
LayoutAttr getLayoutAttr(const OpOperand &opr);
7878

79+
/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
80+
template <typename T,
81+
typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
82+
std::is_same_v<T, OpResult>>>
83+
void removeLayoutAttr(const T &operandOrResult);
84+
85+
/// Removes the LayoutAttr for each OpOperand and OpResult of the given
86+
/// operation if they exist. If the operation contains regions, it is also
87+
/// applied recursively to the contained operations
88+
void removeLayoutAttrs(Operation *op);
89+
7990
/// Sets the LayoutAttr for a given OpOperand or OpResult by attaching
8091
/// it to the owner's dictionary attributes
8192
template <typename T,

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -136,19 +136,6 @@ static Value resolveDistributedTy(Value orig, T expected,
136136
return orig;
137137
}
138138

139-
/// Helper function to filter out the temporary layout attributes attached
140-
/// during the layout assignment process. These are not needed after going to
141-
/// SIMT.
142-
static SmallVector<NamedAttribute>
143-
removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
144-
SmallVector<NamedAttribute> newAttrs;
145-
for (NamedAttribute attr : attrs) {
146-
if (!isa<xegpu::LayoutAttr>(attr.getValue()))
147-
newAttrs.push_back(attr);
148-
}
149-
return newAttrs;
150-
}
151-
152139
/// Helper function to check if the layout is packed. Layout is packed if it is
153140
/// 2D and lane_data[0] != 1 (data packed from col dimension).
154141
static bool hasPackedLayout(xegpu::LayoutAttr layout) {
@@ -412,9 +399,9 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
412399
resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
413400
distributedTensorDescTy, rewriter));
414401

415-
rewriter.create<xegpu::StoreNdOp>(
416-
newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
417-
removeTemporaryLayoutAttributes(storeOp->getAttrs()));
402+
auto newStoreOp = rewriter.create<xegpu::StoreNdOp>(
403+
newWarpOp.getLoc(), TypeRange{}, newStoreOperands, storeOp->getAttrs());
404+
xegpu::removeLayoutAttrs(newStoreOp);
418405
rewriter.eraseOp(storeOp);
419406
return success();
420407
}
@@ -508,7 +495,8 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
508495
newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
509496
resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
510497
distributedTensorDescTy, rewriter),
511-
removeTemporaryLayoutAttributes(loadOp->getAttrs()));
498+
loadOp->getAttrs());
499+
xegpu::removeLayoutAttrs(newLoadOp);
512500
// Set the packed attribute if the layout requires it.
513501
newLoadOp.setPacked(hasPackedLayout(layout));
514502
Value distributedVal = newWarpOp.getResult(operandIdx);
@@ -639,14 +627,16 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
639627
resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
640628
newDpasOperandExpectedTypes[i], rewriter));
641629
}
642-
Value newDpasOp = rewriter.create<xegpu::DpasOp>(
643-
newWarpOp->getLoc(), distributedResultTy, newDpasOperands,
644-
removeTemporaryLayoutAttributes(dpasOp->getAttrs()));
630+
auto newDpasOp =
631+
rewriter.create<xegpu::DpasOp>(newWarpOp->getLoc(), distributedResultTy,
632+
newDpasOperands, dpasOp->getAttrs());
633+
xegpu::removeLayoutAttrs(newDpasOp);
645634
Value distributedVal = newWarpOp.getResult(operandIdx);
646635
// Resolve the output type.
647-
newDpasOp = resolveDistributedTy(
648-
newDpasOp, distResultTypeByWarpOpOrFailure.value(), rewriter);
649-
rewriter.replaceAllUsesWith(distributedVal, newDpasOp);
636+
Value typeResolved =
637+
resolveDistributedTy(newDpasOp.getResult(),
638+
distResultTypeByWarpOpOrFailure.value(), rewriter);
639+
rewriter.replaceAllUsesWith(distributedVal, typeResolved);
650640
return success();
651641
}
652642
};
@@ -726,14 +716,15 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
726716
}
727717
}
728718
// Create a new update op outside the warp op.
729-
Value newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
719+
auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
730720
newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
731-
removeTemporaryLayoutAttributes(updateOp->getAttrs()));
721+
updateOp->getAttrs());
722+
xegpu::removeLayoutAttrs(newUpdateOp);
732723
Value distributedVal = newWarpOp.getResult(operandIdx);
733724
// Resolve the distributed type with the original type.
734-
newUpdateOp =
735-
resolveDistributedTy(newUpdateOp, distributedVal.getType(), rewriter);
736-
rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
725+
Value typeResolved = resolveDistributedTy(
726+
newUpdateOp.getResult(), distributedVal.getType(), rewriter);
727+
rewriter.replaceAllUsesWith(distributedVal, typeResolved);
737728
return success();
738729
}
739730
};
@@ -792,9 +783,10 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
792783
rewriter.setInsertionPointAfter(newWarpOp);
793784
SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
794785
newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
795-
rewriter.create<xegpu::PrefetchNdOp>(
796-
newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
797-
removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
786+
rewriter.create<xegpu::PrefetchNdOp>(newWarpOp.getLoc(), TypeRange{},
787+
newPrefetchOperands,
788+
prefetchOp->getAttrs());
789+
xegpu::removeLayoutAttrs(prefetchOp);
798790
rewriter.eraseOp(prefetchOp);
799791
return success();
800792
}

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,33 @@ void xegpu::setLayoutAttrs(Operation *op,
184184
});
185185
}
186186

187+
template <typename T, typename>
188+
void xegpu::removeLayoutAttr(const T &operandOrResult) {
189+
Operation *owner = operandOrResult.getOwner();
190+
std::string name = xegpu::getLayoutName(operandOrResult);
191+
if (owner->hasAttrOfType<LayoutAttr>(name))
192+
owner->removeAttr(name);
193+
}
194+
195+
// Explicit instantiation for OpResult
196+
template void
197+
xegpu::removeLayoutAttr<mlir::OpResult>(const mlir::OpResult &result);
198+
199+
// Explicit instantiation for OpOperand
200+
template void
201+
xegpu::removeLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand);
202+
203+
void xegpu::removeLayoutAttrs(Operation *op) {
204+
op->walk([&](Operation *nestOp) {
205+
for (OpOperand &opr : nestOp->getOpOperands()) {
206+
removeLayoutAttr(opr);
207+
}
208+
for (OpResult result : nestOp->getOpResults()) {
209+
removeLayoutAttr(result);
210+
}
211+
});
212+
}
213+
187214
SmallVector<Value>
188215
xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
189216
Value value, ArrayRef<int64_t> shape) {

0 commit comments

Comments
 (0)