
Conversation

@krzysz00
Contributor

When inferring the return type of amdgpu.fat_raw_buffer_cast with the offset reset, we would sometimes use a strided layout, such as strided<[1]>, in cases where, after stripping the offset, the memref had the identity layout. This would cause issues with EmulateNarrowTypes, which does perform this layout canonicalization, leading to type mismatches.

Now, the return type inference will use the identity layout after offset stripping (illustrated below) for

  1. Statically-shaped memrefs of any rank whose strides match the suffix product of the shape, and
  2. Memrefs of rank <= 1 whose strides are [1] (or []) and whose offset was just removed by resetOffset.
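
For instance (taken from the new tests in this PR), the cast below now infers an identity-layout result type instead of an equivalent strided<[1]> one:

// Before this change, the inferred result type carried an explicit
// strided<[1]> layout; it is now the equivalent identity-layout type.
%ret = amdgpu.fat_raw_buffer_cast %m resetOffset
  : memref<?xi32, strided<[1], offset: ?>> to memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>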

@llvmbot
Member

llvmbot commented Jul 21, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Krzysztof Drewniak (krzysz00)

Changes

When inferring the return type of amdgpu.fat_raw_buffer_cast with the offset reset, we would sometimes use a strided layout, such as strided<[1]>, in cases where, after stripping the offset, the memref had the identity layout. This would cause issues with EmulateNarrowTypes, which does perform this layout canonicalization, leading to type mismatches.

Now, the return type inference will use the identity layout after offset stripping for

  1. Statically-shaped memrefs of any rank whose strides match the suffix product of the shape, and
  2. Memrefs of rank <= 1 whose strides are [1] (or []) and whose offset was just removed by resetOffset.

Full diff: https://github.com/llvm/llvm-project/pull/149867.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+17-1)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir (+3-3)
  • (modified) mlir/test/Dialect/AMDGPU/ops.mlir (+47-3)
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 88c2eb3326d96..18e8270f5aa99 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -89,7 +90,22 @@ static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
     auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
     if (!stridedLayout)
       return failure();
-    mb.setLayout(StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides()));
+    MemRefLayoutAttrInterface newLayout =
+        StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
+    // Special case: if resetting the offset causes the strided layout to become
+    // the identity layout, then reset to the identity layout.
+    // TODO: this'll get a lot simpler when we have the contiguous layout.
+    SmallVector<int64_t> stridesIfIdentity;
+    if (source.hasStaticShape()) {
+      stridesIfIdentity = computeSuffixProduct(source.getShape());
+    } else if (source.getRank() <= 1) {
+      stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
+    }
+    if (stridesIfIdentity == stridedLayout.getStrides()) {
+      newLayout = AffineMapAttr::get(
+          AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
+    }
+    mb.setLayout(newLayout);
   }
   return (MemRefType)(mb);
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 8871b2ce0eadb..cc1162d8b0de8 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -68,7 +68,7 @@ func.func @fat_raw_buffer_cast_dyn_size_offset(%buf: memref<?xi32, strided<[1],
 }
 
 // CHECK-LABEL: func @fat_raw_buffer_cast_reset_offset
-func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> {
+func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, #amdgpu.address_space<fat_raw_buffer>> {
   // CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<?xi32, strided<[1], offset: ?>, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK-DAG: %[[memRefPtr:.*]] = llvm.extractvalue %[[desc]][1]
   // CHECK-DAG: %[[memRefOff:.*]] = llvm.extractvalue %[[desc]][2]
@@ -77,8 +77,8 @@ func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], off
   // CHECK: %[[fatBuf:.*]] = rocdl.make.buffer.rsrc %[[basePtr]], %{{.*}}, %{{.*}}, %{{.*}}
   // CHECK: llvm.insertvalue %[[fatBuf]], %{{.*}}[1]
   // CHECK: llvm.insertvalue %[[zeroOff]], %{{.*}}[2]
-  %ret = amdgpu.fat_raw_buffer_cast %buf resetOffset : memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace> to memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
-  return %ret : memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
+  %ret = amdgpu.fat_raw_buffer_cast %buf resetOffset : memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace> to memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
+  return %ret : memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
 }
 
 // CHECK-LABEL: func @fat_raw_buffer_cast_valid_bytes
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 5559ac8f1a5c3..fe2b32be04de4 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -360,10 +360,54 @@ func.func @fat_raw_buffer_cast_easy(%m: memref<8xi32>) -> memref<8xi32, #amdgpu.
 // CHECK-SAME: cacheSwizzleStride(%{{[^)]*}})
 // CHECK-SAME: boundsCheck(false)
 // CHECK-SAME: resetOffset
-func.func @fat_raw_buffer_cast(%m: memref<8xi32, strided<[1], offset: ?>>, %validBytes: i32, %cacheSwizzle: i14) -> memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> {
+func.func @fat_raw_buffer_cast(%m: memref<8xi32, strided<[1], offset: ?>>, %validBytes: i32, %cacheSwizzle: i14) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
   %ret = amdgpu.fat_raw_buffer_cast %m validBytes(%validBytes) cacheSwizzleStride(%cacheSwizzle) boundsCheck(false) resetOffset
-    : memref<8xi32, strided<[1], offset: ?>> to memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
-  func.return %ret : memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
+    : memref<8xi32, strided<[1], offset: ?>> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+  func.return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_dynamic_1d_reset_offset
+// CHECK: amdgpu.fat_raw_buffer_cast
+func.func @fat_raw_buffer_cast_dynamic_1d_reset_offset(%m: memref<?xi32, strided<[1], offset: ?>>) -> memref<?xi32, #amdgpu.address_space<fat_raw_buffer>> {
+  %ret = amdgpu.fat_raw_buffer_cast %m resetOffset
+    : memref<?xi32, strided<[1], offset: ?>> to memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
+  func.return %ret : memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_dynamic_0d_reset_offset
+// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
+// CHECK: return %[[ret]]
+func.func @fat_raw_buffer_cast_dynamic_0d_reset_offset(%m: memref<i32, strided<[], offset: ?>>) -> memref<i32, #amdgpu.address_space<fat_raw_buffer>> {
+  %ret = amdgpu.fat_raw_buffer_cast %m resetOffset
+    : memref<i32, strided<[], offset: ?>> to memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+  func.return %ret : memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_static_shape_2d_reset_offset
+// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
+// CHECK: return %[[ret]]
+func.func @fat_raw_buffer_cast_static_shape_2d_reset_offset(%m: memref<4x4xi32, strided<[4, 1], offset: ?>>) -> memref<4x4xi32, #amdgpu.address_space<fat_raw_buffer>> {
+  %ret = amdgpu.fat_raw_buffer_cast %m resetOffset
+    : memref<4x4xi32, strided<[4, 1], offset: ?>> to memref<4x4xi32, #amdgpu.address_space<fat_raw_buffer>>
+  func.return %ret : memref<4x4xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_dynamic_2d_reset_offset
+// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
+// CHECK: return %[[ret]]
+func.func @fat_raw_buffer_cast_dynamic_2d_reset_offset(%m: memref<?x?xi32, strided<[?, 1], offset: ?>>) -> memref<?x?xi32, strided<[?, 1]>, #amdgpu.address_space<fat_raw_buffer>> {
+  %ret = amdgpu.fat_raw_buffer_cast %m resetOffset
+    : memref<?x?xi32, strided<[?, 1], offset: ?>> to memref<?x?xi32, strided<[?, 1]>, #amdgpu.address_space<fat_raw_buffer>>
+  func.return %ret : memref<?x?xi32, strided<[?, 1]>, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_noncontiguous_2d_reset_offset
+// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
+// CHECK: return %[[ret]]
+func.func @fat_raw_buffer_cast_noncontiguous_2d_reset_offset(%m: memref<4x4xi32, strided<[8, 1], offset: ?>>) -> memref<4x4xi32, strided<[8, 1]>, #amdgpu.address_space<fat_raw_buffer>> {
+  %ret = amdgpu.fat_raw_buffer_cast %m resetOffset
+    : memref<4x4xi32, strided<[8, 1], offset: ?>> to memref<4x4xi32, strided<[8, 1]>, #amdgpu.address_space<fat_raw_buffer>>
+  func.return %ret : memref<4x4xi32, strided<[8, 1]>, #amdgpu.address_space<fat_raw_buffer>>
 }
 
 // CHECK-LABEL: func @raw_buffer_load_f32_from_rank_1

@llvmbot
Member

llvmbot commented Jul 21, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: Krzysztof Drewniak (krzysz00)

@lialan lialan requested a review from Copilot July 21, 2025 18:50
Contributor

Copilot AI left a comment

Pull Request Overview

This PR enhances the type inference for amdgpu.fat_raw_buffer_cast operations with resetOffset to generate canonical identity layouts instead of strided layouts when appropriate. This addresses compatibility issues with the EmulateNarrowTypes pass, which expects canonical layouts.

  • Implements logic to detect when a strided layout becomes equivalent to an identity layout after offset removal
  • Updates return type inference for statically-shaped memrefs and rank ≤ 1 memrefs with unit strides
  • Adds comprehensive test coverage for various memref configurations with resetOffset (one such case is shown after this list)
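
For contrast, the new tests also pin down when the identity layout is not inferred; strides that are not the suffix product of the shape keep their strided layout, with only the offset reset:

// [8, 1] is not the suffix product of the 4x4 shape (that would be [4, 1]),
// so the result stays strided; only the dynamic offset becomes 0.
%ret = amdgpu.fat_raw_buffer_cast %m resetOffset
  : memref<4x4xi32, strided<[8, 1], offset: ?>>
    to memref<4x4xi32, strided<[8, 1]>, #amdgpu.address_space<fat_raw_buffer>>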

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.

File Description

  • mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp: Core implementation of the canonical layout inference logic
  • mlir/test/Dialect/AMDGPU/ops.mlir: New test cases covering different memref shapes and stride patterns
  • mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir: Updated existing test to reflect the new canonical layout behavior
Comments suppressed due to low confidence (1)

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp:98

  • [nitpick] The variable name 'stridesIfIdentity' could be more descriptive. Consider renaming to 'expectedIdentityStrides' or 'identityLayoutStrides' to better convey that these are the strides that would indicate an identity layout.
    SmallVector<int64_t> stridesIfIdentity;

// the identity layout, then reset to the identity layout.
// TODO: this'll get a lot simpler when we have the contiguous layout.
SmallVector<int64_t> stridesIfIdentity;
if (source.hasStaticShape()) {
Member

What about making a utility function, something like isIdentity, for this? Looks useful.

Contributor Author

Agreed. In my to-be-split contiguous layouts PR, this was something along the lines of MemRefLayoutAttrInterface StridedLayoutAttr::getCanonical(sizes, strides, offset);, but that's future work / a branch that needs to be rescued.

@krzysz00 krzysz00 merged commit 9052a85 into llvm:main Jul 21, 2025
14 checks passed
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Jul 28, 2025
…ffset (llvm#149867)
