[RNN, GRU] support causal convolution with RNN and GRU layers #360
Merged
```diff
@@ -13,8 +13,11 @@
 from ....utils import divide_if_divisible, is_xma_available
 from ...cache import GenerationCache
 from ...parameter import mark_parameter_as_mup_learning_rate, mark_parameter_as_no_weight_decay
+from ..activations import get_activation_function, is_glu
+from ..convolution import ParameterizedConv1d
 from ..linear import ParameterizedLinear
 from ..normalization import get_normalization_function
+from .causal_convolution import causal_convolution
 from .utils import compute_cu_seqlens_and_max_seqlen_from_attention_mask, pack_sequence, unpack_sequence
```
```diff
@@ -30,6 +33,8 @@ def __init__(
         output_size: int,
         num_input_heads: int,
         num_weight_heads: int,
+        kernel_size: int | None,
+        activation_function: str | None,
         add_bias: bool,
         gradient_clipping: float | None,
         initializer_range: float,
```
```diff
@@ -53,18 +58,39 @@ def __init__(
         self.state_head_dim = state_head_dim
         self.state_size = self.num_heads * self.state_head_dim
+        self.kernel_size = kernel_size
+        self.activation_string = activation_function
         self.layer_idx = layer_idx
         self.use_padding_free_transformer = use_padding_free_transformer

+        self.x_shape = self.num_input_heads * self.state_head_dim
+        self.g_shape = self.num_heads * self.state_head_dim

         std = initializer_range
         if init_method == "mup":
             std /= math.sqrt(m_width)
         self.state_weight_std = std

-        self.input_projection = ParameterizedLinear(
-            input_size, (self.num_input_heads + self.num_heads) * self.state_head_dim, bias=add_bias, std=std
-        )
+        self.input_projection = ParameterizedLinear(input_size, self.x_shape + self.g_shape, bias=add_bias, std=std)

+        if kernel_size is None:
+            assert activation_function is None
+        else:
+            assert not is_glu(self.activation_string)
+
+            self.conv1d = ParameterizedConv1d(
+                in_channels=self.state_size,
+                out_channels=self.state_size,
+                kernel_size=kernel_size,
+                bias=add_bias,
+                padding=kernel_size - 1,
+                groups=self.state_size,
+                std=std,
+            )
+
+            mark_parameter_as_mup_learning_rate(self.conv1d.weight)
+
+        self.activation_function = get_activation_function(self.activation_string)

         self.state_weight = nn.Parameter(torch.empty(self.num_heads, self.state_head_dim, self.state_head_dim))
```
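The new `conv1d` is depthwise (`groups=self.state_size`) with `padding=kernel_size - 1`, which is the usual recipe for a causal convolution: pad by `kernel_size - 1` and drop the trailing outputs so position `t` never sees inputs beyond `t`. Below is a minimal sketch of that idea, using plain `torch.nn.Conv1d` instead of the repo's `ParameterizedConv1d` and assuming a `(batch, seq_len, channels)` layout (both assumptions, not taken from this diff):

```python
import torch
import torch.nn as nn

# Sketch of a depthwise causal convolution; ParameterizedConv1d is replaced by nn.Conv1d.
kernel_size, channels = 4, 8
conv1d = nn.Conv1d(
    in_channels=channels,
    out_channels=channels,
    kernel_size=kernel_size,
    padding=kernel_size - 1,  # pad so every position has a full receptive field to its left
    groups=channels,          # depthwise: one filter per channel
)

x = torch.randn(2, 16, channels)          # (batch, seq_len, channels)
y = conv1d(x.transpose(1, 2))             # Conv1d expects (batch, channels, seq_len)
y = y[..., : x.size(1)].transpose(1, 2)   # drop the trailing padded outputs -> causal
assert y.shape == x.shape
```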
```diff
@@ -75,17 +101,17 @@ def __init__(
         self.norm = get_normalization_function(normalization_function, self.state_size)

-        self.reset_parameters()

         mark_parameter_as_mup_learning_rate(self.input_projection.weight)
         mark_parameter_as_mup_learning_rate(self.state_weight)
         mark_parameter_as_mup_learning_rate(self.output_projection.weight)

         mark_parameter_as_no_weight_decay(self.state_weight)

+        self.reset_parameters()

     def forward(
         self,
-        input: torch.Tensor,
+        x: torch.Tensor,
         cache_params: GenerationCache | None = None,
         attention_mask: torch.Tensor | None = None,
         cu_seqlens: torch.Tensor | None = None,
```
```diff
@@ -98,46 +124,56 @@ def forward(
             assert cu_seqlens is None
             assert max_seqlen is None

-        batch_size, sequence_length = input.size()[:2]
+        B, S = x.size()[:2]

         if attention_mask is not None:
             cu_seqlens, max_seqlen = compute_cu_seqlens_and_max_seqlen_from_attention_mask(attention_mask)
-            input = pack_sequence(inputs=input, cu_seqlens=cu_seqlens)
+            x = pack_sequence(inputs=x, cu_seqlens=cu_seqlens)

-        input_state = None if cache_params is None else cache_params.get_cache(self.layer_idx)
+        c, h = (None, None) if cache_params is None else cache_params.get_cache(self.layer_idx)

-        input = self.input_projection(input)
-        input, gate = input.split(
-            (self.num_input_heads * self.state_head_dim, self.num_heads * self.state_head_dim), dim=-1
-        )
+        x = self.input_projection(x)
+        x, g = x.split((self.x_shape, self.g_shape), dim=-1)

+        if self.kernel_size is None:
+            x = self.activation_function(x)
+        else:
+            x, c = causal_convolution(
+                hidden_states=x,
+                input_state=c,
+                attention_mask=attention_mask,
+                conv1d_weight=self.conv1d.weight,
+                conv1d_bias=self.conv1d.bias,
+                conv1d_num_groups=self.state_size,
+                return_cache_state=cache_params is not None,
+                activation_string=self.activation_string,
+                conv1d_padding=self.kernel_size - 1,
+                conv1d_stride=1,
+            )

-        input = input.view(*input.size()[:-1], -1, self.state_head_dim)
+        x = x.view(*x.size()[:-1], -1, self.state_head_dim)

-        input, input_state = rnn(
-            input=input,
+        x, h = rnn(
+            input=x,
             weight=self.state_weight,
-            input_state=input_state,
+            input_state=h,
             gradient_clipping=self.gradient_clipping,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
         )

         if not self.use_padding_free_transformer and attention_mask is not None:
-            input = unpack_sequence(
-                inputs=input, cu_seqlens=cu_seqlens, output_shape=(batch_size, sequence_length, *input.size()[1:])
-            )
+            x = unpack_sequence(inputs=x, cu_seqlens=cu_seqlens, output_shape=(B, S, *x.size()[1:]))

         if cache_params is not None:
-            cache_params.update(state=input_state, num_tokens_added=input.size(1), layer_idx=self.layer_idx)
+            cache_params.update(conv_state=c, ssm_state=h, num_tokens_added=x.size(1), layer_idx=self.layer_idx)

-        input = input.view(*input.size()[:-2], -1)
-        input = input * F.silu(gate)
-        input = self.norm(input)
-        input = self.output_projection(input)
+        x = x.flatten(-2, -1)
+        x = x * F.silu(g)
+        x = self.norm(x)
+        x = self.output_projection(x)

-        return input
+        return x

     @torch.no_grad()
     def reset_parameters(self) -> None:
```
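For context on the packed path in `forward`: `cu_seqlens` marks the cumulative boundaries of the concatenated, unpadded sequences and `max_seqlen` is the longest sequence in the batch. A rough sketch of what `compute_cu_seqlens_and_max_seqlen_from_attention_mask` presumably computes, written against plain PyTorch rather than the repo's helper (the exact return types and dtypes are assumptions):

```python
import torch
import torch.nn.functional as F

def cu_seqlens_and_max_seqlen(attention_mask: torch.Tensor) -> tuple[torch.Tensor, int]:
    # attention_mask: (batch, seq_len) with 1 for real tokens and 0 for padding.
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)               # per-sample lengths
    cu_seqlens = F.pad(seqlens.cumsum(dim=0, dtype=torch.int32), (1, 0))  # prepend 0
    return cu_seqlens, int(seqlens.max())

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
cu, max_len = cu_seqlens_and_max_seqlen(mask)
# cu == tensor([0, 3, 5], dtype=torch.int32), max_len == 3
```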
Contributor

There seems to be a mismatch with the `GenerationCache` implementation for RNN/GRU layers (`_RNNCache`).

The new forward pass unpacks two values (`c`, `h`) from `cache_params.get_cache(self.layer_idx)`. However, the current `_RNNCache.get_cache` returns a single tensor, which will lead to a `TypeError` during unpacking. Similarly, `cache_params.update` is called with `conv_state` and `ssm_state` arguments, but the `_RNNCache.update` method does not accept these keyword arguments, which will also cause a `TypeError`.

It appears the `_RNNCache` class needs to be updated to store and manage two separate states (one for the convolution and one for the GRU state) when causal convolution is used.
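A hedged sketch of the cache shape the comment is asking for: a per-layer entry that holds both the convolution state and the recurrent state, with `get_cache` returning a pair and `update` accepting `conv_state`/`ssm_state`. The class and method names follow the comment; everything else (constructor signature, list-based storage) is assumed for illustration and is not the repo's actual implementation:

```python
import torch

class _RNNCache:
    """Sketch of a per-layer cache holding a conv state and an RNN/GRU state."""

    def __init__(self, num_layers: int) -> None:
        self.conv_states: list[torch.Tensor | None] = [None] * num_layers
        self.ssm_states: list[torch.Tensor | None] = [None] * num_layers
        self.num_tokens: int = 0

    def get_cache(self, layer_idx: int) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        # Return both states so the caller can unpack `c, h = cache.get_cache(layer_idx)`.
        return self.conv_states[layer_idx], self.ssm_states[layer_idx]

    def update(
        self,
        conv_state: torch.Tensor | None,
        ssm_state: torch.Tensor | None,
        num_tokens_added: int,
        layer_idx: int,
    ) -> None:
        # Accept the keyword arguments used by the new forward pass.
        self.conv_states[layer_idx] = conv_state
        self.ssm_states[layer_idx] = ssm_state
        self.num_tokens += num_tokens_added
```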