Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
a4a1a59
[mlir][amdgpu] Add make_dma_base operation
amd-eochoalo Nov 21, 2025
d14f3e2
Remove MemRead and MemWrite from operation
amd-eochoalo Nov 24, 2025
d3ca18c
Add Pure to make_dma_base
amd-eochoalo Nov 24, 2025
76e47f1
Add DynamicIndexList
amd-eochoalo Nov 24, 2025
f6f67e3
Add globalStride
amd-eochoalo Nov 24, 2025
1e2668c
Add verifier for innermost dimension
amd-eochoalo Nov 24, 2025
f1df3c5
Add sharedSize
amd-eochoalo Nov 24, 2025
a24a840
Add optional atomic barrier
amd-eochoalo Nov 24, 2025
ccaf771
Add iterate
amd-eochoalo Nov 24, 2025
566d2e6
[mlir][amdgpu] Add make_dma_descriptor.
amd-eochoalo Nov 24, 2025
2be4ccc
Fix indentation
amd-eochoalo Nov 24, 2025
b3ba450
Review
amd-eochoalo Nov 24, 2025
d34c423
Fix parser
amd-eochoalo Nov 24, 2025
cfb20cc
whitespace
amd-eochoalo Nov 24, 2025
5e98ed0
Fix parser
amd-eochoalo Nov 24, 2025
5cca5f9
check if it is not empty
amd-eochoalo Nov 25, 2025
0f913f5
less variables
amd-eochoalo Nov 25, 2025
adcbc32
mlir example
amd-eochoalo Nov 25, 2025
61fd94d
MLIR examples
amd-eochoalo Nov 25, 2025
3de0f3c
Use custom<DynamicIndexList> for indices.
amd-eochoalo Nov 25, 2025
0a70e24
Update mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
amd-eochoalo Nov 25, 2025
ec58b7c
Merge branch 'main' into eochoa/2025-11-24/amdgpu-make-dma-descriptor
amd-eochoalo Nov 26, 2025
29072b8
Remove OptionalAttr from static indices
amd-eochoalo Nov 26, 2025
50e76d4
Atomic barrier only takes Variadic<Index>
amd-eochoalo Nov 26, 2025
b339c7a
Do not use attribute splitting
amd-eochoalo Nov 26, 2025
fb82ac3
Rename $every to $pad_every.
amd-eochoalo Nov 26, 2025
445f96e
Remove unused functions
amd-eochoalo Nov 26, 2025
e022322
Only use dynamic indices in make_dma_base
amd-eochoalo Nov 26, 2025
850d6d0
Add verification for same rank
amd-eochoalo Nov 26, 2025
a8fbe1a
Verify tile and tensor's rank are the same
amd-eochoalo Nov 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def AMDGPU_Dialect : Dialect {
"gpu::GPUDialect"
];
let useDefaultAttributePrinterParser = 1;
let useDefaultTypePrinterParser = 1;
}

def AnyIntegerOrFloat : AnyTypeOf<[AnySignlessInteger, AnyFloat], "Integer or Float">;
Expand Down Expand Up @@ -79,6 +80,40 @@ def AMDGPU_AddressSpaceAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_AddressSpace,
let assemblyFormat = "`<` $value `>`";
}


//===----------------------------------------------------------------------===//
// AMDGPU Type definitions
//===----------------------------------------------------------------------===//

// Base class for all AMDGPU dialect types. `typeMnemonic` is the keyword
// used in the textual IR form `!amdgpu.<mnemonic>`.
class AMDGPU_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<AMDGPU_Dialect, name, traits> {
let mnemonic = typeMnemonic;
}

// Opaque pair of base addresses (one in LDS, one in global memory); produced
// by amdgpu.make_dma_base and consumed by amdgpu.make_dma_descriptor.
def AMDGPU_TDMBaseType : AMDGPU_Type<"TDMBase", "tdm_base"> {
let summary = "Pair of base addresses that move data between LDS and global storage.";
let description = [{
This type is opaque and it is used to represent a struct of two addresses.
One address is in LDS while the other is in global memory.
}];
// Element type of the buffers the two addresses point into; it is the only
// piece of the type that is printed: `!amdgpu.tdm_base<elementType>`.
let parameters = (ins "Type":$elementType);
let builders = [
TypeBuilderWithInferredContext<(ins "Type":$elementType), [{
return $_get(elementType.getContext(), elementType);
}]>
];
let assemblyFormat = "`<` $elementType `>`";
}

// Opaque, parameterless handle for the descriptor groups used by the ROCDL
// tensor load/store operations; printed simply as `!amdgpu.tdm_descriptor`.
def AMDGPU_TDMDescriptorType : AMDGPU_Type<"TDMDescriptor", "tdm_descriptor"> {
let summary = "Descriptors used in tensor store/load operations.";
let description = [{
This type is opaque and corresponds to the two or four descriptor groups
used in tensor_load_to_lds or tensor_store_from_lds.
}];

}

//===----------------------------------------------------------------------===//
// AMDGPU Op definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1192,4 +1227,75 @@ def AMDGPU_ScaledMFMAOp :
}];
let hasCanonicalizer = 1;
}

def AMDGPU_MakeDmaBaseOp :
    AMDGPU_Op<"make_dma_base", [Pure, AttrSizedOperandSegments]>,
    Arguments<(ins
        Arg<AnyMemRef, "buffer to read from">:$src,
        Variadic<Index>:$srcIndices,
        Arg<AnyMemRef, "buffer to write to">:$dst,
        Variadic<Index>:$dstIndices)>,
    Results<(outs AMDGPU_TDMBaseType:$base)> {

  // TODO:
  // * Add verifiers such that one of the memrefs is from LDS and the other global.
  // * Add verifiers to make sure that the number of indices do not exceed the number of dimensions.

  let summary = "Pair of base addresses used when moving tiles between LDS and global memory.";
  let description = [{
    This operation creates a pair of addresses that will be used by tensor_load_to_lds
    and tensor_store_from_lds.

    This operation creates a value corresponding roughly to the descriptor group 0
    found in TensorLoadToLDSOp and TensorStoreFromLDSOp in the rocdl dialect.

    Example:

    ```mlir
    amdgpu.make_dma_base %mem[%idx], %smem[%idx]
        : memref<8xi32>, memref<8xi32, #gpu.address_space<workgroup>>
        to !amdgpu.tdm_base<i32>
    ```
  }];

  let assemblyFormat = [{
    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst) `to` type(results)
  }];
}

def AMDGPU_MakeDmaDescriptorOp :
    AMDGPU_Op<"make_dma_descriptor", [Pure, AttrSizedOperandSegments]>,
    Arguments<(ins
        AMDGPU_TDMBaseType:$base,
        Variadic<Index>:$global_dynamic_sizes,
        OptionalAttr<DenseI64ArrayAttr>:$global_static_sizes,
        Variadic<Index>:$global_dynamic_strides,
        OptionalAttr<DenseI64ArrayAttr>:$global_static_strides,
        Variadic<Index>:$shared_dynamic_sizes,
        OptionalAttr<DenseI64ArrayAttr>:$shared_static_sizes,
        Optional<Index>:$pad,
        OptionalAttr<IndexAttr>:$pad_const,
        Optional<Index>:$every,
        OptionalAttr<IndexAttr>:$every_const,
        Optional<AnyMemRef>:$atomic_barrier_address,
        Variadic<Index>:$atomic_barrier_dynamic_indices,
        OptionalAttr<DenseI64ArrayAttr>:$atomic_barrier_static_indices,
        Optional<Index>:$global_increment,
        Optional<Index>:$lds_increment,
        Optional<Index>:$iteration_count)>,
    Results<(outs AMDGPU_TDMDescriptorType:$desc)> {

  let summary = "Creates a descriptor for a tile transfer between global memory and LDS.";
  let description = [{
    This operation assembles the descriptor value consumed by
    tensor_load_to_lds and tensor_store_from_lds, starting from the pair of
    base addresses produced by amdgpu.make_dma_base.

    Operand groups (each size/stride entry may be a static integer or a
    dynamic `index` value):

    * `globalSize` / `globalStride`: sizes and strides of the tensor in
      global memory. The stride of the innermost dimension must be 1.
    * `sharedSize`: the size of the tile in LDS.
    * `padShared(<pad> every <n>)` (optional): padding applied to data placed
      in LDS.
    * `atomicBarrier(<memref> [<indices>] : <type>)` (optional): the memory
      location of an atomic barrier associated with the transfer.
      NOTE(review): barrier semantics inferred from operand names — confirm
      against the ROCDL tensor load/store descriptor definitions.
    * `iterate %globalInc, %ldsInc, %count` (optional): address increments
      applied to the global and LDS sides across `%count` iterations.
      NOTE(review): iteration semantics inferred from operand names — confirm.

    Example:

    ```mlir
    amdgpu.make_dma_descriptor %base
        globalSize [128, 128] globalStride [128, 1] sharedSize [64, 64]
        : !amdgpu.tdm_base<i32> to !amdgpu.tdm_descriptor
    ```
  }];

  let assemblyFormat = [{
    $base
    `globalSize` custom<DynamicIndexList>($global_dynamic_sizes, $global_static_sizes)
    `globalStride` custom<DynamicIndexList>($global_dynamic_strides, $global_static_strides)
    `sharedSize` custom<DynamicIndexList>($shared_dynamic_sizes, $shared_static_sizes)
    ( `padShared` `(` custom<DynamicIndex>($pad, $pad_const)^ `every` custom<DynamicIndex>($every, $every_const) `)` )?
    ( `atomicBarrier` `(` $atomic_barrier_address^
        custom<DynamicIndexList>($atomic_barrier_dynamic_indices, $atomic_barrier_static_indices)
        `:` type($atomic_barrier_address) `)`)?
    ( `iterate` $global_increment^ `,` $lds_increment `,` $iteration_count )?
    attr-dict `:` qualified(type($base)) `to` type(results)
  }];

  let hasVerifier = 1;
}

#endif // AMDGPU
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h.inc"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.h.inc"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.h.inc"

namespace mlir::amdgpu {
/// Parser for the `custom<MNKDimensionList>` custom assembly format used by
Expand Down Expand Up @@ -52,6 +53,9 @@ inline void printMNKDimensionList(OpAsmPrinter &printer, Operation *,
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.h.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.h.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.h.inc"

Expand Down
40 changes: 40 additions & 0 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,37 @@ struct AMDGPUInlinerInterface final : DialectInlinerInterface {
};
} // namespace

/// Parses the `custom<DynamicIndex>` assembly directive: either an integer
/// literal (stored into `staticSize` as an index attribute) or a single SSA
/// operand (recorded in `dynamicSize`).
///
/// `dynamicSize` must be taken by reference: the generated parser resolves
/// the operand through this out-parameter, so a by-value copy would silently
/// drop the parsed operand.
static ParseResult
parseDynamicIndex(OpAsmParser &parser,
                  std::optional<OpAsmParser::UnresolvedOperand> &dynamicSize,
                  IntegerAttr &staticSize) {
  // Static form, e.g. `padShared(1 every 1)`.
  int64_t staticVal;
  OptionalParseResult parsedInt = parser.parseOptionalInteger(staticVal);
  if (parsedInt.has_value()) {
    // An integer token was present; it may still have failed to parse
    // (e.g. out of range), which must be propagated as an error.
    if (failed(*parsedInt))
      return failure();
    staticSize = parser.getBuilder().getIndexAttr(staticVal);
    return success();
  }

  // Dynamic form, e.g. `padShared(%idx every %idx)`.
  dynamicSize.emplace();
  return parser.parseOperand(*dynamicSize);
}

/// Prints the `custom<DynamicIndex>` assembly directive: the static value
/// when the attribute is present, otherwise the dynamic SSA operand.
static void printDynamicIndex(OpAsmPrinter &printer, Operation *op,
                              Value dynamicSize, IntegerAttr staticSize) {
  if (staticSize) {
    printer << staticSize.getValue();
    return;
  }
  printer << dynamicSize;
}

void AMDGPUDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
>();
addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"
>();
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
Expand Down Expand Up @@ -701,6 +727,17 @@ LogicalResult TransposeLoadOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// MakeDmaDescriptorOp
//===----------------------------------------------------------------------===//

LogicalResult MakeDmaDescriptorOp::verify() {
  // The static arrays are optional attributes; guard against both a missing
  // attribute and an empty list before touching back().
  std::optional<ArrayRef<int64_t>> staticStrides = getGlobalStaticStrides();
  if (staticStrides && !staticStrides->empty()) {
    int64_t innermost = staticStrides->back();
    // A dynamic innermost stride cannot be checked statically; only reject
    // a known-static stride different from 1.
    if (!ShapedType::isDynamic(innermost) && innermost != 1)
      return emitOpError("strides for the innermost dimension must be 1.");
  }

  // The global sizes, global strides, and shared (LDS) tile sizes must all
  // describe tensors of the same rank. Checked after the innermost-stride
  // rule so that malformed-stride ops keep reporting the stride error first.
  std::optional<ArrayRef<int64_t>> staticSizes = getGlobalStaticSizes();
  if (staticSizes && staticStrides &&
      staticSizes->size() != staticStrides->size())
    return emitOpError("expected global sizes and strides to have the same rank.");
  std::optional<ArrayRef<int64_t>> sharedSizes = getSharedStaticSizes();
  if (staticSizes && sharedSizes &&
      staticSizes->size() != sharedSizes->size())
    return emitOpError("expected the global tensor and the shared tile to have the same rank.");

  return success();
}

//===----------------------------------------------------------------------===//
// ScaledMFMAOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -839,5 +876,8 @@ void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
10 changes: 10 additions & 0 deletions mlir/test/Dialect/AMDGPU/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,13 @@ func.func @scaled_mfma_invalid_k(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32x
%0 = amdgpu.scaled_mfma 32x32x32 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32>
func.return %0 : vector<16xf32>
}

// -----

// CHECK-LABEL: func @make_dma_descriptor_invalid_strides
// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>)
func.func @make_dma_descriptor_invalid_strides(%base: !amdgpu.tdm_base<i32>) {
  // expected-error@+1 {{'amdgpu.make_dma_descriptor' op strides for the innermost dimension must be 1.}}
  amdgpu.make_dma_descriptor %base globalSize [0] globalStride [1, 2] sharedSize [0] : !amdgpu.tdm_base<i32> to !amdgpu.tdm_descriptor
  func.return
}
77 changes: 77 additions & 0 deletions mlir/test/Dialect/AMDGPU/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -685,3 +685,80 @@ func.func @memory_counter_wait() {
amdgpu.memory_counter_wait exp(4)
func.return
}

// Round-trips amdgpu.make_dma_base in both directions: global buffer as the
// source with LDS as the destination, and vice versa.
// CHECK-LABEL: func @make_dma_base
// CHECK-SAME: (%[[IDX:.+]]: index, %[[MEM:.+]]: memref<8xi32>, %[[SMEM:.+]]: memref<8xi32, #gpu.address_space<workgroup>>)
func.func @make_dma_base(%idx: index, %mem: memref<8xi32>, %smem: memref<8xi32, #gpu.address_space<workgroup>>) {
// CHECK: amdgpu.make_dma_base %[[MEM]][%[[IDX]]], %[[SMEM]][%[[IDX]]] : memref<8xi32>, memref<8xi32, #gpu.address_space<workgroup>> to !amdgpu.tdm_base<i32>
amdgpu.make_dma_base %mem[%idx], %smem[%idx] : memref<8xi32>, memref<8xi32, #gpu.address_space<workgroup>> to !amdgpu.tdm_base<i32>

// CHECK: amdgpu.make_dma_base %[[SMEM]][%[[IDX]]], %[[MEM]][%[[IDX]]] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32> to !amdgpu.tdm_base<i32>
amdgpu.make_dma_base %smem[%idx], %mem[%idx] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32> to !amdgpu.tdm_base<i32>
func.return
}

// Round-trips amdgpu.make_dma_descriptor through each optional clause:
// bare, padShared (static and dynamic), atomicBarrier, and iterate.
// CHECK-LABEL: func @make_dma_descriptor
// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>, %[[BARRIER:.+]]: memref<8xi32>, %[[IDX:.+]]: index)
func.func @make_dma_descriptor(%base: !amdgpu.tdm_base<i32>, %barrier: memref<8xi32>, %idx: index) {

// CHECK: amdgpu.make_dma_descriptor %[[BASE]]
amdgpu.make_dma_descriptor %base
// CHECK-SAME: globalSize [0]
globalSize [0]
// CHECK-SAME: globalStride [1]
globalStride [1]
// CHECK-SAME: sharedSize [0] : !amdgpu.tdm_base<i32> to !amdgpu.tdm_descriptor
sharedSize [0] : !amdgpu.tdm_base<i32> to !amdgpu.tdm_descriptor

// CHECK: amdgpu.make_dma_descriptor %[[BASE]]
amdgpu.make_dma_descriptor %base
// CHECK-SAME: globalSize [0]
globalSize [0]
// CHECK-SAME: globalStride [1]
globalStride [1]
// CHECK-SAME: sharedSize [0]
sharedSize [0]
// CHECK-SAME: padShared(1 every 1)
padShared(1 every 1)
: !amdgpu.tdm_base<i32> to !amdgpu.tdm_descriptor

// CHECK: amdgpu.make_dma_descriptor %[[BASE]]
amdgpu.make_dma_descriptor %base
// CHECK-SAME: globalSize [0]
globalSize [0]
// CHECK-SAME: globalStride [1]
globalStride [1]
// CHECK-SAME: sharedSize [0]
sharedSize [0]
// The dynamic form prints the SSA operands, not constants.
// CHECK-SAME: padShared(%[[IDX]] every %[[IDX]])
padShared(%idx every %idx)
: !amdgpu.tdm_base<i32> to !amdgpu.tdm_descriptor

// CHECK: amdgpu.make_dma_descriptor %[[BASE]]
amdgpu.make_dma_descriptor %base
// CHECK-SAME: globalSize [0]
globalSize [0]
// CHECK-SAME: globalStride [1]
globalStride [1]
// CHECK-SAME: sharedSize [0]
sharedSize [0]
// CHECK-SAME: atomicBarrier(%[[BARRIER]] [0] : memref<8xi32>)
atomicBarrier(%barrier [0] : memref<8xi32>)
: !amdgpu.tdm_base<i32> to !amdgpu.tdm_descriptor

// CHECK: amdgpu.make_dma_descriptor %[[BASE]]
amdgpu.make_dma_descriptor %base
// CHECK-SAME: globalSize [0]
globalSize [0]
// CHECK-SAME: globalStride [1]
globalStride [1]
// CHECK-SAME: sharedSize [0]
sharedSize [0]
// CHECK-SAME: iterate %[[IDX]], %[[IDX]], %[[IDX]]
iterate %idx, %idx, %idx
: !amdgpu.tdm_base<i32> to !amdgpu.tdm_descriptor

func.return
}

Loading