diff --git a/FlagEmbedding/inference/embedder/encoder_only/m3.py b/FlagEmbedding/inference/embedder/encoder_only/m3.py
index 2082003a..6e582bf2 100644
--- a/FlagEmbedding/inference/embedder/encoder_only/m3.py
+++ b/FlagEmbedding/inference/embedder/encoder_only/m3.py
@@ -363,9 +363,12 @@ def _process_token_weights(token_weights: np.ndarray, input_ids: list):
             return result
 
         def _process_colbert_vecs(colbert_vecs: np.ndarray, attention_mask: list):
-            # delte the vectors of padding tokens
+            # Remove padding and EOS token vectors
+            # Note: CLS is already excluded in colbert_embedding (last_hidden_state[:, 1:])
+            # tokens_num includes CLS and EOS, but colbert_vecs excludes CLS
+            # So we use tokens_num - 2 to also exclude EOS
             tokens_num = np.sum(attention_mask)
-            return colbert_vecs[:tokens_num - 1]  # we don't use the embedding of cls, so select tokens_num-1
+            return colbert_vecs[:tokens_num - 2]
 
         # tokenize without padding to get the correct length
         all_inputs = []
diff --git a/tests/test_colbert_vecs.py b/tests/test_colbert_vecs.py
new file mode 100644
index 00000000..7ce9dc20
--- /dev/null
+++ b/tests/test_colbert_vecs.py
@@ -0,0 +1,73 @@
+"""Test that _process_colbert_vecs correctly excludes special tokens."""
+
+import numpy as np
+
+
+def _process_colbert_vecs(colbert_vecs: np.ndarray, attention_mask: list):
+    """Process colbert vectors to exclude special tokens.
+
+    This is the fixed version that correctly excludes the EOS token.
+    CLS is already excluded in colbert_embedding (last_hidden_state[:, 1:]).
+    """
+    tokens_num = np.sum(attention_mask)
+    return colbert_vecs[:tokens_num - 2]
+
+
+def test_process_colbert_vecs_excludes_eos():
+    """Test that _process_colbert_vecs excludes the EOS token.
+
+    Scenario:
+    - Original sequence: [CLS, tok1, tok2, tok3, EOS, PAD, PAD]
+    - attention_mask: [1, 1, 1, 1, 1, 0, 0] (5 valid tokens)
+    - colbert_vecs already excludes CLS, so it is [tok1, tok2, tok3, EOS, PAD, PAD]
+    - Expected output: [tok1, tok2, tok3] (3 vectors, excluding EOS)
+    """
+    # Simulate colbert_vecs after CLS removal (4 valid + 2 padding)
+    # Shape: (6, hidden_dim) where hidden_dim = 4 for testing
+    colbert_vecs = np.array([
+        [1.0, 0.0, 0.0, 0.0],  # tok1
+        [0.0, 1.0, 0.0, 0.0],  # tok2
+        [0.0, 0.0, 1.0, 0.0],  # tok3
+        [0.0, 0.0, 0.0, 1.0],  # EOS (should be excluded)
+        [0.0, 0.0, 0.0, 0.0],  # PAD
+        [0.0, 0.0, 0.0, 0.0],  # PAD
+    ])
+
+    # Original attention_mask (includes the CLS position)
+    attention_mask = [1, 1, 1, 1, 1, 0, 0]  # CLS, tok1, tok2, tok3, EOS, PAD, PAD
+
+    result = _process_colbert_vecs(colbert_vecs, attention_mask)
+
+    # Should return only tok1, tok2, tok3 (3 vectors)
+    assert result.shape[0] == 3, f"Expected 3 vectors, got {result.shape[0]}"
+
+    # Verify the content
+    expected = np.array([
+        [1.0, 0.0, 0.0, 0.0],  # tok1
+        [0.0, 1.0, 0.0, 0.0],  # tok2
+        [0.0, 0.0, 1.0, 0.0],  # tok3
+    ])
+    np.testing.assert_array_equal(result, expected)
+
+
+def test_process_colbert_vecs_single_token():
+    """Test with minimal valid tokens (just CLS, one token, EOS)."""
+    colbert_vecs = np.array([
+        [1.0, 0.0],  # tok1
+        [0.0, 1.0],  # EOS
+    ])
+    attention_mask = [1, 1, 1]  # CLS, tok1, EOS
+
+    result = _process_colbert_vecs(colbert_vecs, attention_mask)
+
+    # Should return only tok1
+    assert result.shape[0] == 1, f"Expected 1 vector, got {result.shape[0]}"
+    np.testing.assert_array_equal(result, np.array([[1.0, 0.0]]))
+
+
+if __name__ == "__main__":
+    test_process_colbert_vecs_excludes_eos()
+    print("test_process_colbert_vecs_excludes_eos passed!")
+    test_process_colbert_vecs_single_token()
+    print("test_process_colbert_vecs_single_token passed!")
+    print("All tests passed!")