Skip to content

Commit 635c899

Browse files
committed
Add FastGRNNCUDA to FastCells and Fix torch.randn() Argument Errors
1 parent 9020ed3 commit 635c899

File tree

4 files changed

+20
-14
lines changed

4 files changed

+20
-14
lines changed

examples/pytorch/FastCells/fastcell_example.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ def main():
5757
gate_nonlinearity=gate_non_linearity,
5858
update_nonlinearity=update_non_linearity,
5959
wRank=wRank, uRank=uRank)
60+
elif cell == "FastGRNNCUDA":
61+
FastCell = FastGRNNCUDACell(inputDims, hiddenDims,
62+
gate_nonlinearity=gate_non_linearity,
63+
update_nonlinearity=update_non_linearity,
64+
wRank=wRank, uRank=uRank)
6065
elif cell == "FastRNN":
6166
FastCell = FastRNNCell(inputDims, hiddenDims,
6267
update_nonlinearity=update_non_linearity,

examples/pytorch/FastCells/helpermethods.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ def getArgs():
8888
'train.npy and test.npy')
8989

9090
parser.add_argument('-c', '--cell', type=str, default="FastGRNN",
91-
help='Choose between [FastGRNN, FastRNN, UGRNN' +
92-
', GRU, LSTM], default: FastGRNN')
91+
help='Choose between [FastGRNN, FastGRNNCUDA, FastRNN,' +
92+
' UGRNN, GRU, LSTM], default: FastGRNN')
9393

9494
parser.add_argument('-id', '--input-dim', type=checkIntNneg, required=True,
9595
help='Input Dimension of RNN, each timestep will ' +

pytorch/edgeml_pytorch/cuda/setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
name='fastgrnn_cuda',
99
ext_modules=[
1010
CUDAExtension('fastgrnn_cuda', [
11-
'edgeml_pytorch/cuda/fastgrnn_cuda.cpp',
12-
'edgeml_pytorch/cuda/fastgrnn_cuda_kernel.cu',
11+
'fastgrnn_cuda.cpp',
12+
'fastgrnn_cuda_kernel.cu',
1313
]),
1414
],
1515
cmdclass={

pytorch/edgeml_pytorch/graph/rnn.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
if utils.findCUDA() is not None:
1414
import fastgrnn_cuda
1515
except:
16+
print("Running without FastGRNN CUDA")
1617
pass
1718

1819

@@ -354,29 +355,29 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
354355
self._name = name
355356

356357
if wRank is None:
357-
self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], self.device))
358+
self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], device=self.device))
358359
self.W1 = torch.empty(0)
359360
self.W2 = torch.empty(0)
360361
else:
361362
self.W = torch.empty(0)
362-
self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], self.device))
363-
self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], self.device))
363+
self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], device=self.device))
364+
self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], device=self.device))
364365

365366
if uRank is None:
366-
self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], self.device))
367+
self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], device=self.device))
367368
self.U1 = torch.empty(0)
368369
self.U2 = torch.empty(0)
369370
else:
370371
self.U = torch.empty(0)
371-
self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], self.device))
372-
self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], self.device))
372+
self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], device=self.device))
373+
self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], device=self.device))
373374

374375
self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]
375376

376-
self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], self.device))
377-
self.bias_update = nn.Parameter(torch.ones([1, hidden_size], self.device))
378-
self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], self.device))
379-
self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], self.device))
377+
self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], device=self.device))
378+
self.bias_update = nn.Parameter(torch.ones([1, hidden_size], device=self.device))
379+
self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], device=self.device))
380+
self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], device=self.device))
380381

381382
@property
382383
def name(self):

0 commit comments

Comments
 (0)