@@ -858,6 +858,64 @@ struct CUDADeviceTy : public GenericDeviceTy {
858858 void *DstPtr, int64_t Size,
859859 AsyncInfoWrapperTy &AsyncInfoWrapper) override ;
860860
861+ Error dataFillImpl (void *TgtPtr, const void *PatternPtr, int64_t PatternSize,
862+ int64_t Size,
863+ AsyncInfoWrapperTy &AsyncInfoWrapper) override {
864+ if (auto Err = setContext ())
865+ return Err;
866+
867+ CUstream Stream;
868+ if (auto Err = getStream (AsyncInfoWrapper, Stream))
869+ return Err;
870+
871+ CUresult Res;
872+ size_t N = Size / PatternSize;
873+ if (PatternSize == 1 ) {
874+ Res = cuMemsetD8Async ((CUdeviceptr)TgtPtr,
875+ *(static_cast <const uint8_t *>(PatternPtr)), N,
876+ Stream);
877+ } else if (PatternSize == 2 ) {
878+ Res = cuMemsetD16Async ((CUdeviceptr)TgtPtr,
879+ *(static_cast <const uint16_t *>(PatternPtr)), N,
880+ Stream);
881+ } else if (PatternSize == 4 ) {
882+ Res = cuMemsetD32Async ((CUdeviceptr)TgtPtr,
883+ *(static_cast <const uint32_t *>(PatternPtr)), N,
884+ Stream);
885+ } else {
886+ // For larger patterns we can do a series of strided fills to copy the
887+ // pattern efficiently
888+ int64_t MemsetSize = PatternSize % 4u == 0u ? 4u
889+ : PatternSize % 2u == 0u ? 2u
890+ : 1u ;
891+
892+ int64_t NumberOfSteps = PatternSize / MemsetSize;
893+ int64_t Pitch = NumberOfSteps * MemsetSize;
894+ int64_t Height = Size / PatternSize;
895+
896+ for (auto Step = 0u ; Step < NumberOfSteps; ++Step) {
897+ if (MemsetSize == 4 ) {
898+ Res = cuMemsetD2D32Async (
899+ (CUdeviceptr)TgtPtr + Step * MemsetSize, Pitch,
900+ *(static_cast <const uint32_t *>(PatternPtr) + Step), 1u , Height,
901+ Stream);
902+ } else if (MemsetSize == 2 ) {
903+ Res = cuMemsetD2D16Async (
904+ (CUdeviceptr)TgtPtr + Step * MemsetSize, Pitch,
905+ *(static_cast <const uint16_t *>(PatternPtr) + Step), 1u , Height,
906+ Stream);
907+ } else {
908+ Res = cuMemsetD2D8Async (
909+ (CUdeviceptr)TgtPtr + Step * MemsetSize, Pitch,
910+ *(static_cast <const uint8_t *>(PatternPtr) + Step), 1u , Height,
911+ Stream);
912+ }
913+ }
914+ }
915+
916+ return Plugin::check (Res, " error in cuMemset: %s" );
917+ }
918+
861919 // / Initialize the async info for interoperability purposes.
862920 Error initAsyncInfoImpl (AsyncInfoWrapperTy &AsyncInfoWrapper) override {
863921 if (auto Err = setContext ())
0 commit comments