Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
Expand Down Expand Up @@ -89,7 +90,22 @@ static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
if (!stridedLayout)
return failure();
mb.setLayout(StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides()));
MemRefLayoutAttrInterface newLayout =
StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
// Special case: if resetting the offset causes the strided layout to become
// the identity layout, then reset to the identity layout.
// TODO: this'll get a lot simpler when we have the contiguous layout.
SmallVector<int64_t> stridesIfIdentity;
if (source.hasStaticShape()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about making a utility function for this, something like isIdentity? Looks useful.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed - in my to-be-split contiguous layouts PR, this was something along the lines of MemRefLayoutAttrInterface StridedLayoutAttr::getCanonical(sizes, strides, offset);, but that's future work / a branch that needs to be rescued.

stridesIfIdentity = computeSuffixProduct(source.getShape());
} else if (source.getRank() <= 1) {
stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
}
if (stridesIfIdentity == stridedLayout.getStrides()) {
newLayout = AffineMapAttr::get(
AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
}
mb.setLayout(newLayout);
}
return (MemRefType)(mb);
}
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func.func @fat_raw_buffer_cast_dyn_size_offset(%buf: memref<?xi32, strided<[1],
}

// CHECK-LABEL: func @fat_raw_buffer_cast_reset_offset
func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> {
func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, #amdgpu.address_space<fat_raw_buffer>> {
// CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<?xi32, strided<[1], offset: ?>, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-DAG: %[[memRefPtr:.*]] = llvm.extractvalue %[[desc]][1]
// CHECK-DAG: %[[memRefOff:.*]] = llvm.extractvalue %[[desc]][2]
Expand All @@ -77,8 +77,8 @@ func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], off
// CHECK: %[[fatBuf:.*]] = rocdl.make.buffer.rsrc %[[basePtr]], %{{.*}}, %{{.*}}, %{{.*}}
// CHECK: llvm.insertvalue %[[fatBuf]], %{{.*}}[1]
// CHECK: llvm.insertvalue %[[zeroOff]], %{{.*}}[2]
%ret = amdgpu.fat_raw_buffer_cast %buf resetOffset : memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace> to memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
return %ret : memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
%ret = amdgpu.fat_raw_buffer_cast %buf resetOffset : memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace> to memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
return %ret : memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
}

// CHECK-LABEL: func @fat_raw_buffer_cast_valid_bytes
Expand Down
50 changes: 47 additions & 3 deletions mlir/test/Dialect/AMDGPU/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -360,10 +360,54 @@ func.func @fat_raw_buffer_cast_easy(%m: memref<8xi32>) -> memref<8xi32, #amdgpu.
// CHECK-SAME: cacheSwizzleStride(%{{[^)]*}})
// CHECK-SAME: boundsCheck(false)
// CHECK-SAME: resetOffset
func.func @fat_raw_buffer_cast(%m: memref<8xi32, strided<[1], offset: ?>>, %validBytes: i32, %cacheSwizzle: i14) -> memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> {
func.func @fat_raw_buffer_cast(%m: memref<8xi32, strided<[1], offset: ?>>, %validBytes: i32, %cacheSwizzle: i14) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
%ret = amdgpu.fat_raw_buffer_cast %m validBytes(%validBytes) cacheSwizzleStride(%cacheSwizzle) boundsCheck(false) resetOffset
: memref<8xi32, strided<[1], offset: ?>> to memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
func.return %ret : memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
: memref<8xi32, strided<[1], offset: ?>> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
func.return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
}

// CHECK-LABEL: func @fat_raw_buffer_cast_dynamic_1d_reset_offset
// CHECK: amdgpu.fat_raw_buffer_cast
// resetOffset on a rank-1 memref with unit stride: once the dynamic offset is
// dropped, stride [1] matches the identity layout for any rank-1 shape, so the
// result type carries no strided layout attribute at all.
func.func @fat_raw_buffer_cast_dynamic_1d_reset_offset(%m: memref<?xi32, strided<[1], offset: ?>>) -> memref<?xi32, #amdgpu.address_space<fat_raw_buffer>> {
%ret = amdgpu.fat_raw_buffer_cast %m resetOffset
: memref<?xi32, strided<[1], offset: ?>> to memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
func.return %ret : memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
}

// CHECK-LABEL: func @fat_raw_buffer_cast_dynamic_0d_reset_offset
// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
// CHECK: return %[[ret]]
// Rank-0 case: the empty stride list trivially matches the identity layout,
// so resetting the dynamic offset drops the strided layout from the result.
func.func @fat_raw_buffer_cast_dynamic_0d_reset_offset(%m: memref<i32, strided<[], offset: ?>>) -> memref<i32, #amdgpu.address_space<fat_raw_buffer>> {
%ret = amdgpu.fat_raw_buffer_cast %m resetOffset
: memref<i32, strided<[], offset: ?>> to memref<i32, #amdgpu.address_space<fat_raw_buffer>>
func.return %ret : memref<i32, #amdgpu.address_space<fat_raw_buffer>>
}

// CHECK-LABEL: func @fat_raw_buffer_cast_static_shape_2d_reset_offset
// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
// CHECK: return %[[ret]]
// Static shape 4x4: strides [4, 1] equal the suffix product of the shape, so
// with the offset reset to 0 the layout folds to the identity and the result
// type needs no layout attribute.
func.func @fat_raw_buffer_cast_static_shape_2d_reset_offset(%m: memref<4x4xi32, strided<[4, 1], offset: ?>>) -> memref<4x4xi32, #amdgpu.address_space<fat_raw_buffer>> {
%ret = amdgpu.fat_raw_buffer_cast %m resetOffset
: memref<4x4xi32, strided<[4, 1], offset: ?>> to memref<4x4xi32, #amdgpu.address_space<fat_raw_buffer>>
func.return %ret : memref<4x4xi32, #amdgpu.address_space<fat_raw_buffer>>
}

// CHECK-LABEL: func @fat_raw_buffer_cast_dynamic_2d_reset_offset
// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
// CHECK: return %[[ret]]
// Dynamic shape with rank > 1: the identity strides cannot be computed, so the
// result keeps an explicit strided layout (offset reset to 0, i.e. omitted).
func.func @fat_raw_buffer_cast_dynamic_2d_reset_offset(%m: memref<?x?xi32, strided<[?, 1], offset: ?>>) -> memref<?x?xi32, strided<[?, 1]>, #amdgpu.address_space<fat_raw_buffer>> {
%ret = amdgpu.fat_raw_buffer_cast %m resetOffset
: memref<?x?xi32, strided<[?, 1], offset: ?>> to memref<?x?xi32, strided<[?, 1]>, #amdgpu.address_space<fat_raw_buffer>>
func.return %ret : memref<?x?xi32, strided<[?, 1]>, #amdgpu.address_space<fat_raw_buffer>>
}

// CHECK-LABEL: func @fat_raw_buffer_cast_noncontiguous_2d_reset_offset
// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
// CHECK: return %[[ret]]
// Non-contiguous layout: strides [8, 1] differ from the suffix product [4, 1]
// of shape 4x4, so the layout does not fold to the identity and the result
// keeps the strided layout with the offset reset to 0.
func.func @fat_raw_buffer_cast_noncontiguous_2d_reset_offset(%m: memref<4x4xi32, strided<[8, 1], offset: ?>>) -> memref<4x4xi32, strided<[8, 1]>, #amdgpu.address_space<fat_raw_buffer>> {
%ret = amdgpu.fat_raw_buffer_cast %m resetOffset
: memref<4x4xi32, strided<[8, 1], offset: ?>> to memref<4x4xi32, strided<[8, 1]>, #amdgpu.address_space<fat_raw_buffer>>
func.return %ret : memref<4x4xi32, strided<[8, 1]>, #amdgpu.address_space<fat_raw_buffer>>
}

// CHECK-LABEL: func @raw_buffer_load_f32_from_rank_1
Expand Down
Loading