11#!/usr/bin/env python
2+ import glob
23import multiprocessing .pool
34import os
5+ import tarfile
6+ import urllib .request
7+ import warnings
48
59from setuptools import setup , find_packages , distutils
610from torch .utils .cpp_extension import BuildExtension
11+ from torch .utils .cpp_extension import CppExtension , include_paths
712
8- this_file = os .path .dirname (__file__ )
13+
14+ def download_extract (url , dl_path ):
15+ if not os .path .isfile (dl_path ):
16+ # Already downloaded
17+ urllib .request .urlretrieve (url , dl_path )
18+ if dl_path .endswith (".tar.gz" ) and os .path .isdir (dl_path [:- len (".tar.gz" )]):
19+ # Already extracted
20+ return
21+ tar = tarfile .open (dl_path )
22+ tar .extractall ('third_party/' )
23+ tar .close ()
24+
25+
26+ # Download/Extract openfst, boost
27+ download_extract ('https://github.com/parlance/ctcdecode/releases/download/v1.0/openfst-1.6.7.tar.gz' ,
28+ 'third_party/openfst-1.6.7.tar.gz' )
29+ download_extract ('https://github.com/parlance/ctcdecode/releases/download/v1.0/boost_1_67_0.tar.gz' ,
30+ 'third_party/boost_1_67_0.tar.gz' )
31+
32+ for file in ['third_party/kenlm/setup.py' , 'third_party/ThreadPool/ThreadPool.h' ]:
33+ if not os .path .exists (file ):
34+ warnings .warn ('File `{}` does not appear to be present. Did you forget `git submodule update`?' .format (file ))
35+
36+
37+ # Does gcc compile with this header and library?
38+ def compile_test (header , library ):
39+ dummy_path = os .path .join (os .path .dirname (__file__ ), "dummy" )
40+ command = "bash -c \" g++ -include " + header + " -l" + library + " -x c++ - <<<'int main() {}' -o " + dummy_path \
41+ + " >/dev/null 2>/dev/null && rm " + dummy_path + " 2>/dev/null\" "
42+ return os .system (command ) == 0
43+
44+
45+ compile_args = ['-O3' , '-DKENLM_MAX_ORDER=6' , '-std=c++14' , '-fPIC' ]
46+ ext_libs = []
47+ if compile_test ('zlib.h' , 'z' ):
48+ compile_args .append ('-DHAVE_ZLIB' )
49+ ext_libs .append ('z' )
50+
51+ if compile_test ('bzlib.h' , 'bz2' ):
52+ compile_args .append ('-DHAVE_BZLIB' )
53+ ext_libs .append ('bz2' )
54+
55+ if compile_test ('lzma.h' , 'lzma' ):
56+ compile_args .append ('-DHAVE_XZLIB' )
57+ ext_libs .append ('lzma' )
58+
59+ third_party_libs = ["kenlm" , "openfst-1.6.7/src/include" , "ThreadPool" , "boost_1_67_0" , "utf8" ]
60+ compile_args .extend (['-DINCLUDE_KENLM' , '-DKENLM_MAX_ORDER=6' ])
61+ lib_sources = glob .glob ('third_party/kenlm/util/*.cc' ) + glob .glob ('third_party/kenlm/lm/*.cc' ) + glob .glob (
62+ 'third_party/kenlm/util/double-conversion/*.cc' ) + glob .glob ('third_party/openfst-1.6.7/src/lib/*.cc' )
63+ lib_sources = [fn for fn in lib_sources if not (fn .endswith ('main.cc' ) or fn .endswith ('test.cc' ))]
64+
65+ third_party_includes = [os .path .realpath (os .path .join ("third_party" , lib )) for lib in third_party_libs ]
66+ ctc_sources = glob .glob ('ctcdecode/src/*.cpp' )
67+
68+ extension = CppExtension (
69+ name = 'ctcdecode._ext.ctc_decode' ,
70+ package = True ,
71+ with_cuda = False ,
72+ sources = ctc_sources + lib_sources ,
73+ include_dirs = third_party_includes + include_paths (),
74+ libraries = ext_libs ,
75+ extra_compile_args = compile_args ,
76+ language = 'c++'
77+ )
978
1079
1180# monkey-patch for parallel compilation
@@ -40,7 +109,6 @@ def _single_compile(obj):
40109
41110# hack compile to support parallel compiling
42111distutils .ccompiler .CCompiler .compile = parallelCCompile
43- import build
44112
45113setup (
46114 name = "ctcdecode" ,
@@ -51,6 +119,6 @@ def _single_compile(obj):
5111952120 # Exclude the build files.
53121 packages = find_packages (exclude = ["build" ]),
54- ext_modules = [ build . extension ],
122+ ext_modules = [ extension ],
55123 cmdclass = {'build_ext' : BuildExtension }
56124)
0 commit comments