Commit cd0f038

Merge pull request opencv#17788 from YashasSamaga:cuda4dnn-nice-build
2 parents 5924770 + 1949056

3 files changed: 163 additions, 47 deletions

modules/dnn/CMakeLists.txt

Lines changed: 12 additions & 2 deletions
@@ -27,8 +27,18 @@ ocv_option(OPENCV_DNN_CUDA "Build with CUDA support"
         AND HAVE_CUDNN
 )
 
-if(OPENCV_DNN_CUDA AND HAVE_CUDA AND HAVE_CUBLAS AND HAVE_CUDNN)
-  add_definitions(-DCV_CUDA4DNN=1)
+if(OPENCV_DNN_CUDA)
+  if(HAVE_CUDA AND HAVE_CUBLAS AND HAVE_CUDNN)
+    add_definitions(-DCV_CUDA4DNN=1)
+  else()
+    if(NOT HAVE_CUDA)
+      message(SEND_ERROR "DNN: CUDA backend requires CUDA Toolkit. Please resolve dependency or disable OPENCV_DNN_CUDA=OFF")
+    elseif(NOT HAVE_CUBLAS)
+      message(SEND_ERROR "DNN: CUDA backend requires cuBLAS. Please resolve dependency or disable OPENCV_DNN_CUDA=OFF")
+    elseif(NOT HAVE_CUDNN)
+      message(SEND_ERROR "DNN: CUDA backend requires cuDNN. Please resolve dependency or disable OPENCV_DNN_CUDA=OFF")
+    endif()
+  endif()
 endif()
 
 ocv_cmake_hook_append(INIT_MODULE_SOURCES_opencv_dnn "${CMAKE_CURRENT_LIST_DIR}/cmake/hooks/INIT_MODULE_SOURCES_opencv_dnn.cmake")

modules/dnn/src/cuda4dnn/init.hpp

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#ifndef OPENCV_DNN_SRC_CUDA4DNN_INIT_HPP
+#define OPENCV_DNN_SRC_CUDA4DNN_INIT_HPP
+
+#include "csl/error.hpp"
+
+#include <cuda_runtime.h>
+#include <cudnn.h>
+
+#include <opencv2/core/cuda.hpp>
+#include <sstream>
+
+namespace cv { namespace dnn { namespace cuda4dnn {
+
+    void checkVersions()
+    {
+        int cudart_version = 0;
+        CUDA4DNN_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version));
+        if (cudart_version != CUDART_VERSION)
+        {
+            std::ostringstream oss;
+            oss << "CUDART reports version " << cudart_version << " which does not match with the version " << CUDART_VERSION << " with which OpenCV was built";
+            CV_LOG_WARNING(NULL, oss.str().c_str());
+        }
+
+        auto cudnn_version = cudnnGetVersion();
+        if (cudnn_version != CUDNN_VERSION)
+        {
+            std::ostringstream oss;
+            oss << "cuDNN reports version " << cudnn_version << " which does not match with the version " << CUDNN_VERSION << " with which OpenCV was built";
+            CV_LOG_WARNING(NULL, oss.str().c_str());
+        }
+
+        auto cudnn_cudart_version = cudnnGetCudartVersion();
+        if (cudart_version != cudnn_cudart_version)
+        {
+            std::ostringstream oss;
+            oss << "CUDART version " << cudnn_cudart_version << " reported by cuDNN " << cudnn_version << " does not match with the version reported by CUDART " << cudart_version;
+            CV_LOG_WARNING(NULL, oss.str().c_str());
+        }
+    }
+
+    int getDeviceCount()
+    {
+        return cuda::getCudaEnabledDeviceCount();
+    }
+
+    int getDevice()
+    {
+        int device_id = -1;
+        CUDA4DNN_CHECK_CUDA(cudaGetDevice(&device_id));
+        return device_id;
+    }
+
+    bool isDeviceCompatible()
+    {
+        int device_id = getDevice();
+        if (device_id < 0)
+            return false;
+
+        int major = 0, minor = 0;
+        CUDA4DNN_CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id));
+        CUDA4DNN_CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_id));
+
+        if (cv::cuda::TargetArchs::hasEqualOrLessPtx(major, minor))
+            return true;
+
+        for (int i = minor; i >= 0; i--)
+            if (cv::cuda::TargetArchs::hasBin(major, i))
+                return true;
+
+        return false;
+    }
+
+    bool doesDeviceSupportFP16()
+    {
+        int device_id = getDevice();
+        if (device_id < 0)
+            return false;
+
+        int major = 0, minor = 0;
+        CUDA4DNN_CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id));
+        CUDA4DNN_CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_id));
+
+        int version = major * 10 + minor;
+        if (version < 53)
+            return false;
+        return true;
+    }
+
+}}} /* namespace cv::dnn::cuda4dnn */
+
+#endif /* OPENCV_DNN_SRC_CUDA4DNN_INIT_HPP */
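For orientation, here is a hedged sketch of how these helpers could be exercised from inside the DNN module (init.hpp is an internal header under modules/dnn/src/cuda4dnn, so this is illustrative rather than public API; the reporting function below is hypothetical):

// Hypothetical caller inside the DNN module; assumes an OpenCV build with
// OPENCV_DNN_CUDA=ON and access to this internal header.
#include "cuda4dnn/init.hpp"
#include <iostream>

static void reportCudaBackendStatus()
{
    using namespace cv::dnn::cuda4dnn;

    if (getDeviceCount() <= 0 || getDevice() < 0)
    {
        std::cout << "No usable CUDA device.\n";
        return;
    }

    checkVersions(); // logs CV_LOG_WARNING on CUDART/cuDNN version mismatches

    std::cout << "device compatible with this build: " << isDeviceCompatible() << '\n'
              << "FP16 target available (needs sm_53+): " << doesDeviceSupportFP16() << '\n';
}

Note the FP16 gate: the compute capability is folded into major * 10 + minor, so capability 5.3 becomes 53; any device below sm_53 cannot use DNN_TARGET_CUDA_FP16.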

modules/dnn/src/dnn.cpp

Lines changed: 55 additions & 45 deletions
@@ -47,7 +47,8 @@
 #include "op_cuda.hpp"
 
 #ifdef HAVE_CUDA
-#include "cuda4dnn/primitives/eltwise.hpp"
+#include "cuda4dnn/init.hpp"
+#include "cuda4dnn/primitives/eltwise.hpp" // required by fuseLayers
 #endif
 
 #include "halide_scheduler.hpp"
@@ -66,8 +67,6 @@
 #include <opencv2/core/utils/configuration.private.hpp>
 #include <opencv2/core/utils/logger.hpp>
 
-#include <opencv2/core/cuda.hpp>
-
 namespace cv {
 namespace dnn {
 CV__DNN_INLINE_NS_BEGIN
@@ -159,23 +158,6 @@ class BackendRegistry
     }
 #endif
 
-#ifdef HAVE_CUDA
-    static inline bool cudaDeviceSupportsFp16() {
-        if (cv::cuda::getCudaEnabledDeviceCount() <= 0)
-            return false;
-        const int devId = cv::cuda::getDevice();
-        if (devId<0)
-            return false;
-        cv::cuda::DeviceInfo dev_info(devId);
-        if (!dev_info.isCompatible())
-            return false;
-        int version = dev_info.majorVersion() * 10 + dev_info.minorVersion();
-        if (version < 53)
-            return false;
-        return true;
-    }
-#endif
-
 private:
     BackendRegistry()
     {
@@ -247,9 +229,10 @@ class BackendRegistry
 #endif
 
 #ifdef HAVE_CUDA
-        if (haveCUDA()) {
+        if (haveCUDA() && cuda4dnn::isDeviceCompatible())
+        {
             backends.push_back(std::make_pair(DNN_BACKEND_CUDA, DNN_TARGET_CUDA));
-            if (cudaDeviceSupportsFp16())
+            if (cuda4dnn::doesDeviceSupportFP16())
                 backends.push_back(std::make_pair(DNN_BACKEND_CUDA, DNN_TARGET_CUDA_FP16));
         }
 #endif
@@ -1189,19 +1172,6 @@ struct Net::Impl : public detail::NetImplBase
        preferableBackend = DNN_BACKEND_DEFAULT;
        preferableTarget = DNN_TARGET_CPU;
        skipInfEngineInit = false;
-
-#ifdef HAVE_CUDA
-        if (cv::cuda::getCudaEnabledDeviceCount() > 0)
-        {
-            cuda4dnn::csl::CSLContext context;
-            context.stream = cuda4dnn::csl::Stream(true);
-            context.cublas_handle = cuda4dnn::csl::cublas::Handle(context.stream);
-            context.cudnn_handle = cuda4dnn::csl::cudnn::Handle(context.stream);
-
-            auto d2h_stream = cuda4dnn::csl::Stream(true); // stream for background D2H data transfers
-            cudaInfo = std::unique_ptr<CudaInfo_t>(new CudaInfo_t(std::move(context), std::move(d2h_stream)));
-        }
-#endif
    }
 
    Ptr<DataLayer> netInputLayer;
@@ -1300,13 +1270,6 @@ struct Net::Impl : public detail::NetImplBase
        }
 
        Ptr<BackendWrapper> wrapper = wrapMat(preferableBackend, preferableTarget, host);
-#ifdef HAVE_CUDA
-        if (preferableBackend == DNN_BACKEND_CUDA)
-        {
-            auto cudaWrapper = wrapper.dynamicCast<CUDABackendWrapper>();
-            cudaWrapper->setStream(cudaInfo->context.stream, cudaInfo->d2h_stream);
-        }
-#endif
        backendWrappers[data] = wrapper;
        return wrapper;
    }
@@ -2374,10 +2337,57 @@ struct Net::Impl : public detail::NetImplBase
 #endif
    }
 
-    void initCUDABackend(const std::vector<LayerPin>& blobsToKeep_) {
+    void initCUDABackend(const std::vector<LayerPin>& blobsToKeep_)
+    {
        CV_Assert(haveCUDA());
+        CV_Assert(preferableBackend == DNN_BACKEND_CUDA);
 
 #ifdef HAVE_CUDA
+        if (cuda4dnn::getDeviceCount() <= 0)
+            CV_Error(Error::StsError, "No CUDA capable device found.");
+
+        if (cuda4dnn::getDevice() < 0)
+            CV_Error(Error::StsError, "No CUDA capable device selected.");
+
+        if (!cuda4dnn::isDeviceCompatible())
+            CV_Error(Error::GpuNotSupported, "OpenCV was not built to work with the selected device. Please check CUDA_ARCH_PTX or CUDA_ARCH_BIN in your build configuration.");
+
+        if (preferableTarget == DNN_TARGET_CUDA_FP16 && !cuda4dnn::doesDeviceSupportFP16())
+            CV_Error(Error::StsError, "The selected CUDA device does not support FP16 operations.");
+
+        if (!cudaInfo)
+        {
+            cuda4dnn::csl::CSLContext context;
+            context.stream = cuda4dnn::csl::Stream(true);
+            context.cublas_handle = cuda4dnn::csl::cublas::Handle(context.stream);
+            context.cudnn_handle = cuda4dnn::csl::cudnn::Handle(context.stream);
+
+            auto d2h_stream = cuda4dnn::csl::Stream(true); // stream for background D2H data transfers
+            cudaInfo = std::unique_ptr<CudaInfo_t>(new CudaInfo_t(std::move(context), std::move(d2h_stream)));
+            cuda4dnn::checkVersions();
+        }
+
+        cudaInfo->workspace = cuda4dnn::csl::Workspace(); // release workspace memory if any
+
+        for (auto& layer : layers)
+        {
+            auto& ld = layer.second;
+            if (ld.id == 0)
+            {
+                for (auto& wrapper : ld.inputBlobsWrappers)
+                {
+                    auto cudaWrapper = wrapper.dynamicCast<CUDABackendWrapper>();
+                    cudaWrapper->setStream(cudaInfo->context.stream, cudaInfo->d2h_stream);
+                }
+            }
+
+            for (auto& wrapper : ld.outputBlobsWrappers)
+            {
+                auto cudaWrapper = wrapper.dynamicCast<CUDABackendWrapper>();
+                cudaWrapper->setStream(cudaInfo->context.stream, cudaInfo->d2h_stream);
+            }
+        }
+
        for (auto& layer : layers)
        {
            auto& ld = layer.second;
@@ -2653,11 +2663,11 @@ struct Net::Impl : public detail::NetImplBase
            if (IS_DNN_CUDA_TARGET(preferableTarget) && !nextEltwiseLayer.empty())
            {
                // we create a temporary backend node for eltwise layer to obtain the eltwise configuration
-                auto context = cudaInfo->context; /* make a copy so that initCUDA doesn't modify cudaInfo */
+                cuda4dnn::csl::CSLContext context; // assume that initCUDA and EltwiseOp does not use the context during init
                const auto node = nextData->layerInstance->initCUDA(&context, nextData->inputBlobsWrappers, nextData->outputBlobsWrappers);
                const auto eltwiseNode = node.dynamicCast<cuda4dnn::EltwiseOpBase>();
                if (eltwiseNode->op != cuda4dnn::EltwiseOpType::SUM || !eltwiseNode->coeffs.empty())
-                nextEltwiseLayer = Ptr<EltwiseLayer>();
+                    nextEltwiseLayer = Ptr<EltwiseLayer>();
 
                // check for variable channels
                auto& inputs = nextData->inputBlobs;
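Taken together, the dnn.cpp changes defer all CUDA context and stream creation to initCUDABackend(), so the new validation runs only when the CUDA backend is actually selected. A minimal sketch of the application path that reaches these checks (the model file name is hypothetical; setPreferableBackend/setPreferableTarget and DNN_BACKEND_CUDA/DNN_TARGET_CUDA_FP16 are the existing public API referenced in this diff):

// Hypothetical usage: selecting the CUDA backend triggers initCUDABackend()
// and the CV_Error checks added in this commit on the first forward pass.
#include <opencv2/dnn.hpp>

int main()
{
    cv::dnn::Net net = cv::dnn::readNet("model.onnx"); // hypothetical model

    net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);
    net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA_FP16); // rejected on pre-sm_53 devices

    cv::Mat img(224, 224, CV_8UC3, cv::Scalar()); // dummy input image
    cv::Mat blob = cv::dnn::blobFromImage(img);
    net.setInput(blob);
    cv::Mat out = net.forward(); // CV_Error here if no device, incompatible build, or no FP16
    return 0;
}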
