Skip to content

Commit 2c8d19f

Browse files
amd-eochoalokcloudy0717
authored and committed
[mlir][amdgpu] Lower amdgpu.make_dma_base (llvm#169817)
* Adds lowering for `amdgpu.make_dma_base`
1 parent d878878 commit 2c8d19f

File tree

6 files changed

+183
-27
lines changed

6 files changed

+183
-27
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,15 +1229,13 @@ def AMDGPU_ScaledMFMAOp :
12291229

12301230
def AMDGPU_MakeDmaBaseOp :
12311231
AMDGPU_Op<"make_dma_base", [Pure, AttrSizedOperandSegments]>,
1232-
Arguments<(ins
1233-
Arg<AnyMemRef, "buffer to read from">:$src,
1234-
Variadic<Index>:$src_indices,
1235-
Arg<AnyMemRef, "buffer to write to">:$dst,
1236-
Variadic<Index>:$dst_indices)>,
1232+
Arguments<(ins Arg<AnyMemRef>:$global,
1233+
Variadic<Index>:$global_indices,
1234+
Arg<AnyMemRef>:$lds,
1235+
Variadic<Index>:$lds_indices)>,
12371236
Results<(outs AMDGPU_TDMBaseType: $base)> {
12381237

12391238
// TODO:
1240-
// * Add verifiers such that one of the memrefs is from LDS and the other global.
12411239
// * Add verifiers to make sure that the number of indices does not exceed the number of dimensions.
12421240

12431241
let summary = "Pair of base addresses used when moving tiles between LDS and global memory.";
@@ -1251,35 +1249,39 @@ def AMDGPU_MakeDmaBaseOp :
12511249
For example:
12521250

12531251
```mlir
1254-
%base = amdgpu.make_dma_base %src[%idx0], %dst[%idx1] : memref<8xi32>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
1252+
%base = amdgpu.make_dma_base %global[%idx0, %idx1], %lds[%idx2, %idx3] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
12551253
%descriptor = amdgpu.make_dma_descriptor %base globalSize [2, 2] globalStride [2, 1] sharedSize [2, 2] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
12561254
amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor
12571255
```
12581256

12591257
to
12601258

12611259
```mlir
1262-
// pseudocode
1263-
%base_0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr)>
1264-
%base_1 = llvm.insertvalue %global_addr, %base_0[0] : !llvm.struct<(ptr, ptr)>
1265-
%base_2 = llvm.insertvalue %lds_addr, %base_1[1] : !llvm.struct(ptr, ptr)>
1266-
// type(%base_2) = !llvm.struct<(ptr, ptr) roughly corresponds to amdgpu.tdm_base<i32>
1267-
1268-
// The base will be used when contructing dgroup0
1269-
// when lowering amdgpu.make_dma_descriptor
1270-
%dgroup0_0 = llvm.mlir.undef : !llvm.struct<(....)>
1271-
%dgroup0_1 = llvm.insertvalue %base2, %dgroup0_0 : ....
1272-
1273-
// When lowering amdgpu.tensor_load_to_lds
1274-
rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
1260+
// pseudo-code
1261+
%global_base = llvm.extractvalue %global_memref[1]
1262+
%global_address = llvm.getelementptr ...
1263+
1264+
%lds_base = llvm.extractvalue %lds_memref[1]
1265+
%lds_address = llvm.getelementptr ...
1266+
1267+
// Definition of %base
1268+
%poison = llvm.mlir.poison : vector<4xi32>
1269+
%v0 = llvm.insertelement %c1, %poison[0] : vector<4xi32> // %c1 is the i32 constant 1
1270+
%v1 = llvm.insertelement %lds_address, %v0[1] : vector<4xi32>
1271+
%v2 = llvm.insertelement %global_address_low, %v1[2] : vector<4xi32>
1272+
%base = llvm.insertelement %global_address_high, %v2[3] : vector<4xi32> // high half is masked to 25 bits and OR'd with the type field (2 << 30)
1273+
1274+
rocdl.tensor.load.to.lds %base, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
12751275
```
12761276

12771277
These tensor DMA operations were introduced in gfx1250.
12781278
}];
12791279

12801280
let assemblyFormat = [{
1281-
$src `[` $src_indices `]` `,` $dst `[` $dst_indices `]` attr-dict `:` type($src) `,` type($dst) `->` type(results)
1281+
$global `[` $global_indices `]` `,` $lds `[` $lds_indices `]` attr-dict `:` type($global) `,` type($lds) `->` type(results)
12821282
}];
1283+
1284+
let hasVerifier = 1;
12831285
}
12841286

12851287
def AMDGPU_MakeDmaDescriptorOp :
@@ -1323,12 +1325,12 @@ def AMDGPU_MakeDmaDescriptorOp :
13231325

13241326
```mlir
13251327
// Example of moving a two-dimensional tensor to LDS.
1326-
%base = amdgpu.make_dma_base %src[0, 0], %dst[0, 0] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
1328+
%base = amdgpu.make_dma_base %global[0, 0], %lds[0, 0] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
13271329
%descriptor = amdgpu.make_dma_descriptor %base globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
13281330
amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor
13291331

13301332
// Example of moving a two-dimensional tensor to LDS where padding is applied after every integer.
1331-
%base = amdgpu.make_dma_base %src[0, 0], %dst[0, 0] : memref<32x32xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
1333+
%base = amdgpu.make_dma_base %global[0, 0], %lds[0, 0] : memref<32x32xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
13321334
%descriptor = amdgpu.make_dma_descriptor %base globalSize [32, 32] globalStride [32, 1] sharedSize [64, 64] padding(%pad pad_every %pad_every) : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
13331335
amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor
13341336
```

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2264,6 +2264,77 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
22642264
}
22652265
};
22662266

2267+
struct AMDGPUMakeDmaBaseLowering
2268+
: public ConvertOpToLLVMPattern<MakeDmaBaseOp> {
2269+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
2270+
2271+
AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
2272+
: ConvertOpToLLVMPattern<MakeDmaBaseOp>(converter), chipset(chipset) {}
2273+
Chipset chipset;
2274+
2275+
LogicalResult
2276+
matchAndRewrite(MakeDmaBaseOp op, OpAdaptor adaptor,
2277+
ConversionPatternRewriter &rewriter) const override {
2278+
if (chipset < kGfx1250)
2279+
return op->emitOpError("make_dma_base is only supported on gfx1250");
2280+
2281+
Location loc = op.getLoc();
2282+
2283+
ValueRange ldsIndices = adaptor.getLdsIndices();
2284+
Value lds = adaptor.getLds();
2285+
auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
2286+
2287+
Value ldsPtr =
2288+
getStridedElementPtr(rewriter, loc, ldsMemRefType, lds, ldsIndices);
2289+
2290+
ValueRange globalIndices = adaptor.getGlobalIndices();
2291+
Value global = adaptor.getGlobal();
2292+
auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
2293+
2294+
Value globalPtr = getStridedElementPtr(rewriter, loc, globalMemRefType,
2295+
global, globalIndices);
2296+
2297+
Type i32 = rewriter.getI32Type();
2298+
Type i64 = rewriter.getI64Type();
2299+
2300+
Value castForLdsAddr = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
2301+
Value castForGlobalAddr =
2302+
LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
2303+
2304+
Value lowHalf =
2305+
LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
2306+
2307+
Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
2308+
createI64Constant(rewriter, loc, 32));
2309+
2310+
Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
2311+
2312+
Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
2313+
Value validHighHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
2314+
2315+
Value typeField = createI32Constant(rewriter, loc, 2 << 30);
2316+
Value highHalfPlusType =
2317+
LLVM::OrOp::create(rewriter, loc, validHighHalf, typeField);
2318+
2319+
Value c0 = createI32Constant(rewriter, loc, 0);
2320+
Value c1 = createI32Constant(rewriter, loc, 1);
2321+
Value c2 = createI32Constant(rewriter, loc, 2);
2322+
Value c3 = createI32Constant(rewriter, loc, 3);
2323+
2324+
Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
2325+
Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
2326+
result = LLVM::InsertElementOp::create(rewriter, loc, result, c1, c0);
2327+
result = LLVM::InsertElementOp::create(rewriter, loc, result,
2328+
castForLdsAddr, c1);
2329+
result = LLVM::InsertElementOp::create(rewriter, loc, result, lowHalf, c2);
2330+
result = LLVM::InsertElementOp::create(rewriter, loc, result,
2331+
highHalfPlusType, c3);
2332+
2333+
rewriter.replaceOp(op, result);
2334+
return success();
2335+
}
2336+
};
2337+
22672338
struct ConvertAMDGPUToROCDLPass
22682339
: public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
22692340
using Base::Base;
@@ -2278,6 +2349,10 @@ struct ConvertAMDGPUToROCDLPass
22782349

22792350
RewritePatternSet patterns(ctx);
22802351
LLVMTypeConverter converter(ctx);
2352+
converter.addConversion([&](TDMBaseType type) -> Type {
2353+
Type i32 = IntegerType::get(type.getContext(), 32);
2354+
return converter.convertType(VectorType::get(4, i32));
2355+
});
22812356
populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
22822357
LLVMConversionTarget target(getContext());
22832358
target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
@@ -2333,6 +2408,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
23332408
ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
23342409
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
23352410
GatherToLDSOpLowering, TransposeLoadOpLowering,
2336-
AMDGPUPermlaneLowering>(converter, chipset);
2411+
AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering>(converter,
2412+
chipset);
23372413
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
23382414
}

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,24 @@ LogicalResult TransposeLoadOp::verify() {
705705
return success();
706706
}
707707

708+
//===----------------------------------------------------------------------===//
709+
// MakeDmaBaseOp
710+
//===----------------------------------------------------------------------===//
711+
712+
LogicalResult MakeDmaBaseOp::verify() {
713+
MemRefType ldsType = cast<MemRefType>(getLds().getType());
714+
MemRefType globalType = cast<MemRefType>(getGlobal().getType());
715+
if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace())) {
716+
return emitOpError(
717+
"lds memref must have workgroup address space attribute.");
718+
}
719+
if (!hasGlobalMemorySpace(globalType.getMemorySpace())) {
720+
return emitOpError(
721+
"global memref must have global address space attribute.");
722+
}
723+
return success();
724+
}
725+
708726
//===----------------------------------------------------------------------===//
709727
// MakeDmaDescriptorOp
710728
//===----------------------------------------------------------------------===//

mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir renamed to mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,51 @@ func.func @amdgpu.scaled_ext_packed816_invalid_dst_elem_type(%v: vector<16xf6E3M
162162
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf64>
163163
return %ret0: vector<16xf64>
164164
}
165+
166+
// -----
167+
168+
#gpu_global_addrspace = 1
169+
#gpu_lds_addrspace = 3
170+
#amdgpu_fat_buffer_addrspace = 7
171+
172+
// CHECK-LABEL: func @make_dma_base
173+
// CHECK-SAME: (%[[IDX:.+]]: index, %[[MEM:.+]]: memref<8xi32, 1>, %[[SMEM:.+]]: memref<8xi32, 3>)
174+
func.func @make_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>, %smem: memref<8xi32,#gpu_lds_addrspace>) -> (!amdgpu.tdm_base<i32>) {
175+
// CHECK-DAG: %[[INT:.+]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
176+
// CHECK-DAG: %[[MEMREF_DESC_MEM:.+]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<8xi32, 1>
177+
// CHECK-DAG: %[[MEMREF_DESC_SMEM:.+]] = builtin.unrealized_conversion_cast %[[SMEM]] : memref<8xi32, 3>
178+
179+
// CHECK-DAG: %[[MEM_BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_DESC_MEM]][1] : !llvm.struct<(ptr<1>
180+
// CHECK-DAG: %[[SMEM_BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_DESC_SMEM]][1] : !llvm.struct<(ptr<3>
181+
182+
// CHECK-DAG: %[[MEM_BASE_OFFSET:.+]] = llvm.getelementptr %[[MEM_BASE_PTR]][%[[INT]]]
183+
// CHECK-DAG: %[[SMEM_BASE_OFFSET:.+]] = llvm.getelementptr %[[SMEM_BASE_PTR]][%[[INT]]]
184+
185+
// CHECK-DAG: %[[MEM_INT:.+]] = llvm.ptrtoint %[[MEM_BASE_OFFSET]] : !llvm.ptr<1> to i64
186+
// CHECK-DAG: %[[SMEM_INT:.+]] = llvm.ptrtoint %[[SMEM_BASE_OFFSET]] : !llvm.ptr<3> to i32
187+
188+
// CHECK: %[[MEM_INT_LOW:.+]] = llvm.trunc %[[MEM_INT]] : i64 to i32
189+
// CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(32 : i64)
190+
// CHECK: %[[SHIFTED_MEM_INT:.+]] = llvm.lshr %[[MEM_INT]], %[[SHIFT]]
191+
// CHECK: %[[MEM_INT_HIGH:.+]] = llvm.trunc %[[SHIFTED_MEM_INT]] : i64 to i32
192+
// CHECK-DAG: %[[MASK:.+]] = llvm.mlir.constant(33554431 : i32)
193+
// CHECK: %[[VALID_MEM_INT_HIGH:.+]] = llvm.and %[[MEM_INT_HIGH]], %[[MASK]]
194+
195+
// CHECK-DAG: %[[TYPE_FIELD:.+]] = llvm.mlir.constant(-2147483648 : i32)
196+
// CHECK: %[[MEM_INT_HIGH_TYPE:.+]] = llvm.or %[[VALID_MEM_INT_HIGH]], %[[TYPE_FIELD]]
197+
198+
// CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32
199+
// CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32
200+
// CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) : i32
201+
// CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) : i32
202+
203+
// CHECK: %[[V4I32_0_0:.+]] = llvm.mlir.poison : vector<4xi32>
204+
// CHECK: %[[V4I32_0_1:.+]] = llvm.insertelement %[[C1]], %[[V4I32_0_0]][%[[C0]] : i32]
205+
// CHECK: %[[V4I32_0_2:.+]] = llvm.insertelement %[[SMEM_INT]], %[[V4I32_0_1]][%[[C1]] : i32]
206+
// CHECK: %[[V4I32_0_3:.+]] = llvm.insertelement %[[MEM_INT_LOW]], %[[V4I32_0_2]][%[[C2]] : i32]
207+
// CHECK: %[[V4I32_0_4:.+]] = llvm.insertelement %[[MEM_INT_HIGH_TYPE]], %[[V4I32_0_3]][%[[C3]] : i32]
208+
209+
%0 = amdgpu.make_dma_base %mem[%idx], %smem[%idx] : memref<8xi32, #gpu_global_addrspace>, memref<8xi32, #gpu_lds_addrspace> -> !amdgpu.tdm_base<i32>
210+
211+
func.return %0 : !amdgpu.tdm_base<i32>
212+
}

mlir/test/Dialect/AMDGPU/invalid.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,20 @@ func.func @scaled_mfma_invalid_k(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32x
357357

358358
// -----
359359

360+
func.func @make_dma_base_invalid_addressspace(%idx: index, %mem: memref<8xi32>) {
361+
// expected-error@+1 {{'amdgpu.make_dma_base' op lds memref must have workgroup address space attribute.}}
362+
amdgpu.make_dma_base %mem[%idx], %mem[%idx] : memref<8xi32>, memref<8xi32> -> !amdgpu.tdm_base<i32>
363+
}
364+
365+
// -----
366+
367+
func.func @make_dma_base_invalid_addressspace(%idx: index, %smem : memref<8xi32, #gpu.address_space<workgroup>>) {
368+
// expected-error@+1 {{'amdgpu.make_dma_base' op global memref must have global address space attribute.}}
369+
amdgpu.make_dma_base %smem[%idx], %smem[%idx] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
370+
}
371+
372+
// -----
373+
360374
func.func @make_dma_base_invalid_barrier(%base: !amdgpu.tdm_base<i32>, %barrier: memref<8xi32>, %idx: index) {
361375
// expected-error@+1 {{'amdgpu.make_dma_descriptor' op atomic barrier address must be in LDS.}}
362376
amdgpu.make_dma_descriptor %base globalSize [0] globalStride [1] sharedSize [0] atomicBarrier(%barrier[%idx] : memref<8xi32>) : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor

mlir/test/Dialect/AMDGPU/ops.mlir

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -691,9 +691,6 @@ func.func @memory_counter_wait() {
691691
func.func @make_dma_base(%idx: index, %mem: memref<8xi32>, %smem: memref<8xi32, #gpu.address_space<workgroup>>) {
692692
// CHECK: amdgpu.make_dma_base %[[MEM]][%[[IDX]]], %[[SMEM]][%[[IDX]]] : memref<8xi32>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
693693
amdgpu.make_dma_base %mem[%idx], %smem[%idx] : memref<8xi32>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
694-
695-
// CHECK: amdgpu.make_dma_base %[[SMEM]][%[[IDX]]], %[[MEM]][%[[IDX]]] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32> -> !amdgpu.tdm_base<i32>
696-
amdgpu.make_dma_base %smem[%idx], %mem[%idx] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32> -> !amdgpu.tdm_base<i32>
697694
func.return
698695
}
699696

@@ -748,3 +745,4 @@ func.func @make_dma_descriptor(%base: !amdgpu.tdm_base<i32>, %barrier: memref<8x
748745

749746
func.return
750747
}
748+

0 commit comments

Comments
 (0)