99#include " GPUOpsLowering.h"
1010
1111#include " mlir/Conversion/GPUCommon/GPUCommonPass.h"
12+ #include " mlir/Conversion/LLVMCommon/VectorPattern.h"
1213#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
1314#include " mlir/IR/Attributes.h"
1415#include " mlir/IR/Builders.h"
@@ -586,22 +587,15 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
586587 return success ();
587588}
588589
589- // / Unrolls op if it's operating on vectors.
590- LogicalResult impl::scalarizeVectorOp (Operation *op, ValueRange operands,
591- ConversionPatternRewriter &rewriter,
592- const LLVMTypeConverter &converter) {
590+ // / Helper for impl::scalarizeVectorOp. Scalarizes vectors to elements.
591+ // / Used either directly (for ops on 1D vectors) or as the callback passed to
592+ // / detail::handleMultidimensionalVectors (for ops on higher-rank vectors).
593+ static Value scalarizeVectorOpHelper (Operation *op, ValueRange operands,
594+ Type llvm1DVectorTy,
595+ ConversionPatternRewriter &rewriter,
596+ const LLVMTypeConverter &converter) {
593597 TypeRange operandTypes (operands);
594- if (llvm::none_of (operandTypes, llvm::IsaPred<VectorType>)) {
595- return rewriter.notifyMatchFailure (op, " expected vector operand" );
596- }
597- if (op->getNumRegions () != 0 || op->getNumSuccessors () != 0 )
598- return rewriter.notifyMatchFailure (op, " expected no region/successor" );
599- if (op->getNumResults () != 1 )
600- return rewriter.notifyMatchFailure (op, " expected single result" );
601- VectorType vectorType = dyn_cast<VectorType>(op->getResult (0 ).getType ());
602- if (!vectorType)
603- return rewriter.notifyMatchFailure (op, " expected vector result" );
604-
598+ VectorType vectorType = cast<VectorType>(llvm1DVectorTy);
605599 Location loc = op->getLoc ();
606600 Value result = rewriter.create <LLVM::PoisonOp>(loc, vectorType);
607601 Type indexType = converter.convertType (rewriter.getIndexType ());
@@ -621,9 +615,32 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
621615 result = rewriter.create <LLVM::InsertElementOp>(
622616 loc, result, scalarOp->getResult (0 ), index);
623617 }
618+ return result;
619+ }
624620
625- rewriter.replaceOp (op, result);
626- return success ();
621+ // / Unrolls op to array/vector elements.
622+ LogicalResult impl::scalarizeVectorOp (Operation *op, ValueRange operands,
623+ ConversionPatternRewriter &rewriter,
624+ const LLVMTypeConverter &converter) {
625+ TypeRange operandTypes (operands);
626+ if (llvm::any_of (operandTypes, llvm::IsaPred<VectorType>)) {
627+ VectorType vectorType = cast<VectorType>(op->getResultTypes ()[0 ]);
628+ rewriter.replaceOp (op, scalarizeVectorOpHelper (op, operands, vectorType,
629+ rewriter, converter));
630+ return success ();
631+ }
632+
633+ if (llvm::any_of (operandTypes, llvm::IsaPred<LLVM::LLVMArrayType>)) {
634+ return LLVM::detail::handleMultidimensionalVectors (
635+ op, operands, converter,
636+ [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
637+ return scalarizeVectorOpHelper (op, operands, llvm1DVectorTy, rewriter,
638+ converter);
639+ },
640+ rewriter);
641+ }
642+
643+ return rewriter.notifyMatchFailure (op, " no llvm.array or vector to unroll" );
627644}
628645
629646static IntegerAttr wrapNumericMemorySpace (MLIRContext *ctx, unsigned space) {
0 commit comments