Skip to content

Commit 5774365

Browse files
[InsertGPUAllocs] Use gpu.memcpy for opencl instead of memref.copy
1 parent 8ac2fb6 commit 5774365

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

lib/Transforms/InsertGPUAllocs.cpp

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -360,8 +360,10 @@ class InsertGPUAllocsPass final
360360
auto newAlloc = builder.create<mlir::memref::AllocOp>(
361361
loc, alloc.getType(), alloc.getDynamicSizes(),
362362
alloc.getSymbolOperands());
363-
builder.create<mlir::memref::CopyOp>(loc, allocResult,
364-
newAlloc.getResult());
363+
builder.create<mlir::gpu::MemcpyOp>(
364+
loc, /*asyncToken*/ static_cast<mlir::Type>(nullptr),
365+
/*asyncDependencies*/ std::nullopt, newAlloc.getResult(),
366+
allocResult);
365367
use.set(newAlloc.getResult());
366368
}
367369
}
@@ -401,8 +403,9 @@ class InsertGPUAllocsPass final
401403
/*symbolOperands*/ std::nullopt, hostShared);
402404
auto allocResult = gpuAlloc.getResult(0);
403405
if (access.hostWrite && access.deviceRead) {
404-
auto copy =
405-
builder.create<mlir::memref::CopyOp>(loc, op, allocResult);
406+
auto copy = builder.create<mlir::gpu::MemcpyOp>(
407+
loc, /*asyncToken*/ static_cast<mlir::Type>(nullptr),
408+
/*asyncDependencies*/ std::nullopt, allocResult, op);
406409
filter.insert(copy);
407410
}
408411

@@ -421,7 +424,9 @@ class InsertGPUAllocsPass final
421424
op.replaceAllUsesExcept(allocResult, filter);
422425
builder.setInsertionPoint(term);
423426
if (access.hostRead && access.deviceWrite) {
424-
builder.create<mlir::memref::CopyOp>(loc, allocResult, op);
427+
builder.create<mlir::gpu::MemcpyOp>(
428+
loc, /*asyncToken*/ static_cast<mlir::Type>(nullptr),
429+
/*asyncDependencies*/ std::nullopt, op, allocResult);
425430
}
426431
builder.create<mlir::gpu::DeallocOp>(loc, std::nullopt, allocResult);
427432
}

test/Transforms/InsertGpuAllocs/add-gpu-alloc.mlir

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,9 +7,9 @@ func.func @addt(%arg0: memref<2x5xf32>, %arg1: memref<2x5xf32>) -> memref<2x5xf3
77
%c1 = arith.constant 1 : index
88
%c5 = arith.constant 5 : index
99
// OPENCL: %[[MEMREF0:.*]] = gpu.alloc host_shared () : memref<2x5xf32>
10-
// OPENCL: memref.copy %arg1, %[[MEMREF0]] : memref<2x5xf32> to memref<2x5xf32>
10+
// OPENCL: gpu.memcpy %[[MEMREF0]], %arg1 : memref<2x5xf32>, memref<2x5xf32>
1111
// OPENCL: %[[MEMREF1:.*]] = gpu.alloc host_shared () : memref<2x5xf32>
12-
// OPENCL: memref.copy %arg0, %[[MEMREF1]] : memref<2x5xf32> to memref<2x5xf32>
12+
// OPENCL: gpu.memcpy %[[MEMREF1]], %arg0 : memref<2x5xf32>, memref<2x5xf32>
1313
// VULKAN: %[[MEMREF0:.*]] = memref.alloc() : memref<2x5xf32>
1414
// VULKAN: memref.copy %arg1, %[[MEMREF0]] : memref<2x5xf32> to memref<2x5xf32>
1515
// VULKAN: %[[MEMREF1:.*]] = memref.alloc() : memref<2x5xf32>

0 commit comments

Comments
 (0)