Skip to content

Commit 80d9499

Browse files
Fix bug in gpu.memcpy lowering for dynamically shaped operands.
Compute the number of elements to be copied by multiplying the sizes of all dimensions.
1 parent 98542a3 commit 80d9499

File tree

2 files changed: 30 additions and 8 deletions.

mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -76,14 +76,16 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
7676
Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
7777
MemRefType type, MemRefDescriptor desc) const {
7878
Type indexType = ConvertToLLVMPattern::getIndexType();
79-
return type.hasStaticShape()
80-
? ConvertToLLVMPattern::createIndexAttrConstant(
81-
rewriter, loc, indexType, type.getNumElements())
82-
// For identity maps (verified by caller), the number of
83-
// elements is stride[0] * size[0].
84-
: rewriter.create<LLVM::MulOp>(loc,
85-
desc.stride(rewriter, loc, 0),
86-
desc.size(rewriter, loc, 0));
79+
if (type.hasStaticShape())
80+
return ConvertToLLVMPattern::createIndexAttrConstant(
81+
rewriter, loc, indexType, type.getNumElements());
82+
// Compute the number of elements by multiplying all the dim sizes.
83+
uint64_t rank = type.getRank();
84+
Value numElements = desc.size(rewriter, loc, /*pos=*/0);
85+
for (unsigned i = 1; i < rank; i++)
86+
numElements = rewriter.create<LLVM::MulOp>(
87+
loc, numElements, desc.size(rewriter, loc, /*pos=*/i));
88+
return numElements;
8789
}
8890

8991
MLIRContext *context = &this->getTypeConverter()->getContext();

mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,3 +17,23 @@ module attributes {gpu.container_module} {
1717
return
1818
}
1919
}
20+
21+
// -----
22+
23+
module attributes {gpu.container_module} {
24+
25+
// CHECK: func @dynamic
26+
func.func @dynamic(%dst : memref<?x?xf32, 1>, %src : memref<?x?xf32>) {
27+
// CHECK: %[[T0:.*]] = llvm.call @mgpuStreamCreate
28+
%t0 = gpu.wait async
29+
%t1 = gpu.memcpy async [%t0] %dst, %src : memref<?x?xf32, 1>, memref<?x?xf32>
30+
// CHECK: %[[DIM_SIZE_0:.*]] = llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
31+
// CHECK-NEXT: %[[DIM_SIZE_1:.*]] = llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
32+
// CHECK: %[[NUM_ELEMENTS:.*]] = llvm.mul %[[DIM_SIZE_0]], %[[DIM_SIZE_1]] : i64
33+
// CHECK: %[[SIZE_PTR:.*]] = llvm.getelementptr %{{.*}}[%[[NUM_ELEMENTS]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
34+
// CHECK-NEXT: %[[SIZE_INT:.*]] = llvm.ptrtoint %[[SIZE_PTR]] : !llvm.ptr to i64
35+
// CHECK: %[[ADDR_CAST:.*]] = llvm.addrspacecast
36+
// CHECK: llvm.call @mgpuMemcpy(%[[ADDR_CAST]], %{{.*}}, %[[SIZE_INT]], %[[T0]])
37+
return
38+
}
39+
}

0 commit comments

Comments
 (0)