Skip to content

Commit fa30258

Browse files
committed
Add support for 6-bit transpose loads.
1 parent 087046a commit fa30258

File tree

4 files changed

+60
-22
lines changed

4 files changed

+60
-22
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -898,10 +898,26 @@ def AMDGPU_GatherToLDSOp :
898898
let hasVerifier = 1;
899899
}
900900

901+
// Type constraint that matches any one of the 8-bit floating-point element
// types listed below. Used by TrLoadTypes as an allowed element type for
// 8-element vectors.
def F8Types : AnyTypeOf<[
902+
F8E8M0FNU, // 8 exponent, 0 mantissa
903+
F8E5M2, // 5 exponent, 2 mantissa
904+
F8E5M2FNUZ, // 5 exponent, 2 mantissa
905+
F8E4M3, // 4 exponent, 3 mantissa
906+
F8E4M3FN, // 4 exponent, 3 mantissa
907+
F8E4M3B11FNUZ, // 4 exponent, 3 mantissa (with bias 11)
908+
F8E3M4 // 3 exponent, 4 mantissa
909+
]>;
910+
// Type constraint that matches either of the two 6-bit floating-point element
// types F6E2M3FN and F6E3M2FN.
def F6Types : AnyTypeOf<[F6E2M3FN, F6E3M2FN]>;
911+
// Vector types permitted as the result of AMDGPU_TransposeLoadOp. Each entry
// fixes the vector length and the allowed element types:
//   - 4 elements of F16 or AnyI<16>
//   - 8 elements of F8Types or AnyI<8>
//   - 16 elements of AnyI<4> or F6Types
//   - 3 elements of I32
// NOTE(review): the 3 x I32 form appears to be the carrier for 6-bit loads
// (the lowering in this commit maps 32-bit elements to ds_read_tr6_b96).
// Neither the lowering switch nor the verifier's size map shown in this
// commit has an entry for a 6-bit element width, so confirm whether
// 16 x F6Types results can actually be verified and lowered.
def TrLoadTypes : AnyTypeOf<[VectorOfLengthAndType<[4], [F16, AnyI<16>]>,
912+
VectorOfLengthAndType<[8], [F8Types, AnyI<8>]>,
913+
VectorOfLengthAndType<[16], [AnyI<4>, F6Types]>,
914+
VectorOfLengthAndType<[3], [I32]>,
915+
]>;
916+
901917
def AMDGPU_TransposeLoadOp :
902918
AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
903919
Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
904-
Results<(outs MFMAInTypes:$result)> {
920+
Results<(outs TrLoadTypes:$result)> {
905921
let summary = "MLIR wrapper for CDNA Transpose Load instructions";
906922
let description = [{
907923
The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,30 +1115,37 @@ struct TransposeLoadOpLowering
11151115

11161116
Location loc = op.getLoc();
11171117
auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1118+
auto resultType = cast<VectorType>(op.getResult().getType());
11181119
Value srcPtr =
11191120
getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
11201121
(adaptor.getSrcIndices()));
1121-
auto elementTypeSize = cast<VectorType>(op.getDst().getType())
1122-
.getElementType()
1123-
.getIntOrFloatBitWidth();
11241122

1125-
// TODO: support ds_read_tr16_b64 intrinsic.
1123+
size_t numElements = resultType.getNumElements();
1124+
size_t elementTypeSize =
1125+
resultType.getElementType().getIntOrFloatBitWidth();
1126+
11261127
switch (elementTypeSize) {
11271128
case 4:
1128-
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr4_b64>(
1129-
op, op.getDst().getType(), srcPtr);
1129+
assert(numElements == 16);
1130+
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr4_b64>(op, resultType,
1131+
srcPtr);
11301132
break;
1131-
case 6:
1132-
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr6_b64>(
1133-
op, op.getDst().getType(), srcPtr);
1133+
case 32:
1134+
// To use ds_read_tr6_b96, the load size is vector<3xi32>.
1135+
// TODO: support native 6-bit data types.
1136+
assert(numElements == 3);
1137+
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr6_b96>(op, resultType,
1138+
srcPtr);
11341139
break;
11351140
case 8:
1136-
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr8_b64>(
1137-
op, op.getDst().getType(), srcPtr);
1141+
assert(numElements == 8);
1142+
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr8_b64>(op, resultType,
1143+
srcPtr);
11381144
break;
11391145
case 16:
1140-
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(
1141-
op, op.getDst().getType(), srcPtr);
1146+
assert(numElements == 4);
1147+
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, resultType,
1148+
srcPtr);
11421149
break;
11431150
default:
11441151
return op.emitOpError("Unsupported element size for transpose load");

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -531,31 +531,32 @@ LogicalResult TransposeLoadOp::verify() {
531531
return emitOpError("source memory address space must be Workgroup");
532532

533533
// TODO: support 6-bit element type vectors.
534-
auto transferType = dyn_cast<VectorType>(getDst().getType());
534+
auto transferType = dyn_cast<VectorType>(getType());
535535
if (!transferType)
536536
return emitOpError("destination type must be a vector type");
537537
size_t transferSize =
538538
transferType.getNumElements() * transferType.getElementTypeBitWidth();
539539
size_t elementTypeSize = srcType.getElementType().getIntOrFloatBitWidth();
540540

541541
// ElementSize -> LoadSize
542-
const std::map<int, int> KValidLoadSizeMap = {
542+
const std::map<size_t, size_t> KValidLoadSizeMap = {
543543
{4, 64},
544-
{6, 96},
544+
{32, 96}, // 6-bit element loads use casted vector<3xi32>
545545
{8, 64},
546546
{16, 64},
547547
};
548548

549549
auto validLoadSize = KValidLoadSizeMap.find(elementTypeSize);
550-
if (validLoadSize == KValidLoadSizeMap.end())
550+
if (validLoadSize == KValidLoadSizeMap.end()) {
551551
return emitOpError("Unsupported element type size for transpose load: ")
552552
<< elementTypeSize << " bits";
553-
if (transferSize != validLoadSize->second)
553+
}
554+
if (transferSize != validLoadSize->second) {
554555
return emitOpError("Transferring type size must be ")
555-
<< validLoadSize->second
556-
<< " bits for element type size "
556+
<< validLoadSize->second << " bits for element type size ";
557+
}
557558

558-
return success();
559+
return success();
559560
}
560561

561562
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,17 @@ func.func @transpose_load_to_rocdl_8xi8(%idx1 : index, %idx2 : index, %wgmem : m
1616
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x128xi8, #gpu_lds_addrspace> -> vector<8xi8>
1717
return %0 : vector<8xi8>
1818
}
19+
20+
// CHECK-LABEL: func @transpose_load_to_rocdl_16xi4
21+
// Lowering test for 4-bit elements: loads a vector<16xi4> with
// amdgpu.transpose_load from a memref tagged #gpu_lds_addrspace and expects
// the rocdl.ds.read.tr4.b64 op in the converted output.
func.func @transpose_load_to_rocdl_16xi4(%idx1 : index, %idx2 : index, %wgmem : memref<128x16xi4, #gpu_lds_addrspace>) -> vector<16xi4> {
22+
// CHECK: rocdl.ds.read.tr4.b64
23+
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x16xi4, #gpu_lds_addrspace> -> vector<16xi4>
24+
return %0 : vector<16xi4>
25+
}
26+
27+
// CHECK-LABEL: func @transpose_load_to_rocdl_3xi32
28+
// Lowering test for the 96-bit form: loads a vector<3xi32> with
// amdgpu.transpose_load from a memref tagged #gpu_lds_addrspace and expects
// the rocdl.ds.read.tr6.b96 op in the converted output.
func.func @transpose_load_to_rocdl_3xi32(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi32, #gpu_lds_addrspace>) -> vector<3xi32> {
29+
// CHECK: rocdl.ds.read.tr6.b96
30+
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi32, #gpu_lds_addrspace> -> vector<3xi32>
31+
return %0 : vector<3xi32>
32+
}

0 commit comments

Comments
 (0)