Skip to content

Commit 087046a

Browse files
committed
Adding 6-bit loads.
1 parent 50d19a6 commit 087046a

File tree

3 files changed

+27
-7
lines changed

3 files changed

+27
-7
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -901,20 +901,20 @@ def AMDGPU_GatherToLDSOp :
901901
def AMDGPU_TransposeLoadOp :
902902
AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
903903
Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
904-
Results<(outs MFMAInTypes:$dst)> {
904+
Results<(outs MFMAInTypes:$result)> {
905905
let summary = "MLIR wrapper for CDNA Transpose Load instructions";
906906
let description = [{
907907
The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.
908908

909909
Operands:
910910
* `$src`: LDS memref to read from.
911911
* `$srcIndices`: indices into `$src` to read from for this thread.
912-
* `$dst`: target register this transpose load instruction will write to.
912+
* `$result`: target register this transpose load instruction will write to.
913913

914914
Note: Lowering is only supported on gfx950 and up.
915915
}];
916916
let assemblyFormat = [{
917-
$src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($dst)
917+
$src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($result)
918918
}];
919919
let hasVerifier = 1;
920920
}

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1128,6 +1128,10 @@ struct TransposeLoadOpLowering
11281128
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr4_b64>(
11291129
op, op.getDst().getType(), srcPtr);
11301130
break;
1131+
case 6:
1132+
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr6_b64>(
1133+
op, op.getDst().getType(), srcPtr);
1134+
break;
11311135
case 8:
11321136
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr8_b64>(
11331137
op, op.getDst().getType(), srcPtr);

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 20 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -536,10 +536,26 @@ LogicalResult TransposeLoadOp::verify() {
536536
return emitOpError("destination type must be a vector type");
537537
size_t transferSize =
538538
transferType.getNumElements() * transferType.getElementTypeBitWidth();
539-
if (transferSize != 64)
540-
return emitOpError("Transferring type size must be 64 bits");
541-
542-
return success();
539+
size_t elementTypeSize = srcType.getElementType().getIntOrFloatBitWidth();
540+
541+
// ElementSize -> LoadSize
542+
const std::map<int, int> KValidLoadSizeMap = {
543+
{4, 64},
544+
{6, 96},
545+
{8, 64},
546+
{16, 64},
547+
};
548+
549+
auto validLoadSize = KValidLoadSizeMap.find(elementTypeSize);
550+
if (validLoadSize == KValidLoadSizeMap.end())
551+
return emitOpError("Unsupported element type size for transpose load: ")
552+
<< elementTypeSize << " bits";
553+
if (transferSize != validLoadSize->second)
554+
return emitOpError("Transferring type size must be ")
555+
<< validLoadSize->second
556+
<< " bits for element type size "
557+
558+
return success();
543559
}
544560

545561
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

0 commit comments

Comments
 (0)