Merged
Changes from 4 commits
37 changes: 33 additions & 4 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -898,23 +898,52 @@ def AMDGPU_GatherToLDSOp :
let hasVerifier = 1;
}

def F8Types : AnyTypeOf<[
F8E8M0FNU, // 8 exponent, 0 mantissa
F8E5M2, // 5 exponent, 2 mantissa
F8E5M2FNUZ, // 5 exponent, 2 mantissa
F8E4M3, // 4 exponent, 3 mantissa
F8E4M3FN, // 4 exponent, 3 mantissa
F8E4M3B11FNUZ, // 4 exponent, 3 mantissa (with bias 11)
F8E3M4 // 3 exponent, 4 mantissa
]>;
def F6Types : AnyTypeOf<[F6E2M3FN, F6E3M2FN]>;
def TrLoadTypes : AnyTypeOf<[VectorOfLengthAndType<[4], [F16, AnyI<16>]>,
Review comment (Contributor):
BF16 exists ... and also, we can probably leave this open and rely on a getIntOrFloatBitWidth() check in the verifier?

Reply (Member Author):
yeah, now it accepts any vectors and the verifier will serve as the checker.

VectorOfLengthAndType<[8], [F8Types, AnyI<8>]>,
VectorOfLengthAndType<[16], [AnyI<4>, F6Types]>,
VectorOfLengthAndType<[3], [I32]>,
]>;

def AMDGPU_TransposeLoadOp :
AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
Results<(outs MFMAInTypes:$dst)> {
Results<(outs TrLoadTypes:$result)> {
let summary = "MLIR wrapper for CDNA Transpose Load instructions";
let description = [{
The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.

The transpose load op represents a subgroup load from LDS memory,
where the subgroup of threads collectively reads a matrix from the source
memref, with each thread reading a vector of the matrix, and gets a transposed matrix
as the result. That is, each thread reads a vector of the col-major matrix at different
indices, and the thread's read result is a vector of the corresponding row of the transposed
matrix.

This op is a direct wrapper around the ROCDL `ds_read_tr` family of intrinsics. Please refer
to the ROCDL documentation for more details about its exact semantics.

Format example:
```
%0 = amdgpu.transpose_load %src[%srcIndices] : memref<128x256xf16> -> vector<4xf16>
```
Operands:
* `$src`: LDS memref to read from.
* `$srcIndices`: indices into `$src` to read from for this thread.
* `$dst`: target register this transpose load instruction will write to.
* `$result`: target register this transpose load instruction will write to.

Note: Lowering is only supported on gfx950 and up.
}];
let assemblyFormat = [{
Review comment (Member):
I know other ops here don't provide examples, but I think it would be worth adding going forward -- I rely on these all the time

Reply (Member Author):
I like your idea. So I tried to add a very simple example to show the format of the op. In terms of the semantics of the instruction, it is too hard to explain in a few sentences so I wrote that "please refer to the actual document for detailed explanation".

Reply (Contributor):
Probably call out that you mean the CDNA4 ISA manual

$src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($dst)
$src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($result)
}];
let hasVerifier = 1;
}
31 changes: 21 additions & 10 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1115,26 +1115,37 @@ struct TransposeLoadOpLowering

Location loc = op.getLoc();
auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
auto resultType = cast<VectorType>(op.getResult().getType());
Value srcPtr =
getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
(adaptor.getSrcIndices()));
auto elementTypeSize = cast<VectorType>(op.getDst().getType())
.getElementType()
.getIntOrFloatBitWidth();

// TODO: support ds_read_tr16_b64 intrinsic.
size_t numElements = resultType.getNumElements();
size_t elementTypeSize =
resultType.getElementType().getIntOrFloatBitWidth();

switch (elementTypeSize) {
case 4:
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr4_b64>(
op, op.getDst().getType(), srcPtr);
assert(numElements == 16);
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr4_b64>(op, resultType,
srcPtr);
break;
case 32:
// To use ds_read_tr6_b96, the load size is vector<3xi32>.
// TODO: support native 6-bit data types.
assert(numElements == 3);
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr6_b96>(op, resultType,
srcPtr);
break;
case 8:
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr8_b64>(
op, op.getDst().getType(), srcPtr);
assert(numElements == 8);
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr8_b64>(op, resultType,
srcPtr);
break;
case 16:
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(
op, op.getDst().getType(), srcPtr);
assert(numElements == 4);
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, resultType,
srcPtr);
break;
default:
return op.emitOpError("Unsupported element size for transpose load");
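The dispatch in the lowering above can be sketched outside of MLIR as a plain function from the result vector's element bit width and element count to an intrinsic name. This is only an illustrative sketch: `selectTransposeLoadIntrinsic` is a hypothetical helper that does not exist in the patch, the returned strings are the op names the tests in this PR check for, and a mismatched element count returns `std::nullopt` here where the patch uses `assert`.

```cpp
#include <cstddef>
#include <optional>
#include <string_view>

// Hypothetical helper (not part of the patch): picks the ROCDL op name the
// lowering would emit for a result vector of `numElements` elements that are
// each `elementBits` wide.
std::optional<std::string_view>
selectTransposeLoadIntrinsic(std::size_t elementBits, std::size_t numElements) {
  switch (elementBits) {
  case 4: // vector<16xi4>, 64 bits in total
    if (numElements == 16)
      return "rocdl.ds.read.tr4.b64";
    break;
  case 8: // vector<8xi8> or vector<8 x f8>, 64 bits in total
    if (numElements == 8)
      return "rocdl.ds.read.tr8.b64";
    break;
  case 16: // vector<4xf16> or vector<4xi16>, 64 bits in total
    if (numElements == 4)
      return "rocdl.ds.read.tr16.b64";
    break;
  case 32: // 6-bit data is carried as vector<3xi32>, 96 bits in total
    if (numElements == 3)
      return "rocdl.ds.read.tr6.b96";
    break;
  default:
    break;
  }
  // Unsupported element width, or an element count that does not match the
  // width (the patch asserts on the count instead of returning).
  return std::nullopt;
}
```

Returning an empty optional keeps the sketch free of aborts; in the patch itself the verifier is expected to have rejected mismatched shapes before the lowering runs.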
23 changes: 20 additions & 3 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -531,13 +531,30 @@ LogicalResult TransposeLoadOp::verify() {
return emitOpError("source memory address space must be Workgroup");

// TODO: support 6-bit element type vectors.
auto transferType = dyn_cast<VectorType>(getDst().getType());
auto transferType = dyn_cast<VectorType>(getType());
if (!transferType)
return emitOpError("destination type must be a vector type");
size_t transferSize =
transferType.getNumElements() * transferType.getElementTypeBitWidth();
if (transferSize != 64)
return emitOpError("Transferring type size must be 64 bits");
size_t elementTypeSize = srcType.getElementType().getIntOrFloatBitWidth();

// ElementSize -> LoadSize
const std::map<size_t, size_t> KValidLoadSizeMap = {
{4, 64},
{32, 96}, // 6-bit element loads use casted vector<3xi32>
{8, 64},
{16, 64},
};

auto validLoadSize = KValidLoadSizeMap.find(elementTypeSize);
if (validLoadSize == KValidLoadSizeMap.end()) {
return emitOpError("Unsupported element type size for transpose load: ")
<< elementTypeSize << " bits";
}
if (transferSize != validLoadSize->second) {
return emitOpError("Transferring type size must be ")
<< validLoadSize->second << " bits for element type size "
<< elementTypeSize << " bits";
}

return success();
}
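The size check in the verifier above can be exercised in isolation with a small standalone sketch. The names `expectedLoadSizeBits` and `isValidTransposeLoad` are hypothetical and not part of the patch. The sketch also assumes a single element width for both source and result, whereas the verifier reads the width from the source memref and the total size from the result vector.

```cpp
#include <cstddef>
#include <map>
#include <optional>

// Hypothetical helper mirroring the verifier's table:
// element bit width -> required total load size in bits.
std::optional<std::size_t> expectedLoadSizeBits(std::size_t elementBits) {
  static const std::map<std::size_t, std::size_t> kValidLoadSize = {
      {4, 64},
      {8, 64},
      {16, 64},
      {32, 96}, // 6-bit element loads are expressed as vector<3xi32>
  };
  auto it = kValidLoadSize.find(elementBits);
  if (it == kValidLoadSize.end())
    return std::nullopt; // unsupported element width
  return it->second;
}

// True when a vector of `numElements` elements of `elementBits` bits each
// adds up to exactly the load size required for that element width.
bool isValidTransposeLoad(std::size_t numElements, std::size_t elementBits) {
  std::optional<std::size_t> expected = expectedLoadSizeBits(elementBits);
  return expected.has_value() && numElements * elementBits == *expected;
}
```

Under these assumptions, `vector<4xf16>`, `vector<8xi8>`, `vector<16xi4>`, and `vector<3xi32>` pass, while a shape such as `vector<2xi32>` (64 bits with 32-bit elements) is rejected because 32-bit elements require a 96-bit load.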
35 changes: 28 additions & 7 deletions mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
@@ -1,18 +1,39 @@
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s

#gpu_lds_addrspace = 3
#amdgpu_fat_buffer_addrspace = 7
// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx945 2>&1 | FileCheck %s --check-prefix=CHECK-OLD

// CHECK-LABEL: func @transpose_load_to_rocdl_4xf16
func.func @transpose_load_to_rocdl_4xf16(%idx1 : index, %idx2 : index, %wgmem : memref<128x72xf16, #gpu_lds_addrspace>) -> vector<4xf16> {
func.func @transpose_load_to_rocdl_4xf16(%idx1 : index, %idx2 : index, %wgmem : memref<128x72xf16, 3>) -> vector<4xf16> {
// CHECK: rocdl.ds.read.tr16.b64
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x72xf16, #gpu_lds_addrspace> -> vector<4xf16>
// CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x72xf16, 3> -> vector<4xf16>
return %0 : vector<4xf16>
}

// -----

// CHECK-LABEL: func @transpose_load_to_rocdl_8xi8
func.func @transpose_load_to_rocdl_8xi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x128xi8, #gpu_lds_addrspace>) -> vector<8xi8> {
func.func @transpose_load_to_rocdl_8xi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x128xi8, 3>) -> vector<8xi8> {
// CHECK: rocdl.ds.read.tr8.b64
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x128xi8, #gpu_lds_addrspace> -> vector<8xi8>
// CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x128xi8, 3> -> vector<8xi8>
return %0 : vector<8xi8>
}

// -----

// CHECK-LABEL: func @transpose_load_to_rocdl_16xi4
func.func @transpose_load_to_rocdl_16xi4(%idx1 : index, %idx2 : index, %wgmem : memref<128x16xi4, 3>) -> vector<16xi4> {
// CHECK: rocdl.ds.read.tr4.b64
// CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x16xi4, 3> -> vector<16xi4>
return %0 : vector<16xi4>
}

// -----

// CHECK-LABEL: func @transpose_load_to_rocdl_3xi32
func.func @transpose_load_to_rocdl_3xi32(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi32, 3>) -> vector<3xi32> {
// CHECK: rocdl.ds.read.tr6.b96
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi32, 3> -> vector<3xi32>
return %0 : vector<3xi32>
}