@@ -68,17 +68,17 @@ def __post_init__(self):
68
68
logger .info (
69
69
"%s defaults to %d" ,
70
70
bold ("prefill_chunk_size" ),
71
- min (self .context_window_size , 8192 ),
71
+ min (self .context_window_size , 2048 ),
72
72
)
73
- self .prefill_chunk_size = min (self .context_window_size , 8192 )
73
+ self .prefill_chunk_size = min (self .context_window_size , 2048 )
74
74
elif self .prefill_chunk_size > self .context_window_size :
75
75
logger .info (
76
76
"Overriding %s from %d to %d" ,
77
77
bold ("prefill_chunk_size" ),
78
78
self .prefill_chunk_size ,
79
- min (self .context_window_size , 8192 ),
79
+ min (self .context_window_size , 2048 ),
80
80
)
81
- self .prefill_chunk_size = min (self .context_window_size , 8192 )
81
+ self .prefill_chunk_size = min (self .context_window_size , 2048 )
82
82
83
83
84
84
# pylint: disable=invalid-name,missing-docstring
@@ -178,11 +178,11 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id:
178
178
residual = hidden_states
179
179
hidden_states = self .attention_norm (hidden_states )
180
180
hidden_states = self .attention (hidden_states , paged_kv_cache , layer_id )
181
- hidden_states = self ._apply_residual (residual , residual = hidden_states )
181
+ hidden_states = self ._apply_residual (hidden_states , residual = residual )
182
182
residual = hidden_states
183
183
hidden_states = self .ffn_norm (hidden_states )
184
184
hidden_states = self .feed_forward (hidden_states )
185
- hidden_states = self ._apply_residual (residual , residual = hidden_states )
185
+ hidden_states = self ._apply_residual (hidden_states , residual = residual )
186
186
return hidden_states
187
187
188
188
def _apply_residual (self , out , residual ):
0 commit comments