We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a484559 commit 5328ec5Copy full SHA for 5328ec5
torch_struct/linearchain.py
@@ -53,7 +53,7 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
53
chart = self._chart((batch, bin_N, C, C), log_potentials, force_grad)
54
55
# Init
56
- init = torch.zeros(*chart.shape).bool()
+ init = torch.zeros_like(chart).bool()
57
init.diagonal(0, 3, 4).fill_(True)
58
chart = semiring.fill(chart, init, semiring.one)
59
0 commit comments