|
13 | 13 | if utils.findCUDA() is not None:
|
14 | 14 | import fastgrnn_cuda
|
15 | 15 | except:
|
| 16 | + print("Running without FastGRNN CUDA") |
16 | 17 | pass
|
17 | 18 |
|
18 | 19 |
|
@@ -354,29 +355,29 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
|
354 | 355 | self._name = name
|
355 | 356 |
|
356 | 357 | if wRank is None:
|
357 |
| - self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], self.device)) |
| 358 | + self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], device=self.device)) |
358 | 359 | self.W1 = torch.empty(0)
|
359 | 360 | self.W2 = torch.empty(0)
|
360 | 361 | else:
|
361 | 362 | self.W = torch.empty(0)
|
362 |
| - self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], self.device)) |
363 |
| - self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], self.device)) |
| 363 | + self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], device=self.device)) |
| 364 | + self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], device=self.device)) |
364 | 365 |
|
365 | 366 | if uRank is None:
|
366 |
| - self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], self.device)) |
| 367 | + self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], device=self.device)) |
367 | 368 | self.U1 = torch.empty(0)
|
368 | 369 | self.U2 = torch.empty(0)
|
369 | 370 | else:
|
370 | 371 | self.U = torch.empty(0)
|
371 |
| - self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], self.device)) |
372 |
| - self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], self.device)) |
| 372 | + self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], device=self.device)) |
| 373 | + self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], device=self.device)) |
373 | 374 |
|
374 | 375 | self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]
|
375 | 376 |
|
376 |
| - self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], self.device)) |
377 |
| - self.bias_update = nn.Parameter(torch.ones([1, hidden_size], self.device)) |
378 |
| - self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], self.device)) |
379 |
| - self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], self.device)) |
| 377 | + self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], device=self.device)) |
| 378 | + self.bias_update = nn.Parameter(torch.ones([1, hidden_size], device=self.device)) |
| 379 | + self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], device=self.device)) |
| 380 | + self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], device=self.device)) |
380 | 381 |
|
381 | 382 | @property
|
382 | 383 | def name(self):
|
|
0 commit comments