Skip to content

Commit 27843a9

Browse files
committed
[WIP]: add TDNNF to PyTorch.
1 parent 139efff commit 27843a9

File tree

1 file changed

+183
-0
lines changed

1 file changed

+183
-0
lines changed
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
4+
# Apache 2.0
5+
6+
import torch
7+
import torch.nn as nn
8+
import torch.nn.functional as F
9+
10+
11+
def constraint_orthonormal_internal(M):
    '''Apply one step of Kaldi's orthonormal-constraint update to a matrix.

    This is a port of the **floating** scale branch of

        void ConstrainOrthonormalInternal(BaseFloat scale,
                                          CuMatrixBase<BaseFloat> *M)

    from
    https://github.com/kaldi-asr/kaldi/blob/master/src/nnet3/nnet-utils.cc#L982

    Args:
        M: a 2-D tensor with at most as many rows as columns.

    Returns:
        A new tensor, one gradient step closer to satisfying
        M * M^T == scale^2 * I (scale is chosen automatically).
    '''
    assert M.ndim == 2

    rows = M.size(0)
    cols = M.size(1)
    assert rows <= cols

    # gram = M * M^T; symmetric, so gram * gram^T == gram^2.
    gram = torch.mm(M, M.t())
    gram_sq = torch.mm(gram, gram.t())

    trace_p = torch.trace(gram)
    trace_pp = torch.trace(gram_sq)

    # The scale^2 that best matches gram to scale^2 * I.
    scale = torch.sqrt(trace_pp / trace_p)

    # ratio >= 1 by Cauchy-Schwarz; values near 1 mean nearly orthonormal.
    ratio = trace_pp * rows / (trace_p * trace_p)
    assert ratio > 0.99

    # Take smaller update steps when far from orthonormal, for stability
    # (1.1 > 1.02, so nesting reproduces the original two halvings).
    update_speed = 0.125
    if ratio > 1.02:
        update_speed *= 0.5
        if ratio > 1.1:
            update_speed *= 0.5

    eye = torch.eye(rows, dtype=gram.dtype, device=gram.device)
    gram = gram - scale * scale * eye

    alpha = update_speed / (scale * scale)
    return M - 4 * alpha * torch.mm(gram, M)
52+
53+
54+
class FactorizedTDNN(nn.Module):
    '''
    This class implements the following topology in kaldi:

      tdnnf-layer name=tdnnf2 $tdnnf_opts dim=1024 bottleneck-dim=128 time-stride=1

    References:
        - http://danielpovey.com/files/2018_interspeech_tdnnf.pdf
        - ConstrainOrthonormalInternal() from
          https://github.com/kaldi-asr/kaldi/blob/master/src/nnet3/nnet-utils.cc#L982
    '''

    def __init__(self, dim, bottleneck_dim, time_stride):
        '''
        Args:
            dim: input and output feature dimension.
            bottleneck_dim: dimension of the factorized bottleneck.
            time_stride: 0 for a pointwise (kernel 1) convolution; 1 for a
                kernel covering offsets [-1, 0, 1] (output is 2 frames
                shorter, since no padding is used).
        '''
        super().__init__()
        assert time_stride in [0, 1]

        if time_stride == 0:
            kernel_size = 1
        else:
            kernel_size = 3

        # WARNING(fangjun): kaldi uses [-1, 0] for the first linear layer
        # and [0, 1] for the second affine layer;
        # we use [-1, 0, 1] for the first linear layer

        # conv requires [N, C, T]
        self.conv = nn.Conv1d(in_channels=dim,
                              out_channels=bottleneck_dim,
                              kernel_size=kernel_size,
                              bias=False)

        # affine requires [N, T, C]
        self.affine = nn.Linear(in_features=bottleneck_dim, out_features=dim)

        # batchnorm requires [N, C, T]
        self.batchnorm = nn.BatchNorm1d(num_features=dim)

    def forward(self, x):
        '''
        Args:
            x: input of shape [batch_size, feat_dim, seq_len] = [N, C, T].

        Returns:
            A tensor of shape [N, C, T'], where T' == T when
            time_stride == 0 and T' == T - 2 when time_stride == 1.
        '''
        assert x.ndim == 3
        # [N, C, T] -> [N, bottleneck_dim, T']
        x = self.conv(x)

        # the affine layer operates on the last dim, so go to [N, T, C]
        x = x.permute(0, 2, 1)
        x = self.affine(x)
        x = F.relu(x)

        # batchnorm needs channels in dim 1, so back to [N, C, T]
        x = x.permute(0, 2, 1)
        x = self.batchnorm(x)

        # TODO(fangjun): implement GeneralDropoutComponent in PyTorch

        # at this point, x is [N, C, T]
        return x

    def constraint_orthonormal(self):
        '''Nudge the conv weight one step towards (scaled) orthonormality.

        Modifies the weight in place.  The update is not part of training
        proper, so it is wrapped in torch.no_grad(): the previous version
        built an unnecessary autograd graph over the weight and
        round-tripped the result through state_dict()/load_state_dict().
        '''
        with torch.no_grad():
            w = self.conv.weight
            # w is of shape [out_channels, in_channels, kernel_size]
            out_channels = w.size(0)
            in_channels = w.size(1)
            kernel_size = w.size(2)

            # flatten to 2-D: [out_channels, in_channels * kernel_size]
            m = w.reshape(out_channels, -1)

            # constraint_orthonormal_internal() requires rows <= cols
            need_transpose = m.size(0) > m.size(1)
            if need_transpose:
                m = m.t()

            m = constraint_orthonormal_internal(m)

            if need_transpose:
                m = m.t()

            w.copy_(m.reshape(out_channels, in_channels, kernel_size))
141+
142+
143+
def _test_constraint_orthonormal():
    '''Sanity check: repeated updates should drive the loss towards zero.'''

    def _orthonormal_loss(m):
        # Frobenius distance between M*M^T / scale^2 and the identity.
        gram = torch.mm(m, m.t())
        gram_sq = torch.mm(gram, gram.t())
        scale = torch.sqrt(torch.trace(gram_sq) / torch.trace(gram))
        eye = torch.eye(gram.size(0), dtype=gram.dtype, device=gram.device)
        return torch.norm(gram / (scale * scale) - eye, p='fro')

    w = torch.randn(6, 8) * 10
    loss = [_orthonormal_loss(w)]
    for _ in range(15):
        w = constraint_orthonormal_internal(w)
        loss.append(_orthonormal_loss(w))
    # TODO(fangjun): draw the loss using matplotlib
    print(loss)
168+
169+
170+
def _test_factorized_tdnn():
    '''Check that time_stride=1 shrinks the time axis by two frames.'''
    batch_size, seq_len, feat_dim = 1, 10, 4
    model = FactorizedTDNN(dim=feat_dim, bottleneck_dim=2, time_stride=1)
    x = torch.arange(batch_size * seq_len * feat_dim)
    x = x.reshape(batch_size, feat_dim, seq_len).float()
    assert model(x).size(2) == seq_len - 2
178+
179+
180+
def _main():
    # Fixed seed so the printed losses are reproducible across runs.
    # Keep the call order: both tests consume the global RNG stream.
    torch.manual_seed(20200130)
    _test_factorized_tdnn()
    _test_constraint_orthonormal()


if __name__ == '__main__':
    _main()

0 commit comments

Comments
 (0)