@@ -545,18 +545,16 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
545545}
546546
547547// / Linearize a vector.create_mask that has at most 1 non-unit dimension.
548- // / Example:
549- // /
548+ // / For example,
550549// / ```
551- // / %0 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
550+ // / %mask3 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
552551// / ```
553552// /
554- // / becomes
555- // /
553+ // / becomes,
556554// / ```
557555// / [...]
558- // / %2 = vector.create_mask %prod: vector<16xi1>
559- // / %3 = vector.shape_cast %2 : vector<16xi1> to vector<1x16x1xi1>
556+ // / %mask1 = vector.create_mask %prod: vector<16xi1>
557+ // / %mask3 = vector.shape_cast %mask1 : vector<16xi1> to vector<1x16x1xi1>
560558// / ```
561559// /
562560// / where %prod above the product of the (clamped) dimension-wise masking ranges
@@ -601,11 +599,11 @@ struct LinearizeVectorCreateMask final
601599 Value zero = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
602600 int nonUnitDim = -1 ;
603601 for (unsigned i = 0 ; i < type.getRank (); ++i) {
604- auto v = adaptor.getOperands ()[i];
605- auto dimSize = type.getDimSize (i);
602+ Value dimRange = adaptor.getOperands ()[i];
603+ int64_t dimSize = type.getDimSize (i);
606604 if (dimSize <= 1 ) {
607605 Value nxt = rewriter.create <arith::CmpIOp>(
608- loc, arith::CmpIPredicate::sgt, v , zero);
606+ loc, arith::CmpIPredicate::sgt, dimRange , zero);
609607 prod = rewriter.create <arith::MulIOp>(loc, prod, nxt);
610608 } else {
611609 assert (nonUnitDim == -1 && " at most 1 non-unit expected" );
0 commit comments