Skip to content

Commit 76928f0

Browse files
[mlir][ROCDL] Address review comments
- Use SmallVectorImpl for output parameters per LLVM coding standards - Fix type casting bug (getSourceA -> getSourceB on line 1309) - Use cast<> instead of dyn_cast<> for guaranteed vector types - Simplify assembly format with functional-type shorthand Addresses review comments from @kuhar on PR #167041
1 parent fef3ace commit 76928f0

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,7 @@ class ROCDL_WMMA_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemon
590590
LLVM_ScalarOrVectorOf<CD>:$C)> {
591591
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
592592
let assemblyFormat = [{
593-
$A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
593+
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
594594
}];
595595
}
596596

@@ -603,7 +603,7 @@ class ROCDL_WMMA_Opsel_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<
603603
DefaultValuedAttr<I1Attr, "0">:$opsel)> {
604604
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
605605
let assemblyFormat = [{
606-
$A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
606+
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
607607
}];
608608
}
609609

@@ -618,7 +618,7 @@ class ROCDL_WMMA_IU_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mne
618618
DefaultValuedAttr<I1Attr, "0">:$clamp)> {
619619
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
620620
let assemblyFormat = [{
621-
$A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
621+
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
622622
}];
623623
}
624624

@@ -635,7 +635,7 @@ class ROCDL_WMMA_ModsAll_Reuse_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL
635635
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
636636
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
637637
let assemblyFormat = [{
638-
$A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
638+
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
639639
}];
640640
}
641641

@@ -650,7 +650,7 @@ class ROCDL_WMMA_ModsC_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<
650650
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
651651
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
652652
let assemblyFormat = [{
653-
$A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
653+
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
654654
}];
655655
}
656656

@@ -667,7 +667,7 @@ class ROCDL_WMMA_ModsAll_Diff_IntrOp<string mnemonic, Type AB, Type C, Type D> :
667667
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
668668
let results = (outs LLVM_ScalarOrVectorOf<D>:$res);
669669
let assemblyFormat = [{
670-
$A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
670+
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
671671
}];
672672
}
673673

@@ -683,7 +683,7 @@ class ROCDL_WMMA_ModsAB_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp
683683
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
684684
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
685685
let assemblyFormat = [{
686-
$A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
686+
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
687687
}];
688688
}
689689

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -682,8 +682,8 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
682682
static void wmmaPushInputOperand(
683683
ConversionPatternRewriter &rewriter, Location loc,
684684
const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
685-
Value mlirInput, SmallVector<Value, 4> &operands,
686-
SmallVector<NamedAttribute, 4> &attrs, StringRef attrName) {
685+
Value mlirInput, SmallVectorImpl<Value> &operands,
686+
SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) {
687687
Type inputType = llvmInput.getType();
688688
auto vectorType = dyn_cast<VectorType>(inputType);
689689
if (!vectorType) {
@@ -741,8 +741,8 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
741741
Location loc,
742742
const TypeConverter *typeConverter,
743743
Value output, int32_t subwordOffset,
744-
bool clamp, SmallVector<Value, 4> &operands,
745-
SmallVector<NamedAttribute, 4> &attrs) {
744+
bool clamp, SmallVectorImpl<Value> &operands,
745+
SmallVectorImpl<NamedAttribute> &attrs) {
746746
Type inputType = output.getType();
747747
auto vectorType = dyn_cast<VectorType>(inputType);
748748
Type elemType = vectorType.getElementType();
@@ -1305,9 +1305,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
13051305
// The WMMA operations represent vectors of bf16s as vectors of i16s
13061306
// (except on gfx1250), so we need to bitcast bfloats to i16 and then
13071307
// bitcast them back.
1308-
auto aType = dyn_cast<VectorType>(adaptor.getSourceA().getType());
1309-
auto bType = dyn_cast<VectorType>(adaptor.getSourceA().getType());
1310-
auto destCType = dyn_cast<VectorType>(adaptor.getDestC().getType());
1308+
auto aType = cast<VectorType>(adaptor.getSourceA().getType());
1309+
auto bType = cast<VectorType>(adaptor.getSourceB().getType());
1310+
auto destCType = cast<VectorType>(adaptor.getDestC().getType());
13111311
bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
13121312
bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
13131313
bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;

0 commit comments

Comments
 (0)