Skip to content

Commit d5e1c73

Browse files
jantonguiraoJanuszL
authored andcommitted
Add fallback build configuration in case of no CUDA runtime library a… (#427)
* Add fallback build configuration in case of no CUDA runtime library at the moment of nvidia-dali-tf-plugin installation Signed-off-by: Joaquin Anton <janton@nvidia.com> * Make TF plugin compiling without CUDA runtime present Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
1 parent b9482fc commit d5e1c73

File tree

1 file changed

+44
-10
lines changed

1 file changed

+44
-10
lines changed

dali/python/tf_plugin/setup.py.in

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,57 @@ from setuptools import setup, find_packages
1616
from setuptools.command.install import install
1717
import subprocess
1818
import os
19+
import sys
1920

20-
class CustomInstall(install, object):
21-
def run(self):
22-
try:
23-
import tensorflow as tensorflow
24-
except:
25-
raise ImportError('Failed to import Tensorflow. Tensorflow must be installed before installing nvidia-dali-tf-plugin')
21+
def get_module_path(module_name):
22+
module_path = ''
23+
for d in sys.path:
24+
possible_path = os.path.join(d, module_name)
25+
# skip current dir as this is plugin dir
26+
if os.path.isdir(possible_path) and len(d) != 0:
27+
module_path = possible_path
28+
break
29+
return module_path
30+
31+
def get_tf_build_flags():
32+
tf_cflags = ''
33+
tf_lflags = ''
34+
try:
35+
import tensorflow as tensorflow
36+
tf_cflags=" ".join(tensorflow.sysconfig.get_compile_flags())
37+
tf_lflags=" ".join(tensorflow.sysconfig.get_link_flags())
38+
except:
39+
tensorflow_path = get_module_path('tensorflow')
40+
if tensorflow_path is not '':
41+
tf_cflags=" ".join(["-I" + tensorflow_path + "/include", "-I" + tensorflow_path + "/include/external/nsync/public", "-D_GLIBCXX_USE_CXX11_ABI=0"])
42+
tf_lflags=" ".join(["-L" + tensorflow_path, "-ltensorflow_framework"])
43+
44+
if tf_cflags is '' and tf_lflags is '':
45+
raise ImportError('Could not find Tensorflow. Tensorflow must be installed before installing nvidia-dali-tf-plugin')
46+
return (tf_cflags, tf_lflags)
2647

48+
def get_dali_build_flags():
49+
dali_cflags = ''
50+
dali_lflags = ''
51+
try:
2752
import nvidia.dali.sysconfig as dali_sc
2853
dali_lib_path = dali_sc.get_lib_dir()
2954
dali_cflags=" ".join(dali_sc.get_compile_flags())
3055
dali_lflags=" ".join(dali_sc.get_link_flags())
56+
except:
57+
dali_path = get_module_path('nvidia/dali')
58+
if dali_path is not '':
59+
dali_cflags=" ".join(["-I" + dali_path + "/include", "-D_GLIBCXX_USE_CXX11_ABI=0"])
60+
dali_lflags=" ".join(["-L" + dali_path, "-ldali"])
61+
if dali_cflags is '' and dali_lflags is '':
62+
raise ImportError('Could not find DALI.')
63+
return (dali_cflags, dali_lflags)
3164

32-
tf_cflags=" ".join(tensorflow.sysconfig.get_compile_flags())
33-
tf_lflags=" ".join(tensorflow.sysconfig.get_link_flags())
65+
class CustomInstall(install, object):
66+
def run(self):
67+
dali_cflags, dali_lflags = get_dali_build_flags()
68+
dali_lib_path = get_module_path('nvidia/dali')
69+
tf_cflags, tf_lflags = get_tf_build_flags()
3470

3571
src_path = os.path.dirname(os.path.realpath(__file__)) + '/nvidia/dali'
3672
plugin_src = src_path + '/plugin/daliop.cc'
@@ -41,8 +77,6 @@ class CustomInstall(install, object):
4177

4278
super(CustomInstall, self).run()
4379

44-
45-
4680
setup(name='nvidia-dali-tf-plugin',
4781
description='NVIDIA DALI Tensorflow plugin',
4882
url='https://github.com/NVIDIA/dali',

0 commit comments

Comments
 (0)