Merged
4 changes: 4 additions & 0 deletions lm_engine/hf_models/config/sequence_mixer.py
@@ -79,6 +79,8 @@ class _GRUArgs(BaseArgs):
add_bias: bool = False
normalization_function: str | None = None
gradient_clipping: float | None = None
kernel_size: int | None = None
activation_function: str | None = None

def model_post_init(self, __context: Any) -> None:
assert self.sequence_mixer_type == "gru"
@@ -92,6 +94,8 @@ class _RNNArgs(BaseArgs):
add_bias: bool = False
normalization_function: str | None = None
gradient_clipping: float | None = None
kernel_size: int | None = None
activation_function: str | None = None

def model_post_init(self, __context: Any) -> None:
assert self.sequence_mixer_type == "rnn"
@@ -55,6 +55,8 @@ def get_sequence_mixer(
num_weight_heads=block.num_weight_heads,
num_forget_weight_heads=block.num_forget_weight_heads,
num_reset_weight_heads=block.num_reset_weight_heads,
kernel_size=block.kernel_size,
activation_function=block.activation_function,
add_bias=block.add_bias,
gradient_clipping=block.gradient_clipping,
initializer_range=config.initializer_range,
@@ -72,6 +74,8 @@ def get_sequence_mixer(
output_size=config.hidden_size,
num_input_heads=block.num_input_heads,
num_weight_heads=block.num_weight_heads,
kernel_size=block.kernel_size,
activation_function=block.activation_function,
add_bias=block.add_bias,
gradient_clipping=block.gradient_clipping,
initializer_range=config.initializer_range,
123 changes: 76 additions & 47 deletions lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gru.py
@@ -13,8 +13,11 @@
from ....utils import divide_if_divisible, is_xma_available
from ...cache import GenerationCache
from ...parameter import mark_parameter_as_mup_learning_rate, mark_parameter_as_no_weight_decay
from ..activations import get_activation_function, is_glu
from ..convolution import ParameterizedConv1d
from ..linear import ParameterizedLinear
from ..normalization import get_normalization_function
from .causal_convolution import causal_convolution
from .utils import compute_cu_seqlens_and_max_seqlen_from_attention_mask, pack_sequence, unpack_sequence


@@ -34,6 +37,8 @@ def __init__(
num_weight_heads: int,
num_forget_weight_heads: int,
num_reset_weight_heads: int,
kernel_size: int | None,
activation_function: str | None,
add_bias: bool,
gradient_clipping: float | None,
initializer_range: float,
@@ -71,25 +76,47 @@ def __init__(
divide_if_divisible(self.num_heads, self.num_reset_weight_heads)

self.gradient_clipping = gradient_clipping

self.state_head_dim = state_head_dim
self.state_size = self.num_heads * self.state_head_dim

self.kernel_size = kernel_size
self.activation_string = activation_function
self.layer_idx = layer_idx
self.use_padding_free_transformer = use_padding_free_transformer

self.x_shape = self.num_input_heads * self.state_head_dim
self.xf_shape = self.num_forget_input_heads * self.state_head_dim
self.xr_shape = self.num_reset_input_heads * self.state_head_dim
self.g_shape = self.num_heads * self.state_head_dim

std = initializer_range
if init_method == "mup":
std /= math.sqrt(m_width)
self.state_weight_std = std

self.input_projection = ParameterizedLinear(
input_size,
(self.num_input_heads + self.num_forget_input_heads + self.num_reset_input_heads + self.num_heads)
* self.state_head_dim,
bias=add_bias,
std=std,
input_size, self.x_shape + self.xf_shape + self.xr_shape + self.g_shape, bias=add_bias, std=std
)

if kernel_size is None:
assert activation_function is None
else:
assert not is_glu(self.activation_string)

self.conv1d = ParameterizedConv1d(
in_channels=self.state_size,
out_channels=self.state_size,
kernel_size=kernel_size,
bias=add_bias,
padding=kernel_size - 1,
groups=self.state_size,
std=std,
)

mark_parameter_as_mup_learning_rate(self.conv1d.weight)

self.activation_function = get_activation_function(self.activation_string)

self.state_weight = nn.Parameter(
torch.empty(
self.num_weight_heads + self.num_forget_weight_heads + self.num_reset_weight_heads,
@@ -105,17 +132,17 @@ def __init__(

self.norm = get_normalization_function(normalization_function, self.state_size)

self.reset_parameters()

mark_parameter_as_mup_learning_rate(self.input_projection.weight)
mark_parameter_as_mup_learning_rate(self.state_weight)
mark_parameter_as_mup_learning_rate(self.output_projection.weight)

mark_parameter_as_no_weight_decay(self.state_weight)

self.reset_parameters()

def forward(
self,
input: torch.Tensor,
x: torch.Tensor,
cache_params: GenerationCache | None = None,
attention_mask: torch.Tensor | None = None,
cu_seqlens: torch.Tensor | None = None,
@@ -128,62 +155,64 @@ def forward(
assert cu_seqlens is None
assert max_seqlen is None

batch_size, sequence_length = input.size()[:2]
B, S = x.size()[:2]

if attention_mask is not None:
cu_seqlens, max_seqlen = compute_cu_seqlens_and_max_seqlen_from_attention_mask(attention_mask)
input = pack_sequence(inputs=input, cu_seqlens=cu_seqlens)

input_state = None if cache_params is None else cache_params.get_cache(self.layer_idx)

input = self.input_projection(input)
input, forget_input, reset_input, gate = input.split(
(
self.num_input_heads * self.state_head_dim,
self.num_forget_input_heads * self.state_head_dim,
self.num_reset_input_heads * self.state_head_dim,
self.num_heads * self.state_head_dim,
),
dim=-1,
)
x = pack_sequence(inputs=x, cu_seqlens=cu_seqlens)

c, h = (None, None) if cache_params is None else cache_params.get_cache(self.layer_idx)
critical

There seems to be a mismatch with the GenerationCache implementation for RNN/GRU layers (_RNNCache).

  1. This line attempts to unpack two values (c, h) from cache_params.get_cache(self.layer_idx). However, the current _RNNCache.get_cache returns a single tensor, which will lead to a TypeError during unpacking.
  2. Similarly, on line 209, cache_params.update is called with conv_state and ssm_state arguments, but the _RNNCache.update method does not accept these keyword arguments, which will also cause a TypeError.

It appears the _RNNCache class needs to be updated to store and manage two separate states (one for the convolution and one for the GRU state) when causal convolution is used.
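
For concreteness, here is a minimal sketch of a two-state cache that would satisfy both call sites in this diff: the c, h = cache_params.get_cache(self.layer_idx) unpack and the cache_params.update(conv_state=..., ssm_state=...) call. The class body, constructor, and field names below are illustrative assumptions, not the actual lm_engine implementation.

import torch

# Illustrative two-state cache: one slot for the causal-conv state and one for
# the recurrent (GRU/RNN) state per layer. Sketch only; the real _RNNCache in
# lm_engine may store its states differently.
class _RNNCache:
    def __init__(self, num_layers: int) -> None:
        self.conv_states: list[torch.Tensor | None] = [None] * num_layers
        self.ssm_states: list[torch.Tensor | None] = [None] * num_layers
        self.num_tokens: list[int] = [0] * num_layers

    def get_cache(self, layer_idx: int) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        # Return both states so callers can unpack c, h = cache.get_cache(idx).
        return self.conv_states[layer_idx], self.ssm_states[layer_idx]

    def update(
        self,
        conv_state: torch.Tensor | None,
        ssm_state: torch.Tensor | None,
        num_tokens_added: int,
        layer_idx: int,
    ) -> None:
        # conv_state stays None when the layer has no causal convolution
        # (kernel_size is None); only overwrite it when one is produced.
        if conv_state is not None:
            self.conv_states[layer_idx] = conv_state
        self.ssm_states[layer_idx] = ssm_state
        self.num_tokens[layer_idx] += num_tokens_added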


x = self.input_projection(x)
x, xf, xr, g = x.split((self.x_shape, self.xf_shape, self.xr_shape, self.g_shape), dim=-1)

if self.kernel_size is None:
x = self.activation_function(x)
else:
x, c = causal_convolution(
hidden_states=x,
input_state=c,
attention_mask=attention_mask,
conv1d_weight=self.conv1d.weight,
conv1d_bias=self.conv1d.bias,
conv1d_num_groups=self.state_size,
return_cache_state=cache_params is not None,
activation_string=self.activation_string,
conv1d_padding=self.kernel_size - 1,
conv1d_stride=1,
)

input, forget_input, reset_input = [
i.view(*i.size()[:-1], -1, self.state_head_dim) for i in (input, forget_input, reset_input)
]
x, xf, xr = [i.view(*i.size()[:-1], -1, self.state_head_dim) for i in (x, xf, xr)]

weight, forget_weight, reset_weight = self.state_weight.split(
W, Wf, Wr = self.state_weight.split(
(self.num_weight_heads, self.num_forget_weight_heads, self.num_reset_weight_heads), dim=0
)

input, input_state = gru(
input=input,
weight=weight,
forget_input=forget_input,
forget_weight=forget_weight,
reset_input=reset_input,
reset_weight=reset_weight,
input_state=input_state,
x, h = gru(
input=x,
weight=W,
forget_input=xf,
forget_weight=Wf,
reset_input=xr,
reset_weight=Wr,
input_state=h,
gradient_clipping=self.gradient_clipping,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)

if not self.use_padding_free_transformer and attention_mask is not None:
input = unpack_sequence(
inputs=input, cu_seqlens=cu_seqlens, output_shape=(batch_size, sequence_length, *input.size()[1:])
)
x = unpack_sequence(inputs=x, cu_seqlens=cu_seqlens, output_shape=(B, S, *x.size()[1:]))

if cache_params is not None:
cache_params.update(state=input_state, num_tokens_added=input.size(1), layer_idx=self.layer_idx)

input = input.view(*input.size()[:-2], -1)

input = input * F.silu(gate)
input = self.norm(input)
cache_params.update(conv_state=c, ssm_state=h, num_tokens_added=x.size(1), layer_idx=self.layer_idx)

input = self.output_projection(input)
x = x.flatten(-2, -1)
x = x * F.silu(g)
x = self.norm(x)
x = self.output_projection(x)

return input
return x

@torch.no_grad()
def reset_parameters(self) -> None:
94 changes: 65 additions & 29 deletions lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/rnn.py
@@ -13,8 +13,11 @@
from ....utils import divide_if_divisible, is_xma_available
from ...cache import GenerationCache
from ...parameter import mark_parameter_as_mup_learning_rate, mark_parameter_as_no_weight_decay
from ..activations import get_activation_function, is_glu
from ..convolution import ParameterizedConv1d
from ..linear import ParameterizedLinear
from ..normalization import get_normalization_function
from .causal_convolution import causal_convolution
from .utils import compute_cu_seqlens_and_max_seqlen_from_attention_mask, pack_sequence, unpack_sequence


@@ -30,6 +33,8 @@ def __init__(
output_size: int,
num_input_heads: int,
num_weight_heads: int,
kernel_size: int | None,
activation_function: str | None,
add_bias: bool,
gradient_clipping: float | None,
initializer_range: float,
@@ -53,18 +58,39 @@

self.state_head_dim = state_head_dim
self.state_size = self.num_heads * self.state_head_dim

self.kernel_size = kernel_size
self.activation_string = activation_function
self.layer_idx = layer_idx
self.use_padding_free_transformer = use_padding_free_transformer

self.x_shape = self.num_input_heads * self.state_head_dim
self.g_shape = self.num_heads * self.state_head_dim

std = initializer_range
if init_method == "mup":
std /= math.sqrt(m_width)
self.state_weight_std = std

self.input_projection = ParameterizedLinear(
input_size, (self.num_input_heads + self.num_heads) * self.state_head_dim, bias=add_bias, std=std
)
self.input_projection = ParameterizedLinear(input_size, self.x_shape + self.g_shape, bias=add_bias, std=std)

if kernel_size is None:
assert activation_function is None
else:
assert not is_glu(self.activation_string)

self.conv1d = ParameterizedConv1d(
in_channels=self.state_size,
out_channels=self.state_size,
kernel_size=kernel_size,
bias=add_bias,
padding=kernel_size - 1,
groups=self.state_size,
std=std,
)

mark_parameter_as_mup_learning_rate(self.conv1d.weight)

self.activation_function = get_activation_function(self.activation_string)

self.state_weight = nn.Parameter(torch.empty(self.num_heads, self.state_head_dim, self.state_head_dim))

@@ -75,17 +101,17 @@ def __init__(

self.norm = get_normalization_function(normalization_function, self.state_size)

self.reset_parameters()

mark_parameter_as_mup_learning_rate(self.input_projection.weight)
mark_parameter_as_mup_learning_rate(self.state_weight)
mark_parameter_as_mup_learning_rate(self.output_projection.weight)

mark_parameter_as_no_weight_decay(self.state_weight)

self.reset_parameters()

def forward(
self,
input: torch.Tensor,
x: torch.Tensor,
cache_params: GenerationCache | None = None,
attention_mask: torch.Tensor | None = None,
cu_seqlens: torch.Tensor | None = None,
@@ -98,46 +124,56 @@ def forward(
assert cu_seqlens is None
assert max_seqlen is None

batch_size, sequence_length = input.size()[:2]
B, S = x.size()[:2]

if attention_mask is not None:
cu_seqlens, max_seqlen = compute_cu_seqlens_and_max_seqlen_from_attention_mask(attention_mask)
input = pack_sequence(inputs=input, cu_seqlens=cu_seqlens)
x = pack_sequence(inputs=x, cu_seqlens=cu_seqlens)

input_state = None if cache_params is None else cache_params.get_cache(self.layer_idx)
c, h = (None, None) if cache_params is None else cache_params.get_cache(self.layer_idx)
critical

There's a mismatch with the _RNNCache implementation that will cause runtime errors.

  1. This line attempts to unpack two values (c, h) from cache_params.get_cache(self.layer_idx), but _RNNCache.get_cache returns a single tensor, which will cause a TypeError.
  2. On line 170, cache_params.update is called with conv_state and ssm_state, but the _RNNCache.update method signature doesn't support these arguments, leading to another TypeError.

The _RNNCache needs to be updated to handle separate states for the convolution and the RNN when causal convolution is enabled.
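
Assuming a two-state cache along the lines of the sketch in the gru.py comment above, the decode-time call pattern the new forward expects would look roughly like this; the stand-in class and tensor shapes are purely illustrative, not lm_engine's actual API.

import torch

# Tiny stand-in with the same two-state interface; illustrative only.
class _TwoStateCache:
    def __init__(self, num_layers: int) -> None:
        self._conv = [None] * num_layers
        self._ssm = [None] * num_layers

    def get_cache(self, layer_idx: int):
        return self._conv[layer_idx], self._ssm[layer_idx]

    def update(self, conv_state, ssm_state, num_tokens_added: int, layer_idx: int) -> None:
        self._conv[layer_idx], self._ssm[layer_idx] = conv_state, ssm_state

cache = _TwoStateCache(num_layers=2)

c, h = cache.get_cache(0)  # both states are None before the first forward step
assert c is None and h is None

# After a step, the mixer hands back both states, mirroring the update call in
# this diff (conv_state=c, ssm_state=h); the shapes here are made up for the demo.
cache.update(conv_state=torch.zeros(1, 8, 3), ssm_state=torch.zeros(1, 4, 2),
             num_tokens_added=1, layer_idx=0)
c, h = cache.get_cache(0)
assert c.shape == (1, 8, 3) and h.shape == (1, 4, 2)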


input = self.input_projection(input)
input, gate = input.split(
(self.num_input_heads * self.state_head_dim, self.num_heads * self.state_head_dim), dim=-1
)
x = self.input_projection(x)
x, g = x.split((self.x_shape, self.g_shape), dim=-1)

input = input.view(*input.size()[:-1], -1, self.state_head_dim)
if self.kernel_size is None:
x = self.activation_function(x)
else:
x, c = causal_convolution(
hidden_states=x,
input_state=c,
attention_mask=attention_mask,
conv1d_weight=self.conv1d.weight,
conv1d_bias=self.conv1d.bias,
conv1d_num_groups=self.state_size,
return_cache_state=cache_params is not None,
activation_string=self.activation_string,
conv1d_padding=self.kernel_size - 1,
conv1d_stride=1,
)

input, input_state = rnn(
input=input,
x = x.view(*x.size()[:-1], -1, self.state_head_dim)

x, h = rnn(
input=x,
weight=self.state_weight,
input_state=input_state,
input_state=h,
gradient_clipping=self.gradient_clipping,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)

if not self.use_padding_free_transformer and attention_mask is not None:
input = unpack_sequence(
inputs=input, cu_seqlens=cu_seqlens, output_shape=(batch_size, sequence_length, *input.size()[1:])
)
x = unpack_sequence(inputs=x, cu_seqlens=cu_seqlens, output_shape=(B, S, *x.size()[1:]))

if cache_params is not None:
cache_params.update(state=input_state, num_tokens_added=input.size(1), layer_idx=self.layer_idx)

input = input.view(*input.size()[:-2], -1)

input = input * F.silu(gate)
input = self.norm(input)
cache_params.update(conv_state=c, ssm_state=h, num_tokens_added=x.size(1), layer_idx=self.layer_idx)

input = self.output_projection(input)
x = x.flatten(-2, -1)
x = x * F.silu(g)
x = self.norm(x)
x = self.output_projection(x)

return input
return x

@torch.no_grad()
def reset_parameters(self) -> None: