File tree Expand file tree Collapse file tree 1 file changed +10
-18
lines changed
mlir/lib/Dialect/AMDGPU/IR Expand file tree Collapse file tree 1 file changed +10
-18
lines changed Original file line number Diff line number Diff line change @@ -559,26 +559,18 @@ struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
559559 LogicalResult matchAndRewrite (GatherToLDSOp gatherOp,
560560 PatternRewriter &rewriter) const override {
561561 bool modified = false ;
562-
563- // Check source.
564- if (auto castOp = gatherOp.getSrc ().getDefiningOp <memref::CastOp>()) {
565- if (memref::CastOp::canFoldIntoConsumerOp (castOp)) {
566- rewriter.modifyOpInPlace (gatherOp, [&] {
567- gatherOp.getSrcMutable ().assign (castOp.getSource ());
568- });
569- modified = true ;
562+ auto foldCast = [&](OpOperand &operand) {
563+ if (auto castOp = operand.get ().getDefiningOp <memref::CastOp>()) {
564+ if (memref::CastOp::canFoldIntoConsumerOp (castOp)) {
565+ rewriter.modifyOpInPlace (gatherOp,
566+ [&] { operand.assign (castOp.getSource ()); });
567+ modified = true ;
568+ }
570569 }
571- }
570+ };
572571
573- // Check target.
574- if (auto castOp = gatherOp.getDst ().getDefiningOp <memref::CastOp>()) {
575- if (memref::CastOp::canFoldIntoConsumerOp (castOp)) {
576- rewriter.modifyOpInPlace (gatherOp, [&] {
577- gatherOp.getDstMutable ().assign (castOp.getSource ());
578- });
579- modified = true ;
580- }
581- }
572+ foldCast (gatherOp.getSrcMutable ());
573+ foldCast (gatherOp.getDstMutable ());
582574
583575 return success (modified);
584576 }
You can’t perform that action at this time.
0 commit comments