@@ -34,7 +34,7 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
         )

         # Init.
-        mask = torch.zeros(*init.shape).bool()
+        mask = torch.zeros(*init.shape, device=log_potentials.device).bool()
         mask[:, :, :, 0, 0].diagonal(0, -2, -1).fill_(True)
         init = semiring.fill(init, mask, semiring.one)

@@ -61,10 +61,13 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
         c[:, :, : K - 1, 0] = semiring.sum(
             torch.stack([c.data[:, :, : K - 1, 0], lp[:, :, 1:K]], dim=-1)
         )
-        end = torch.min(lengths) - 1
-        mask = torch.zeros(*init.shape).bool()
+        mask = torch.zeros(*init.shape, device=log_potentials.device).bool()
+        mask_length = torch.arange(bin_N).view(1, bin_N, 1).expand(batch, bin_N, C)
+        mask_length = mask_length.to(log_potentials.device)
         for k in range(1, K - 1):
-            mask[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1).fill_(True)
+            mask_length_k = mask_length < (lengths - 1 - (k - 1)).view(batch, 1, 1)
+            mask_length_k = semiring.convert(mask_length_k)
+            mask[:, :, :, k - 1, k].diagonal(0, -2, -1).masked_fill_(mask_length_k, True)
         init = semiring.fill(init, mask, semiring.one)

         K_1 = K - 1
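The rewritten hunk drops the single cutoff end = torch.min(lengths) - 1, which clipped every example to the shortest sequence in the batch, and instead builds a per-example validity mask with torch.arange, created directly on the potentials' device. Roughly, the broadcasting pattern looks like the standalone sketch below (illustration only, not part of the patch; the toy shapes and names are assumed):

    import torch

    batch, N, C = 2, 6, 3                              # assumed toy sizes
    lengths = torch.tensor([6, 4])                     # per-example sequence lengths
    device = torch.device("cpu")                       # log_potentials.device in practice

    # One row of positions per example, broadcast over the label dimension.
    positions = torch.arange(N, device=device).view(1, N, 1).expand(batch, N, C)
    # True only while position n lies inside example b's sequence, so the
    # masked_fill_ above only switches on entries each example can legally reach.
    valid = positions < (lengths.to(device) - 1).view(batch, 1, 1)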
@@ -83,37 +86,37 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
         v = semiring.sum(semiring.sum(final[:, :, 0, :, 0, :].contiguous()))
         return v, [log_potentials]

-    # def _dp_standard(self, edge, lengths=None, force_grad=False):
-    #     semiring = self.semiring
-    #     ssize = semiring.size()
-    #     edge, batch, N, K, C, lengths = self._check_potentials(edge, lengths)
-    #     edge.requires_grad_(True)
-
-    #     # Init
-    #     # All paths starting at N of len K
-    #     alpha = self._make_chart(1, (batch, N, K, C), edge, force_grad)[0]
-
-    #     # All paths finishing at N with label C
-    #     beta = self._make_chart(N, (batch, C), edge, force_grad)
-    #     semiring.one_(beta[0].data)
-
-    #     # Main.
-    #     for n in range(1, N):
-    #         alpha[:, :, n - 1] = semiring.dot(
-    #             beta[n - 1].view(ssize, batch, 1, 1, C),
-    #             edge[:, :, n - 1].view(ssize, batch, K, C, C),
-    #         )
-
-    #         t = max(n - K, -1)
-    #         f1 = torch.arange(n - 1, t, -1)
-    #         f2 = torch.arange(1, len(f1) + 1)
-    #         beta[n][:] = semiring.sum(
-    #             torch.stack([alpha[:, :, a, b] for a, b in zip(f1, f2)], dim=-1)
-    #         )
-    #     v = semiring.sum(
-    #         torch.stack([beta[l - 1][:, i] for i, l in enumerate(lengths)], dim=1)
-    #     )
-    #     return v, [edge], beta
+    def _dp_standard(self, edge, lengths=None, force_grad=False):
+        semiring = self.semiring
+        ssize = semiring.size()
+        edge, batch, N, K, C, lengths = self._check_potentials(edge, lengths)
+        edge.requires_grad_(True)
+
+        # Init
+        # All paths starting at N of len K
+        alpha = self._make_chart(1, (batch, N, K, C), edge, force_grad)[0]
+
+        # All paths finishing at N with label C
+        beta = self._make_chart(N, (batch, C), edge, force_grad)
+        beta[0] = semiring.fill(beta[0], torch.tensor(True).to(edge.device), semiring.one)
+
+        # Main.
+        for n in range(1, N):
+            alpha[:, :, n - 1] = semiring.dot(
+                beta[n - 1].view(ssize, batch, 1, 1, C),
+                edge[:, :, n - 1].view(ssize, batch, K, C, C),
+            )
+
+            t = max(n - K, -1)
+            f1 = torch.arange(n - 1, t, -1)
+            f2 = torch.arange(1, len(f1) + 1)
+            beta[n][:] = semiring.sum(
+                torch.stack([alpha[:, :, a, b] for a, b in zip(f1, f2)], dim=-1)
+            )
+        v = semiring.sum(
+            torch.stack([beta[l - 1][:, i] for i, l in enumerate(lengths)], dim=1)
+        )
+        return v, [edge], beta

     @staticmethod
     def to_parts(sequence, extra, lengths=None):
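Besides restoring _dp_standard, the uncommented version also swaps the old in-place semiring.one_(beta[0].data) for the functional semiring.fill(beta[0], ..., semiring.one), matching the masked fills used in logpartition above. A minimal sketch of that fill-by-mask idea in plain PyTorch (the fill helper here is illustrative, not torch_struct's actual semiring API):

    import torch

    def fill(chart, mask, value):
        # Return a new chart with masked entries set to `value`,
        # rather than mutating chart.data in place.
        return torch.where(mask, torch.full_like(chart, value), chart)

    chart = torch.full((2, 3), float("-inf"))      # log-semiring "zero"
    mask = torch.zeros(2, 3, dtype=torch.bool)
    mask[:, 0] = True                              # start states
    chart = fill(chart, mask, 0.0)                 # log-semiring "one"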