Skip to content

Commit a0512d9

Browse files
committed
address comments
1 parent 3f5d692 commit a0512d9

File tree

3 files changed

+28
-13
lines changed

3 files changed

+28
-13
lines changed

mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,11 @@ void doSCFStructuralTypeConversionWithTensorType(Operation *op,
144144
/// if no GPU module parent or XeVM target attribute exists.
145145
std::optional<std::string> getChipStr(Operation *op);
146146

147+
/// Generates element-wise addition ops of two arrays with same length.
148+
SmallVector<OpFoldResult> addElementwise(OpBuilder &builder, Location loc,
149+
ArrayRef<OpFoldResult> lhs,
150+
ArrayRef<OpFoldResult> rhs);
151+
147152
/// Generates element-wise addition ops of two arrays with automatic alignment.
148153
/// When the input arrays have different sizes, the shorter array is
149154
/// right-aligned with the longer array, and the unmatched leading elements from
@@ -157,7 +162,6 @@ std::optional<std::string> getChipStr(Operation *op);
157162
SmallVector<OpFoldResult> addWithRightAligned(OpBuilder &builder, Location loc,
158163
ArrayRef<OpFoldResult> lhs,
159164
ArrayRef<OpFoldResult> rhs);
160-
161165
} // namespace xegpu
162166

163167
} // namespace mlir

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -686,12 +686,12 @@ struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
686686
using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
687687
LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
688688
PatternRewriter &rewriter) const override {
689+
Location loc = op.getLoc();
690+
VectorType valueTy = op.getType();
689691
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
690-
if (!targetShape)
692+
if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
691693
return failure();
692694

693-
Location loc = op.getLoc();
694-
VectorType valueTy = op.getType();
695695
Type elemTy = valueTy.getElementType();
696696
ArrayRef<int64_t> shape = valueTy.getShape();
697697
auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
@@ -702,17 +702,17 @@ struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
702702
SmallVector<SmallVector<OpFoldResult>> offsetsList;
703703
for (SmallVector<int64_t> offsets :
704704
StaticTileOffsetRange(shape, *targetShape)) {
705-
auto adds = xegpu::addWithRightAligned(
705+
auto adds = xegpu::addElementwise(
706706
rewriter, loc, mixedOffsets,
707707
getAsIndexOpFoldResult(op.getContext(), offsets));
708708
offsetsList.push_back(adds);
709709
}
710710

711711
SmallVector<Value> newOps;
712+
layout = layout.dropInstData();
712713
for (SmallVector<OpFoldResult> offsets : offsetsList) {
713714
auto newOp = rewriter.create<xegpu::LoadMatrixOp>(
714-
op.getLoc(), newValueTy, op.getMemDesc(), offsets,
715-
layout.dropInstData());
715+
op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
716716
newOps.push_back(newOp);
717717
}
718718
Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
@@ -743,7 +743,7 @@ struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
743743
SmallVector<SmallVector<OpFoldResult>> offsetsList;
744744
for (SmallVector<int64_t> offsets :
745745
StaticTileOffsetRange(shape, *targetShape)) {
746-
auto adds = xegpu::addWithRightAligned(
746+
auto adds = xegpu::addElementwise(
747747
rewriter, loc, mixedOffsets,
748748
getAsIndexOpFoldResult(op.getContext(), offsets));
749749
offsetsList.push_back(adds);

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,21 @@ std::optional<std::string> xegpu::getChipStr(Operation *op) {
447447
return std::nullopt;
448448
}
449449

450+
/// Generates element-wise addition ops of two arrays with same length.
451+
SmallVector<OpFoldResult> xegpu::addElementwise(OpBuilder &builder,
452+
Location loc,
453+
ArrayRef<OpFoldResult> lhs,
454+
ArrayRef<OpFoldResult> rhs) {
455+
assert(lhs.size() == rhs.size() && "lhs and rhs must have the same size");
456+
SmallVector<OpFoldResult> results;
457+
for (auto [l, r] : llvm::zip_equal(lhs, rhs)) {
458+
auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
459+
auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
460+
results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval));
461+
}
462+
return results;
463+
}
464+
450465
/// Generates element-wise addition ops of two arrays with automatic alignment.
451466
/// When the input arrays have different sizes, the shorter array is
452467
/// right-aligned with the longer array, and the unmatched leading elements from
@@ -466,10 +481,6 @@ xegpu::addWithRightAligned(OpBuilder &builder, Location loc,
466481
ArrayRef<OpFoldResult> b = lhs.size() >= rhs.size() ? rhs : lhs;
467482
SmallVector<OpFoldResult> results(a.take_front(a.size() - b.size()));
468483
a = a.slice(a.size() - b.size());
469-
for (auto [l, r] : llvm::zip(a, b)) {
470-
auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
471-
auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
472-
results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval));
473-
}
484+
results.append(addElementwise(builder, loc, a, b));
474485
return results;
475486
}

0 commit comments

Comments
 (0)