Conversation

@fanlu fanlu commented Feb 11, 2020

Hi, @danpovey @csukuangfj. Please review this TDNN-F version. Thanks.

self.prefinal_l = OrthonormalLinear(dim=hidden_dim,
                                    bottleneck_dim=bottleneck_dim * 2,
                                    time_stride=0)
Contributor

mm.. I'm a bit surprised this * 2 is here?

Contributor

mm.. I think you were assuming that the final layer's bottleneck is always twice the TDNN-F layers' bottleneck. In fact we generally leave the final layer's bottleneck at 256, which for some reason seems to work across a range of conditions. You could make that a separate configuration value.
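A minimal sketch of that suggestion (the parameter name `prefinal_bottleneck_dim` is hypothetical, not part of this PR):

```python
# expose the final layer's bottleneck as its own configuration value,
# defaulting to 256 as suggested above
self.prefinal_l = OrthonormalLinear(dim=hidden_dim,
                                    bottleneck_dim=prefinal_bottleneck_dim,
                                    time_stride=0)
```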

Author

When I checked the parameter shapes of Kaldi's model, I didn't find the difference you mentioned between the final layer and the previous layers.

'tdnnf12.linear':'time_offsets': array([-3,  0]), 'params': (128, 2048)
'tdnnf12.affine':'time_offsets': array([0, 3]), 'params': (1024, 256)
'tdnnf13.linear':'time_offsets': array([-3,  0]), 'params': (128, 2048)
'tdnnf13.affine':'time_offsets': array([0, 3]), 'params': (1024, 256)
'prefinal-l':'params': (256, 1024)

Contributor

> mm.. I'm a bit surprised this * 2 is here?

The `* 2` is used here to follow what Kaldi does. I've changed it to be configurable in this pull request: #3925

x_left = x[:, :, :cur_context]
x_mid = x[:, :, cur_context:-cur_context:self.frame_subsampling_factor]
x_right = x[:, :, -cur_context:]
x = torch.cat([x_left, x_mid, x_right], dim=2)
Contributor

I'm surprised that you are doing this manually rather than using a 1d convolution. This could be quite slow.

Author

I want to subsample only the length of the window, not the left_context and right_context. Training is slower than before, but it worked. Please help me write this as a 1-d convolution.

Contributor

What might have happened here is that you tripled the dimension in the middle of the network.
This would have led to a system with many more parameters for your "dilation" system.

Author

I just subsample t_out_length from (24+150+24) to (24+50+24) manually; the number of parameters does not increase compared with the stride kernel(2,2) version. I explained this code's behaviour in the picture below.
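Checking those lengths (a quick sketch; 24 is the per-side context, 150 the middle, 3 the subsampling factor):

```python
print(24 + 150 + 24)       # 198 frames before manual subsampling
print(24 + 150 // 3 + 24)  # 98 frames after subsampling only the middle
```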

self.conv = nn.Conv1d(in_channels=dim,
                      out_channels=bottleneck_dim,
                      kernel_size=kernel_size,
                      dilation=dilation,
Contributor

You should never need the dilation parameter. I think we discussed this before.

Contributor

... instead of using dilation, do a 3-fold subsampling after the last layer that had stride=1. Please don't argue about this! I remember last time was quite painful.
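A sketch of that approach (not the PR's code): fold the 3-fold subsampling into a convolution's stride instead of using dilation or manual slicing.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 1024, 150)                              # [N, C, T]
subsample = nn.Conv1d(1024, 1024, kernel_size=1, stride=3)
print(subsample(x).shape)                                  # torch.Size([1, 1024, 50])
```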

Author

Hah, I found the earlier discussion. I just want to make the output length equal to the supervision.

2020-02-11 14:59:29,145 (model_tdnnf3:172) DEBUG: x shape is torch.Size([1, 1024, 202])
2020-02-11 14:59:29,150 (model_tdnnf3:172) DEBUG: x shape is torch.Size([1, 1024, 200])
2020-02-11 14:59:29,156 (model_tdnnf3:172) DEBUG: x shape is torch.Size([1, 1024, 198])
2020-02-11 14:59:29,162 (model_tdnnf3:172) DEBUG: x shape is torch.Size([1, 1024, 98])
2020-02-11 14:59:29,175 (model_tdnnf3:172) DEBUG: x shape is torch.Size([1, 1024, 92])
2020-02-11 14:59:29,184 (model_tdnnf3:172) DEBUG: x shape is torch.Size([1, 1024, 86])
2020-02-11 14:59:29,195 (model_tdnnf3:172) DEBUG: x shape is torch.Size([1, 1024, 80])
2020-02-11 14:59:29,204 (model_tdnnf3:172) DEBUG: x shape is torch.Size([1, 1024, 74])
2020-02-11 14:59:29,215 (model_tdnnf3:172) DEBUG: x shape is torch.Size([1, 1024, 68])
2020-02-11 14:59:29,224 (model_tdnnf3:172) DEBUG: x shape is torch.Size([1, 1024, 62])
2020-02-11 14:59:29,233 (model_tdnnf3:172) DEBUG: x shape is torch.Size([1, 1024, 56])
2020-02-11 14:59:29,243 (model_tdnnf3:172) DEBUG: x shape is torch.Size([1, 1024, 50])

Author

Are these output shapes generated by the TDNN-F layers correct?

2020-02-11 19:21:40,284 (model_tdnnf3:179) DEBUG: x shape is torch.Size([1, 1024, 202])
2020-02-11 19:21:40,291 (model_tdnnf3:179) DEBUG: x shape is torch.Size([1, 1024, 200])
2020-02-11 19:21:40,297 (model_tdnnf3:179) DEBUG: x shape is torch.Size([1, 1024, 198])
2020-02-11 19:21:40,306 (model_tdnnf3:179) DEBUG: x shape is torch.Size([1, 1024, 198])
2020-02-11 19:21:40,317 (model_tdnnf3:179) DEBUG: x shape is torch.Size([1, 1024, 192])
2020-02-11 19:21:40,327 (model_tdnnf3:179) DEBUG: x shape is torch.Size([1, 1024, 186])
2020-02-11 19:21:40,336 (model_tdnnf3:179) DEBUG: x shape is torch.Size([1, 1024, 180])
2020-02-11 19:21:40,346 (model_tdnnf3:179) DEBUG: x shape is torch.Size([1, 1024, 174])
2020-02-11 19:21:40,355 (model_tdnnf3:179) DEBUG: x shape is torch.Size([1, 1024, 168])
2020-02-11 19:21:40,364 (model_tdnnf3:179) DEBUG: x shape is torch.Size([1, 1024, 162])
2020-02-11 19:21:40,373 (model_tdnnf3:179) DEBUG: x shape is torch.Size([1, 1024, 156])
2020-02-11 19:21:40,382 (model_tdnnf3:179) DEBUG: x shape is torch.Size([1, 1024, 150])
2020-02-11 19:21:40,382 (model_tdnnf3:182) DEBUG: x shape is torch.Size([1, 1024, 50])
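For reference, the standard Conv1d output-length formula reproduces this trace (a quick sketch; each TDNN-F layer applies two convolutions, linear then affine):

```python
def conv1d_out_len(t, kernel_size, dilation=1, stride=1):
    # nn.Conv1d output length with no padding
    return (t - dilation * (kernel_size - 1) - 1) // stride + 1

t = 202
t = conv1d_out_len(conv1d_out_len(t, 2), 2)  # 200: two kernel-2 convs, -1 each
t = conv1d_out_len(conv1d_out_len(t, 2), 2)  # 198
t = conv1d_out_len(conv1d_out_len(t, 1), 1)  # 198: the kernel-1 layer keeps T
t = conv1d_out_len(conv1d_out_len(t, 2, dilation=3), 2, dilation=3)  # 192
# ... each dilated layer then removes 2 * 3 frames, down to 150,
# and the final 3-fold subsampling gives 50
```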

@csukuangfj csukuangfj commented Feb 11, 2020 via email

@csukuangfj csukuangfj commented Feb 11, 2020 via email

@fanlu fanlu commented Feb 11, 2020

> I will help you tomorrow.

thanks

@danpovey danpovey commented Feb 11, 2020 via email

@danpovey danpovey commented Feb 11, 2020 via email

@fanlu fanlu commented Feb 11, 2020

> Not really: somewhere near the beginning, there should be a subsampling by a factor of 3. The original script was much closer to being correct, it only needed a very small change.

OK, I will try the original script again with affine=False in the BN layers and the prefinal layer bottleneck at 256.

@fanlu fanlu commented Feb 12, 2020

I have changed dilation to stride, changed the kernel_size of (linear, affine) from (3,1) to (2,2), and set affine=False in all BatchNorm1d layers in all stride versions. Here are the results:

|          | original stride kernel(3,1) | stride kernel(2,2) | dilation kernel(2,2) |
|----------|-----------------------------|--------------------|----------------------|
| dev_cer  | 7.36                        | 7.35               | 6.67                 |
| dev_wer  | 15.55                       | 15.48              | 14.72                |
| test_cer | 9.25                        | 9.08               | 8.38                 |
| test_wer | 18.08                       | 17.85              | 17.08                |

@csukuangfj

You can format the table in Markdown and GitHub will render it correctly.

@fanlu fanlu commented Feb 12, 2020

I have tested these two versions on speed-perturbed data with hires features, with the tree leaves reduced from 5k to 4k, corresponding to tdnn_1c. Here are the results:

|          | stride kernel(2,2) | dilation kernel(2,2) |
|----------|--------------------|----------------------|
| dev_cer  | 7.06               | 6.50                 |
| dev_wer  | 15.18              | 14.52                |
| test_cer | 8.57               | 7.81                 |
| test_wer | 17.33              | 16.42                |

@danpovey danpovey commented Feb 12, 2020 via email

@fanlu fanlu commented Feb 12, 2020

Maybe this makes stride vs. dilation clearer; sorry about the form.

[image: tdnnf]

@fanlu fanlu commented Feb 12, 2020

Not striding the left and right context may give the model more context information, which may make it more robust, judging from the test results.

@danpovey danpovey commented Feb 12, 2020 via email

kernel_size=1)
self.batchnorm2 = nn.BatchNorm1d(num_features=small_dim)

def forward(self, x):
@danpovey danpovey Feb 12, 2020

I'm surprised you didn't implement the TDNN_F layer in the "obvious" way with 1-d convolution.
[oh, sorry, this is prefinal layer]

Contributor

Inside OrthonormalLinear, it is nn.Conv1d. So it is indeed implemented via 1-d convolution.
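For context, a minimal sketch of that structure (this omits the semi-orthogonal weight constraint the real OrthonormalLinear applies during training):

```python
import torch.nn as nn

class OrthonormalLinear(nn.Module):
    # sketch only: a projection over [N, C, T] wrapping nn.Conv1d
    def __init__(self, dim, bottleneck_dim, kernel_size=1, stride=1):
        super().__init__()
        self.conv = nn.Conv1d(in_channels=dim,
                              out_channels=bottleneck_dim,
                              kernel_size=kernel_size,
                              stride=stride)

    def forward(self, x):
        # x: [N, C, T]
        return self.conv(x)
```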

Contributor

kernel_size=1 doesn't look right. Some extremely weird stuff is going on in this PR.

@fanlu fanlu Feb 12, 2020

Hi Dan, this code's behavior is no different from the previous code's; only the parameters changed. The original stride version also uses kernel_size=1 as the default.

Contributor

It's getting less clear to me, not more clear, and in any case the code is not right.
You showed various experimental numbers but you never said with any clarity what code or parameters each one corresponded to.
I would prefer if you just went back to the original code and made the small and very specific changes that I asked for.

Author

Sorry about this.

@fanlu fanlu commented Feb 12, 2020

Sorry, Dan, I am confused about the diff. Do I have to do anything else?

> That doesn't sound right to me. Please show me the diff.

@danpovey

> Sorry, Dan, I am confused about the diff. Do I have to do anything else?

You showed a table with columns:
stride kernel(2,2) | dilation kernel(2,2)
What I want to know is: what was the code difference between those two? Or more specifically, what code and configuration was each one run with? If done correctly, those two numbers should be the same. The stride one is the one you should be using, it will be faster, but there must have been a bug.


@fanlu fanlu commented Feb 12, 2020

• stride kernel(2,2), time_stride=[1,1,1,0,3,3,3,3,3,3,3,3]

if time_stride > 0:
    self.kernel_size, self.stride = 2, 1
else:
    self.kernel_size, self.stride = 1, 3

# linear requires [N, C, T]
self.linear = OrthonormalLinear(dim=dim,
                                bottleneck_dim=bottleneck_dim,
                                kernel_size=self.kernel_size,
                                stride=1)

# affine requires [N, C, T]
self.affine = nn.Conv1d(in_channels=bottleneck_dim,
                        out_channels=dim,
                        kernel_size=self.kernel_size,
                        stride=self.stride)
• dilation kernel(2,2), time_stride=[1,1,1,0,3,3,3,3,3,3,3,3]

if time_stride > 0:
    kernel_size, dilation = 2, time_stride
else:
    kernel_size, dilation = 1, 1

# linear requires [N, C, T]
self.linear = OrthonormalLinear(dim=dim,
                                bottleneck_dim=bottleneck_dim,
                                kernel_size=kernel_size,
                                dilation=dilation)

# affine requires [N, C, T]
self.affine = nn.Conv1d(in_channels=bottleneck_dim,
                        out_channels=dim,
                        kernel_size=kernel_size,
                        dilation=dilation)


if self.tdnnfs[i].time_stride == 0:
    cur_context = sum(self.time_stride_list[i:])
    x_left = x[:, :, :cur_context]
    x_mid = x[:, :, cur_context:-cur_context:self.frame_subsampling_factor]
    x_right = x[:, :, -cur_context:]
    x = torch.cat([x_left, x_mid, x_right], dim=2)

Note that the dilation version is necessarily different from the original Kaldi version.

@csukuangfj

@fanlu
I used [-1, 0, 1] instead of the stride kernel(2,2) one you are currently using, since it is not easy to support [-1, 0] and [0, 1], i.e., [t-1, t] and [t, t+1], in PyTorch.

Regarding the stride and dilation:

  • there is no need to use dilation in PyTorch. The original network can produce
    an output with length 50.

@fanlu fanlu commented Feb 12, 2020

Compared the speed of these two versions:

| 100 iter | stride kernel(2,2) | dilation kernel(2,2) |
|----------|--------------------|----------------------|
| s        | 81-82s             | 33-34s               |

@fanlu fanlu commented Feb 12, 2020

Yes, I have noticed that; that's why I called this a new version.
If we use kernel_size==2 everywhere time_stride>0: after the OrthonormalLinear conv1d we have t_out = t_in - 1, and after the affine conv1d we have t_out_2 = t_out - 1. So we can consider that we get [t-1, t] and [t, t+1] convolutions from linear and affine. Am I right?

> @fanlu
> I used [-1, 0, 1] instead of the stride kernel(2,2) one you are currently using, since it is not easy to support [-1, 0] and [0, 1], i.e., [t-1, t] and [t, t+1], in PyTorch.
>
> Regarding the stride and dilation:
>
> • there is no need to use dilation in PyTorch. The original network can produce an output with length 50.
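A quick shape check of that claim (a sketch, not the PR's code):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 8, 10)                # [N, C, T]
linear = nn.Conv1d(8, 4, kernel_size=2)  # output[t] sees input[t], input[t+1]
affine = nn.Conv1d(4, 8, kernel_size=2)
print(affine(linear(x)).shape)           # torch.Size([1, 8, 8]); T: 10 -> 8
```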

@fanlu fanlu commented Feb 12, 2020

But never mind, the results of stride kernel(3,1) and stride kernel(2,2) look similar.
@csukuangfj Do you know how to accelerate this operation? It also looks weird.

if self.tdnnfs[i].time_stride == 0:
    cur_context = sum(self.time_stride_list[i:])
    x_left = x[:, :, :cur_context]
    x_mid = x[:, :, cur_context:-cur_context:self.frame_subsampling_factor]
    x_right = x[:, :, -cur_context:]
    x = torch.cat([x_left, x_mid, x_right], dim=2)

@danpovey danpovey commented Feb 12, 2020

> • stride kernel(2,2) / • dilation kernel(2,2) (code quoted from the comment above)
>
> Note that the dilation version is necessarily different from the original Kaldi version.

OK, there are a couple of problems with this. I'm confident that what you are doing in the middle layer (with time_stride=0 in Kaldi) is not right, but what's going on is a bit strange and I'm not quite sure what it's doing, partly because I'm not sure where that code goes. But it doesn't matter.. let me describe how I think it should be done.

There's an issue of clarity, in that time_stride is very unintuitive in this context. It's better to have pairs of (kernel_size, subsampling_factor). Let these be the pairs
(2, 1), (2, 1), ..., (2, 3), (1, 1), (2, 1), (2, 1), ..

EDIT:
there is actually no issue about the asymmetry of the context, with the [0,1] vs [-1,1]. No special padding is needed, it is already correct. The explanation is a bit complicated.

Later we can figure out why in your experiments stride was not faster than dilation-- it definitely should be faster, as the flops differ by nearly a factor of 3. There must have been some problem or bug or unexpected machine load.

@fanlu fanlu commented Feb 12, 2020

Thanks, Dan. Let me try to implement this based on #3925.

> (code and reply quoted from the comments above)
>
> Next, there is an issue of the asymmetry of the context, since you can't have (-1,0) like in Kaldi. I'd make the second conv1d, in self.affine, have padding="same" and remove the last time index after that convolution. This actually makes a difference, there is a time asymmetry and the bypass gets done wrong otherwise.

@danpovey

.. also, I suspect that in your "dilated" version of the code, the layer with (in Kaldi terminology) time_stride=0 effectively had time_stride=1. That may be why the WER differed.
In the terminology introduced in my previous comment, where there are pairs of (kernel_size, subsampling_factor) with values
(2, 1), (2, 1), ..., (2, 3), (1, 1), (2, 1), (2, 1), ..
you could just change the (2, 3), (1, 1) near the middle to: (2, 1), (2, 3).
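A minimal sketch of building the stack from such pairs (hypothetical helper code, not the PR's):

```python
import torch.nn as nn

# each TDNN-F layer is configured by a (kernel_size, subsampling_factor)
# pair; subsampling is done by the conv stride, never by dilation
pairs = [(2, 1), (2, 1), (2, 3), (1, 1), (2, 1)]  # truncated example
dim, bottleneck_dim = 1024, 128

layers = nn.ModuleList()
for kernel_size, subsampling_factor in pairs:
    layers.append(nn.Sequential(
        nn.Conv1d(dim, bottleneck_dim, kernel_size=kernel_size),
        nn.Conv1d(bottleneck_dim, dim, kernel_size=kernel_size,
                  stride=subsampling_factor),
    ))
```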

@fanlu fanlu commented Feb 12, 2020

Sorry, the columns in my earlier speed table were swapped. The correct table is:

| 100 iter | stride kernel(2,2) | dilation kernel(2,2) |
|----------|--------------------|----------------------|
| s        | 33-34s             | 81-82s               |

@fanlu fanlu commented Feb 12, 2020

kernel_size_list=[2, 2, **2, 1**, 2, 2, 2, 2, 2, 2, 2, 2],
subsampling_factor_list=[1, 1, **3, 1**, 1, 1, 1, 1, 1, 1, 1, 1],

or

kernel_size_list=[2, 2, **2, 1**, 2, 2, 2, 2, 2, 2, 2, 2],
subsampling_factor_list=[1, 1, **1, 3**, 1, 1, 1, 1, 1, 1, 1, 1],

Which one is correct?

> .. also, I suspect that in your "dilated" version of the code, the layer with (in Kaldi terminology) time_stride=0 effectively had time_stride=1. That may be why the WER differed. In the terminology introduced in my previous comment, where there are pairs of (kernel_size, subsampling_factor) with values (2, 1), (2, 1), ..., (2, 3), (1, 1), (2, 1), (2, 1), .., you could just change the (2, 3), (1, 1) near the middle to: (2, 1), (2, 3).

@fanlu fanlu commented Feb 12, 2020

@danpovey this is the PR for the stride kernel(2,2) version: https://github.com/mobvoi/kaldi/pull/1, please have a look.

@danpovey danpovey commented Feb 12, 2020 via email

@fanlu fanlu commented Feb 12, 2020

Considering the Kaldi version, the left or right context is sum([1,1,1,0,3,3,3,3,3,3,3,3]) = 27.
In this version, after manually subsampling 150 -> 50, the layers with dilation=3 can see at most 3*8 = 24 frames of context on each side (after manual subsampling), i.e. 3*24 = 72 frames in the original window.
This operation does not overstep the boundary of the egs context, because I do not subsample the left and right context.
I think this is also the reason the training speed is slower than in the stride version.
But never mind, thanks for explaining the TDNN-F logic.

> OK. That's not right though, I'm afraid. (I mean, the code is doing something super strange that can't at all be the right thing.)
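Checking the context arithmetic above (a quick sketch):

```python
time_stride_list = [1, 1, 1, 0, 3, 3, 3, 3, 3, 3, 3, 3]
print(sum(time_stride_list))   # 27: per-side context of the Kaldi version

dilated = sum(1 for s in time_stride_list if s == 3)
print(3 * dilated)             # 24: per-side context after manual subsampling
print(3 * dilated * 3)         # 72: the same context in original frames
```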

@danpovey

I am assuming this is not ready to merge, since I am still seeing dilation in the code?
Is this still being developed?

@fanlu fanlu commented Feb 18, 2020

I will close it in favor of #3925.

@fanlu fanlu closed this Feb 18, 2020