Skip to content

Commit e8b43fb

Browse files
committed
add filter
1 parent af01c99 commit e8b43fb

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,11 +295,15 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
295295
}
296296

297297
LayoutAttr dropSgLayoutAndData() {
298+
if (!getInstData() && !getLaneLayout())
299+
return nullptr;
298300
return LayoutAttr::get(getContext(), nullptr, nullptr, getInstData(),
299301
getLaneLayout(), getLaneData(), getOrder());
300302
}
301303

302304
LayoutAttr dropInstData() {
305+
if (!getSgLayout() && !getLaneLayout())
306+
return nullptr;
303307
return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
304308
getLaneLayout(), getLaneData(), getOrder());
305309
}

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
1414
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1515
#include "mlir/IR/Operation.h"
16+
#include "mlir/Interfaces/LoopLikeInterface.h"
1617
#include "llvm/Support/FormatVariadic.h"
1718
#include <cstdint>
1819
#include <numeric>
@@ -88,25 +89,29 @@ mlir::xegpu::getDistributedVectorType(VectorType originalType,
8889

8990
xegpu::LayoutAttr xegpu::getLayoutAttr(Value value) {
9091
if (!value)
91-
return LayoutAttr();
92+
return nullptr;
9293

9394
if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(value.getType()))
9495
return tdescTy.getLayoutAttr();
9596

9697
if (auto result = dyn_cast<OpResult>(value)) {
9798
Operation *defOp = result.getDefiningOp();
9899
assert(defOp && "result must have a defining op");
100+
101+
// for LoadNdOp, the layout is stored in the tensor descriptor
102+
if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
103+
return getLayoutAttr(loadNd.getTensorDesc());
104+
99105
std::string layoutName = getLayoutName(result);
100106
if (defOp->hasAttr(layoutName))
101107
return defOp->getAttrOfType<xegpu::LayoutAttr>(layoutName);
102108
}
103109

104110
if (auto arg = dyn_cast<BlockArgument>(value)) {
105111
auto parentOp = arg.getOwner()->getParentOp();
106-
if (auto funcOp = dyn_cast<FuncOp>(parentOp)) {
107-
std::string layoutName = getLayoutName(arg);
108-
if (funcOp->hasAttr(layoutName))
109-
return funcOp->getAttrOfType<xegpu::LayoutAttr>(layoutName);
112+
if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
113+
OpOperand *tiedInit = loop.getTiedLoopInit(arg);
114+
return getLayoutAttr(tiedInit->get());
110115
}
111116
}
112117

@@ -122,4 +127,3 @@ std::string xegpu::getLayoutName(OpResult res) {
122127
const StringRef prefix = "layout_result_";
123128
return llvm::formatv("{0}{1}", prefix, res.getResultNumber()).str();
124129
}
125-

0 commit comments

Comments
 (0)