Commit 4a01c40

Update stream llm to get correct outputs and re-enable rerotated-attention test. (#656)
During the PyTorch/HF update, there appears to have been a change in how the causal mask is handled. The attention forward function used to receive a `causal_mask` through the `attention_mask` argument when `is_causal` was on; now we need to construct the mask ourselves when `is_causal` is true. This was causing numerical issues in this test, as well as qualitative regressions on Llama2. This PR introduces construction of the causal mask and removes unnecessary tensor-parallel (`pretraining_tp`) config checks, which simplifies the code quite a bit.
1 parent 7877444 · commit 4a01c40
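For reference, here is a minimal, standalone sketch of the additive causal-mask construction this change introduces. It is not code from the commit: the toy tensor shapes and names are assumptions, and it covers the prefill-style case where the query and key lengths match (the patch skips the mask for single-token decode, where q_len == 1, since a new token may attend to the whole cache).

# Illustrative sketch only (not part of this commit): build an additive causal
# mask the same way the patched attention does, for the q_len == kv_seq_len case.
import math

import torch

bsz, num_heads, seq_len, head_dim = 1, 2, 4, 8  # assumed toy sizes
query_states = torch.randn(bsz, num_heads, seq_len, head_dim)
key_states = torch.randn(bsz, num_heads, seq_len, head_dim)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)

# Lower-triangular boolean mask: query position i may attend to key positions <= i.
bool_attention_mask = torch.ones(
    [query_states.shape[-2], key_states.shape[-2]], dtype=torch.bool
).tril()
# Additive form: 0 where attention is allowed, -10000 where it is not.
additive_attention_mask = torch.zeros_like(
    bool_attention_mask, dtype=attn_weights.dtype
).masked_fill(bool_attention_mask.logical_not(), -10000)

attn_weights = attn_weights + additive_attention_mask
attn_probs = torch.softmax(attn_weights, dim=-1)

# Future positions end up with (effectively) zero probability after softmax.
assert torch.allclose(attn_probs.triu(diagonal=1), torch.zeros_like(attn_probs), atol=1e-6)

The patch uses a large negative constant (-10000) rather than -inf; after softmax the two behave the same here, since the masked logits underflow to zero probability, while the finite constant avoids NaN hazards if a row were ever fully masked.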

File tree

2 files changed: 25 additions & 58 deletions

models/turbine_models/custom_models/llm_optimizations/streaming_llm/modify_llama.py

Lines changed: 25 additions & 57 deletions
@@ -37,38 +37,9 @@ def llama_pos_shift_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()

-    if self.config.pretraining_tp > 1:
-        key_value_slicing = (
-            self.num_key_value_heads * self.head_dim
-        ) // self.config.pretraining_tp
-        query_slices = self.q_proj.weight.split(
-            (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-        )
-        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-        query_states = [
-            F.linear(hidden_states, query_slices[i])
-            for i in range(self.config.pretraining_tp)
-        ]
-        query_states = torch.cat(query_states, dim=-1)
-
-        key_states = [
-            F.linear(hidden_states, key_slices[i])
-            for i in range(self.config.pretraining_tp)
-        ]
-        key_states = torch.cat(key_states, dim=-1)
-
-        value_states = [
-            F.linear(hidden_states, value_slices[i])
-            for i in range(self.config.pretraining_tp)
-        ]
-        value_states = torch.cat(value_states, dim=-1)
-
-    else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+    query_states = self.q_proj(hidden_states)
+    key_states = self.k_proj(hidden_states)
+    value_states = self.v_proj(hidden_states)

     query_states = query_states.view(
         bsz, q_len, self.num_heads, self.head_dim
@@ -103,9 +74,9 @@ def llama_pos_shift_attention_forward(
     # repeat k/v heads if n_kv_heads < n_heads
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
-        self.head_dim
+    softmax_scale = 1.0 / math.sqrt(self.head_dim)
+    attn_weights = (
+        torch.matmul(query_states, key_states.transpose(2, 3)) * softmax_scale
     )

     if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
@@ -114,6 +85,23 @@ def llama_pos_shift_attention_forward(
             f" {attn_weights.size()}"
         )

+    # For causal mode, we used to get an input mask, but now causal mode does not expect a mask
+    # and we need to generate the causal mask ourselves.
+    current_is_causal = False
+    if self.is_causal and attention_mask is None and q_len > 1:
+        current_is_causal = True
+    if current_is_causal and attention_mask is None:
+        bool_attention_mask = torch.ones(
+            [query_states.shape[-2], key_states.shape[-2]],
+            device=query_states.device,
+            dtype=torch.bool,
+        ).tril()
+        additive_attention_mask = torch.zeros_like(
+            bool_attention_mask, dtype=attn_weights.dtype
+        ).masked_fill(bool_attention_mask.logical_not(), -10000)
+        attn_weights = attn_weights + additive_attention_mask
+
+    # Legacy support to take in mask for non-causal mode.
     if attention_mask is not None:
         if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
             raise ValueError(
@@ -132,30 +120,10 @@ def llama_pos_shift_attention_forward(
             f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
             f" {attn_output.size()}"
         )
-
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-    if self.config.pretraining_tp > 1:
-        attn_output = attn_output.split(
-            self.hidden_size // self.config.pretraining_tp, dim=2
-        )
-        o_proj_slices = self.o_proj.weight.split(
-            self.hidden_size // self.config.pretraining_tp, dim=1
-        )
-        attn_output = sum(
-            [
-                F.linear(attn_output[i], o_proj_slices[i])
-                for i in range(self.config.pretraining_tp)
-            ]
-        )
-    else:
-        attn_output = self.o_proj(attn_output)
-
-    if not output_attentions:
-        attn_weights = None
-
-    return attn_output, attn_weights, past_key_value
+    attn_output = self.o_proj(attn_output)
+    return attn_output, None, past_key_value


 def enable_llama_pos_shift_attention(model):
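One way to sanity-check the mask construction above is to compare a manually masked attention against PyTorch's built-in torch.nn.functional.scaled_dot_product_attention with is_causal=True, which should agree for the prefill case where the query and key lengths match. This check is not part of the commit; it assumes PyTorch 2.x, and the helper name, toy shapes, and tolerance below are illustrative.

# Hypothetical verification sketch (not from the commit): manual additive causal
# mask vs. PyTorch's built-in causal scaled dot-product attention.
import math

import torch
import torch.nn.functional as F


def manual_causal_attention(q, k, v):
    # q, k, v: [bsz, num_heads, seq_len, head_dim], with equal q/k sequence lengths.
    head_dim = q.shape[-1]
    seq_len = q.shape[-2]
    scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
    bool_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device).tril()
    # Same additive-mask trick as the patch: large negative on disallowed positions.
    scores = scores.masked_fill(bool_mask.logical_not(), -10000.0)
    return torch.matmul(torch.softmax(scores, dim=-1), v)


torch.manual_seed(0)
q, k, v = (torch.randn(1, 4, 16, 32) for _ in range(3))
reference = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(reference, manual_causal_attention(q, k, v), atol=1e-5))  # expected: True

Expected output is True; the tolerance allows for the small numerical differences that fused SDPA kernels can introduce.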

models/turbine_models/tests/stateless_llama_test.py

Lines changed: 0 additions & 1 deletion
@@ -193,7 +193,6 @@ def test_streaming_vmfb_comparison(self):

     # See: https://github.com/nod-ai/SHARK-Turbine/issues/560
     # Developed issues related to the pytorch 2.3 upgrade.
-    @unittest.expectedFailure
     def test_rerotated_torch_comparison(self):
         torch_str = llm_runner.run_torch_llm(
             "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
