diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 682eb82ac5840..dc5eb2527f949 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1229,28 +1229,9 @@ struct WarpOpExtract : public OpRewritePattern { VectorType extractSrcType = extractOp.getSourceVectorType(); Location loc = extractOp.getLoc(); - // "vector.extract %v[] : vector from vector" is an invalid op. - assert(extractSrcType.getRank() > 0 && - "vector.extract does not support rank 0 sources"); - - // "vector.extract %v[] : vector<...xf32> from vector<...xf32>" can be - // canonicalized to %v. - if (extractOp.getNumIndices() == 0) + // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern. + if (extractSrcType.getRank() <= 1) { return failure(); - - // Rewrite vector.extract with 1d source to vector.extractelement. - if (extractSrcType.getRank() == 1) { - if (extractOp.hasDynamicPosition()) - // TODO: Dinamic position not supported yet. - return failure(); - - assert(extractOp.getNumIndices() == 1 && "expected 1 index"); - int64_t pos = extractOp.getStaticPosition()[0]; - rewriter.setInsertionPoint(extractOp); - rewriter.replaceOpWithNewOp( - extractOp, extractOp.getVector(), - rewriter.create(loc, pos)); - return success(); } // All following cases are 2d or higher dimensional source vectors. @@ -1313,22 +1294,27 @@ struct WarpOpExtract : public OpRewritePattern { } }; -/// Pattern to move out vector.extractelement of 0-D tensors. Those don't -/// need to be distributed and can just be propagated outside of the region. -struct WarpOpExtractElement : public OpRewritePattern { - WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn, - PatternBenefit b = 1) +/// Pattern to move out vector.extract with a scalar result. +/// Only supports 1-D and 0-D sources for now. 
+struct WarpOpExtractScalar : public OpRewritePattern { + WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn, + PatternBenefit b = 1) : OpRewritePattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {} LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *operand = - getWarpResult(warpOp, llvm::IsaPred); + getWarpResult(warpOp, llvm::IsaPred); if (!operand) return failure(); unsigned int operandNumber = operand->getOperandNumber(); - auto extractOp = operand->get().getDefiningOp(); + auto extractOp = operand->get().getDefiningOp(); VectorType extractSrcType = extractOp.getSourceVectorType(); + // Only supports 1-D or 0-D sources for now. + if (extractSrcType.getRank() > 1) { + return rewriter.notifyMatchFailure( + extractOp, "only 0-D or 1-D source supported for now"); + } // TODO: Supported shuffle types should be parameterizable, similar to // `WarpShuffleFromIdxFn`. if (!extractSrcType.getElementType().isF32() && @@ -1340,7 +1326,7 @@ struct WarpOpExtractElement : public OpRewritePattern { VectorType distributedVecType; if (!is0dOrVec1Extract) { assert(extractSrcType.getRank() == 1 && - "expected that extractelement src rank is 0 or 1"); + "expected that extract src rank is 0 or 1"); if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0) return failure(); int64_t elementsPerLane = @@ -1352,10 +1338,11 @@ struct WarpOpExtractElement : public OpRewritePattern { // Yield source vector and position (if present) from warp op. 
SmallVector additionalResults{extractOp.getVector()}; SmallVector additionalResultTypes{distributedVecType}; - if (static_cast(extractOp.getPosition())) { - additionalResults.push_back(extractOp.getPosition()); - additionalResultTypes.push_back(extractOp.getPosition().getType()); - } + additionalResults.append( + SmallVector(extractOp.getDynamicPosition())); + additionalResultTypes.append( + SmallVector(extractOp.getDynamicPosition().getTypes())); + Location loc = extractOp.getLoc(); SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( @@ -1368,39 +1355,33 @@ struct WarpOpExtractElement : public OpRewritePattern { // All lanes extract the scalar. if (is0dOrVec1Extract) { Value newExtract; - if (extractSrcType.getRank() == 1) { - newExtract = rewriter.create( - loc, distributedVec, - rewriter.create(loc, 0)); - - } else { - newExtract = - rewriter.create(loc, distributedVec); - } + SmallVector indices(extractSrcType.getRank(), 0); + newExtract = + rewriter.create(loc, distributedVec, indices); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newExtract); return success(); } + int64_t staticPos = extractOp.getStaticPosition()[0]; + OpFoldResult pos = ShapedType::isDynamic(staticPos) + ? (newWarpOp->getResult(newRetIndices[1])) + : OpFoldResult(rewriter.getIndexAttr(staticPos)); // 1d extract: Distribute the source vector. One lane extracts and shuffles // the value to all other lanes. int64_t elementsPerLane = distributedVecType.getShape()[0]; AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext()); // tid of extracting thread: pos / elementsPerLane - Value broadcastFromTid = rewriter.create( - loc, sym0.ceilDiv(elementsPerLane), - newWarpOp->getResult(newRetIndices[1])); + Value broadcastFromTid = affine::makeComposedAffineApply( + rewriter, loc, sym0.ceilDiv(elementsPerLane), pos); // Extract at position: pos % elementsPerLane - Value pos = + Value newPos = elementsPerLane == 1 ? 
rewriter.create(loc, 0).getResult() - : rewriter - .create( - loc, sym0 % elementsPerLane, - newWarpOp->getResult(newRetIndices[1])) - .getResult(); + : affine::makeComposedAffineApply(rewriter, loc, + sym0 % elementsPerLane, pos); Value extracted = - rewriter.create(loc, distributedVec, pos); + rewriter.create(loc, distributedVec, newPos); // Shuffle the extracted value to all lanes. Value shuffled = warpShuffleFromIdxFn( @@ -1413,31 +1394,59 @@ struct WarpOpExtractElement : public OpRewritePattern { WarpShuffleFromIdxFn warpShuffleFromIdxFn; }; -struct WarpOpInsertElement : public OpRewritePattern { +/// Pattern to convert vector.extractelement to vector.extract. +struct WarpOpExtractElement : public OpRewritePattern { + WarpOpExtractElement(MLIRContext *ctx, PatternBenefit b = 1) + : OpRewritePattern(ctx, b) {} + LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *operand = + getWarpResult(warpOp, llvm::IsaPred); + if (!operand) + return failure(); + auto extractOp = operand->get().getDefiningOp(); + SmallVector indices; + if (auto pos = extractOp.getPosition()) { + indices.push_back(pos); + } + rewriter.setInsertionPoint(extractOp); + rewriter.replaceOpWithNewOp( + extractOp, extractOp.getVector(), indices); + return success(); + } +}; + +/// Pattern to move out vector.insert with a scalar input. +/// Only supports 1-D and 0-D destinations for now. 
+struct WarpOpInsertScalar : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { - OpOperand *operand = - getWarpResult(warpOp, llvm::IsaPred); + OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred); if (!operand) return failure(); unsigned int operandNumber = operand->getOperandNumber(); - auto insertOp = operand->get().getDefiningOp(); + auto insertOp = operand->get().getDefiningOp(); VectorType vecType = insertOp.getDestVectorType(); VectorType distrType = cast(warpOp.getResult(operandNumber).getType()); - bool hasPos = static_cast(insertOp.getPosition()); + + // Only supports 1-D or 0-D destinations for now. + if (vecType.getRank() > 1) { + return rewriter.notifyMatchFailure( + insertOp, "only 0-D or 1-D destination supported for now"); + } // Yield destination vector, source scalar and position from warp op. SmallVector additionalResults{insertOp.getDest(), insertOp.getSource()}; SmallVector additionalResultTypes{distrType, insertOp.getSource().getType()}; - if (hasPos) { - additionalResults.push_back(insertOp.getPosition()); - additionalResultTypes.push_back(insertOp.getPosition().getType()); - } + additionalResults.append(SmallVector(insertOp.getDynamicPosition())); + additionalResultTypes.append( + SmallVector(insertOp.getDynamicPosition().getTypes())); + Location loc = insertOp.getLoc(); SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( @@ -1446,13 +1455,26 @@ struct WarpOpInsertElement : public OpRewritePattern { rewriter.setInsertionPointAfter(newWarpOp); Value distributedVec = newWarpOp->getResult(newRetIndices[0]); Value newSource = newWarpOp->getResult(newRetIndices[1]); - Value newPos = hasPos ? 
newWarpOp->getResult(newRetIndices[2]) : Value(); rewriter.setInsertionPointAfter(newWarpOp); + OpFoldResult pos; + if (vecType.getRank() != 0) { + int64_t staticPos = insertOp.getStaticPosition()[0]; + pos = ShapedType::isDynamic(staticPos) + ? (newWarpOp->getResult(newRetIndices[2])) + : OpFoldResult(rewriter.getIndexAttr(staticPos)); + } + + // This condition is always true for 0-d vectors. if (vecType == distrType) { - // Broadcast: Simply move the vector.inserelement op out. - Value newInsert = rewriter.create( - loc, newSource, distributedVec, newPos); + Value newInsert; + SmallVector indices; + if (pos) { + indices.push_back(pos); + } + newInsert = rewriter.create(loc, newSource, + distributedVec, indices); + // Broadcast: Simply move the vector.insert op out. rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert); return success(); @@ -1462,16 +1484,11 @@ struct WarpOpInsertElement : public OpRewritePattern { int64_t elementsPerLane = distrType.getShape()[0]; AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext()); // tid of extracting thread: pos / elementsPerLane - Value insertingLane = rewriter.create( - loc, sym0.ceilDiv(elementsPerLane), newPos); + Value insertingLane = affine::makeComposedAffineApply( + rewriter, loc, sym0.ceilDiv(elementsPerLane), pos); // Insert position: pos % elementsPerLane - Value pos = - elementsPerLane == 1 - ? 
rewriter.create(loc, 0).getResult() - : rewriter - .create(loc, sym0 % elementsPerLane, - newPos) - .getResult(); + OpFoldResult newPos = affine::makeComposedFoldedAffineApply( + rewriter, loc, sym0 % elementsPerLane, pos); Value isInsertingLane = rewriter.create( loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane); Value newResult = @@ -1480,8 +1497,8 @@ struct WarpOpInsertElement : public OpRewritePattern { loc, isInsertingLane, /*thenBuilder=*/ [&](OpBuilder &builder, Location loc) { - Value newInsert = builder.create( - loc, newSource, distributedVec, pos); + Value newInsert = builder.create( + loc, newSource, distributedVec, newPos); builder.create(loc, newInsert); }, /*elseBuilder=*/ @@ -1506,25 +1523,13 @@ struct WarpOpInsert : public OpRewritePattern { auto insertOp = operand->get().getDefiningOp(); Location loc = insertOp.getLoc(); - // "vector.insert %v, %v[] : ..." can be canonicalized to %v. - if (insertOp.getNumIndices() == 0) + // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern. + if (insertOp.getDestVectorType().getRank() <= 1) { return failure(); - - // Rewrite vector.insert with 1d dest to vector.insertelement. - if (insertOp.getDestVectorType().getRank() == 1) { - if (insertOp.hasDynamicPosition()) - // TODO: Dinamic position not supported yet. - return failure(); - - assert(insertOp.getNumIndices() == 1 && "expected 1 index"); - int64_t pos = insertOp.getStaticPosition()[0]; - rewriter.setInsertionPoint(insertOp); - rewriter.replaceOpWithNewOp( - insertOp, insertOp.getSource(), insertOp.getDest(), - rewriter.create(loc, pos)); - return success(); } + // All following cases are 2d or higher dimensional destination vectors. + if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) { // There is no distribution, this is a broadcast. Simply move the insert // out of the warp op. 
@@ -1620,9 +1625,30 @@ struct WarpOpInsert : public OpRewritePattern { } }; +struct WarpOpInsertElement : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *operand = + getWarpResult(warpOp, llvm::IsaPred); + if (!operand) + return failure(); + auto insertOp = operand->get().getDefiningOp(); + SmallVector indices; + if (auto pos = insertOp.getPosition()) { + indices.push_back(pos); + } + rewriter.setInsertionPoint(insertOp); + rewriter.replaceOpWithNewOp( + insertOp, insertOp.getSource(), insertOp.getDest(), indices); + return success(); + } +}; + /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if -/// the scf.ForOp is the last operation in the region so that it doesn't change -/// the order of execution. This creates a new scf.for region after the +/// the scf.ForOp is the last operation in the region so that it doesn't +/// change the order of execution. This creates a new scf.for region after the /// WarpExecuteOnLane0Op. The new scf.for region will contain a new /// WarpExecuteOnLane0Op region. Example: /// ``` @@ -1668,8 +1694,8 @@ struct WarpOpScfForOp : public OpRewritePattern { if (!forOp) return failure(); // Collect Values that come from the warp op but are outside the forOp. - // Those Value needs to be returned by the original warpOp and passed to the - // new op. + // Those Value needs to be returned by the original warpOp and passed to + // the new op. llvm::SmallSetVector escapingValues; SmallVector inputTypes; SmallVector distTypes; @@ -1715,8 +1741,8 @@ struct WarpOpScfForOp : public OpRewritePattern { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(newWarpOp); - // Create a new for op outside the region with a WarpExecuteOnLane0Op region - // inside. + // Create a new for op outside the region with a WarpExecuteOnLane0Op + // region inside. 
auto newForOp = rewriter.create( forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), newOperands); @@ -1778,8 +1804,8 @@ struct WarpOpScfForOp : public OpRewritePattern { }; /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op. -/// The vector is reduced in parallel. Currently limited to vector size matching -/// the warpOp size. E.g.: +/// The vector is reduced in parallel. Currently limited to vector size +/// matching the warpOp size. E.g.: /// ``` /// %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) { /// %0 = "some_def"() : () -> (vector<32xf32>) @@ -1880,13 +1906,13 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns( const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit, PatternBenefit readBenefit) { patterns.add(patterns.getContext(), readBenefit); - patterns - .add( - patterns.getContext(), benefit); - patterns.add(patterns.getContext(), - warpShuffleFromIdxFn, benefit); + patterns.add( + patterns.getContext(), benefit); + patterns.add(patterns.getContext(), warpShuffleFromIdxFn, + benefit); patterns.add(patterns.getContext(), distributionMapFn, benefit); } diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index 3acddd6e54639..b4491812dc26c 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -783,7 +783,7 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) { // CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<64xf32> // CHECK-PROP: vector.yield %[[V]] : vector<64xf32> // CHECK-PROP: } -// CHECK-PROP: %[[E:.*]] = vector.extractelement %[[R]][%[[C1]] : index] : vector<2xf32> +// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][%[[C1]]] : f32 from vector<2xf32> // CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[E]], %[[C5_I32]] // CHECK-PROP: return %[[SHUFFLED]] : f32 func.func 
@vector_extract_1d(%laneid: index) -> (f32) { @@ -874,7 +874,7 @@ func.func @vector_extract_3d(%laneid: index) -> (vector<4x96xf32>) { // CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector // CHECK-PROP: vector.yield %[[V]] : vector // CHECK-PROP: } -// CHECK-PROP: %[[E:.*]] = vector.extractelement %[[R]][] : vector +// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][] : f32 from vector // CHECK-PROP: return %[[E]] : f32 func.func @vector_extractelement_0d(%laneid: index) -> (f32) { %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) { @@ -888,12 +888,11 @@ func.func @vector_extractelement_0d(%laneid: index) -> (f32) { // ----- // CHECK-PROP-LABEL: func.func @vector_extractelement_1element( -// CHECK-PROP: %[[C0:.*]] = arith.constant 0 : index // CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) { // CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<1xf32> // CHECK-PROP: vector.yield %[[V]] : vector<1xf32> // CHECK-PROP: } -// CHECK-PROP: %[[E:.*]] = vector.extractelement %[[R]][%[[C0]] : index] : vector<1xf32> +// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][0] : f32 from vector<1xf32> // CHECK-PROP: return %[[E]] : f32 func.func @vector_extractelement_1element(%laneid: index) -> (f32) { %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) { @@ -918,7 +917,7 @@ func.func @vector_extractelement_1element(%laneid: index) -> (f32) { // CHECK-PROP: } // CHECK-PROP: %[[FROM_LANE:.*]] = affine.apply #[[$map]]()[%[[POS]]] // CHECK-PROP: %[[DISTR_POS:.*]] = affine.apply #[[$map1]]()[%[[POS]]] -// CHECK-PROP: %[[EXTRACTED:.*]] = vector.extractelement %[[W]][%[[DISTR_POS]] : index] : vector<3xf32> +// CHECK-PROP: %[[EXTRACTED:.*]] = vector.extract %[[W]][%[[DISTR_POS]]] : f32 from vector<3xf32> // CHECK-PROP: %[[FROM_LANE_I32:.*]] = arith.index_cast %[[FROM_LANE]] : index to i32 // CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[EXTRACTED]], %[[FROM_LANE_I32]], %[[C32]] : f32 // CHECK-PROP: return %[[SHUFFLED]] @@ 
-938,7 +937,7 @@ func.func @vector_extractelement_1d(%laneid: index, %pos: index) -> (f32) { // CHECK-PROP-LABEL: func.func @vector_extractelement_1d_index( // CHECK-PROP: vector.warp_execute_on_lane_0(%{{.*}})[32] -> (index) { // CHECK-PROP: "some_def" -// CHECK-PROP: vector.extractelement +// CHECK-PROP: vector.extract // CHECK-PROP: vector.yield {{.*}} : index // CHECK-PROP: } func.func @vector_extractelement_1d_index(%laneid: index, %pos: index) -> (index) { @@ -1151,7 +1150,7 @@ func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, % // CHECK-PROP: %[[INSERTING_POS:.*]] = affine.apply #[[$MAP1]]()[%[[POS]]] // CHECK-PROP: %[[SHOULD_INSERT:.*]] = arith.cmpi eq, %[[LANEID]], %[[INSERTING_LANE]] : index // CHECK-PROP: %[[R:.*]] = scf.if %[[SHOULD_INSERT]] -> (vector<3xf32>) { -// CHECK-PROP: %[[INSERT:.*]] = vector.insertelement %[[W]]#1, %[[W]]#0[%[[INSERTING_POS]] : index] +// CHECK-PROP: %[[INSERT:.*]] = vector.insert %[[W]]#1, %[[W]]#0 [%[[INSERTING_POS]]] // CHECK-PROP: scf.yield %[[INSERT]] // CHECK-PROP: } else { // CHECK-PROP: scf.yield %[[W]]#0 @@ -1175,7 +1174,7 @@ func.func @vector_insertelement_1d(%laneid: index, %pos: index) -> (vector<3xf32 // CHECK-PROP: %[[VEC:.*]] = "some_def" // CHECK-PROP: %[[VAL:.*]] = "another_def" // CHECK-PROP: vector.yield %[[VEC]], %[[VAL]] -// CHECK-PROP: vector.insertelement %[[W]]#1, %[[W]]#0[%[[POS]] : index] : vector<96xf32> +// CHECK-PROP: vector.insert %[[W]]#1, %[[W]]#0 [%[[POS]]] : f32 into vector<96xf32> func.func @vector_insertelement_1d_broadcast(%laneid: index, %pos: index) -> (vector<96xf32>) { %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<96xf32>) { %0 = "some_def"() : () -> (vector<96xf32>) @@ -1193,7 +1192,7 @@ func.func @vector_insertelement_1d_broadcast(%laneid: index, %pos: index) -> (ve // CHECK-PROP: %[[VEC:.*]] = "some_def" // CHECK-PROP: %[[VAL:.*]] = "another_def" // CHECK-PROP: vector.yield %[[VEC]], %[[VAL]] -// CHECK-PROP: vector.insertelement %[[W]]#1, 
%[[W]]#0[] : vector +// CHECK-PROP: vector.insert %[[W]]#1, %[[W]]#0 [] : f32 into vector func.func @vector_insertelement_0d(%laneid: index) -> (vector) { %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector) { %0 = "some_def"() : () -> (vector) @@ -1208,7 +1207,6 @@ func.func @vector_insertelement_0d(%laneid: index) -> (vector) { // CHECK-PROP-LABEL: func @vector_insert_1d( // CHECK-PROP-SAME: %[[LANEID:.*]]: index -// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-PROP-DAG: %[[C26:.*]] = arith.constant 26 : index // CHECK-PROP: %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<3xf32>, f32) // CHECK-PROP: %[[VEC:.*]] = "some_def" @@ -1216,7 +1214,7 @@ func.func @vector_insertelement_0d(%laneid: index) -> (vector) { // CHECK-PROP: vector.yield %[[VEC]], %[[VAL]] // CHECK-PROP: %[[SHOULD_INSERT:.*]] = arith.cmpi eq, %[[LANEID]], %[[C26]] // CHECK-PROP: %[[R:.*]] = scf.if %[[SHOULD_INSERT]] -> (vector<3xf32>) { -// CHECK-PROP: %[[INSERT:.*]] = vector.insertelement %[[W]]#1, %[[W]]#0[%[[C1]] : index] +// CHECK-PROP: %[[INSERT:.*]] = vector.insert %[[W]]#1, %[[W]]#0 [1] // CHECK-PROP: scf.yield %[[INSERT]] // CHECK-PROP: } else { // CHECK-PROP: scf.yield %[[W]]#0 @@ -1316,14 +1314,14 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) { // CHECK-PROP: %[[GATHER:.*]] = vector.gather %[[AR1]][{{.*}}] // CHECK-PROP: %[[EXTRACT:.*]] = vector.extract %[[GATHER]][0] : vector<64xi32> from vector<1x64xi32> // CHECK-PROP: %[[CAST:.*]] = arith.index_cast %[[EXTRACT]] : vector<64xi32> to vector<64xindex> -// CHECK-PROP: %[[EXTRACTELT:.*]] = vector.extractelement %[[CAST]][{{.*}}: i32] : vector<64xindex> +// CHECK-PROP: %[[EXTRACTELT:.*]] = vector.extract %[[CAST]][{{.*}}] : index from vector<64xindex> // CHECK-PROP: vector.yield %[[EXTRACTELT]] : index // CHECK-PROP: %[[APPLY:.*]] = affine.apply #[[$MAP]]()[%[[THREADID]]] // CHECK-PROP: %[[TRANSFERREAD:.*]] = vector.transfer_read %[[AR2]][%[[C0]], %[[W]], %[[APPLY]]], // 
CHECK-PROP: return %[[TRANSFERREAD]] func.func @transfer_read_prop_operands(%in2: vector<1x2xindex>, %ar1 : memref<1x4x2xi32>, %ar2 : memref<1x4x1024xf32>)-> vector<2xf32> { %0 = gpu.thread_id x - %c0_i32 = arith.constant 0 : i32 + %c0_i32 = arith.constant 0 : index %c0 = arith.constant 0 : index %cst = arith.constant dense<0> : vector<1x64xi32> %cst_0 = arith.constant dense : vector<1x64xi1> @@ -1336,7 +1334,7 @@ func.func @transfer_read_prop_operands(%in2: vector<1x2xindex>, %ar1 : memref<1 %28 = vector.gather %ar1[%c0, %c0, %c0] [%arg4], %cst_0, %cst : memref<1x4x2xi32>, vector<1x64xindex>, vector<1x64xi1>, vector<1x64xi32> into vector<1x64xi32> %29 = vector.extract %28[0] : vector<64xi32> from vector<1x64xi32> %30 = arith.index_cast %29 : vector<64xi32> to vector<64xindex> - %36 = vector.extractelement %30[%c0_i32 : i32] : vector<64xindex> + %36 = vector.extractelement %30[%c0_i32 : index] : vector<64xindex> %37 = vector.transfer_read %ar2[%c0, %36, %c0], %cst_6 {in_bounds = [true]} : memref<1x4x1024xf32>, vector<64xf32> vector.yield %37 : vector<64xf32> }