Skip to content

Commit 96bd837

Browse files
Merge branch 'main' into gfx950-mfma-rocdl
2 parents 1cf0e2d + 27f15ad commit 96bd837

File tree

3 files changed

+80
-0
lines changed

3 files changed

+80
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,36 @@ def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]
418418
def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
419419
def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
420420

421+
//===---------------------------------------------------------------------===//
422+
// LDS transpose intrinsics (available in GFX950)
423+
424+
// Pointer type constraints used by the GFX950 ops in this part of the file:
// an LLVM pointer in address space 1 (the global-memory operand of the
// global-load-to-LDS op) and an LLVM pointer in address space 3 (the LDS
// operand of both the transpose reads and the global-load-to-LDS op).
def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
425+
def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;
426+
427+
// Base class for the LDS transpose-read ops. Each derived op takes a single
// operand, $ptr (an address-space-3 pointer annotated with a MemRead effect),
// and is written as
//   rocdl.<mnemonic> $ptr attr-dict : type($ptr) -> type($res)
// so the result type is spelled out at each use rather than fixed here.
// NOTE(review): the positional ROCDL_IntrOp arguments ([1], [], [], 1) are
// defined outside this section; presumably they select which types take part
// in intrinsic overloading and request exactly one result -- confirm against
// the ROCDL_IntrOp definition.
class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> :
428+
ROCDL_IntrOp<mnemonic, [1], [], [], 1>,
429+
Arguments<(ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr)>{
430+
let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)";
431+
}
432+
433+
// One op per transpose-read variant. Each def only supplies the mnemonic;
// operand, memory effect and assembly format all come from
// ROCDL_LDS_Read_Tr_IntrOp.
def ROCDL_ds_read_tr4_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr4.b64">;
434+
def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">;
435+
def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">;
436+
def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">;
437+
438+
//===---------------------------------------------------------------------===//
439+
// Global load to LDS intrinsic (available in GFX950)
440+
441+
// Op wrapping the "global.load.lds" intrinsic. Operands, in order:
//   $globalPtr - address-space-1 pointer, annotated as read from (MemRead)
//   $ldsPtr    - address-space-3 pointer, annotated as written to (MemWrite)
//   $size, $offset, $aux - i32 values forwarded to the intrinsic
// The assembly format prints only the operand list and the attribute
// dictionary; no operand types are printed.
// NOTE(review): $size, $offset and $aux are modelled as ordinary SSA i32
// operands. The underlying LLVM intrinsic is believed to declare these three
// arguments as immediates (immarg), in which case a non-constant value would
// translate to invalid LLVM IR -- confirm against the intrinsic definition
// and consider modelling them as I32Attr if so.
// NOTE(review): the trailing 0 passed to ROCDL_IntrOp is presumably the
// result count (the op is used without a result) -- confirm against the
// ROCDL_IntrOp definition.
def ROCDL_GlobalLoadLDSOp :
442+
ROCDL_IntrOp<"global.load.lds", [], [], [], 0>,
443+
Arguments<(ins Arg<ROCDLGlobalBuffer, "", [MemRead]>:$globalPtr,
444+
Arg<ROCDLBufferLDS, "", [MemWrite]>:$ldsPtr,
445+
I32:$size,
446+
I32:$offset,
447+
I32:$aux)> {
448+
let assemblyFormat = "operands attr-dict";
449+
}
450+
421451
//===---------------------------------------------------------------------===//
422452
// Operations on raw buffer resources (stride of 0, bounds checks either off or in
423453
// raw buffer mode).

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,32 @@ func.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
227227
llvm.return
228228
}
229229

230+
// Exercises the textual form of every ds.read.tr* op on an address-space-3
// pointer, including both the f16 and bf16 result types of tr16.b64. The
// expected lines match the printed pointer type as `<3>` followed by the
// result vector type.
llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
231+
// CHECK-LABEL: rocdl.ds.read.tr
232+
// CHECK: rocdl.ds.read.tr4.b64 {{.*}} : <3> -> vector<2xi32>
233+
%r0 = rocdl.ds.read.tr4.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
234+
// CHECK: rocdl.ds.read.tr6.b96 {{.*}} : <3> -> vector<3xi32>
235+
%r1 = rocdl.ds.read.tr6.b96 %ptr : !llvm.ptr<3> -> vector<3xi32>
236+
// CHECK: rocdl.ds.read.tr8.b64 {{.*}} : <3> -> vector<2xi32>
237+
%r2 = rocdl.ds.read.tr8.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
238+
// CHECK: rocdl.ds.read.tr16.b64 {{.*}} : <3> -> vector<4xf16>
239+
%r3 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xf16>
240+
// CHECK: rocdl.ds.read.tr16.b64 {{.*}} : <3> -> vector<4xbf16>
241+
%r4 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xbf16>
242+
llvm.return %r3 : vector<4xf16>
243+
}
244+
245+
// Exercises the textual form of rocdl.global.load.lds: five operands
// (address-space-1 source, address-space-3 destination, then size, offset
// and aux as i32 constants) printed as a bare operand list.
// NOTE(review): unlike the ds.read.tr test, this function has no label
// directive of its own, and %size is 10 -- confirm that 10 is a transfer
// size the intrinsic actually accepts, or pick a known-valid one.
llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
246+
%aux = llvm.mlir.constant(0 : i32) : i32
247+
%offset = llvm.mlir.constant(0 : i32) : i32
248+
%size = llvm.mlir.constant(10 : i32) : i32
249+
250+
//CHECK: rocdl.global.load.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
251+
rocdl.global.load.lds %src, %dst, %size, %offset, %aux
252+
253+
llvm.return
254+
}
255+
230256
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
231257
%stride : i16,
232258
%numRecords : i32,

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,30 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
457457
llvm.return %r0 : vector<8xf32>
458458
}
459459

460+
// Verifies the LLVM IR produced for every ds.read.tr* op: each one is
// expected to become a call to the matching llvm.amdgcn.ds.read.* intrinsic
// whose name carries a suffix for the result vector type (v2i32, v3i32,
// v4f16, v4bf16) and whose argument is the addrspace(3) pointer.
llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
461+
// CHECK-LABEL: rocdl.ds.read.tr
462+
// CHECK: call <2 x i32> @llvm.amdgcn.ds.read.tr4.b64.v2i32(ptr addrspace(3) %0)
463+
%r0 = rocdl.ds.read.tr4.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
464+
// CHECK: call <3 x i32> @llvm.amdgcn.ds.read.tr6.b96.v3i32(ptr addrspace(3) %0)
465+
%r1 = rocdl.ds.read.tr6.b96 %ptr : !llvm.ptr<3> -> vector<3xi32>
466+
// CHECK: call <2 x i32> @llvm.amdgcn.ds.read.tr8.b64.v2i32(ptr addrspace(3) %0)
467+
%r2 = rocdl.ds.read.tr8.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
468+
// CHECK: call <4 x half> @llvm.amdgcn.ds.read.tr16.b64.v4f16(ptr addrspace(3) %0)
469+
%r3 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xf16>
470+
// CHECK: call <4 x bfloat> @llvm.amdgcn.ds.read.tr16.b64.v4bf16(ptr addrspace(3) %0)
471+
%r4 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xbf16>
472+
llvm.return %r3 : vector<4xf16>
473+
}
474+
475+
// Verifies that rocdl.global.load.lds becomes a void call to
// llvm.amdgcn.global.load.lds.
// NOTE(review): the expected line only matches the callee name; it does not
// pin the pointer address spaces or the three i32 arguments, and this
// function has no label directive of its own. Consider tightening both.
llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
476+
%aux = llvm.mlir.constant(0 : i32) : i32
477+
%offset = llvm.mlir.constant(0 : i32) : i32
478+
%size = llvm.mlir.constant(10 : i32) : i32
479+
//CHECK: call void @llvm.amdgcn.global.load.lds
480+
rocdl.global.load.lds %src, %dst, %size, %offset, %aux
481+
llvm.return
482+
}
483+
460484
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
461485
%stride : i16,
462486
%numRecords : i32,

0 commit comments

Comments
 (0)