Skip to content

Commit 79e37d8

Browse files
committed
apply review suggestions
Signed-off-by: Dmitry Chigarev <[email protected]>
1 parent fdb0540 commit 79e37d8

File tree

1 file changed

+41
-50
lines changed

1 file changed

+41
-50
lines changed

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 41 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -98,17 +98,18 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
9898
}
9999

100100
// Extract cache hints from the op attributes if available.
101-
static void getOpCacheHints(Operation *op,
102-
SmallVector<xegpu::CachePolicyAttr, 3> &hints) {
103-
assert(hints.size() == 3 &&
104-
"Expecting a vector of size 3 for l1, l2, l3 hints.");
101+
static SmallVector<xegpu::CachePolicyAttr, 3> getOpCacheHints(Operation *op) {
102+
SmallVector<xegpu::CachePolicyAttr, 3> cacheHints{xegpu::CachePolicyAttr{},
103+
xegpu::CachePolicyAttr{},
104+
xegpu::CachePolicyAttr{}};
105105
// get l1, l2, l3 hints from attributes if available.
106106
if (auto l1Attr = op->getAttrOfType<xegpu::CachePolicyAttr>("l1_hint"))
107-
hints[0] = l1Attr;
107+
cacheHints[0] = l1Attr;
108108
if (auto l2Attr = op->getAttrOfType<xegpu::CachePolicyAttr>("l2_hint"))
109-
hints[1] = l2Attr;
109+
cacheHints[1] = l2Attr;
110110
if (auto l3Attr = op->getAttrOfType<xegpu::CachePolicyAttr>("l3_hint"))
111-
hints[2] = l3Attr;
111+
cacheHints[2] = l3Attr;
112+
return cacheHints;
112113
}
113114

114115
static xegpu::CreateNdDescOp
@@ -389,28 +390,25 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
389390
}
390391
Value indices = gatScatOp.getIndices();
391392
// Extract indices layout and propagate it to all 'vector' ops created here
392-
auto indicesLayout = mlir::xegpu::getDistributeLayoutAttr(indices);
393+
auto indicesLayout = xegpu::getDistributeLayoutAttr(indices);
393394
VectorType vecType = cast<VectorType>(indices.getType());
394395

395396
auto strideVector =
396397
vector::BroadcastOp::create(rewriter, loc, vecType, strides.back());
397-
mlir::xegpu::setDistributeLayoutAttr(strideVector->getOpResult(0),
398-
indicesLayout);
398+
xegpu::setDistributeLayoutAttr(strideVector->getOpResult(0), indicesLayout);
399399

400400
auto stridedIndices =
401401
arith::MulIOp::create(rewriter, loc, strideVector.getResult(), indices);
402-
mlir::xegpu::setDistributeLayoutAttr(stridedIndices->getOpResult(0),
403-
indicesLayout);
402+
xegpu::setDistributeLayoutAttr(stridedIndices->getOpResult(0), indicesLayout);
404403

405404
auto baseVector = vector::BroadcastOp::create(
406405
rewriter, loc,
407406
VectorType::get(vecType.getShape(), rewriter.getIndexType()), baseOffset);
408-
mlir::xegpu::setDistributeLayoutAttr(baseVector->getOpResult(0),
409-
indicesLayout);
407+
xegpu::setDistributeLayoutAttr(baseVector->getOpResult(0), indicesLayout);
410408

411409
auto result = arith::AddIOp::create(rewriter, loc, baseVector.getResult(),
412410
stridedIndices.getResult());
413-
mlir::xegpu::setDistributeLayoutAttr(result->getOpResult(0), indicesLayout);
411+
xegpu::setDistributeLayoutAttr(result->getOpResult(0), indicesLayout);
414412
return result.getResult();
415413
}
416414

@@ -639,37 +637,33 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
639637
Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);
640638

641639
auto numOffsets = gatherOp.getOffsets().size();
642-
auto layoutRes = mlir::xegpu::getDistributeLayoutAttr(gatherOp.getResult());
643-
auto layoutIndices = mlir::xegpu::getDistributeLayoutAttr(
644-
gatherOp->getOpOperand(numOffsets + 1));
645-
auto layoutMask = mlir::xegpu::getDistributeLayoutAttr(
646-
gatherOp->getOpOperand(numOffsets + 2));
647-
auto layoutPassThru = mlir::xegpu::getDistributeLayoutAttr(
648-
gatherOp->getOpOperand(numOffsets + 3));
649-
650-
SmallVector<xegpu::CachePolicyAttr, 3> cacheHints{xegpu::CachePolicyAttr{},
651-
xegpu::CachePolicyAttr{},
652-
xegpu::CachePolicyAttr{}};
653-
getOpCacheHints(gatherOp, cacheHints);
640+
auto layoutRes = xegpu::getDistributeLayoutAttr(gatherOp.getResult());
641+
auto layoutIndices =
642+
xegpu::getDistributeLayoutAttr(gatherOp->getOpOperand(numOffsets + 1));
643+
auto layoutMask =
644+
xegpu::getDistributeLayoutAttr(gatherOp->getOpOperand(numOffsets + 2));
645+
auto layoutPassThru =
646+
xegpu::getDistributeLayoutAttr(gatherOp->getOpOperand(numOffsets + 3));
647+
648+
SmallVector<xegpu::CachePolicyAttr, 3> cacheHints =
649+
getOpCacheHints(gatherOp);
654650
auto xeGatherOp = xegpu::LoadGatherOp::create(
655651
rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
656652
/*chunk_size=*/IntegerAttr{},
657653
/*l1_hint=*/cacheHints[0],
658654
/*l2_hint=*/cacheHints[1],
659655
/*l3_hint=*/cacheHints[2]);
660-
mlir::xegpu::setDistributeLayoutAttr(xeGatherOp->getOpResult(0), layoutRes);
661-
mlir::xegpu::setDistributeLayoutAttr(xeGatherOp->getOpOperand(1),
662-
layoutIndices);
663-
mlir::xegpu::setDistributeLayoutAttr(xeGatherOp->getOpOperand(2),
664-
layoutMask);
656+
xegpu::setDistributeLayoutAttr(xeGatherOp->getOpResult(0), layoutRes);
657+
xegpu::setDistributeLayoutAttr(xeGatherOp->getOpOperand(1), layoutIndices);
658+
xegpu::setDistributeLayoutAttr(xeGatherOp->getOpOperand(2), layoutMask);
665659

666660
auto selectOp =
667661
arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
668662
xeGatherOp.getResult(), gatherOp.getPassThru());
669-
mlir::xegpu::setDistributeLayoutAttr(selectOp->getOpOperand(0), layoutMask);
670-
mlir::xegpu::setDistributeLayoutAttr(selectOp->getOpOperand(2),
671-
layoutPassThru);
672-
mlir::xegpu::setDistributeLayoutAttr(selectOp->getOpResult(0), layoutRes);
663+
xegpu::setDistributeLayoutAttr(selectOp->getOpOperand(0), layoutMask);
664+
xegpu::setDistributeLayoutAttr(selectOp->getOpOperand(1), layoutRes);
665+
xegpu::setDistributeLayoutAttr(selectOp->getOpOperand(2), layoutPassThru);
666+
xegpu::setDistributeLayoutAttr(selectOp->getOpResult(0), layoutRes);
673667

674668
rewriter.replaceOp(gatherOp, selectOp.getResult());
675669
return success();
@@ -696,27 +690,24 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
696690
Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);
697691

698692
auto numOffsets = scatterOp.getOffsets().size();
699-
auto layoutIndices = mlir::xegpu::getDistributeLayoutAttr(
700-
scatterOp->getOpOperand(numOffsets + 1));
701-
auto layoutMask = mlir::xegpu::getDistributeLayoutAttr(
702-
scatterOp->getOpOperand(numOffsets + 2));
703-
auto layoutVal = mlir::xegpu::getDistributeLayoutAttr(
704-
scatterOp->getOpOperand(numOffsets + 3));
705-
SmallVector<xegpu::CachePolicyAttr, 3> cacheHints{xegpu::CachePolicyAttr{},
706-
xegpu::CachePolicyAttr{},
707-
xegpu::CachePolicyAttr{}};
708-
getOpCacheHints(scatterOp, cacheHints);
693+
auto layoutIndices =
694+
xegpu::getDistributeLayoutAttr(scatterOp->getOpOperand(numOffsets + 1));
695+
auto layoutMask =
696+
xegpu::getDistributeLayoutAttr(scatterOp->getOpOperand(numOffsets + 2));
697+
auto layoutVal =
698+
xegpu::getDistributeLayoutAttr(scatterOp->getOpOperand(numOffsets + 3));
699+
SmallVector<xegpu::CachePolicyAttr, 3> cacheHints =
700+
getOpCacheHints(scatterOp);
709701
auto storeOp = xegpu::StoreScatterOp::create(
710702
rewriter, loc, scatterOp.getValueToStore(), flatMemref, localOffsets,
711703
scatterOp.getMask(),
712704
/*chunk_size=*/IntegerAttr{},
713705
/*l1_hint=*/cacheHints[0],
714706
/*l2_hint=*/cacheHints[1],
715707
/*l3_hint=*/cacheHints[2]);
716-
mlir::xegpu::setDistributeLayoutAttr(storeOp->getOpOperand(0), layoutVal);
717-
mlir::xegpu::setDistributeLayoutAttr(storeOp->getOpOperand(2),
718-
layoutIndices);
719-
mlir::xegpu::setDistributeLayoutAttr(storeOp->getOpOperand(3), layoutMask);
708+
xegpu::setDistributeLayoutAttr(storeOp->getOpOperand(0), layoutVal);
709+
xegpu::setDistributeLayoutAttr(storeOp->getOpOperand(2), layoutIndices);
710+
xegpu::setDistributeLayoutAttr(storeOp->getOpOperand(3), layoutMask);
720711
rewriter.eraseOp(scatterOp);
721712
return success();
722713
}

0 commit comments

Comments
 (0)