Skip to content

Conversation

@arnab-polymage
Copy link
Contributor

Compute the number of elements to be copied by multiplying dim sizes along all the dimensions.

@llvmbot
Copy link
Member

llvmbot commented Feb 26, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Arnab Dutta (arnab-polymage)

Changes

Compute the number of elements to be copied by multiplying dim sizes along all the dimensions.


Full diff: https://github.com/llvm/llvm-project/pull/128820.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+10-8)
  • (modified) mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir (+20)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 8017eb6bb383b..512820bab4097 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -76,14 +76,16 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
   Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                        MemRefType type, MemRefDescriptor desc) const {
     Type indexType = ConvertToLLVMPattern::getIndexType();
-    return type.hasStaticShape()
-               ? ConvertToLLVMPattern::createIndexAttrConstant(
-                     rewriter, loc, indexType, type.getNumElements())
-               // For identity maps (verified by caller), the number of
-               // elements is stride[0] * size[0].
-               : rewriter.create<LLVM::MulOp>(loc,
-                                              desc.stride(rewriter, loc, 0),
-                                              desc.size(rewriter, loc, 0));
+    if (type.hasStaticShape())
+      return ConvertToLLVMPattern::createIndexAttrConstant(
+          rewriter, loc, indexType, type.getNumElements());
+    // Compute the number of elements by multiplying all the dim sizes.
+    uint64_t rank = type.getRank();
+    Value numElements = desc.size(rewriter, loc, /*pos=*/0);
+    for (unsigned i = 1; i < rank; i++)
+      numElements = rewriter.create<LLVM::MulOp>(
+          loc, numElements, desc.size(rewriter, loc, /*pos=*/i));
+    return numElements;
   }
 
   MLIRContext *context = &this->getTypeConverter()->getContext();
diff --git a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
index 3f86b07698279..b45d188a77e3f 100644
--- a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
@@ -17,3 +17,23 @@ module attributes {gpu.container_module} {
     return
   }
 }
+
+// -----
+
+module attributes {gpu.container_module} {
+
+  // CHECK: func @dynamic
+  func.func @dynamic(%dst : memref<?x?xf32, 1>, %src : memref<?x?xf32>) {
+    // CHECK: %[[T0:.*]] = llvm.call @mgpuStreamCreate
+    %t0 = gpu.wait async
+    %t1 = gpu.memcpy async [%t0] %dst, %src : memref<?x?xf32, 1>, memref<?x?xf32>
+    // CHECK: %[[DIM_SIZE_0:.*]] = llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    // CHECK-NEXT: %[[DIM_SIZE_1:.*]] = llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    // CHECK: %[[NUM_ELEMENTS:.*]] = llvm.mul %[[DIM_SIZE_0]], %[[DIM_SIZE_1]]  : i64
+    // CHECK: %[[SIZE_PTR:.*]] = llvm.getelementptr %{{.*}}[%[[NUM_ELEMENTS]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK-NEXT: %[[SIZE_INT:.*]] = llvm.ptrtoint %[[SIZE_PTR]] : !llvm.ptr to i64
+    // CHECK: %[[ADDR_CAST:.*]] = llvm.addrspacecast
+    // CHECK: llvm.call @mgpuMemcpy(%[[ADDR_CAST]], %{{.*}}, %[[SIZE_INT]], %[[T0]])
+    return
+  }
+}

@arnab-polymage arnab-polymage force-pushed the ornib/gpu_memcpy_lowering_bug branch from b1ca515 to cb73e69 Compare February 26, 2025 09:56
Compute the number of elements to be copied by multiplying dim sizes
along all the dimensions.
@arnab-polymage arnab-polymage force-pushed the ornib/gpu_memcpy_lowering_bug branch from cb73e69 to 88d79e8 Compare February 27, 2025 04:47
@bondhugula bondhugula merged commit c13ebb5 into llvm:main Mar 3, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants