Skip to content

Commit 0e34f36

Browse files
committed
refine
1 parent ad5d0a8 commit 0e34f36

File tree

6 files changed

+40
-26
lines changed

6 files changed

+40
-26
lines changed

mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,21 @@ std::string getLayoutName(const OpResult result);
7373
/// Returns nullptr if no DistributeLayoutAttr is found.
7474
DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
7575

76+
template <typename AttrTy>
77+
AttrTy getDistributeLayoutAttrOfType(const Value value) {
78+
return dyn_cast_if_present<AttrTy>(getDistributeLayoutAttr(value));
79+
}
80+
7681
/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It will
7782
/// first check the operand_layout_{id} of the owner operation. If not found,
7883
/// it will check the operand itself and its defining op.
7984
DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
8085

86+
template <typename AttrTy>
87+
AttrTy getDistributeLayoutAttrOfType(const OpOperand &opr) {
88+
return dyn_cast_if_present<AttrTy>(getDistributeLayoutAttr(opr));
89+
}
90+
8191
/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
8292
template <typename T,
8393
typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
@@ -94,13 +104,14 @@ void removeLayoutAttrs(Operation *op);
94104
template <typename T,
95105
typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
96106
std::is_same_v<T, OpResult>>>
97-
void setLayoutAttr(const T &operandOrResult, const DistributeLayoutAttr layout);
107+
void setDistributeLayoutAttr(const T &operandOrResult,
108+
const DistributeLayoutAttr layout);
98109

99110
/// Set the DistributeLayoutAttr for each OpOperand and OpResult of the given operation.
100111
/// If the operation contains regions, it is also applied recursively to the
101112
/// contained operations
102-
void setLayoutAttrs(Operation *op,
103-
function_ref<DistributeLayoutAttr(Value)> getLayoutImpl);
113+
void setDistributeLayoutAttrs(
114+
Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl);
104115

105116
/// Extract a set of small vectors from a value with a given shape using
106117
/// vector.extract_stride_slice

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,8 @@ void XeGPUBlockingPass::runOnOperation() {
247247
// Preserve the LayoutAttr for each operand to the owner's DictionaryAttr.
248248
// This ensures that the LayoutAttr remains accessible even if the defining
249249
// operation is replaced.
250-
xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getDistributeLayoutAttr(v); });
250+
xegpu::setDistributeLayoutAttrs(
251+
op, [](Value v) { return xegpu::getDistributeLayoutAttr(v); });
251252

252253
auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
253254
xegpu::LayoutAttr layout) {
@@ -377,7 +378,7 @@ void XeGPUBlockingPass::runOnOperation() {
377378
if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
378379
op->removeAttr(name);
379380
if (!isa<LoopLikeOpInterface>(op))
380-
xegpu::setLayoutAttr(result, layout.dropInstData());
381+
xegpu::setDistributeLayoutAttr(result, layout.dropInstData());
381382
}
382383
}
383384

mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
718718
}
719719
// If the result is a vector type, add a temporary layout attribute to the
720720
// op.
721-
xegpu::setLayoutAttr(result, layout);
721+
xegpu::setDistributeLayoutAttr(result, layout);
722722
}
723723
return success();
724724
}
@@ -800,7 +800,7 @@ updateControlFlowOps(mlir::OpBuilder &builder,
800800
// If the type is a vector type and this region argument is an OpResult,
801801
// set the layout attribute on the OpResult.
802802
if (auto result = dyn_cast<OpResult>(successorInput))
803-
xegpu::setLayoutAttr(result, successorOperandLayout);
803+
xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
804804
}
805805
}
806806
return success();

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -841,14 +841,15 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
841841
if (!isa<VectorType>(operand.get().getType()))
842842
continue;
843843

844-
auto layout = dyn_cast<xegpu::LayoutAttr>(xegpu::getDistributeLayoutAttr(operand));
844+
auto layout =
845+
xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
845846
if (!layout) {
846847
op->emitError("Could not find layout attribute for operand ")
847848
<< operand.getOperandNumber() << " of operation " << op->getName();
848849
signalPassFailure();
849850
return;
850851
}
851-
xegpu::setLayoutAttr(operand, layout);
852+
xegpu::setDistributeLayoutAttr(operand, layout);
852853
}
853854
});
854855
// Step 2: Move all operations of a GPU function inside
@@ -883,7 +884,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
883884
return AffineMap::get(val.getContext());
884885
// Get the layout of the vector type.
885886
// TODO: support more layout types
886-
auto layout = dyn_cast<xegpu::LayoutAttr>(xegpu::getDistributeLayoutAttr(val));
887+
auto layout = xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(val);
887888
// If no layout is specified, assume the inner most dimension is distributed
888889
// for now.
889890
if (!layout)

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -429,8 +429,8 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
429429
VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
430430
resultTy.getElementType());
431431
tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
432-
xegpu::setLayoutAttr(cast<OpResult>(tmpC),
433-
originalLayout.dropSgLayoutAndData());
432+
xegpu::setDistributeLayoutAttr(cast<OpResult>(tmpC),
433+
originalLayout.dropSgLayoutAndData());
434434

435435
newDpasOps.push_back(tmpC);
436436
}
@@ -508,8 +508,8 @@ struct WgToSgVectorBroadcastOp
508508
for (auto operand : adaptor.getOperands().front()) {
509509
auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
510510
newResultType, operand);
511-
xegpu::setLayoutAttr(newBroadcast->getResult(0),
512-
layout.dropSgLayoutAndData());
511+
xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
512+
layout.dropSgLayoutAndData());
513513
newBroadcastOps.push_back(newBroadcast.getResult());
514514
}
515515

@@ -755,7 +755,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
755755
auto cstOp =
756756
arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
757757
if (auto newLayout = layout.dropSgLayoutAndData())
758-
xegpu::setLayoutAttr(cstOp->getResult(0), newLayout);
758+
xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
759759
SmallVector<Value> newConsts(count, cstOp);
760760

761761
rewriter.replaceOpWithMultiple(op, {newConsts});

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -160,33 +160,34 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const OpOperand &opr)
160160
}
161161

162162
template <typename T, typename>
163-
void xegpu::setLayoutAttr(const T &operandOrResult, const DistributeLayoutAttr layout) {
163+
void xegpu::setDistributeLayoutAttr(const T &operandOrResult,
164+
const DistributeLayoutAttr layout) {
164165
Operation *owner = operandOrResult.getOwner();
165166
std::string name = xegpu::getLayoutName(operandOrResult);
166167
if (layout && !owner->hasAttrOfType<DistributeLayoutAttr>(name))
167168
owner->setAttr(name, layout);
168169
}
169170

170171
// Explicit instantiation for OpResult
171-
template void
172-
xegpu::setLayoutAttr<mlir::OpResult>(const mlir::OpResult &result,
173-
const mlir::xegpu::DistributeLayoutAttr layout);
172+
template void xegpu::setDistributeLayoutAttr<mlir::OpResult>(
173+
const mlir::OpResult &result,
174+
const mlir::xegpu::DistributeLayoutAttr layout);
174175

175176
// Explicit instantiation for OpOperand
176-
template void
177-
xegpu::setLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand,
178-
const mlir::xegpu::DistributeLayoutAttr layout);
177+
template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>(
178+
const mlir::OpOperand &operand,
179+
const mlir::xegpu::DistributeLayoutAttr layout);
179180

180-
void xegpu::setLayoutAttrs(Operation *op,
181-
function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {
181+
void xegpu::setDistributeLayoutAttrs(
182+
Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {
182183
op->walk([&](Operation *nestOp) {
183184
for (OpOperand &opr : nestOp->getOpOperands()) {
184185
auto layout = getLayoutImpl(opr.get());
185-
setLayoutAttr(opr, layout);
186+
setDistributeLayoutAttr(opr, layout);
186187
}
187188
for (OpResult result : nestOp->getOpResults()) {
188189
auto layout = getLayoutImpl(result);
189-
setLayoutAttr(result, layout);
190+
setDistributeLayoutAttr(result, layout);
190191
}
191192
});
192193
}

0 commit comments

Comments
 (0)