@@ -76,14 +76,16 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
7676 Value getNumElements (ConversionPatternRewriter &rewriter, Location loc,
7777 MemRefType type, MemRefDescriptor desc) const {
7878 Type indexType = ConvertToLLVMPattern::getIndexType ();
79- return type.hasStaticShape ()
80- ? ConvertToLLVMPattern::createIndexAttrConstant (
81- rewriter, loc, indexType, type.getNumElements ())
82- // For identity maps (verified by caller), the number of
83- // elements is stride[0] * size[0].
84- : rewriter.create <LLVM::MulOp>(loc,
85- desc.stride (rewriter, loc, 0 ),
86- desc.size (rewriter, loc, 0 ));
79+ if (type.hasStaticShape ())
80+ return ConvertToLLVMPattern::createIndexAttrConstant (
81+ rewriter, loc, indexType, type.getNumElements ());
82+ // Compute the number of elements by multiplying all the dim sizes.
83+ uint64_t rank = type.getRank ();
84+ Value numElements = desc.size (rewriter, loc, /* pos=*/ 0 );
85+ for (unsigned i = 1 ; i < rank; i++)
86+ numElements = rewriter.create <LLVM::MulOp>(
87+ loc, numElements, desc.size (rewriter, loc, /* pos=*/ i));
88+ return numElements;
8789 }
8890
8991 MLIRContext *context = &this ->getTypeConverter ()->getContext();
0 commit comments