@@ -68,15 +68,16 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
         # >>> START AH Changes <<<
+        # Loosen constraint on batch_size to allow parallel adapter composition
+        query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
         query_states, key_states, value_states = match_attn_matrices_for_parallel(
             query_states, key_states, value_states
         )
-        (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
+        (attention_mask, position_ids) = adjust_tensors_for_parallel(query_states, attention_mask, position_ids)
         # >>> END AH Changes <<<
 
         cos, sin = self.rotary_emb(value_states, position_ids)
@@ -153,15 +154,16 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
         # >>> START AH Changes <<<
+        # Loosen constraint on batch_size to allow parallel adapter composition
+        query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
         query_states, key_states, value_states = match_attn_matrices_for_parallel(
             query_states, key_states, value_states
         )
-        (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
+        (attention_mask, position_ids) = adjust_tensors_for_parallel(query_states, attention_mask, position_ids)
         # >>> END AH Changes <<<
 
         kv_seq_len = key_states.shape[-2]
@@ -310,15 +312,16 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
         # >>> START AH Changes <<<
+        # Loosen constraint on batch_size to allow parallel adapter composition
+        query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
         query_states, key_states, value_states = match_attn_matrices_for_parallel(
             query_states, key_states, value_states
         )
-        (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
+        (attention_mask, position_ids) = adjust_tensors_for_parallel(query_states, attention_mask, position_ids)
         # >>> END AH Changes <<<
 
         cos, sin = self.rotary_emb(value_states, position_ids)
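
Context for the change: the `view(bsz, ...)` calls are relaxed to `view(-1, ...)` because, under parallel adapter composition, the hidden states reaching attention may already have been replicated along the batch dimension, so their effective batch size can be a multiple of the original `bsz`. The sketch below is only an illustration of that idea under this assumption; `repeat_to_match_batch` is a hypothetical stand-in, not the actual `match_attn_matrices_for_parallel` / `adjust_tensors_for_parallel` implementation from the adapters library.

```python
import torch


def repeat_to_match_batch(reference: torch.Tensor, *tensors):
    """Illustrative sketch (not the library code): repeat each tensor along
    dim 0 so its batch size matches the reference, as needed once parallel
    adapter composition has replicated the batch."""
    ref_bsz = reference.shape[0]
    out = []
    for t in tensors:
        if t is not None and t.shape[0] != ref_bsz:
            # e.g. bsz=2 replicated for 3 parallel adapters -> ref_bsz=6, repeats=3
            repeats = ref_bsz // t.shape[0]
            t = t.repeat_interleave(repeats, dim=0)
        out.append(t)
    return tuple(out)


# Usage sketch: query_states reshaped with view(-1, q_len, ...) carries the
# replicated batch, so attention_mask and position_ids are expanded to match:
# attention_mask, position_ids = repeat_to_match_batch(query_states, attention_mask, position_ids)
```

Whether the replication should be interleaved or block-wise depends on how the parallel composition stacks the batch; the sketch only shows the shape logic that makes the `-1` views and the extra `position_ids` adjustment necessary.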