|
9 | 9 |
|
10 | 10 | import edgeml_pytorch.utils as utils
|
11 | 11 |
|
12 |
| -if utils.findCUDA() is not None: |
13 |
| - import fastgrnn_cuda |
| 12 | +try: |
| 13 | + if utils.findCUDA() is not None: |
| 14 | + import fastgrnn_cuda |
| 15 | +except: |
| 16 | + print("Running without FastGRNN CUDA") |
| 17 | + pass |
14 | 18 |
|
15 | 19 |
|
16 | 20 | # All the matrix vector computations of the form Wx are done
|
@@ -351,29 +355,29 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
|
351 | 355 | self._name = name
|
352 | 356 |
|
353 | 357 | if wRank is None:
|
354 |
| - self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], self.device)) |
| 358 | + self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], device=self.device)) |
355 | 359 | self.W1 = torch.empty(0)
|
356 | 360 | self.W2 = torch.empty(0)
|
357 | 361 | else:
|
358 | 362 | self.W = torch.empty(0)
|
359 |
| - self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], self.device)) |
360 |
| - self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], self.device)) |
| 363 | + self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], device=self.device)) |
| 364 | + self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], device=self.device)) |
361 | 365 |
|
362 | 366 | if uRank is None:
|
363 |
| - self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], self.device)) |
| 367 | + self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], device=self.device)) |
364 | 368 | self.U1 = torch.empty(0)
|
365 | 369 | self.U2 = torch.empty(0)
|
366 | 370 | else:
|
367 | 371 | self.U = torch.empty(0)
|
368 |
| - self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], self.device)) |
369 |
| - self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], self.device)) |
| 372 | + self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], device=self.device)) |
| 373 | + self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], device=self.device)) |
370 | 374 |
|
371 | 375 | self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]
|
372 | 376 |
|
373 |
| - self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], self.device)) |
374 |
| - self.bias_update = nn.Parameter(torch.ones([1, hidden_size], self.device)) |
375 |
| - self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], self.device)) |
376 |
| - self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], self.device)) |
| 377 | + self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], device=self.device)) |
| 378 | + self.bias_update = nn.Parameter(torch.ones([1, hidden_size], device=self.device)) |
| 379 | + self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], device=self.device)) |
| 380 | + self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], device=self.device)) |
377 | 381 |
|
378 | 382 | @property
|
379 | 383 | def name(self):
|
|
0 commit comments