Skip to content

Commit 8477ac0

Browse files
authored
[Fix][Model] Fix tensor parallelism of internlm2 model (#3058)
This PR fixes a bug that causes incorrect output when running internlm2 with tensor parallelism.
1 parent d23d6f5 commit 8477ac0

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

python/mlc_llm/model/internlm2/internlm2_model.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,17 +68,17 @@ def __post_init__(self):
             logger.info(
                 "%s defaults to %d",
                 bold("prefill_chunk_size"),
-                min(self.context_window_size, 8192),
+                min(self.context_window_size, 2048),
             )
-            self.prefill_chunk_size = min(self.context_window_size, 8192)
+            self.prefill_chunk_size = min(self.context_window_size, 2048)
         elif self.prefill_chunk_size > self.context_window_size:
             logger.info(
                 "Overriding %s from %d to %d",
                 bold("prefill_chunk_size"),
                 self.prefill_chunk_size,
-                min(self.context_window_size, 8192),
+                min(self.context_window_size, 2048),
             )
-            self.prefill_chunk_size = min(self.context_window_size, 8192)
+            self.prefill_chunk_size = min(self.context_window_size, 2048)


 # pylint: disable=invalid-name,missing-docstring
@@ -178,11 +178,11 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
         residual = hidden_states
         hidden_states = self.attention_norm(hidden_states)
         hidden_states = self.attention(hidden_states, paged_kv_cache, layer_id)
-        hidden_states = self._apply_residual(residual, residual=hidden_states)
+        hidden_states = self._apply_residual(hidden_states, residual=residual)
         residual = hidden_states
         hidden_states = self.ffn_norm(hidden_states)
         hidden_states = self.feed_forward(hidden_states)
-        hidden_states = self._apply_residual(residual, residual=hidden_states)
+        hidden_states = self._apply_residual(hidden_states, residual=residual)
         return hidden_states

     def _apply_residual(self, out, residual):

0 commit comments

Comments
 (0)