2 files changed, 5 insertions(+), 3 deletions(-)

@@ -1065,6 +1065,7 @@ def embed(
 
         # get pooling information
         pooling_type = self.pooling_type()
+        logits_all = pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE
 
         if self.context_params.embeddings is False:
             raise RuntimeError(
@@ -1142,7 +1143,7 @@ def decode_batch(seq_sizes: List[int]):
            p_batch = 0
 
            # add to batch
-           self._batch.add_sequence(tokens, p_batch)
+           self._batch.add_sequence(tokens, p_batch, logits_all)
 
            # update batch stats
            s_batch.append(n_tokens)
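For context on the change above: with pooling type NONE the embed path returns one embedding per token, so the batch must request output at every position rather than only the last one. Below is a minimal runnable sketch of that flow; ToyBatch and its add_sequence are hypothetical stand-ins for the diff's self._batch.add_sequence, and the constant is assumed to mirror llama.cpp's LLAMA_POOLING_TYPE_NONE.

from typing import List

LLAMA_POOLING_TYPE_NONE = 0  # assumed value, mirroring llama.cpp's enum

class ToyBatch:
    """Hypothetical stand-in for the diff's self._batch."""

    def __init__(self) -> None:
        self.tokens: List[int] = []
        self.seq_ids: List[int] = []
        self.output_mask: List[bool] = []

    def add_sequence(self, tokens: List[int], seq_id: int, logits_all: bool) -> None:
        # With logits_all set, every position is marked for output;
        # otherwise only the sequence's final token is.
        for i, tok in enumerate(tokens):
            self.tokens.append(tok)
            self.seq_ids.append(seq_id)
            self.output_mask.append(logits_all or i == len(tokens) - 1)

# Pooling type NONE means each token keeps its own embedding, so the
# embed path must request output at every position.
pooling_type = LLAMA_POOLING_TYPE_NONE
logits_all = pooling_type == LLAMA_POOLING_TYPE_NONE

batch = ToyBatch()
batch.add_sequence([1, 2, 3], seq_id=0, logits_all=logits_all)
assert batch.output_mask == [True, True, True]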
@@ -81,6 +81,7 @@ def test_real_model(llama_cpp_model_path):
     cparams.n_ubatch = 16
     cparams.n_threads = multiprocessing.cpu_count()
     cparams.n_threads_batch = multiprocessing.cpu_count()
+    cparams.logits_all = False
     cparams.flash_attn = True
     cparams.swa_full = True
 
@@ -103,7 +104,7 @@ def test_real_model(llama_cpp_model_path):
     result = tokens
     n_eval = 0
     for _ in range(4):
-        batch.set_batch(tokens, n_past=n_eval)
+        batch.set_batch(tokens, n_past=n_eval, logits_all=False)
         context.decode(batch)
         n_eval += len(tokens)
         token_id = sampler.sample(context, -1)
@@ -122,8 +123,8 @@ def test_real_llama(llama_cpp_model_path):
         n_ubatch=32,
         n_threads=multiprocessing.cpu_count(),
         n_threads_batch=multiprocessing.cpu_count(),
-        flash_attn=True,
         logits_all=False,
+        flash_attn=True,
         swa_full=True
     )
 
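Why the tests pin logits_all=False: during plain generation the sampler only consumes the final position's logits, so requesting per-token output is wasted compute and memory. A standalone sketch under that assumption (output_mask is a hypothetical helper, not part of the library):

from typing import List

def output_mask(n_tokens: int, logits_all: bool) -> List[bool]:
    # Which positions the backend should compute logits for: all of them
    # for embedding with pooling NONE, only the last one for sampling.
    return [logits_all or i == n_tokens - 1 for i in range(n_tokens)]

assert output_mask(4, logits_all=False) == [False, False, False, True]
assert output_mask(4, logits_all=True) == [True, True, True, True]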