Skip to content

Commit 8fca9c1

Browse files
committed
Address comments
1 parent c5b2e81 commit 8fca9c1

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -488,16 +488,16 @@ struct LinearizeVectorCreateMask final
488488
// the second operand of create_mask to get the 1D mask.
489489
auto firstOperand = adaptor.getOperands().front();
490490
auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
491-
auto isNonZero = rewriter.create<mlir::arith::CmpIOp>(
491+
auto isNonZero = rewriter.createOrFold<mlir::arith::CmpIOp>(
492492
loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
493-
auto isNonZeroIndex = rewriter.create<mlir::arith::IndexCastOp>(
493+
auto isNonZeroIndex = rewriter.createOrFold<mlir::arith::IndexCastOp>(
494494
loc, rewriter.getIndexType(), isNonZero);
495495
auto secondOperand = adaptor.getOperands().back();
496-
auto maskSize = rewriter.create<mlir::arith::MulIOp>(
496+
auto maskSize = rewriter.createOrFold<mlir::arith::AndIOp>(
497497
loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);
498498

499499
auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
500-
loc, dstTy, maskSize.getResult());
500+
loc, dstTy, maskSize);
501501
rewriter.replaceOp(createMaskOp, newMask);
502502
return success();
503503
}

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,10 +355,21 @@ func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1
355355
// CHECK: %[[C0:.*]] = arith.constant 0 : index
356356
// CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index
357357
// CHECK: %[[INDEXCAST:.*]] = arith.index_cast %[[CMP]] : i1 to index
358-
// CHECK: %[[MULI:.*]] = arith.muli %[[INDEXCAST]], %[[ARG1]] : index
358+
// CHECK: %[[MULI:.*]] = arith.andi %[[INDEXCAST]], %[[ARG1]] : index
359359
// CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[MULI]] : vector<16xi1>
360360
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK_1D]] : vector<16xi1> to vector<1x16xi1>
361361
// CHECK: return %[[CAST]] : vector<1x16xi1>
362362
%0 = vector.create_mask %arg0, %arg1 : vector<1x16xi1>
363363
return %0 : vector<1x16xi1>
364364
}
365+
366+
// -----
367+
// CHECK-LABEL: linearize_scalable_create_mask
368+
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x[16]xi1>
369+
func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vector<1x[16]xi1> {
370+
371+
// CHECK: %[[MASK_1D:.*]] = vector.create_mask {{%.*}} : vector<[16]xi1>
372+
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK_1D]] : vector<[16]xi1> to vector<1x[16]xi1>
373+
%0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1>
374+
return %0 : vector<1x[16]xi1>
375+
}

0 commit comments

Comments
 (0)