Skip to content

Conversation

@amd-eochoalo
Copy link
Contributor

  • Adds lowering for amdgpu.make_dma_base
  • Adds definition for amdgpu.tensor_load_to_lds.
  • Adds definition for amdgpu.tensor_store_from_lds.

Builds on top of #169407

@llvmbot
Copy link
Member

llvmbot commented Nov 27, 2025

@llvm/pr-subscribers-mlir-amdgpu
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-backend-amdgpu

Author: Erick Ochoa Lopez (amd-eochoalo)

Changes
  • Adds lowering for amdgpu.make_dma_base
  • Adds definition for amdgpu.tensor_load_to_lds.
  • Adds definition for amdgpu.tensor_store_from_lds.

Builds on top of #169407


Patch is 30.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/169817.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+143-9)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h (+5)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+76-1)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+77-30)
  • (renamed) mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir (+74)
  • (modified) mlir/test/Dialect/AMDGPU/invalid.mlir (+40)
  • (modified) mlir/test/Dialect/AMDGPU/ops.mlir (+65-4)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index e07c72b839e7c..1806c747046b8 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -80,15 +80,15 @@ def AMDGPU_AddressSpaceAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_AddressSpace,
   let assemblyFormat = "`<` $value `>`";
 }
 
+//===----------------------------------------------------------------------===//
+// AMDGPU Type definitions
+//===----------------------------------------------------------------------===//
+
 class AMDGPU_Type<string name, string typeMnemonic, list<Trait> traits = []>
     : TypeDef<AMDGPU_Dialect, name, traits> {
   let mnemonic = typeMnemonic;
 }
 
-//===----------------------------------------------------------------------===//
-// AMDGPU Type definitions
-//===----------------------------------------------------------------------===//
-
 def AMDGPU_TDMBaseType : AMDGPU_Type<"TDMBase", "tdm_base"> {
   let summary = "Pair of base addresses that move data between LDS and global storage.";
   let description = [{
@@ -104,6 +104,15 @@ def AMDGPU_TDMBaseType : AMDGPU_Type<"TDMBase", "tdm_base"> {
   let assemblyFormat = "`<` $elementType `>`";
 }
 
+def AMDGPU_TDMDescriptorType : AMDGPU_Type<"TDMDescriptor", "tdm_descriptor"> {
+  let summary = "Descriptors used in tensor store/load operations.";
+  let description = [{
+    This type is opaque and corresponds to the two or four descriptor groups
+    used in tensor_load_to_lds or tensor_store_from_lds.
+  }];
+
+}
+
 //===----------------------------------------------------------------------===//
 // AMDGPU Op definitions
 //===----------------------------------------------------------------------===//
@@ -1222,14 +1231,12 @@ def AMDGPU_MakeDmaBaseOp :
     AMDGPU_Op<"make_dma_base", [Pure, AttrSizedOperandSegments]>,
     Arguments<(ins
                    Arg<AnyMemRef, "buffer to read from">:$src,
-                   Variadic<Index>:$srcIndices,
+                   Variadic<Index>:$src_indices,
                    Arg<AnyMemRef, "buffer to write to">:$dst,
-                   Variadic<Index>:$dstIndices)>,
+                   Variadic<Index>:$dst_indices)>,
     Results<(outs AMDGPU_TDMBaseType: $base)> {
 
   // TODO:
-  // * Add verifiers such that one of the memrefs is from LDS and the other global.
-  // * Add verifiers to make sure that the type is in the correct direction.
   // * Add verifiers to make sure that the number of indices do not exceed the number of dimensions.
 
   let summary = "Pair of based addresses used when moving tiles between LDS and global memory.";
@@ -1240,11 +1247,138 @@ def AMDGPU_MakeDmaBaseOp :
     This operation creates a value corresponding to the tensor descriptor (D#) group 0
     found in TensorLoadToLDSOp and TensorStoreFromLDSOp in the rocdl dialect.
 
+    For example:
+
+    ```mlir
+      %base = amdgpu.make_dma_base %src[%idx0], %dst[%idx1] : memref<8xi32>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
+      %descriptor = amdgpu.make_dma_descriptor %base globalSize [2, 2] globalStride [2, 1] sharedSize [2, 2] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+      amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor
+    ```
+
+    to
+
+    ```mlir
+      // pseudo-code
+      %global_base = llvm.extractvalue %global_memref[1]
+      %global_address = llvm.get_element_ptr ...
+
+      %lds_base = llvm.extractvalue %lds_memref[1]
+      %lds_address = llvm.get_element_ptr ...
+
+      // Definition of %base
+      %undef = llvm.mlir.undef : vector<4xi32>
+      %v0 = llvm.insertelement %15, %undef[0] : vector<4xi32>
+      %v1 = llvm.insertelement %lds_address, %v0[1] : vector<4xi32>
+      %v2 = llvm.insertelement %global_address_low, %v1[2] : vector<4xi32>
+      %base = llvm.insertelement %global_address_high, %v2[3] : vector<4xi32>
+
+      rocdl.tensor.load.to.lds %base, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+    ```
+
     These tensor DMA operations were introduced in gfx1250.
   }];
 
   let assemblyFormat = [{
-    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst) `to` type(results)
+    $src `[` $src_indices `]` `,` $dst `[` $dst_indices `]` attr-dict `:` type($src) `,` type($dst) `->` type(results)
+  }];
+
+  let hasVerifier = 1;
+}
+
+def AMDGPU_MakeDmaDescriptorOp :
+  AMDGPU_Op<"make_dma_descriptor", [Pure, AttrSizedOperandSegments]>,
+  Arguments<(ins
+    AMDGPU_TDMBaseType: $base,
+    Variadic<Index>: $global_dynamic_sizes,
+    DenseI64ArrayAttr: $global_static_sizes,
+    Variadic<Index>: $global_dynamic_strides,
+    DenseI64ArrayAttr: $global_static_strides,
+    Variadic<Index>: $shared_dynamic_sizes,
+    DenseI64ArrayAttr: $shared_static_sizes,
+    Optional<Index>: $pad,
+    Optional<Index>: $pad_every,
+    Optional<AnyMemRef>: $atomic_barrier_address,
+    Variadic<Index>: $atomic_barrier_indices,
+    Optional<Index>: $global_increment,
+    Optional<Index>: $lds_increment,
+    Optional<Index>: $iteration_count)>,
+  Results<(outs AMDGPU_TDMDescriptorType: $desc)> {
+
+  let summary = "Make all descriptor groups needed by TensorLoadToLDS/TensorStoreFromLDS.";
+  let description = [{
+     Make all descriptor groups needed by tensor memory operations.
+
+     The $base operand corresponds to the base pair addresses, one must be an address in LDS
+     while the other must be a global memory location.
+
+     $global_{static/dynamic}_sizes determine the size of the tensor.
+     $global_{static/dynamic}_strides determine the strides of the tensor.
+     $shared_{static/dynamic}_sizes determines the size of the tile.
+
+     Padding can be applied to the LDS address when copying from memory to LDS,
+     but not when copying from LDS to memory.
+     The values in the padded target addresses remain the same as before the operation was applied.
+
+     2D and 3D tensors may be iterated over by setting $global_increment, $lds_increment, and $iteration_count.
+     $global_increment determines how much to increment the starting global memory address per iteration in units of the $base's element type.
+     $lds_increment determines how much to increment the starting LDS address per iteration in units of the $base's element type.
+     $iterate_count determines how many times to iterate.
+
+     ```mlir
+      // Example of moving a two-dimensional tensor to LDS.
+      %base = amdgpu.make_dma_base %src[0, 0], %dst[0, 0] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
+      %descriptor = amdgpu.make_dma_descriptor %base globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+      amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor
+
+      // Example of moving a two dimension tensor to LDS where padding is applied after every integer.
+      %base = amdgpu.make_dma_base %src[0, 0], %dst[0, 0] : memref<32x32xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
+      %descriptor = amdgpu.make_dma_descriptor %base globalSize [32, 32] globalStride [32, 1] sharedSize [64, 64] padding(%pad pad_every %pad_every) : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
+      amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor
+     ```
+  }];
+
+  let assemblyFormat = [{
+    $base
+    `globalSize` custom<DynamicIndexList>($global_dynamic_sizes, $global_static_sizes)
+    `globalStride` custom<DynamicIndexList>($global_dynamic_strides, $global_static_strides)
+    `sharedSize` custom<DynamicIndexList>($shared_dynamic_sizes, $shared_static_sizes)
+    ( `padShared` `(` $pad^ `every` $pad_every `)` )?
+    ( `atomicBarrier` `(` $atomic_barrier_address^ `[` $atomic_barrier_indices `]`
+                      `:` type($atomic_barrier_address) `)`)?
+    ( `iterate` $global_increment^ `,` $lds_increment `,` $iteration_count )?
+    attr-dict `:` qualified(type($base)) `->` type(results)
+  }];
+
+  let hasVerifier = 1;
+}
+
+def AMDGPU_TensorLoadToLDSOp :
+  AMDGPU_Op<"tensor_load_to_lds", [MemoryEffects<[MemWrite]>, MemoryEffects<[MemRead]>]>,
+  Arguments<(ins AMDGPU_TDMDescriptorType: $desc)> {
+  let summary = "Load tensors from global memory to LDS.";
+  let description = [{
+    Load tensors of up to five dimensions from global memory to LDS.
+
+    The operation is fully described by the descriptor operand.
+  }];
+
+  let assemblyFormat = [{
+    $desc attr-dict `:` qualified(type($desc))
+  }];
+}
+
+def AMDGPU_TensorStoreFromLDSOp :
+  AMDGPU_Op<"tensor_store_from_lds", [MemoryEffects<[MemWrite]>, MemoryEffects<[MemRead]>]>,
+  Arguments<(ins AMDGPU_TDMDescriptorType: $desc)> {
+  let summary = "Store tensors from LDS to global memory.";
+  let description = [{
+    Store tensors of up to five dimensions from LDS to global memory.
+
+    The operation is fully described by the descriptor operand.
+  }];
+
+  let assemblyFormat = [{
+    $desc attr-dict `:` qualified(type($desc))
   }];
 }
 
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
index a7680fb5c3191..958757da0933e 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
@@ -48,6 +48,11 @@ inline void printMNKDimensionList(OpAsmPrinter &printer, Operation *,
                                   IntegerAttr m, IntegerAttr n, IntegerAttr k) {
   printMNKDimensionList(printer, m, n, k);
 }
+
+// Utility functions for quering the address space.
+bool hasGlobalMemorySpace(Attribute memorySpace);
+bool hasWorkgroupMemorySpace(Attribute memorySpace);
+bool hasFatRawBufferMemorySpace(Attribute memorySpace);
 } // namespace mlir::amdgpu
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b9a5e7d7f6eac..3316e16a05d5c 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2264,6 +2264,76 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
   }
 };
 
+struct AMDGPUMakeDmaBaseLowering
+    : public ConvertOpToLLVMPattern<MakeDmaBaseOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<MakeDmaBaseOp>(converter), chipset(chipset) {}
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(MakeDmaBaseOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (chipset < kGfx1250)
+      return op->emitOpError("make_dma_base is only supported on gfx1250");
+
+    Location loc = op.getLoc();
+
+    ValueRange srcIndices = adaptor.getSrcIndices();
+    Value src = adaptor.getSrc();
+    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+
+    Value srcPtr =
+        getStridedElementPtr(rewriter, loc, srcMemRefType, src, srcIndices);
+
+    ValueRange dstIndices = adaptor.getDstIndices();
+    Value dst = adaptor.getDst();
+    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
+
+    Value dstPtr =
+        getStridedElementPtr(rewriter, loc, dstMemRefType, dst, dstIndices);
+
+    bool storeFrom = hasWorkgroupMemorySpace(srcMemRefType.getMemorySpace());
+    Value ldsAddr = storeFrom ? srcPtr : dstPtr;
+    Value globalAddr = storeFrom ? dstPtr : srcPtr;
+
+    Type i32 = rewriter.getI32Type();
+    Type i64 = rewriter.getI64Type();
+
+    Value castForLdsAddr =
+        LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsAddr);
+    Value castForGlobalAddr =
+        LLVM::PtrToIntOp::create(rewriter, loc, i64, globalAddr);
+
+    Value mask = createI64Constant(rewriter, loc, 0x1FFFFFFFFFFFFFF);
+    Value first57BitsOfGlobalAddr =
+        LLVM::AndOp::create(rewriter, loc, castForGlobalAddr, mask);
+    Value shift = LLVM::LShrOp::create(rewriter, loc, first57BitsOfGlobalAddr,
+                                       createI64Constant(rewriter, loc, 32));
+
+    Value lowHalf =
+        LLVM::TruncOp::create(rewriter, loc, i32, first57BitsOfGlobalAddr);
+    Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
+
+    Value c0 = createI32Constant(rewriter, loc, 0);
+    Value c1 = createI32Constant(rewriter, loc, 1);
+    Value c2 = createI32Constant(rewriter, loc, 2);
+    Value c3 = createI32Constant(rewriter, loc, 3);
+
+    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
+    Value result = LLVM::UndefOp::create(rewriter, loc, v4i32);
+    result = LLVM::InsertElementOp::create(rewriter, loc, result, c0, c0);
+    result = LLVM::InsertElementOp::create(rewriter, loc, result,
+                                           castForLdsAddr, c1);
+    result = LLVM::InsertElementOp::create(rewriter, loc, result, lowHalf, c2);
+    result = LLVM::InsertElementOp::create(rewriter, loc, result, highHalf, c3);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct ConvertAMDGPUToROCDLPass
     : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
   using Base::Base;
@@ -2278,6 +2348,10 @@ struct ConvertAMDGPUToROCDLPass
 
     RewritePatternSet patterns(ctx);
     LLVMTypeConverter converter(ctx);
+    converter.addConversion([&](TDMBaseType type) -> Type {
+      Type i32 = IntegerType::get(type.getContext(), 32);
+      return converter.convertType(VectorType::get(4, i32));
+    });
     populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
     LLVMConversionTarget target(getContext());
     target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
@@ -2333,6 +2407,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
            PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
            GatherToLDSOpLowering, TransposeLoadOpLowering,
-           AMDGPUPermlaneLowering>(converter, chipset);
+           AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering>(converter,
+                                                              chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index cdc10c60a42ae..8fc6220efc6ad 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -41,6 +41,38 @@ using namespace mlir::amdgpu;
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
 
+namespace mlir::amdgpu {
+bool hasGlobalMemorySpace(Attribute memorySpace) {
+  if (!memorySpace)
+    return true;
+  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
+    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
+  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
+  return false;
+}
+
+bool hasWorkgroupMemorySpace(Attribute memorySpace) {
+  if (!memorySpace)
+    return false;
+  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
+    return intMemorySpace.getInt() == 3;
+  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
+  return false;
+}
+
+bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
+  if (!memorySpace)
+    return false;
+  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
+    return intMemorySpace.getInt() == 7;
+  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
+    return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
+  return false;
+}
+} // namespace mlir::amdgpu
+
 namespace {
 struct AMDGPUInlinerInterface final : DialectInlinerInterface {
   using DialectInlinerInterface::DialectInlinerInterface;
@@ -158,36 +190,6 @@ LogicalResult FatRawBufferCastOp::verify() {
   return success();
 }
 
-static bool hasGlobalMemorySpace(Attribute memorySpace) {
-  if (!memorySpace)
-    return true;
-  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
-    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
-  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
-    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
-  return false;
-}
-
-static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
-  if (!memorySpace)
-    return false;
-  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
-    return intMemorySpace.getInt() == 3;
-  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
-    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
-  return false;
-}
-
-static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
-  if (!memorySpace)
-    return false;
-  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
-    return intMemorySpace.getInt() == 7;
-  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
-    return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
-  return false;
-}
-
 //===----------------------------------------------------------------------===//
 // RawBuffer*Op
 //===----------------------------------------------------------------------===//
@@ -705,6 +707,51 @@ LogicalResult TransposeLoadOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// MakeDmaBaseOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult MakeDmaBaseOp::verify() {
+  MemRefType srcType = cast<MemRefType>(getSrc().getType());
+  MemRefType dstType = cast<MemRefType>(getDst().getType());
+  bool store_from_lds = hasWorkgroupMemorySpace(srcType.getMemorySpace()) &&
+                        hasGlobalMemorySpace(dstType.getMemorySpace());
+  bool load_to_lds = hasGlobalMemorySpace(srcType.getMemorySpace()) &&
+                     hasWorkgroupMemorySpace(dstType.getMemorySpace());
+  bool is_valid = store_from_lds != load_to_lds;
+  if (!is_valid)
+    return emitOpError("invalid combination of address spaces.");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MakeDmaDescriptorOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult MakeDmaDescriptorOp::verify() {
+  ArrayRef<int64_t> globalStaticStrides = getGlobalStaticStrides();
+
+  if (globalStaticStrides.empty()) {
+    return emitOpError("strides must not be empty.");
+  }
+  if (globalStaticStrides.back() != 1) {
+    return emitOpError("strides for the innermost dimension must be 1.");
+  }
+
+  ArrayRef<int64_t> globalStaticSizes = getGlobalStaticSizes();
+  size_t rank = globalStaticSizes.size();
+  if (rank != globalStaticStrides.size()) {
+    return emitOpError("strides and sizes must have same rank.");
+  }
+
+  ArrayRef<int64_t> sharedStaticSizes = getSharedStaticSizes();
+  if (rank != sharedStaticSizes.size()) {
+    return emitOpError("tensor must have same rank as tile.");
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ScaledMFMAOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
similarity index 73%
rename from mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
rename to mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
index d2391140ce056..96d03a427215f 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
@@ -162,3 +162,77 @@ func.func @amdgpu.scaled_ext_packed816_invalid_dst_elem_type(%v: vector<16xf6E3M
   %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf64>
   return %ret0: vector<16xf64>
 }
+
+// -----
+
+#gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
+#amdgpu_fat_buffer_addrspace = 7
+
+// CHECK-LABEL: func @make_dma_base
+// ...
[truncated]

Comment on lines -1225 to 1237
Variadic<Index>:$srcIndices,
Variadic<Index>:$src_indices,
Arg<AnyMemRef, "buffer to write to">:$dst,
Variadic<Index>:$dstIndices)>,
Variadic<Index>:$dst_indices)>,
Results<(outs AMDGPU_TDMBaseType: $base)> {
Copy link
Contributor Author

@amd-eochoalo amd-eochoalo Nov 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@krzysz00 thinking about this, maybe it would be best to instead of having src, dst memrefs here, just have lds and global memrefs. As you mentioned, the base is intended to be reused across store and load operations. Then it would be easier to fold a pair of instructions:

%base0 = amdgpu.make_dma_base %lds[...], %global[...]
%base1 = amdgpu.make_dma_base %lds[...], %global[...]

Rather than

                             // src       , dst
%base0 = amdgpu.make_dma_base %lds[...], %global[...]
%base1 = amdgpu.make_dma_base %global[...], %lds[...]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants