@@ -906,6 +906,43 @@ struct VectorReductionToFPDotProd final
906906 }
907907};
908908
909+ struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
910+ using OpConversionPattern::OpConversionPattern;
911+
912+ LogicalResult
913+ matchAndRewrite (vector::StepOp stepOp, OpAdaptor adaptor,
914+ ConversionPatternRewriter &rewriter) const override {
915+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
916+ Type dstType = typeConverter.convertType (stepOp.getType ());
917+ if (!dstType)
918+ return failure ();
919+
920+ Location loc = stepOp.getLoc ();
921+ int64_t numElements = stepOp.getType ().getNumElements ();
922+ auto intType =
923+ rewriter.getIntegerType (typeConverter.getIndexTypeBitwidth ());
924+
925+ // Input vectors of size 1 are converted to scalars by the type converter.
926+ // We just create a constant in this case.
927+ if (numElements == 1 ) {
928+ Value zero = spirv::ConstantOp::getZero (intType, loc, rewriter);
929+ rewriter.replaceOp (stepOp, zero);
930+ return success ();
931+ }
932+
933+ SmallVector<Value> source;
934+ source.reserve (numElements);
935+ for (int64_t i = 0 ; i < numElements; ++i) {
936+ Attribute intAttr = rewriter.getIntegerAttr (intType, i);
937+ Value constOp = rewriter.create <spirv::ConstantOp>(loc, intType, intAttr);
938+ source.push_back (constOp);
939+ }
940+ rewriter.replaceOpWithNewOp <spirv::CompositeConstructOp>(stepOp, dstType,
941+ source);
942+ return success ();
943+ }
944+ };
945+
909946} // namespace
910947#define CL_INT_MAX_MIN_OPS \
911948 spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -929,8 +966,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
929966 VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
930967 VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
931968 VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
932- VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
933- typeConverter, patterns.getContext (), PatternBenefit (1 ));
969+ VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
970+ VectorStepOpConvert>(typeConverter, patterns.getContext (),
971+ PatternBenefit (1 ));
934972
935973 // Make sure that the more specialized dot product pattern has higher benefit
936974 // than the generic one that extracts all elements.
0 commit comments