Skip to content

Commit e2eb9e6

Browse files
committed
refactor pack and unpack
1 parent 7b5e8f1 commit e2eb9e6

File tree

4 files changed

+301
-78
lines changed

4 files changed

+301
-78
lines changed

mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ namespace mlir {
1515
class VectorType;
1616
class OpOperand;
1717
class OpResult;
18+
class OpBuilder;
19+
class ValueRange;
1820

1921
namespace xegpu {
2022
class LayoutAttr;
@@ -53,17 +55,46 @@ FailureOr<VectorType> getDistributedVectorType(xegpu::TensorDescType tdescTy);
5355
FailureOr<VectorType> getDistributedVectorType(VectorType originalType,
5456
LayoutAttr layout);
5557

58+
/// Return the attribute name for the OpOperand to attach LayoutAttr
59+
std::string getLayoutName(OpOperand &opr);
60+
61+
/// Return the attribute name for the OpResult to attach LayoutAttr
62+
std::string getLayoutName(OpResult res);
63+
5664
/// Retrieves the LayoutAttr associated with a given Value. For TensorDescType
5765
/// values, the LayoutAttr is extracted from the TensorDescType itself. For
5866
/// other values, it is obtained from the attributes of the defining operation.
5967
/// Returns nullptr if no LayoutAttr is found.
6068
LayoutAttr getLayoutAttr(Value value);
6169

62-
/// Retrieves the name for the LayoutAttr associated with a given OpOperand.
63-
std::string getLayoutName(OpOperand &opr);
70+
/// Retrieves the LayoutAttr associated with a given OpOperand. It will
71+
/// first check the operand_layout_{id} of the owner operation. If not found,
72+
/// it will check the operand itself and its defining op.
73+
LayoutAttr getLayoutAttr(OpOperand &opr);
6474

65-
/// Retrieves the name for the LayoutAttr associated with a given OpResult.
66-
std::string getLayoutName(OpResult res);
75+
/// Sets the LayoutAttr for a given OpOperand by attaching it to the owner
76+
void setLayoutAttr(OpOperand &opr, LayoutAttr layout);
77+
78+
/// Set the LayoutAttr for the given OpResult by attching it to the defining op
79+
void setLayoutAttr(OpResult result, LayoutAttr layout);
80+
81+
/// Set the LayoutAttr for each OpOperand and OpResult of the given operation.
82+
/// If the operation contains regions, it is also applied recursively to the
83+
/// contained operations
84+
void setLayoutAttrs(Operation *op,
85+
function_ref<LayoutAttr(Value)> getLayoutImpl);
86+
87+
/// Extract a set of small vectors from a value with a given shape using
88+
/// vector.extract_stride_slice
89+
SmallVector<Value> extractVectorsWithShapeFromValue(OpBuilder &builder,
90+
Location loc, Value value,
91+
ArrayRef<int64_t> shape);
92+
93+
/// Create a vector of shape from a set of values using
94+
/// vector.insert_stride_slice.
95+
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
96+
ValueRange values,
97+
ArrayRef<int64_t> shape);
6798

6899
/// Do type conversion for SCF structural ops, e.g., scf.for. Since VectorType
69100
/// cannot carry the layout attribute, they are converted into RankedTensorType

mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp

Lines changed: 118 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1414
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
1515
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
16+
#include "mlir/Interfaces/LoopLikeInterface.h"
1617
#include "mlir/Pass/Pass.h"
1718
#include "mlir/Pass/PassManager.h"
1819
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -45,6 +46,10 @@ class XeGPUInstructionlizePass final
4546
std::optional<SmallVector<int64_t>>
4647
getTileShape(TypedValue<ShapedType> value) const;
4748

49+
std::optional<SmallVector<int64_t>> getTileShape(OpOperand &operand) const;
50+
51+
std::optional<SmallVector<int64_t>> getTileShape(OpResult result) const;
52+
4853
// Get the tile shape for a given operation.
4954
std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;
5055

@@ -67,20 +72,46 @@ XeGPUInstructionlizePass::getTileShape(TypedValue<ShapedType> value) const {
6772
return std::nullopt;
6873
}
6974

75+
std::optional<SmallVector<int64_t>>
76+
XeGPUInstructionlizePass::getTileShape(OpOperand &operand) const {
77+
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
78+
if (layout && layout.isSgLayout()) {
79+
if (auto inst_data = layout.getInstData())
80+
return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());
81+
82+
if (auto type = dyn_cast<ShapedType>(operand.get().getType()))
83+
return llvm::to_vector(type.getShape());
84+
}
85+
return std::nullopt;
86+
}
87+
88+
std::optional<SmallVector<int64_t>>
89+
XeGPUInstructionlizePass::getTileShape(OpResult result) const {
90+
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
91+
if (layout && layout.isSgLayout()) {
92+
if (auto inst_data = layout.getInstData())
93+
return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());
94+
95+
if (auto type = dyn_cast<ShapedType>(result.getType()))
96+
return llvm::to_vector(type.getShape());
97+
}
98+
return std::nullopt;
99+
}
100+
70101
std::optional<SmallVector<int64_t>>
71102
XeGPUInstructionlizePass::getTileShape(Operation *op) const {
72103
if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(op))
73-
return getTileShape(cast<TypedValue<ShapedType>>(op->getResult(0)));
104+
return getTileShape(op->getOpResult(0));
74105
if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp>(op))
75-
return getTileShape(cast<TypedValue<ShapedType>>(op->getOperand(0)));
106+
return getTileShape(op->getOpOperand(0));
76107
if (isa<xegpu::StoreNdOp>(op))
77-
return getTileShape(cast<TypedValue<ShapedType>>(op->getOperand(1)));
108+
return getTileShape(op->getOpOperand(1));
78109

79110
if (isa<xegpu::DpasOp>(op)) {
80-
auto a = cast<TypedValue<ShapedType>>(op->getOperand(0));
81-
auto b = cast<TypedValue<ShapedType>>(op->getOperand(1));
82-
std::optional<SmallVector<int64_t>> aTile = getTileShape(a);
83-
std::optional<SmallVector<int64_t>> bTile = getTileShape(b);
111+
std::optional<SmallVector<int64_t>> aTile =
112+
getTileShape(op->getOpOperand(0));
113+
std::optional<SmallVector<int64_t>> bTile =
114+
getTileShape(op->getOpOperand(1));
84115

85116
if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
86117
return std::nullopt;
@@ -91,8 +122,8 @@ XeGPUInstructionlizePass::getTileShape(Operation *op) const {
91122

92123
// semantic check for C
93124
if (op->getNumOperands() == 3) {
94-
auto c = cast<TypedValue<ShapedType>>(op->getOperand(2));
95-
std::optional<SmallVector<int64_t>> cTile = getTileShape(c);
125+
std::optional<SmallVector<int64_t>> cTile =
126+
getTileShape(op->getOpOperand(2));
96127
int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
97128
if (!cTile || !llvm::equal(*cTile, expectedCTile))
98129
return std::nullopt;
@@ -104,59 +135,101 @@ XeGPUInstructionlizePass::getTileShape(Operation *op) const {
104135
}
105136

106137
bool XeGPUInstructionlizePass::needsUnroll(Operation *op) const {
107-
for (Value opr : op->getOperands()) {
108-
if (auto value = dyn_cast<TypedValue<ShapedType>>(opr)) {
109-
std::optional<SmallVector<int64_t>> tileShape = getTileShape(value);
110-
// the tile should have the same rank as the origial type
111-
if (!tileShape ||
112-
tileShape->size() != static_cast<size_t>(value.getType().getRank()))
113-
return false;
114-
if (!llvm::equal(*tileShape, value.getType().getShape()))
115-
return true;
116-
}
138+
if (isa<LoopLikeOpInterface>(op))
139+
return false;
140+
141+
for (auto &opr : op->getOpOperands()) {
142+
std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
143+
auto shapedType = dyn_cast<ShapedType>(opr.get().getType());
144+
if (!shapedType)
145+
continue;
146+
147+
if (tileShape && !llvm::equal(*tileShape, shapedType.getShape()))
148+
return true;
149+
}
150+
151+
for (auto result : op->getOpResults()) {
152+
std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
153+
auto shapedType = dyn_cast<ShapedType>(result.getType());
154+
if (!shapedType)
155+
continue;
156+
157+
if (tileShape && !llvm::equal(*tileShape, shapedType.getShape()))
158+
return true;
117159
}
118160
return false;
119161
}
120162

121163
void XeGPUInstructionlizePass::runOnOperation() {
122164
MLIRContext *ctx = &getContext();
123-
Operation *op = getOperation();
165+
Operation *mod = getOperation();
166+
167+
// Preserve the LayoutAttr for each operand to the owner's DictionaryAttr.
168+
// This ensures that the LayoutAttr remains accessible even if the defining
169+
// operation is replaced.
170+
xegpu::setLayoutAttrs(mod, [&](Value v) { return xegpu::getLayoutAttr(v); });
124171

125-
// first perform type conversion for SCF control folow ops
126-
xegpu::doSCFStructuralTypeConversionWithTensorType(op);
172+
// Perform type conversion for SCF control folow ops
173+
xegpu::doSCFStructuralTypeConversionWithTensorType(mod);
127174

128175
xegpu::UnrollOptions options;
129176
options.setFilterConstraint([&](Operation *op) -> LogicalResult {
130177
return needsUnroll(op) ? success() : failure();
131178
});
132179

133-
options.setNativeShapeFn([&](Operation *op) {
134-
return getTileShape(op);
135-
});
180+
options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });
136181

137-
options.setUnrolledTypesFn(
138-
[&](ShapedType type, ArrayRef<int64_t> tileShape) {
139-
Type elemTy = type.getElementType();
140-
Type newTy;
182+
options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape) {
183+
Type elemTy = type.getElementType();
184+
Type newTy;
141185

142-
if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type))
143-
newTy = xegpu::TensorDescType::get(
144-
ctx, tileShape, elemTy, tdescTy.getEncoding(),
145-
tdescTy.getLayoutAttr().dropInstData());
146-
else
147-
newTy = type.clone(tileShape, elemTy);
186+
if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type))
187+
newTy = xegpu::TensorDescType::get(
188+
ctx, tileShape, elemTy, tdescTy.getEncoding(),
189+
tdescTy.getLayoutAttr().dropInstData());
190+
else
191+
newTy = type.clone(tileShape, elemTy);
148192

149-
std::optional<SmallVector<int64_t>> ratio =
150-
computeShapeRatio(type.getShape(), tileShape);
151-
assert(ratio &&
152-
"The shape of the type must be a multiple of tileShape.");
153-
return SmallVector<Type>(computeProduct(*ratio), newTy);
154-
});
155-
156-
GreedyRewriteConfig config;
157-
config.setStrictness(GreedyRewriteStrictness::ExistingOps);
193+
std::optional<SmallVector<int64_t>> ratio =
194+
computeShapeRatio(type.getShape(), tileShape);
195+
assert(ratio && "The shape of the type must be a multiple of tileShape.");
196+
return SmallVector<Type>(computeProduct(*ratio), newTy);
197+
});
158198

159199
RewritePatternSet patterns(ctx);
160200
populateXeGPUUnrollPatterns(patterns, options);
161-
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
201+
(void)applyPatternsGreedily(mod, std::move(patterns));
202+
203+
mod->walk([&](UnrealizedConversionCastOp castOp) {
204+
ValueRange inputs = castOp.getInputs();
205+
ValueRange outputs = castOp.getOutputs();
206+
207+
if (inputs.size() == 1 && outputs.size() == 1) {
208+
castOp->replaceAllUsesWith(inputs);
209+
castOp->erase();
210+
}
211+
212+
VectorType inputTy = dyn_cast<VectorType>(inputs[0].getType());
213+
VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
214+
if (inputTy && outputTy) {
215+
OpBuilder builder(castOp);
216+
// unpack
217+
if (inputs.size() > 1 && outputs.size() == 1) {
218+
ArrayRef<int64_t> shape = outputTy.getShape();
219+
Value result = xegpu::createVectorWithShapeFromValues(
220+
builder, castOp.getLoc(), inputs, shape);
221+
castOp->replaceAllUsesWith(ValueRange(result));
222+
castOp->erase();
223+
}
224+
225+
// pack
226+
if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
227+
ArrayRef<int64_t> tileShape = outputTy.getShape();
228+
SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
229+
builder, castOp.getLoc(), inputs[0], tileShape);
230+
castOp->replaceAllUsesWith(results);
231+
castOp->erase();
232+
}
233+
}
234+
});
162235
}

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/Utils/IndexingUtils.h"
1818
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1919
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
20+
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
2021
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2122
#include "llvm/ADT/STLExtras.h"
2223
#include "llvm/Support/Debug.h"
@@ -74,17 +75,7 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
7475
assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
7576
"Expecting blockSize size to match the rank of destTy.");
7677
auto shape = vecTy.getShape();
77-
auto zeroAttr = rewriter.getZeroAttr(vecTy.getElementType());
78-
79-
Value result = rewriter.create<arith::ConstantOp>(
80-
loc, vecTy, DenseElementsAttr::get(vecTy, zeroAttr));
81-
for (auto [src, offsets] :
82-
llvm::zip_equal(srcs, StaticTileOffsetRange(shape, blockSize))) {
83-
SmallVector<int64_t> staticStrides(offsets.size(), 1);
84-
result = rewriter.create<vector::InsertStridedSliceOp>(
85-
loc, src, result, offsets, staticStrides);
86-
}
87-
return result;
78+
return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
8879
}
8980

9081
if (isa<xegpu::TensorDescType>(destTy)) {
@@ -109,16 +100,8 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
109100
if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
110101
assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
111102
"Expecting blockSize size to match the rank of src.");
112-
auto shape = vecTy.getShape();
113-
SmallVector<Value> results;
114-
for (SmallVector<int64_t> offsets :
115-
StaticTileOffsetRange(shape, blockSize)) {
116-
SmallVector<int64_t> staticStrides(offsets.size(), 1);
117-
auto slice = rewriter.create<vector::ExtractStridedSliceOp>(
118-
loc, src, offsets, blockSize, staticStrides);
119-
results.push_back(slice);
120-
}
121-
return results;
103+
return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
104+
blockSize);
122105
}
123106

124107
if (isa<xegpu::TensorDescType>(src.getType())) {

0 commit comments

Comments
 (0)