[mlir][AMDGPU] Infer canonical layouts for fat_raw_buffer_cast resetOffset #149867
Conversation
When inferring the return type of amdgpu.fat_raw_buffer_cast with the offset reset, we would sometimes use a strided layout, like strided<[1]>, in cases where, after stripping the offset, the memref had the identity layout. This caused issues with EmulateNarrowTypes, which performs this layout canonicalization itself. Now, the return type inference produces an identity layout after offset stripping for:
1. Statically-shaped memrefs of any rank whose strides match the suffix product of the shape, and
2. Memrefs of rank <= 1 whose strides are [1] (or []) and just had their offset removed by resetOffset.
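For example (a minimal before/after distilled from the updated tests below), the inferred result type for a fully contiguous 1-D source changes as follows:

    // Previously inferred result type:
    //   memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
    // Now inferred (the equivalent identity layout):
    %ret = amdgpu.fat_raw_buffer_cast %buf resetOffset
        : memref<?xi32, strided<[1], offset: ?>>
          to memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>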
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-backend-amdgpu
Author: Krzysztof Drewniak (krzysz00)
Changes: same as the PR description above.
Full diff: https://github.com/llvm/llvm-project/pull/149867.diff (3 files affected):
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 88c2eb3326d96..18e8270f5aa99 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
@@ -89,7 +90,22 @@ static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
if (!stridedLayout)
return failure();
- mb.setLayout(StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides()));
+ MemRefLayoutAttrInterface newLayout =
+ StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
+ // Special case: if resetting the offset causes the strided layout to become
+ // the identity layout, then reset to the identity layout.
+ // TODO: this'll get a lot simpler when we have the contiguous layout.
+ SmallVector<int64_t> stridesIfIdentity;
+ if (source.hasStaticShape()) {
+ stridesIfIdentity = computeSuffixProduct(source.getShape());
+ } else if (source.getRank() <= 1) {
+ stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
+ }
+ if (stridesIfIdentity == stridedLayout.getStrides()) {
+ newLayout = AffineMapAttr::get(
+ AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
+ }
+ mb.setLayout(newLayout);
}
return (MemRefType)(mb);
}
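For reference, here is the new check as a standalone sketch (not upstream code; the helper name is illustrative). It relies only on computeSuffixProduct from mlir/Dialect/Utils/IndexingUtils.h, which returns the row-major strides implied by a static shape (e.g. 4x4 -> [4, 1]):

    #include "mlir/Dialect/Utils/IndexingUtils.h" // computeSuffixProduct
    #include "mlir/IR/BuiltinTypes.h"             // ShapedType
    using namespace mlir;

    // Sketch: true when `strides` are exactly the strides of the row-major
    // identity layout for `shape`. Mirrors the check added above.
    static bool stridesAreIdentity(ArrayRef<int64_t> shape,
                                   ArrayRef<int64_t> strides) {
      SmallVector<int64_t> identityStrides;
      if (llvm::none_of(shape, ShapedType::isDynamic)) {
        // Static shape: identity strides are the suffix product of the shape.
        identityStrides = computeSuffixProduct(shape);
      } else if (shape.size() <= 1) {
        // Rank 0 or 1: the identity stride list is [] or [1], even when the
        // size itself is dynamic.
        identityStrides.assign(shape.size(), 1);
      } else {
        // Dynamic shapes of rank >= 2 cannot be proven contiguous here.
        return false;
      }
      return ArrayRef<int64_t>(identityStrides) == strides;
    }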
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 8871b2ce0eadb..cc1162d8b0de8 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -68,7 +68,7 @@ func.func @fat_raw_buffer_cast_dyn_size_offset(%buf: memref<?xi32, strided<[1],
}
// CHECK-LABEL: func @fat_raw_buffer_cast_reset_offset
-func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> {
+func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, #amdgpu.address_space<fat_raw_buffer>> {
// CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<?xi32, strided<[1], offset: ?>, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-DAG: %[[memRefPtr:.*]] = llvm.extractvalue %[[desc]][1]
// CHECK-DAG: %[[memRefOff:.*]] = llvm.extractvalue %[[desc]][2]
@@ -77,8 +77,8 @@ func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], off
// CHECK: %[[fatBuf:.*]] = rocdl.make.buffer.rsrc %[[basePtr]], %{{.*}}, %{{.*}}, %{{.*}}
// CHECK: llvm.insertvalue %[[fatBuf]], %{{.*}}[1]
// CHECK: llvm.insertvalue %[[zeroOff]], %{{.*}}[2]
- %ret = amdgpu.fat_raw_buffer_cast %buf resetOffset : memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace> to memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
- return %ret : memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
+ %ret = amdgpu.fat_raw_buffer_cast %buf resetOffset : memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace> to memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
+ return %ret : memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
}
// CHECK-LABEL: func @fat_raw_buffer_cast_valid_bytes
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 5559ac8f1a5c3..fe2b32be04de4 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -360,10 +360,54 @@ func.func @fat_raw_buffer_cast_easy(%m: memref<8xi32>) -> memref<8xi32, #amdgpu.
// CHECK-SAME: cacheSwizzleStride(%{{[^)]*}})
// CHECK-SAME: boundsCheck(false)
// CHECK-SAME: resetOffset
-func.func @fat_raw_buffer_cast(%m: memref<8xi32, strided<[1], offset: ?>>, %validBytes: i32, %cacheSwizzle: i14) -> memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> {
+func.func @fat_raw_buffer_cast(%m: memref<8xi32, strided<[1], offset: ?>>, %validBytes: i32, %cacheSwizzle: i14) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
%ret = amdgpu.fat_raw_buffer_cast %m validBytes(%validBytes) cacheSwizzleStride(%cacheSwizzle) boundsCheck(false) resetOffset
- : memref<8xi32, strided<[1], offset: ?>> to memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
- func.return %ret : memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
+ : memref<8xi32, strided<[1], offset: ?>> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+ func.return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_dynamic_1d_reset_offset
+// CHECK: amdgpu.fat_raw_buffer_cast
+func.func @fat_raw_buffer_cast_dynamic_1d_reset_offset(%m: memref<?xi32, strided<[1], offset: ?>>) -> memref<?xi32, #amdgpu.address_space<fat_raw_buffer>> {
+ %ret = amdgpu.fat_raw_buffer_cast %m resetOffset
+ : memref<?xi32, strided<[1], offset: ?>> to memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
+ func.return %ret : memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_dynamic_0d_reset_offset
+// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
+// CHECK: return %[[ret]]
+func.func @fat_raw_buffer_cast_dynamic_0d_reset_offset(%m: memref<i32, strided<[], offset: ?>>) -> memref<i32, #amdgpu.address_space<fat_raw_buffer>> {
+ %ret = amdgpu.fat_raw_buffer_cast %m resetOffset
+ : memref<i32, strided<[], offset: ?>> to memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+ func.return %ret : memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_static_shape_2d_reset_offset
+// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
+// CHECK: return %[[ret]]
+func.func @fat_raw_buffer_cast_static_shape_2d_reset_offset(%m: memref<4x4xi32, strided<[4, 1], offset: ?>>) -> memref<4x4xi32, #amdgpu.address_space<fat_raw_buffer>> {
+ %ret = amdgpu.fat_raw_buffer_cast %m resetOffset
+ : memref<4x4xi32, strided<[4, 1], offset: ?>> to memref<4x4xi32, #amdgpu.address_space<fat_raw_buffer>>
+ func.return %ret : memref<4x4xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_dynamic_2d_reset_offset
+// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
+// CHECK: return %[[ret]]
+func.func @fat_raw_buffer_cast_dynamic_2d_reset_offset(%m: memref<?x?xi32, strided<[?, 1], offset: ?>>) -> memref<?x?xi32, strided<[?, 1]>, #amdgpu.address_space<fat_raw_buffer>> {
+ %ret = amdgpu.fat_raw_buffer_cast %m resetOffset
+ : memref<?x?xi32, strided<[?, 1], offset: ?>> to memref<?x?xi32, strided<[?, 1]>, #amdgpu.address_space<fat_raw_buffer>>
+ func.return %ret : memref<?x?xi32, strided<[?, 1]>, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_noncontiguous_2d_reset_offset
+// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
+// CHECK: return %[[ret]]
+func.func @fat_raw_buffer_cast_noncontiguous_2d_reset_offset(%m: memref<4x4xi32, strided<[8, 1], offset: ?>>) -> memref<4x4xi32, strided<[8, 1]>, #amdgpu.address_space<fat_raw_buffer>> {
+ %ret = amdgpu.fat_raw_buffer_cast %m resetOffset
+ : memref<4x4xi32, strided<[8, 1], offset: ?>> to memref<4x4xi32, strided<[8, 1]>, #amdgpu.address_space<fat_raw_buffer>>
+ func.return %ret : memref<4x4xi32, strided<[8, 1]>, #amdgpu.address_space<fat_raw_buffer>>
}
// CHECK-LABEL: func @raw_buffer_load_f32_from_rank_1
Pull Request Overview
This PR enhances the type inference for amdgpu.fat_raw_buffer_cast operations with resetOffset so that it generates canonical identity layouts instead of strided layouts when the two are equivalent. This addresses compatibility issues with the EmulateNarrowTypes pass, which expects canonical layouts.
- Implements logic to detect when a strided layout becomes equivalent to an identity layout after offset removal
- Updates return type inference for statically-shaped memrefs and rank ≤ 1 memrefs with unit strides
- Adds comprehensive test coverage for various memref configurations with resetOffset
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | Core implementation of canonical layout inference logic |
| mlir/test/Dialect/AMDGPU/ops.mlir | New test cases covering different memref shapes and stride patterns |
| mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir | Updated existing test to reflect new canonical layout behavior |
Comments suppressed due to low confidence (1)
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp:98
- [nitpick] The variable name 'stridesIfIdentity' could be more descriptive. Consider renaming to 'expectedIdentityStrides' or 'identityLayoutStrides' to better convey that these are the strides that would indicate an identity layout.
SmallVector<int64_t> stridesIfIdentity;
Inline comment on this snippet from AMDGPUDialect.cpp:

    // the identity layout, then reset to the identity layout.
    // TODO: this'll get a lot simpler when we have the contiguous layout.
    SmallVector<int64_t> stridesIfIdentity;
    if (source.hasStaticShape()) {
What about making a utility function, something like isIdentity, for this? Looks useful.
Agreed - in my to-be-split contiguous layouts PR, this was something along the lines of MemRefLayoutAttrInterface StridedLayoutAttr::getCanonical(sizes, strides, offset), but that's future work / a branch that needs to be rescued.
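A rough sketch of what such a utility could look like (StridedLayoutAttr::getCanonical does not exist upstream; the free-function name, signature, and body below are illustrative, reusing the same contiguity check as this PR):

    // Hypothetical utility, not an existing MLIR API: returns the identity
    // AffineMap layout when (offset, strides) describe the canonical
    // row-major layout for `sizes`, and a StridedLayoutAttr otherwise.
    static MemRefLayoutAttrInterface getCanonicalLayout(MLIRContext *ctx,
                                                        ArrayRef<int64_t> sizes,
                                                        ArrayRef<int64_t> strides,
                                                        int64_t offset) {
      bool contiguous = false;
      if (offset == 0) {
        if (llvm::none_of(sizes, ShapedType::isDynamic))
          contiguous = ArrayRef<int64_t>(computeSuffixProduct(sizes)) == strides;
        else if (sizes.size() <= 1)
          contiguous = llvm::all_of(strides, [](int64_t s) { return s == 1; });
      }
      if (contiguous)
        return AffineMapAttr::get(
            AffineMap::getMultiDimIdentityMap(sizes.size(), ctx));
      return StridedLayoutAttr::get(ctx, offset, strides);
    }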