Skip to content

Commit 7c7d739

Browse files
committed
[MLIR][NVVM][Refactor] Refactor intrinsic lowering for NVVM Ops
This change standardizes the usage of getIntrinsicIDAndArgsMaybeWithTypes across NVVM Ops for intrinsic lowering. The method returns the intrinsic ID, the arguments, and, in the case of overloaded intrinsics, the types of the arguments as well.
1 parent 97d4c7d commit 7c7d739

File tree

3 files changed

+252
-199
lines changed

3 files changed

+252
-199
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,14 @@ enum NVVMMemorySpace {
5555
kSharedClusterMemorySpace = 7,
5656
};
5757

58-
/// A pair type of LLVM's Intrinsic ID and args (which are llvm values).
59-
/// This type is returned by the getIntrinsicIDAndArgs() methods.
60-
using IDArgPair =
61-
std::pair<llvm::Intrinsic::ID, llvm::SmallVector<llvm::Value *>>;
58+
/// A tuple type of LLVM's Intrinsic ID, args (which are llvm values),
59+
/// and args types (which are llvm types).
60+
/// Args types are only required for overloaded intrinsics to provide the
61+
/// correct argument types to the createIntrinsicCall() method.
62+
/// This type is returned by the getIntrinsicIDAndArgsMaybeWithTypes() methods.
63+
using IIDArgsMaybeWithTypes =
64+
std::tuple<llvm::Intrinsic::ID, llvm::SmallVector<llvm::Value *>,
65+
llvm::SmallVector<llvm::Type *>>;
6266

6367
/// Return the element type and number of elements associated with a wmma matrix
6468
/// of given characteristics. This matches the logic in IntrinsicsNVVM.td

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 62 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,15 +1108,14 @@ def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
11081108
let assemblyFormat = "$dst `,` $src `,` $size `,` `cache` `=` $modifier (`,` $cpSize^)? attr-dict `:` type(operands)";
11091109
let hasVerifier = 1;
11101110
let extraClassDeclaration = [{
1111-
static llvm::Intrinsic::ID
1112-
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
1113-
llvm::SmallVector<llvm::Value *> &args);
1111+
static NVVM::IIDArgsMaybeWithTypes
1112+
getIntrinsicIDAndArgsMaybeWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
1113+
llvm::IRBuilderBase &builder);
11141114
}];
11151115
string llvmBuilder = [{
1116-
llvm::SmallVector<llvm::Value *> translatedOperands;
1117-
auto id = NVVM::CpAsyncOp::getIntrinsicIDAndArgs(
1118-
*op, moduleTranslation, translatedOperands);
1119-
createIntrinsicCall(builder, id, translatedOperands);
1116+
auto [id, args, types] = NVVM::CpAsyncOp::getIntrinsicIDAndArgsMaybeWithTypes(
1117+
*op, moduleTranslation, builder);
1118+
createIntrinsicCall(builder, id, args);
11201119
}];
11211120
}
11221121

@@ -2543,8 +2542,8 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
25432542
let extraClassDeclaration = [{
25442543
bool hasIntrinsic() { return !getPredicate(); }
25452544

2546-
static mlir::NVVM::IDArgPair
2547-
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
2545+
static mlir::NVVM::IIDArgsMaybeWithTypes
2546+
getIntrinsicIDAndArgsMaybeWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
25482547
llvm::IRBuilderBase& builder);
25492548
}];
25502549

@@ -2565,7 +2564,7 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
25652564
let hasVerifier = 1;
25662565

25672566
string llvmBuilder = [{
2568-
auto [id, args] = NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
2567+
auto [id, args, types] = NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgsMaybeWithTypes(
25692568
*op, moduleTranslation, builder);
25702569
createIntrinsicCall(builder, id, args);
25712570
}];
@@ -2631,8 +2630,8 @@ def NVVM_PrefetchOp : NVVM_Op<"prefetch",
26312630
let hasVerifier = 1;
26322631

26332632
let extraClassDeclaration = [{
2634-
static NVVM::IDArgPair
2635-
getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,LLVM::ModuleTranslation &mt,
2633+
static NVVM::IIDArgsMaybeWithTypes
2634+
getIntrinsicIDAndArgsMaybeWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
26362635
llvm::IRBuilderBase &builder);
26372636
bool hasIntrinsic() { return !getPredicate() || !getTensormap(); }
26382637
}];
@@ -2643,7 +2642,7 @@ def NVVM_PrefetchOp : NVVM_Op<"prefetch",
26432642
}
26442643
}];
26452644
let llvmBuilder = [{
2646-
auto [id, args] = NVVM::PrefetchOp::getIntrinsicIDAndArgs(op,
2645+
auto [id, args, types] = NVVM::PrefetchOp::getIntrinsicIDAndArgsMaybeWithTypes(*op,
26472646
moduleTranslation, builder);
26482647

26492648
if(op.getTensormap())
@@ -2685,13 +2684,13 @@ def NVVM_CpAsyncBulkPrefetchOp : NVVM_Op<"cp.async.bulk.prefetch"> {
26852684
}];
26862685

26872686
let extraClassDeclaration = [{
2688-
static mlir::NVVM::IDArgPair
2689-
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
2687+
static mlir::NVVM::IIDArgsMaybeWithTypes
2688+
getIntrinsicIDAndArgsMaybeWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
26902689
llvm::IRBuilderBase& builder);
26912690
}];
26922691

26932692
string llvmBuilder = [{
2694-
auto [id, args] = NVVM::CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
2693+
auto [id, args, types] = NVVM::CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgsMaybeWithTypes(
26952694
*op, moduleTranslation, builder);
26962695
createIntrinsicCall(builder, id, args);
26972696
}];
@@ -2726,15 +2725,15 @@ def NVVM_CpAsyncBulkTensorPrefetchOp :
27262725
}];
27272726

27282727
let extraClassDeclaration = [{
2729-
static mlir::NVVM::IDArgPair
2730-
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
2728+
static mlir::NVVM::IIDArgsMaybeWithTypes
2729+
getIntrinsicIDAndArgsMaybeWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
27312730
llvm::IRBuilderBase& builder);
27322731
}];
27332732

27342733
let hasVerifier = 1;
27352734

27362735
string llvmBuilder = [{
2737-
auto [id, args] = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
2736+
auto [id, args, types] = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgsMaybeWithTypes(
27382737
*op, moduleTranslation, builder);
27392738
createIntrinsicCall(builder, id, args);
27402739
}];
@@ -2795,35 +2794,17 @@ def NVVM_CpAsyncBulkTensorReduceOp :
27952794
}];
27962795

27972796
let extraClassDeclaration = [{
2798-
static llvm::Intrinsic::ID getIntrinsicID(int tensorDims,
2799-
NVVM::TMAReduxKind kind,
2800-
bool isIm2Col);
2797+
static mlir::NVVM::IIDArgsMaybeWithTypes
2798+
getIntrinsicIDAndArgsMaybeWithTypes(Operation &op,
2799+
LLVM::ModuleTranslation &mt, llvm::IRBuilderBase& builder);
28012800
}];
28022801

28032802
let hasVerifier = 1;
28042803

28052804
string llvmBuilder = [{
2806-
// Arguments to the intrinsic:
2807-
// shared_mem_ptr, tmaDesc, tensorDims
2808-
// cache_hint(if applicable) and flag(boolean)
2809-
llvm::SmallVector<llvm::Value *> translatedOperands;
2810-
translatedOperands.push_back($srcMem);
2811-
translatedOperands.push_back($tmaDescriptor);
2812-
2813-
for (auto v : op.getCoordinates())
2814-
translatedOperands.push_back(moduleTranslation.lookupValue(v));
2815-
2816-
llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
2817-
auto *i64Undef = llvm::UndefValue::get(llvm::IntegerType::get(ctx, 64));
2818-
2819-
bool isCacheHint = op.getL2CacheHint() ? true : false;
2820-
translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Undef);
2821-
translatedOperands.push_back(builder.getInt1(isCacheHint));
2822-
2823-
auto intId = NVVM::CpAsyncBulkTensorReduceOp::getIntrinsicID(
2824-
op.getCoordinates().size(), $redKind,
2825-
(op.getMode() == NVVM::TMAStoreMode::IM2COL));
2826-
createIntrinsicCall(builder, intId, translatedOperands);
2805+
auto [id, args, types] = NVVM::CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgsMaybeWithTypes(
2806+
*op, moduleTranslation, builder);
2807+
createIntrinsicCall(builder, id, args);
28272808
}];
28282809
}
28292810

@@ -2860,36 +2841,17 @@ def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
28602841
(`l2_cache_hint` `=` $l2CacheHint^ )?
28612842
attr-dict `:` type($dstMem) `,` type($srcMem)
28622843
}];
2844+
2845+
let extraClassDeclaration = [{
2846+
static mlir::NVVM::IIDArgsMaybeWithTypes
2847+
getIntrinsicIDAndArgsMaybeWithTypes(Operation &op,
2848+
LLVM::ModuleTranslation &mt, llvm::IRBuilderBase& builder);
2849+
}];
28632850

28642851
string llvmBuilder = [{
2865-
// Arguments to the intrinsic:
2866-
// dst, mbar, src, size
2867-
// multicast_mask, cache_hint,
2868-
// flag for multicast_mask,
2869-
// flag for cache_hint
2870-
llvm::SmallVector<llvm::Value *> translatedOperands;
2871-
translatedOperands.push_back($dstMem);
2872-
translatedOperands.push_back($mbar);
2873-
translatedOperands.push_back($srcMem);
2874-
translatedOperands.push_back($size);
2875-
2876-
// Multicast, if available
2877-
llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
2878-
auto *i16Unused = llvm::ConstantInt::get(llvm::Type::getInt16Ty(ctx), 0);
2879-
bool isMulticast = op.getMulticastMask() ? true : false;
2880-
translatedOperands.push_back(isMulticast ? $multicastMask : i16Unused);
2881-
2882-
// Cachehint, if available
2883-
auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
2884-
bool isCacheHint = op.getL2CacheHint() ? true : false;
2885-
translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);
2886-
2887-
// Flag arguments for multicast and cachehint
2888-
translatedOperands.push_back(builder.getInt1(isMulticast));
2889-
translatedOperands.push_back(builder.getInt1(isCacheHint));
2890-
2891-
createIntrinsicCall(builder,
2892-
llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster, translatedOperands);
2852+
auto [id, args, types] = NVVM::CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgsMaybeWithTypes(
2853+
*op, moduleTranslation, builder);
2854+
createIntrinsicCall(builder, id, args);
28932855
}];
28942856
}
28952857

@@ -2971,12 +2933,12 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
29712933
}];
29722934

29732935
let extraClassDeclaration = [{
2974-
static mlir::NVVM::IDArgPair
2975-
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
2936+
static mlir::NVVM::IIDArgsMaybeWithTypes
2937+
getIntrinsicIDAndArgsMaybeWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
29762938
llvm::IRBuilderBase& builder);
29772939
}];
29782940
string llvmBuilder = [{
2979-
auto [id, args] = NVVM::CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
2941+
auto [id, args, types] = NVVM::CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgsMaybeWithTypes(
29802942
*op, moduleTranslation, builder);
29812943
createIntrinsicCall(builder, id, args);
29822944
}];
@@ -3392,14 +3354,13 @@ def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSMa<[100, 101]>]
33923354
let assemblyFormat = "$addr `,` $nCols attr-dict `:` type(operands)";
33933355

33943356
let extraClassDeclaration = [{
3395-
static llvm::Intrinsic::ID
3396-
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3397-
llvm::SmallVector<llvm::Value *> &args);
3357+
static NVVM::IIDArgsMaybeWithTypes
3358+
getIntrinsicIDAndArgsMaybeWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
3359+
llvm::IRBuilderBase &builder);
33983360
}];
33993361
string llvmBuilder = [{
3400-
llvm::SmallVector<llvm::Value *> args;
3401-
auto id = NVVM::Tcgen05AllocOp::getIntrinsicIDAndArgs(
3402-
*op, moduleTranslation, args);
3362+
auto [id, args, types] = NVVM::Tcgen05AllocOp::getIntrinsicIDAndArgsMaybeWithTypes(
3363+
*op, moduleTranslation, builder);
34033364
createIntrinsicCall(builder, id, args);
34043365
}];
34053366
}
@@ -3420,14 +3381,13 @@ def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSMa<[100, 10
34203381
let assemblyFormat = "$taddr `,` $nCols attr-dict `:` type(operands)";
34213382

34223383
let extraClassDeclaration = [{
3423-
static llvm::Intrinsic::ID
3424-
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3425-
llvm::SmallVector<llvm::Value *> &args);
3384+
static NVVM::IIDArgsMaybeWithTypes
3385+
getIntrinsicIDAndArgsMaybeWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
3386+
llvm::IRBuilderBase &builder);
34263387
}];
34273388
string llvmBuilder = [{
3428-
llvm::SmallVector<llvm::Value *> args;
3429-
auto id = NVVM::Tcgen05DeallocOp::getIntrinsicIDAndArgs(
3430-
*op, moduleTranslation, args);
3389+
auto [id, args, types] = NVVM::Tcgen05DeallocOp::getIntrinsicIDAndArgsMaybeWithTypes(
3390+
*op, moduleTranslation, builder);
34313391
createIntrinsicCall(builder, id, args);
34323392
}];
34333393
}
@@ -3524,15 +3484,14 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSMa<[100, 101]
35243484
}];
35253485

35263486
let extraClassDeclaration = [{
3527-
static llvm::Intrinsic::ID
3528-
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3529-
llvm::SmallVector<llvm::Value *> &args);
3487+
static NVVM::IIDArgsMaybeWithTypes
3488+
getIntrinsicIDAndArgsMaybeWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
3489+
llvm::IRBuilderBase &builder);
35303490
}];
35313491

35323492
string llvmBuilder = [{
3533-
llvm::SmallVector<llvm::Value *> args;
3534-
auto id = NVVM::Tcgen05CommitOp::getIntrinsicIDAndArgs(
3535-
*op, moduleTranslation, args);
3493+
auto [id, args, types] = NVVM::Tcgen05CommitOp::getIntrinsicIDAndArgsMaybeWithTypes(
3494+
*op, moduleTranslation, builder);
35363495
createIntrinsicCall(builder, id, args);
35373496
}];
35383497
}
@@ -3636,12 +3595,14 @@ def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMa<[100, 101]>]> {
36363595
let hasVerifier = 1;
36373596

36383597
let extraClassDeclaration = [{
3639-
static llvm::Intrinsic::ID getIntrinsicID(Operation &op);
3598+
static NVVM::IIDArgsMaybeWithTypes getIntrinsicIDAndArgsMaybeWithTypes(Operation &op,
3599+
LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
36403600
}];
36413601

36423602
string llvmBuilder = [{
3643-
auto id = NVVM::Tcgen05CpOp::getIntrinsicID(*op);
3644-
createIntrinsicCall(builder, id, {$taddr, $smem_desc});
3603+
auto [id, args, types] = NVVM::Tcgen05CpOp::getIntrinsicIDAndArgsMaybeWithTypes(*op,
3604+
moduleTranslation, builder);
3605+
createIntrinsicCall(builder, id, args);
36453606
}];
36463607
}
36473608

@@ -3969,13 +3930,13 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
39693930
let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
39703931

39713932
let extraClassDeclaration = [{
3972-
static mlir::NVVM::IDArgPair
3973-
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3933+
static mlir::NVVM::IIDArgsMaybeWithTypes
3934+
getIntrinsicIDAndArgsMaybeWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
39743935
llvm::IRBuilderBase &builder);
39753936
}];
39763937

39773938
string llvmBuilder = [{
3978-
auto [id, args] = NVVM::DotAccumulate4WayOp::getIntrinsicIDAndArgs(
3939+
auto [id, args, types] = NVVM::DotAccumulate4WayOp::getIntrinsicIDAndArgsMaybeWithTypes(
39793940
*op, moduleTranslation, builder);
39803941
$res = createIntrinsicCall(builder, id, args);
39813942
}];
@@ -4023,13 +3984,13 @@ def NVVM_DotAccumulate2WayOp : NVVM_Op<"dot.accumulate.2way"> {
40233984
let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
40243985

40253986
let extraClassDeclaration = [{
4026-
static mlir::NVVM::IDArgPair
4027-
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3987+
static mlir::NVVM::IIDArgsMaybeWithTypes
3988+
getIntrinsicIDAndArgsMaybeWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
40283989
llvm::IRBuilderBase &builder);
40293990
}];
40303991

40313992
string llvmBuilder = [{
4032-
auto [id, args] = NVVM::DotAccumulate2WayOp::getIntrinsicIDAndArgs(
3993+
auto [id, args, types] = NVVM::DotAccumulate2WayOp::getIntrinsicIDAndArgsMaybeWithTypes(
40333994
*op, moduleTranslation, builder);
40343995
$res = createIntrinsicCall(builder, id, args);
40353996
}];

0 commit comments

Comments
 (0)