Skip to content

Commit fe750ec

Browse files
committed
Refactor FastGRNN CUDA Setup
1 parent cbba9f8 commit fe750ec

File tree

4 files changed

+24
-18
lines changed

4 files changed

+24
-18
lines changed

pytorch/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ Install appropriate CUDA and cuDNN [Tested with >= CUDA 8.1 and cuDNN >= 6.1]
6868
```
6969
pip install -r requirements-gpu.txt
7070
pip install -e .
71+
pip install -e edgeml_pytorch/cuda/
7172
```
7273

7374
**Note**: For using the optimized FastGRNNCUDA implementation, it is recommended to use CUDA v10.1, gcc 7.5 and cuDNN v7.6 and torch==1.4.0. Also, there are some known issues when compiling custom CUDA kernels on Windows [pytorch/#11004](https://github.com/pytorch/pytorch/issues/11004).
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import setuptools #enables develop
2+
import os
3+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4+
from edgeml_pytorch.utils import findCUDA
5+
6+
if findCUDA() is not None:
7+
setuptools.setup(
8+
name='fastgrnn_cuda',
9+
ext_modules=[
10+
CUDAExtension('fastgrnn_cuda', [
11+
'edgeml_pytorch/cuda/fastgrnn_cuda.cpp',
12+
'edgeml_pytorch/cuda/fastgrnn_cuda_kernel.cu',
13+
]),
14+
],
15+
cmdclass={
16+
'build_ext': BuildExtension
17+
}
18+
)

pytorch/edgeml_pytorch/graph/rnn.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99

1010
import edgeml_pytorch.utils as utils
1111

12-
if utils.findCUDA() is not None:
13-
import fastgrnn_cuda
12+
try:
13+
if utils.findCUDA() is not None:
14+
import fastgrnn_cuda
15+
except:
16+
pass
1417

1518

1619
# All the matrix vector computations of the form Wx are done

pytorch/setup.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,5 @@
11
import setuptools #enables develop
22
import os
3-
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4-
from edgeml_pytorch.utils import findCUDA
5-
6-
if findCUDA() is not None:
7-
setuptools.setup(
8-
name='fastgrnn_cuda',
9-
ext_modules=[
10-
CUDAExtension('fastgrnn_cuda', [
11-
'edgeml_pytorch/cuda/fastgrnn_cuda.cpp',
12-
'edgeml_pytorch/cuda/fastgrnn_cuda_kernel.cu',
13-
]),
14-
],
15-
cmdclass={
16-
'build_ext': BuildExtension
17-
}
18-
)
193

204
setuptools.setup(
215
name='edgeml',

0 commit comments

Comments
 (0)