@@ -1045,10 +1045,87 @@ static constexpr IntrinsicHandler handlers[]{
10451045 {" dst" , asAddr},
10461046 {" nbytes" , asValue}}},
10471047 /* isElemental=*/ false },
1048+ {" tma_bulk_ldc4" ,
1049+ &I::genTMABulkLoadC4,
1050+ {{{" barrier" , asAddr},
1051+ {" src" , asAddr},
1052+ {" dst" , asAddr},
1053+ {" nelems" , asValue}}},
1054+ /* isElemental=*/ false },
1055+ {" tma_bulk_ldc8" ,
1056+ &I::genTMABulkLoadC8,
1057+ {{{" barrier" , asAddr},
1058+ {" src" , asAddr},
1059+ {" dst" , asAddr},
1060+ {" nelems" , asValue}}},
1061+ /* isElemental=*/ false },
1062+ {" tma_bulk_ldi4" ,
1063+ &I::genTMABulkLoadI4,
1064+ {{{" barrier" , asAddr},
1065+ {" src" , asAddr},
1066+ {" dst" , asAddr},
1067+ {" nelems" , asValue}}},
1068+ /* isElemental=*/ false },
1069+ {" tma_bulk_ldi8" ,
1070+ &I::genTMABulkLoadI8,
1071+ {{{" barrier" , asAddr},
1072+ {" src" , asAddr},
1073+ {" dst" , asAddr},
1074+ {" nelems" , asValue}}},
1075+ /* isElemental=*/ false },
1076+ {" tma_bulk_ldr2" ,
1077+ &I::genTMABulkLoadR2,
1078+ {{{" barrier" , asAddr},
1079+ {" src" , asAddr},
1080+ {" dst" , asAddr},
1081+ {" nelems" , asValue}}},
1082+ /* isElemental=*/ false },
1083+ {" tma_bulk_ldr4" ,
1084+ &I::genTMABulkLoadR4,
1085+ {{{" barrier" , asAddr},
1086+ {" src" , asAddr},
1087+ {" dst" , asAddr},
1088+ {" nelems" , asValue}}},
1089+ /* isElemental=*/ false },
1090+ {" tma_bulk_ldr8" ,
1091+ &I::genTMABulkLoadR8,
1092+ {{{" barrier" , asAddr},
1093+ {" src" , asAddr},
1094+ {" dst" , asAddr},
1095+ {" nelems" , asValue}}},
1096+ /* isElemental=*/ false },
10481097 {" tma_bulk_s2g" ,
10491098 &I::genTMABulkS2G,
10501099 {{{" src" , asAddr}, {" dst" , asAddr}, {" nbytes" , asValue}}},
10511100 /* isElemental=*/ false },
1101+ {" tma_bulk_store_c4" ,
1102+ &I::genTMABulkStoreC4,
1103+ {{{" src" , asAddr}, {" dst" , asAddr}, {" count" , asValue}}},
1104+ /* isElemental=*/ false },
1105+ {" tma_bulk_store_c8" ,
1106+ &I::genTMABulkStoreC8,
1107+ {{{" src" , asAddr}, {" dst" , asAddr}, {" count" , asValue}}},
1108+ /* isElemental=*/ false },
1109+ {" tma_bulk_store_i4" ,
1110+ &I::genTMABulkStoreI4,
1111+ {{{" src" , asAddr}, {" dst" , asAddr}, {" count" , asValue}}},
1112+ /* isElemental=*/ false },
1113+ {" tma_bulk_store_i8" ,
1114+ &I::genTMABulkStoreI8,
1115+ {{{" src" , asAddr}, {" dst" , asAddr}, {" count" , asValue}}},
1116+ /* isElemental=*/ false },
1117+ {" tma_bulk_store_r2" ,
1118+ &I::genTMABulkStoreR2,
1119+ {{{" src" , asAddr}, {" dst" , asAddr}, {" count" , asValue}}},
1120+ /* isElemental=*/ false },
1121+ {" tma_bulk_store_r4" ,
1122+ &I::genTMABulkStoreR4,
1123+ {{{" src" , asAddr}, {" dst" , asAddr}, {" count" , asValue}}},
1124+ /* isElemental=*/ false },
1125+ {" tma_bulk_store_r8" ,
1126+ &I::genTMABulkStoreR8,
1127+ {{{" src" , asAddr}, {" dst" , asAddr}, {" count" , asValue}}},
1128+ /* isElemental=*/ false },
10521129 {" tma_bulk_wait_group" ,
10531130 &I::genTMABulkWaitGroup,
10541131 {{}},
@@ -9278,6 +9355,93 @@ void IntrinsicLibrary::genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue> args) {
92789355 builder, loc, dst, src, barrier, fir::getBase (args[3 ]), {}, {});
92799356}
92809357
9358+ static void genTMABulkLoad (fir::FirOpBuilder &builder, mlir::Location loc,
9359+ mlir::Value barrier, mlir::Value src,
9360+ mlir::Value dst, mlir::Value nelem,
9361+ mlir::Value eleSize) {
9362+ mlir::Value size = mlir::arith::MulIOp::create (builder, loc, nelem, eleSize);
9363+ auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get (builder.getContext ());
9364+ barrier = builder.createConvert (loc, llvmPtrTy, barrier);
9365+ mlir::NVVM::InlinePtxOp::create (
9366+ builder, loc, mlir::TypeRange{}, {dst, src, size, barrier}, {},
9367+ " cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], "
9368+ " [%1], %2, [%3];" ,
9369+ {});
9370+ mlir::NVVM::InlinePtxOp::create (
9371+ builder, loc, mlir::TypeRange{}, {barrier, size}, {},
9372+ " mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" , {});
9373+ }
9374+
9375+ // TMA_BULK_LOADC4
9376+ void IntrinsicLibrary::genTMABulkLoadC4 (
9377+ llvm::ArrayRef<fir::ExtendedValue> args) {
9378+ assert (args.size () == 4 );
9379+ mlir::Value eleSize =
9380+ builder.createIntegerConstant (loc, builder.getI32Type (), 8 );
9381+ genTMABulkLoad (builder, loc, fir::getBase (args[0 ]), fir::getBase (args[1 ]),
9382+ fir::getBase (args[2 ]), fir::getBase (args[3 ]), eleSize);
9383+ }
9384+
9385+ // TMA_BULK_LOADC8
9386+ void IntrinsicLibrary::genTMABulkLoadC8 (
9387+ llvm::ArrayRef<fir::ExtendedValue> args) {
9388+ assert (args.size () == 4 );
9389+ mlir::Value eleSize =
9390+ builder.createIntegerConstant (loc, builder.getI32Type (), 16 );
9391+ genTMABulkLoad (builder, loc, fir::getBase (args[0 ]), fir::getBase (args[1 ]),
9392+ fir::getBase (args[2 ]), fir::getBase (args[3 ]), eleSize);
9393+ }
9394+
9395+ // TMA_BULK_LOADI4
9396+ void IntrinsicLibrary::genTMABulkLoadI4 (
9397+ llvm::ArrayRef<fir::ExtendedValue> args) {
9398+ assert (args.size () == 4 );
9399+ mlir::Value eleSize =
9400+ builder.createIntegerConstant (loc, builder.getI32Type (), 4 );
9401+ genTMABulkLoad (builder, loc, fir::getBase (args[0 ]), fir::getBase (args[1 ]),
9402+ fir::getBase (args[2 ]), fir::getBase (args[3 ]), eleSize);
9403+ }
9404+
9405+ // TMA_BULK_LOADI8
9406+ void IntrinsicLibrary::genTMABulkLoadI8 (
9407+ llvm::ArrayRef<fir::ExtendedValue> args) {
9408+ assert (args.size () == 4 );
9409+ mlir::Value eleSize =
9410+ builder.createIntegerConstant (loc, builder.getI32Type (), 8 );
9411+ genTMABulkLoad (builder, loc, fir::getBase (args[0 ]), fir::getBase (args[1 ]),
9412+ fir::getBase (args[2 ]), fir::getBase (args[3 ]), eleSize);
9413+ }
9414+
9415+ // TMA_BULK_LOADR2
9416+ void IntrinsicLibrary::genTMABulkLoadR2 (
9417+ llvm::ArrayRef<fir::ExtendedValue> args) {
9418+ assert (args.size () == 4 );
9419+ mlir::Value eleSize =
9420+ builder.createIntegerConstant (loc, builder.getI32Type (), 2 );
9421+ genTMABulkLoad (builder, loc, fir::getBase (args[0 ]), fir::getBase (args[1 ]),
9422+ fir::getBase (args[2 ]), fir::getBase (args[3 ]), eleSize);
9423+ }
9424+
9425+ // TMA_BULK_LOADR4
9426+ void IntrinsicLibrary::genTMABulkLoadR4 (
9427+ llvm::ArrayRef<fir::ExtendedValue> args) {
9428+ assert (args.size () == 4 );
9429+ mlir::Value eleSize =
9430+ builder.createIntegerConstant (loc, builder.getI32Type (), 4 );
9431+ genTMABulkLoad (builder, loc, fir::getBase (args[0 ]), fir::getBase (args[1 ]),
9432+ fir::getBase (args[2 ]), fir::getBase (args[3 ]), eleSize);
9433+ }
9434+
9435+ // TMA_BULK_LOADR8
9436+ void IntrinsicLibrary::genTMABulkLoadR8 (
9437+ llvm::ArrayRef<fir::ExtendedValue> args) {
9438+ assert (args.size () == 4 );
9439+ mlir::Value eleSize =
9440+ builder.createIntegerConstant (loc, builder.getI32Type (), 8 );
9441+ genTMABulkLoad (builder, loc, fir::getBase (args[0 ]), fir::getBase (args[1 ]),
9442+ fir::getBase (args[2 ]), fir::getBase (args[3 ]), eleSize);
9443+ }
9444+
92819445// TMA_BULK_S2G (CUDA)
92829446void IntrinsicLibrary::genTMABulkS2G (llvm::ArrayRef<fir::ExtendedValue> args) {
92839447 assert (args.size () == 3 );
@@ -9287,6 +9451,97 @@ void IntrinsicLibrary::genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue> args) {
92879451 mlir::NVVM::NVVMMemorySpace::Global);
92889452 mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create (
92899453 builder, loc, dst, src, fir::getBase (args[2 ]), {}, {});
9454+
9455+ mlir::NVVM::InlinePtxOp::create (builder, loc, mlir::TypeRange{}, {}, {},
9456+ " cp.async.bulk.commit_group" , {});
9457+ mlir::NVVM::CpAsyncBulkWaitGroupOp::create (builder, loc,
9458+ builder.getI32IntegerAttr (0 ), {});
9459+ }
9460+
9461+ static void genTMABulkStore (fir::FirOpBuilder &builder, mlir::Location loc,
9462+ mlir::Value src, mlir::Value dst, mlir::Value count,
9463+ mlir::Value eleSize) {
9464+ mlir::Value size = mlir::arith::MulIOp::create (builder, loc, eleSize, count);
9465+ src = convertPtrToNVVMSpace (builder, loc, src,
9466+ mlir::NVVM::NVVMMemorySpace::Shared);
9467+ dst = convertPtrToNVVMSpace (builder, loc, dst,
9468+ mlir::NVVM::NVVMMemorySpace::Global);
9469+ mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create (builder, loc, dst, src,
9470+ size, {}, {});
9471+ mlir::NVVM::InlinePtxOp::create (builder, loc, mlir::TypeRange{}, {}, {},
9472+ " cp.async.bulk.commit_group" , {});
9473+ mlir::NVVM::CpAsyncBulkWaitGroupOp::create (builder, loc,
9474+ builder.getI32IntegerAttr (0 ), {});
9475+ }
9476+
9477+ // TMA_BULK_STORE_C4 (CUDA)
9478+ void IntrinsicLibrary::genTMABulkStoreC4 (
9479+ llvm::ArrayRef<fir::ExtendedValue> args) {
9480+ assert (args.size () == 3 );
9481+ mlir::Value eleSize =
9482+ builder.createIntegerConstant (loc, builder.getI32Type (), 8 );
9483+ genTMABulkStore (builder, loc, fir::getBase (args[0 ]), fir::getBase (args[1 ]),
9484+ fir::getBase (args[2 ]), eleSize);
9485+ }
9486+
9487+ // TMA_BULK_STORE_C8 (CUDA)
9488+ void IntrinsicLibrary::genTMABulkStoreC8 (
9489+ llvm::ArrayRef<fir::ExtendedValue> args) {
9490+ assert (args.size () == 3 );
9491+ mlir::Value eleSize =
9492+ builder.createIntegerConstant (loc, builder.getI32Type (), 16 );
9493+ genTMABulkStore (builder, loc, fir::getBase (args[0 ]), fir::getBase (args[1 ]),
9494+ fir::getBase (args[2 ]), eleSize);
9495+ }
9496+
9497+ // TMA_BULK_STORE_I4 (CUDA)
9498+ void IntrinsicLibrary::genTMABulkStoreI4 (
9499+ llvm::ArrayRef<fir::ExtendedValue> args) {
9500+ assert (args.size () == 3 );
9501+ mlir::Value eleSize =
9502+ builder.createIntegerConstant (loc, builder.getI32Type (), 4 );
9503+ genTMABulkStore (builder, loc, fir::getBase (args[0 ]), fir::getBase (args[1 ]),
9504+ fir::getBase (args[2 ]), eleSize);
9505+ }
9506+
9507+ // TMA_BULK_STORE_I8 (CUDA)
9508+ void IntrinsicLibrary::genTMABulkStoreI8 (
9509+ llvm::ArrayRef<fir::ExtendedValue> args) {
9510+ assert (args.size () == 3 );
9511+ mlir::Value eleSize =
9512+ builder.createIntegerConstant (loc, builder.getI32Type (), 8 );
9513+ genTMABulkStore (builder, loc, fir::getBase (args[0 ]), fir::getBase (args[1 ]),
9514+ fir::getBase (args[2 ]), eleSize);
9515+ }
9516+
9517+ // TMA_BULK_STORE_R2 (CUDA)
9518+ void IntrinsicLibrary::genTMABulkStoreR2 (
9519+ llvm::ArrayRef<fir::ExtendedValue> args) {
9520+ assert (args.size () == 3 );
9521+ mlir::Value eleSize =
9522+ builder.createIntegerConstant (loc, builder.getI32Type (), 2 );
9523+ genTMABulkStore (builder, loc, fir::getBase (args[0 ]), fir::getBase (args[1 ]),
9524+ fir::getBase (args[2 ]), eleSize);
9525+ }
9526+
9527+ // TMA_BULK_STORE_R4 (CUDA)
9528+ void IntrinsicLibrary::genTMABulkStoreR4 (
9529+ llvm::ArrayRef<fir::ExtendedValue> args) {
9530+ assert (args.size () == 3 );
9531+ mlir::Value eleSize =
9532+ builder.createIntegerConstant (loc, builder.getI32Type (), 4 );
9533+ genTMABulkStore (builder, loc, fir::getBase (args[0 ]), fir::getBase (args[1 ]),
9534+ fir::getBase (args[2 ]), eleSize);
9535+ }
9536+
9537+ // TMA_BULK_STORE_R8 (CUDA)
9538+ void IntrinsicLibrary::genTMABulkStoreR8 (
9539+ llvm::ArrayRef<fir::ExtendedValue> args) {
9540+ assert (args.size () == 3 );
9541+ mlir::Value eleSize =
9542+ builder.createIntegerConstant (loc, builder.getI32Type (), 8 );
9543+ genTMABulkStore (builder, loc, fir::getBase (args[0 ]), fir::getBase (args[1 ]),
9544+ fir::getBase (args[2 ]), eleSize);
92909545}
92919546
92929547// TMA_BULK_WAIT_GROUP (CUDA)
0 commit comments