@@ -16,21 +16,57 @@ from setuptools import setup, find_packages
1616from setuptools.command.install import install
1717import subprocess
1818import os
19+ import sys
1920
20- class CustomInstall(install, object):
21- def run(self):
22- try:
23- import tensorflow as tensorflow
24- except:
25- raise ImportError('Failed to import Tensorflow. Tensorflow must be installed before installing nvidia-dali-tf-plugin')
21+ def get_module_path(module_name):
22+ module_path = ''
23+ for d in sys.path:
24+ possible_path = os.path.join(d, module_name)
25+ # skip current dir as this is plugin dir
26+ if os.path.isdir(possible_path) and len(d) != 0:
27+ module_path = possible_path
28+ break
29+ return module_path
30+
31+ def get_tf_build_flags():
32+ tf_cflags = ''
33+ tf_lflags = ''
34+ try:
35+ import tensorflow as tensorflow
36+ tf_cflags=" ".join(tensorflow.sysconfig.get_compile_flags())
37+ tf_lflags=" ".join(tensorflow.sysconfig.get_link_flags())
38+ except:
39+ tensorflow_path = get_module_path('tensorflow')
40+ if tensorflow_path is not '':
41+ tf_cflags=" ".join(["-I" + tensorflow_path + "/include", "-I" + tensorflow_path + "/include/external/nsync/public", "-D_GLIBCXX_USE_CXX11_ABI=0"])
42+ tf_lflags=" ".join(["-L" + tensorflow_path, "-ltensorflow_framework"])
43+
44+ if tf_cflags is '' and tf_lflags is '':
45+ raise ImportError('Could not find Tensorflow. Tensorflow must be installed before installing nvidia-dali-tf-plugin')
46+ return (tf_cflags, tf_lflags)
2647
48+ def get_dali_build_flags():
49+ dali_cflags = ''
50+ dali_lflags = ''
51+ try:
2752 import nvidia.dali.sysconfig as dali_sc
2853 dali_lib_path = dali_sc.get_lib_dir()
2954 dali_cflags=" ".join(dali_sc.get_compile_flags())
3055 dali_lflags=" ".join(dali_sc.get_link_flags())
56+ except:
57+ dali_path = get_module_path('nvidia/dali')
58+ if dali_path is not '':
59+ dali_cflags=" ".join(["-I" + dali_path + "/include", "-D_GLIBCXX_USE_CXX11_ABI=0"])
60+ dali_lflags=" ".join(["-L" + dali_path, "-ldali"])
61+ if dali_cflags is '' and dali_lflags is '':
62+ raise ImportError('Could not find DALI.')
63+ return (dali_cflags, dali_lflags)
3164
32- tf_cflags=" ".join(tensorflow.sysconfig.get_compile_flags())
33- tf_lflags=" ".join(tensorflow.sysconfig.get_link_flags())
65+ class CustomInstall(install, object):
66+ def run(self):
67+ dali_cflags, dali_lflags = get_dali_build_flags()
68+ dali_lib_path = get_module_path('nvidia/dali')
69+ tf_cflags, tf_lflags = get_tf_build_flags()
3470
3571 src_path = os.path.dirname(os.path.realpath(__file__)) + '/nvidia/dali'
3672 plugin_src = src_path + '/plugin/daliop.cc'
@@ -41,8 +77,6 @@ class CustomInstall(install, object):
4177
4278 super(CustomInstall, self).run()
4379
44-
45-
4680setup(name='nvidia-dali-tf-plugin',
4781 description='NVIDIA DALI Tensorflow plugin',
4882 url='https://github.com/NVIDIA/dali',
0 commit comments