1313#include " mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
1414#include " mlir/Dialect/XeGPU/IR/XeGPU.h"
1515#include " mlir/IR/Operation.h"
16+ #include " mlir/Interfaces/LoopLikeInterface.h"
1617#include " llvm/Support/FormatVariadic.h"
1718#include < cstdint>
1819#include < numeric>
@@ -88,25 +89,29 @@ mlir::xegpu::getDistributedVectorType(VectorType originalType,
8889
8990xegpu::LayoutAttr xegpu::getLayoutAttr (Value value) {
9091 if (!value)
91- return LayoutAttr () ;
92+ return nullptr ;
9293
9394 if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(value.getType ()))
9495 return tdescTy.getLayoutAttr ();
9596
9697 if (auto result = dyn_cast<OpResult>(value)) {
9798 Operation *defOp = result.getDefiningOp ();
9899 assert (defOp && " result must have a defining op" );
100+
101+ // for LoadNdOp, the layout is stored in the tensor descriptor
102+ if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
103+ return getLayoutAttr (loadNd.getTensorDesc ());
104+
99105 std::string layoutName = getLayoutName (result);
100106 if (defOp->hasAttr (layoutName))
101107 return defOp->getAttrOfType <xegpu::LayoutAttr>(layoutName);
102108 }
103109
104110 if (auto arg = dyn_cast<BlockArgument>(value)) {
105111 auto parentOp = arg.getOwner ()->getParentOp ();
106- if (auto funcOp = dyn_cast<FuncOp>(parentOp)) {
107- std::string layoutName = getLayoutName (arg);
108- if (funcOp->hasAttr (layoutName))
109- return funcOp->getAttrOfType <xegpu::LayoutAttr>(layoutName);
112+ if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
113+ OpOperand *tiedInit = loop.getTiedLoopInit (arg);
114+ return getLayoutAttr (tiedInit->get ());
110115 }
111116 }
112117
@@ -122,4 +127,3 @@ std::string xegpu::getLayoutName(OpResult res) {
122127 const StringRef prefix = " layout_result_" ;
123128 return llvm::formatv (" {0}{1}" , prefix, res.getResultNumber ()).str ();
124129}
125-
0 commit comments