Skip to content

Commit 9af00e6

Browse files
authored
[mlir][amdgpu] Add make_dma_base operation (llvm#169086)
1 parent d58ebe3 commit 9af00e6

File tree

4 files changed

+79
-0
lines changed

4 files changed

+79
-0
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def AMDGPU_Dialect : Dialect {
3333
"gpu::GPUDialect"
3434
];
3535
let useDefaultAttributePrinterParser = 1;
36+
let useDefaultTypePrinterParser = 1;
3637
}
3738

3839
def AnyIntegerOrFloat : AnyTypeOf<[AnySignlessInteger, AnyFloat], "Integer or Float">;
@@ -79,6 +80,30 @@ def AMDGPU_AddressSpaceAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_AddressSpace,
7980
let assemblyFormat = "`<` $value `>`";
8081
}
8182

83+
class AMDGPU_Type<string name, string typeMnemonic, list<Trait> traits = []>
84+
: TypeDef<AMDGPU_Dialect, name, traits> {
85+
let mnemonic = typeMnemonic;
86+
}
87+
88+
//===----------------------------------------------------------------------===//
89+
// AMDGPU Type definitions
90+
//===----------------------------------------------------------------------===//
91+
92+
def AMDGPU_TDMBaseType : AMDGPU_Type<"TDMBase", "tdm_base"> {
93+
let summary = "Pair of base addresses that move data between LDS and global storage.";
94+
let description = [{
95+
This type is opaque and it is used to represent a struct of two addresses.
96+
One address is in LDS while the other is in global memory.
97+
}];
98+
let parameters = (ins "Type":$elementType);
99+
let builders = [
100+
TypeBuilderWithInferredContext<(ins "Type":$elementType), [{
101+
return $_get(elementType.getContext(), elementType);
102+
}]>
103+
];
104+
let assemblyFormat = "`<` $elementType `>`";
105+
}
106+
82107
//===----------------------------------------------------------------------===//
83108
// AMDGPU Op definitions
84109
//===----------------------------------------------------------------------===//
@@ -1192,4 +1217,35 @@ def AMDGPU_ScaledMFMAOp :
11921217
}];
11931218
let hasCanonicalizer = 1;
11941219
}
1220+
1221+
def AMDGPU_MakeDmaBaseOp :
1222+
AMDGPU_Op<"make_dma_base", [Pure, AttrSizedOperandSegments]>,
1223+
Arguments<(ins
1224+
Arg<AnyMemRef, "buffer to read from">:$src,
1225+
Variadic<Index>:$srcIndices,
1226+
Arg<AnyMemRef, "buffer to write to">:$dst,
1227+
Variadic<Index>:$dstIndices)>,
1228+
Results<(outs AMDGPU_TDMBaseType: $base)> {
1229+
1230+
// TODO:
1231+
// * Add verifiers such that one of the memrefs is from LDS and the other global.
1232+
// * Add verifiers to make sure that the type is in the correct direction.
1233+
// * Add verifiers to make sure that the number of indices do not exceed the number of dimensions.
1234+
1235+
let summary = "Pair of based addresses used when moving tiles between LDS and global memory.";
1236+
let description = [{
1237+
This operation creates a pair of addresses that will be used by tensor_load_to_lds
1238+
and tensor_store_from_lds.
1239+
1240+
This operation creates a value corresponding to the tensor descriptor (D#) group 0
1241+
found in TensorLoadToLDSOp and TensorStoreFromLDSOp in the rocdl dialect.
1242+
1243+
These tensor DMA operations were introduced in gfx1250.
1244+
}];
1245+
1246+
let assemblyFormat = [{
1247+
$src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst) `to` type(results)
1248+
}];
1249+
}
1250+
11951251
#endif // AMDGPU

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h.inc"
2626

2727
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.h.inc"
28+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.h.inc"
2829

2930
namespace mlir::amdgpu {
3031
/// Parser for the `custom<MNKDimensionList>` custom assembly format used by
@@ -52,6 +53,9 @@ inline void printMNKDimensionList(OpAsmPrinter &printer, Operation *,
5253
#define GET_ATTRDEF_CLASSES
5354
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.h.inc"
5455

56+
#define GET_TYPEDEF_CLASSES
57+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.h.inc"
58+
5559
#define GET_OP_CLASSES
5660
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.h.inc"
5761

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ void AMDGPUDialect::initialize() {
5555
#define GET_OP_LIST
5656
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
5757
>();
58+
addTypes<
59+
#define GET_TYPEDEF_LIST
60+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"
61+
>();
5862
addAttributes<
5963
#define GET_ATTRDEF_LIST
6064
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
@@ -839,5 +843,8 @@ void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
839843
#define GET_ATTRDEF_CLASSES
840844
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
841845

846+
#define GET_TYPEDEF_CLASSES
847+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"
848+
842849
#define GET_OP_CLASSES
843850
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"

mlir/test/Dialect/AMDGPU/ops.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,3 +685,15 @@ func.func @memory_counter_wait() {
685685
amdgpu.memory_counter_wait exp(4)
686686
func.return
687687
}
688+
689+
// CHECK-LABEL: func @make_dma_base
690+
// CHECK-SAME: (%[[IDX:.+]]: index, %[[MEM:.+]]: memref<8xi32>, %[[SMEM:.+]]: memref<8xi32, #gpu.address_space<workgroup>>)
691+
func.func @make_dma_base(%idx: index, %mem: memref<8xi32>, %smem: memref<8xi32, #gpu.address_space<workgroup>>) {
692+
// CHECK: amdgpu.make_dma_base %[[MEM]][%[[IDX]]], %[[SMEM]][%[[IDX]]] : memref<8xi32>, memref<8xi32, #gpu.address_space<workgroup>> to !amdgpu.tdm_base<i32>
693+
amdgpu.make_dma_base %mem[%idx], %smem[%idx] : memref<8xi32>, memref<8xi32, #gpu.address_space<workgroup>> to !amdgpu.tdm_base<i32>
694+
695+
// CHECK: amdgpu.make_dma_base %[[SMEM]][%[[IDX]]], %[[MEM]][%[[IDX]]] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32> to !amdgpu.tdm_base<i32>
696+
amdgpu.make_dma_base %smem[%idx], %mem[%idx] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32> to !amdgpu.tdm_base<i32>
697+
func.return
698+
}
699+

0 commit comments

Comments
 (0)