Skip to content

Commit dff1c0e

Browse files
committed
add findCUDA method
1 parent 79bd969 commit dff1c0e

File tree

3 files changed

+33
-4
lines changed

3 files changed

+33
-4
lines changed

pytorch/edgeml_pytorch/graph/rnn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import edgeml_pytorch.utils as utils
1111

12-
if "CUDA_HOME" in os.environ:
12+
if utils.findCUDA() is not None:
1313
import fastgrnn_cuda
1414

1515
def onnx_exportable_rnn(input, fargs, cell, output):
@@ -320,7 +320,7 @@ class FastGRNNCUDACell(RNNCell):
320320
'''
321321
def __init__(self, input_size, hidden_size, zetaInit=1.0, nuInit=-4.0, name="FastGRNNCUDACell"):
322322
super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, "sigmoid", "tanh", 1, 1, 2)
323-
if not "CUDA_HOME" in os.environ:
323+
if utils.findCUDA() is None:
324324
raise Exception('FastGRNNCUDACell is supported only on GPU devices.')
325325
self._input_size = input_size
326326
self._hidden_size = hidden_size

pytorch/edgeml_pytorch/utils.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,40 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT license.
3-
3+
import sys
4+
import os
45
import numpy as np
56
import torch
67
import torch.nn.functional as F
78
import torch.optim as optim
89

910

11+
def findCUDA():
12+
'''Finds the CUDA install path.'''
13+
# Guess #1
14+
IS_WINDOWS = sys.platform == 'win32'
15+
cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
16+
if cuda_home is None:
17+
# Guess #2
18+
try:
19+
which = 'where' if IS_WINDOWS else 'which'
20+
nvcc = subprocess.check_output(
21+
[which, 'nvcc']).decode().rstrip('\r\n')
22+
cuda_home = os.path.dirname(os.path.dirname(nvcc))
23+
except Exception:
24+
# Guess #3
25+
if IS_WINDOWS:
26+
cuda_homes = glob.glob(
27+
'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
28+
if len(cuda_homes) == 0:
29+
cuda_home = ''
30+
else:
31+
cuda_home = cuda_homes[0]
32+
else:
33+
cuda_home = '/usr/local/cuda'
34+
if not os.path.exists(cuda_home):
35+
cuda_home = None
36+
return cuda_home
37+
1038
def multiClassHingeLoss(logits, labels):
1139
'''
1240
MultiClassHingeLoss to match C++ Version - No pytorch internal version

pytorch/setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import setuptools #enables develop
22
import os
33
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4+
from edgeml_pytorch.utils import findCUDA
45

5-
if "CUDA_HOME" in os.environ:
6+
if findCUDA() is not None:
67
setup(
78
name='fastgrnn_cuda',
89
ext_modules=[

0 commit comments

Comments
 (0)