@@ -195,7 +195,7 @@ static constexpr IntrinsicHandler cudaHandlers[]{
195195 false },
196196 {" atomicadd_r4x4" ,
197197 static_cast <CUDAIntrinsicLibrary::ExtendedGenerator>(
198- &CI::genAtomicAddVector< 4 > ),
198+ &CI::genAtomicAddVector4x4 ),
199199 {{{" a" , asAddr}, {" v" , asAddr}}},
200200 false },
201201 {" atomicaddd" ,
@@ -758,6 +758,56 @@ fir::ExtendedValue CUDAIntrinsicLibrary::genAtomicAddVector(
758758 return fir::ArrayBoxValue (res, {ext});
759759}
760760
761+ // ATOMICADDVECTOR4x4
762+ fir::ExtendedValue CUDAIntrinsicLibrary::genAtomicAddVector4x4 (
763+ mlir::Type resultType, llvm::ArrayRef<fir::ExtendedValue> args) {
764+ assert (args.size () == 2 );
765+ mlir::Value a = fir::getBase (args[0 ]);
766+ if (mlir::isa<fir::BaseBoxType>(a.getType ()))
767+ a = fir::BoxAddrOp::create (builder, loc, a);
768+
769+ const unsigned extent = 4 ;
770+ auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get (builder.getContext ());
771+ mlir::Value ptr = builder.createConvert (loc, llvmPtrTy, a);
772+ mlir::Type f32Ty = builder.getF32Type ();
773+ mlir::Type idxTy = builder.getIndexType ();
774+ mlir::Type refTy = fir::ReferenceType::get (f32Ty);
775+ llvm::SmallVector<mlir::Value> values;
776+ for (unsigned i = 0 ; i < extent; ++i) {
777+ mlir::Value pos = builder.createIntegerConstant (loc, idxTy, i);
778+ mlir::Value coord = fir::CoordinateOp::create (builder, loc, refTy,
779+ fir::getBase (args[1 ]), pos);
780+ mlir::Value value = fir::LoadOp::create (builder, loc, coord);
781+ values.push_back (value);
782+ }
783+
784+ auto inlinePtx = mlir::NVVM::InlinePtxOp::create (
785+ builder, loc, {f32Ty, f32Ty, f32Ty, f32Ty},
786+ {ptr, values[0 ], values[1 ], values[2 ], values[3 ]}, {},
787+ " atom.add.v4.f32 {%0, %1, %2, %3}, [%4], {%5, %6, %7, %8};" , {});
788+
789+ llvm::SmallVector<mlir::Value> results;
790+ results.push_back (inlinePtx.getResult (0 ));
791+ results.push_back (inlinePtx.getResult (1 ));
792+ results.push_back (inlinePtx.getResult (2 ));
793+ results.push_back (inlinePtx.getResult (3 ));
794+
795+ mlir::Type vecF32Ty = mlir::VectorType::get ({extent}, f32Ty);
796+ mlir::Value undef = mlir::LLVM::UndefOp::create (builder, loc, vecF32Ty);
797+ mlir::Type i32Ty = builder.getI32Type ();
798+ for (unsigned i = 0 ; i < extent; ++i)
799+ undef = mlir::LLVM::InsertElementOp::create (
800+ builder, loc, undef, results[i],
801+ builder.createIntegerConstant (loc, i32Ty, i));
802+
803+ auto i128Ty = builder.getIntegerType (128 );
804+ auto i128VecTy = mlir::VectorType::get ({1 }, i128Ty);
805+ mlir::Value vec128 =
806+ mlir::vector::BitCastOp::create (builder, loc, i128VecTy, undef);
807+ return mlir::vector::ExtractOp::create (builder, loc, vec128,
808+ mlir::ArrayRef<int64_t >{0 });
809+ }
810+
761811mlir::Value
762812CUDAIntrinsicLibrary::genAtomicAnd (mlir::Type resultType,
763813 llvm::ArrayRef<mlir::Value> args) {
0 commit comments