 import numpy as np
 from bonito.nn import Module, Convolution, LinearCRFEncoder, Serial, Permute, layers, from_dict
 
-import seqdist.sparse
-from seqdist.ctc_simple import logZ_cupy, viterbi_alignments
+if torch.cuda.is_available():
+    import seqdist.sparse
+    from seqdist.ctc_simple import logZ_cupy, viterbi_alignments
 from seqdist.core import SequenceDist, Max, Log, semiring
 
 
@@ -21,6 +22,69 @@ def get_stride(m):
     return 1
 
 
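+# CPU fallback for seqdist.sparse's forward recursion: step through time,
+# expanding the alphas over the sparse transitions Ms and keeping each step's
+# products so the backward pass can reuse them.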
+def logZ_fwd_cpu(Ms, idx, v0, vT, S):
+    T, N, C, NZ = Ms.shape
+    Ms_grad = torch.zeros(T, N, C, NZ)
+
+    a = v0
+    for t in range(T):
+        s = S.mul(a[:, idx], Ms[t])
+        a = S.sum(s, -1)
+        Ms_grad[t] = s
+    return S.sum(a + vT, dim=1), Ms_grad
+
+
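+# CPU fallback for the backward recursion: idx_T transposes the sparse
+# transition index so the betas can be accumulated in reverse.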
+def logZ_bwd_cpu(Ms, idx, vT, S, K=1):
+    assert K == 1
+    T, N, C, NZ = Ms.shape
+    Ms = Ms.reshape(T, N, -1)
+    idx_T = idx.flatten().argsort().to(dtype=torch.long).reshape(C, NZ)
+
+    betas = torch.ones(T + 1, N, C)
+
+    a = vT
+    betas[T] = a
+    for t in reversed(range(T)):
+        s = S.mul(a[:, idx_T // NZ], Ms[t, :, idx_T])
+        a = S.sum(s, -1)
+        betas[t] = a
+    return betas
+
+
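+# autograd wrapper around the two CPU recursions, making logZ differentiable
+# with respect to the scores Ms.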
+class _LogZ(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, Ms, idx, v0, vT, S: semiring):
+        idx = idx.to(dtype=torch.long, device=Ms.device)
+        logZ, Ms_grad = logZ_fwd_cpu(Ms, idx, v0, vT, S)
+        ctx.save_for_backward(Ms_grad, Ms, idx, vT)
+        ctx.semiring = S
+        return logZ
+
+    @staticmethod
+    def backward(ctx, grad):
+        Ms_grad, Ms, idx, vT = ctx.saved_tensors
+        S = ctx.semiring
+        T, N, C, NZ = Ms.shape
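+        # weight each saved forward product by its beta, then let S.dsum
+        # normalise the result into d(logZ)/d(Ms).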
+        betas = logZ_bwd_cpu(Ms, idx, vT, S)
+        Ms_grad = S.mul(Ms_grad, betas[1:, :, :, None])
+        Ms_grad = S.dsum(Ms_grad.reshape(T, N, -1), dim=2).reshape(T, N, C, NZ)
+        return grad[None, :, None, None] * Ms_grad, None, None, None, None
+
+
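+# drop-in replacement for seqdist.sparse.logZ on CPU tensors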
+def sparse_logZ(Ms, idx, v0, vT, S: semiring = Log):
+    return _LogZ.apply(Ms, idx, v0, vT, S)
+
+
 class CTC_CRF(SequenceDist):
 
     def __init__(self, state_len, alphabet):
@@ -43,7 +107,10 @@ def logZ(self, scores, S:semiring=Log):
         Ms = scores.reshape(T, N, -1, len(self.alphabet))
         alpha_0 = Ms.new_full((N, self.n_base**(self.state_len)), S.one)
         beta_T = Ms.new_full((N, self.n_base**(self.state_len)), S.one)
-        return seqdist.sparse.logZ(Ms, self.idx, alpha_0, beta_T, S)
+        if Ms.is_cuda:
+            return seqdist.sparse.logZ(Ms, self.idx, alpha_0, beta_T, S)
+        else:
+            return sparse_logZ(Ms, self.idx, alpha_0, beta_T, S)
 
     def normalise(self, scores):
         return (scores - self.logZ(scores)[:, None] / len(scores))
@@ -58,7 +125,10 @@ def backward_scores(self, scores, S: semiring=Log):
         T, N, _ = scores.shape
         Ms = scores.reshape(T, N, -1, self.n_base + 1)
         beta_T = Ms.new_full((N, self.n_base**(self.state_len)), S.one)
-        return seqdist.sparse.bwd_scores_cupy(Ms, self.idx, beta_T, S, K=1)
+        if Ms.is_cuda:
+            return seqdist.sparse.bwd_scores_cupy(Ms, self.idx, beta_T, S, K=1)
+        else:
+            return logZ_bwd_cpu(Ms, self.idx, beta_T, S, K=1)
 
     def compute_transition_probs(self, scores, betas):
         T, N, C = scores.shape