Skip to content

Commit 442c18a

Browse files
committed
merge
1 parent b3af260 commit 442c18a

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,11 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
136136

137137
// for LoadMatrixOp, the layout is attached to the property of the op
138138
if (auto loadOp = dyn_cast<xegpu::LoadMatrixOp>(defOp))
139-
return dyn_cast_if_present<xegpu::LayoutAttr>(loadOp.getLayoutAttr());
139+
return loadOp.getLayoutAttr();
140140

141141
// for StoreMatrixOp, the layout is attached to the property of the op
142142
if (auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(defOp))
143-
return dyn_cast_if_present<xegpu::LayoutAttr>(storeOp.getLayoutAttr());
143+
return storeOp.getLayoutAttr();
144144

145145
std::string layoutName = getLayoutName(result);
146146
if (defOp->hasAttr(layoutName))
@@ -164,10 +164,10 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
164164
Operation *op = opr.getOwner();
165165

166166
if (auto loadOp = dyn_cast<xegpu::LoadMatrixOp>(op))
167-
return dyn_cast_if_present<xegpu::LayoutAttr>(loadOp.getLayoutAttr());
167+
return loadOp.getLayoutAttr();
168168

169169
if (auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(op))
170-
return dyn_cast_if_present<xegpu::LayoutAttr>(storeOp.getLayoutAttr());
170+
return storeOp.getLayoutAttr();
171171

172172
std::string layoutName = xegpu::getLayoutName(opr);
173173
if (op->hasAttr(layoutName))
@@ -199,6 +199,7 @@ void xegpu::setDistributeLayoutAttrs(
199199
op->walk([&](Operation *nestOp) {
200200
if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(nestOp))
201201
return;
202+
202203
for (OpOperand &opr : nestOp->getOpOperands()) {
203204
auto layout = getLayoutImpl(opr.get());
204205
setDistributeLayoutAttr(opr, layout);
@@ -471,5 +472,4 @@ xegpu::addWithRightAligned(OpBuilder &builder, Location loc,
471472
results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval));
472473
}
473474
return results;
474-
return {};
475475
}

0 commit comments

Comments
 (0)