From 33b10358a0c64e6399dad59bf0eef4c5a2aa76ad Mon Sep 17 00:00:00 2001 From: LEON <2194369201@qq.com> Date: Thu, 4 Jul 2024 20:03:41 +0800 Subject: [PATCH] Update flag_models.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将 109 行 torch.stack 修改为 torch.concatenate,这里的逻辑是合并多卡处理的 embeddings,所以应该是 stack,而且 convert_to_numpy 里的逻辑就是np.concatenate --- FlagEmbedding/flag_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/FlagEmbedding/flag_models.py b/FlagEmbedding/flag_models.py index 121f6533..74932bb1 100644 --- a/FlagEmbedding/flag_models.py +++ b/FlagEmbedding/flag_models.py @@ -106,7 +106,7 @@ def encode(self, if convert_to_numpy: all_embeddings = np.concatenate(all_embeddings, axis=0) else: - all_embeddings = torch.stack(all_embeddings) + all_embeddings = torch.concatenate(all_embeddings) if input_was_string: return all_embeddings[0]