@@ -945,6 +945,53 @@ static bool upgradeArmOrAarch64IntrinsicFunction(bool IsArm, Function *F,
945945 return false ; // No other 'arm.*', 'aarch64.*'.
946946}
947947
948+ static Intrinsic::ID shouldUpgradeNVPTXTMAG2SIntrinsics (Function *F,
949+ StringRef Name) {
950+ if (Name.consume_front (" cp.async.bulk.tensor.g2s." )) {
951+ Intrinsic::ID ID =
952+ StringSwitch<Intrinsic::ID>(Name)
953+ .Case (" im2col.3d" ,
954+ Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d)
955+ .Case (" im2col.4d" ,
956+ Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d)
957+ .Case (" im2col.5d" ,
958+ Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d)
959+ .Case (" tile.1d" , Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d)
960+ .Case (" tile.2d" , Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d)
961+ .Case (" tile.3d" , Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d)
962+ .Case (" tile.4d" , Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d)
963+ .Case (" tile.5d" , Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d)
964+ .Default (Intrinsic::not_intrinsic);
965+
966+ if (ID == Intrinsic::not_intrinsic)
967+ return ID;
968+
969+ // These intrinsics may need upgrade for two reasons:
970+ // (1) When the address-space of the first argument is shared[AS=3]
971+ // (and we upgrade it to use shared_cluster address-space[AS=7])
972+ if (F->getArg (0 )->getType ()->getPointerAddressSpace () ==
973+ NVPTXAS::ADDRESS_SPACE_SHARED)
974+ return ID;
975+
976+ // (2) When there are only two boolean flag arguments at the end:
977+ //
978+ // The last three parameters of the older version of these
979+ // intrinsics are: arg1, arg2, .. i64 ch, i1 mc_flag, i1 ch_flag
980+ //
981+ // The newer version reads as:
982+ // arg1, arg2, .. i64 ch, i1 mc_flag, i1 ch_flag, i32 cta_group_flag
983+ //
984+ // So, when the type of the [N-3]rd argument is "not i1", then
985+ // it is the older version and we need to upgrade.
986+ size_t FlagStartIndex = F->getFunctionType ()->getNumParams () - 3 ;
987+ Type *ArgType = F->getFunctionType ()->getParamType (FlagStartIndex);
988+ if (!ArgType->isIntegerTy (1 ))
989+ return ID;
990+ }
991+
992+ return Intrinsic::not_intrinsic;
993+ }
994+
948995static Intrinsic::ID shouldUpgradeNVPTXSharedClusterIntrinsic (Function *F,
949996 StringRef Name) {
950997 if (Name.consume_front (" mapa.shared.cluster" ))
@@ -959,22 +1006,6 @@ static Intrinsic::ID shouldUpgradeNVPTXSharedClusterIntrinsic(Function *F,
9591006 Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster)
9601007 .Case (" shared.cta.to.cluster" ,
9611008 Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster)
962- .Case (" tensor.g2s.im2col.3d" ,
963- Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d)
964- .Case (" tensor.g2s.im2col.4d" ,
965- Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d)
966- .Case (" tensor.g2s.im2col.5d" ,
967- Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d)
968- .Case (" tensor.g2s.tile.1d" ,
969- Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d)
970- .Case (" tensor.g2s.tile.2d" ,
971- Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d)
972- .Case (" tensor.g2s.tile.3d" ,
973- Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d)
974- .Case (" tensor.g2s.tile.4d" ,
975- Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d)
976- .Case (" tensor.g2s.tile.5d" ,
977- Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d)
9781009 .Default (Intrinsic::not_intrinsic);
9791010
9801011 if (ID != Intrinsic::not_intrinsic)
@@ -1339,6 +1370,14 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
13391370 return true ;
13401371 }
13411372
1373+ // Upgrade TMA copy G2S Intrinsics
1374+ IID = shouldUpgradeNVPTXTMAG2SIntrinsics (F, Name);
1375+ if (IID != Intrinsic::not_intrinsic) {
1376+ rename (F);
1377+ NewFn = Intrinsic::getOrInsertDeclaration (F->getParent (), IID);
1378+ return true ;
1379+ }
1380+
13421381 // The following nvvm intrinsics correspond exactly to an LLVM idiom, but
13431382 // not to an intrinsic alone. We expand them in UpgradeIntrinsicCall.
13441383 //
@@ -4831,7 +4870,18 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {
48314870 return ;
48324871 }
48334872 case Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster:
4834- case Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster:
4873+ case Intrinsic::nvvm_cp_async_bulk_shared_cta_to_cluster: {
4874+ // Create a new call with the correct address space.
4875+ SmallVector<Value *, 4 > Args (CI->args ());
4876+ Args[0 ] = Builder.CreateAddrSpaceCast (
4877+ Args[0 ], Builder.getPtrTy (NVPTXAS::ADDRESS_SPACE_SHARED_CLUSTER));
4878+
4879+ NewCall = Builder.CreateCall (NewFn, Args);
4880+ NewCall->takeName (CI);
4881+ CI->replaceAllUsesWith (NewCall);
4882+ CI->eraseFromParent ();
4883+ return ;
4884+ }
48354885 case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d:
48364886 case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d:
48374887 case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d:
@@ -4840,10 +4890,22 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {
48404890 case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d:
48414891 case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d:
48424892 case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d: {
4843- // Create a new call with the correct address space.
4844- SmallVector<Value *, 4 > Args (CI->args ());
4845- Args[0 ] = Builder.CreateAddrSpaceCast (
4846- Args[0 ], Builder.getPtrTy (NVPTXAS::ADDRESS_SPACE_SHARED_CLUSTER));
4893+ SmallVector<Value *, 16 > Args (CI->args ());
4894+
4895+ // Create AddrSpaceCast to shared_cluster if needed.
4896+ // This handles case (1) in shouldUpgradeNVPTXTMAG2SIntrinsics().
4897+ unsigned AS = CI->getArgOperand (0 )->getType ()->getPointerAddressSpace ();
4898+ if (AS == NVPTXAS::ADDRESS_SPACE_SHARED)
4899+ Args[0 ] = Builder.CreateAddrSpaceCast (
4900+ Args[0 ], Builder.getPtrTy (NVPTXAS::ADDRESS_SPACE_SHARED_CLUSTER));
4901+
4902+ // Attach the flag argument for cta_group, with a
4903+ // default value of 0. This handles case (2) in
4904+ // shouldUpgradeNVPTXTMAG2SIntrinsics().
4905+ size_t NumArgs = CI->arg_size ();
4906+ Value *FlagArg = CI->getArgOperand (NumArgs - 3 );
4907+ if (!FlagArg->getType ()->isIntegerTy (1 ))
4908+ Args.push_back (ConstantInt::get (Builder.getInt32Ty (), 0 ));
48474909
48484910 NewCall = Builder.CreateCall (NewFn, Args);
48494911 NewCall->takeName (CI);
0 commit comments