
Commit e3ad344

chore: change ratio for hybrid dbfs
1 parent 630223c commit e3ad344

File tree

1 file changed (+78, -20 lines changed)


examples/inference/embedder/encoder_only/m3_single_device_ensemble.py

Lines changed: 78 additions & 20 deletions
@@ -5,6 +5,26 @@
 
 
 def pad_colbert_vecs(colbert_vecs_list, device):
+    """
+    Since ColBERT embeddings are computed on a token-level basis, each document (or query)
+    may produce a different number of token embeddings. This function aligns all embeddings
+    to the same length by padding shorter sequences with zeros, ensuring that every input
+    ends up with a uniform shape.
+
+    Steps:
+    1. Determine the maximum sequence length (i.e., the largest number of tokens in any
+       query or passage within the batch).
+    2. For each set of token embeddings, pad it with zeros until it matches the max
+       sequence length. Zeros here act as placeholders and do not affect the similarity
+       computations since they represent "no token."
+    3. Convert all padded embeddings into a single, consistent tensor and move it to the
+       specified device (e.g., GPU) for efficient batch computation.
+
+    By performing this padding operation, subsequent tensor operations (like the einsum
+    computations for ColBERT scoring) become simpler and more efficient, as all sequences
+    share a common shape.
+    """
+
     lengths = [vec.shape[0] for vec in colbert_vecs_list]
     max_len = max(lengths)
     dim = colbert_vecs_list[0].shape[1]
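
Note: as a quick illustration of the padding described in the docstring above, here is a minimal standalone sketch (not part of this commit; all shapes and values are made up) of zero-padding variable-length token embeddings into one batch tensor:

    import torch

    # two hypothetical token-embedding matrices with different token counts (4 and 2), dim 3
    vecs = [torch.randn(4, 3), torch.randn(2, 3)]

    max_len = max(v.shape[0] for v in vecs)       # longest sequence in the batch
    padded = torch.zeros(len(vecs), max_len, 3)   # zeros act as "no token" placeholders
    for i, v in enumerate(vecs):
        padded[i, : v.shape[0]] = v               # copy real token embeddings; the rest stays zero

    print(padded.shape)  # torch.Size([2, 4, 3]) -- a uniform (batch, max_len, dim) tensor
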
@@ -18,18 +38,57 @@ def pad_colbert_vecs(colbert_vecs_list, device):
 
 
 def compute_colbert_scores(query_colbert_vecs, passage_colbert_vecs):
-    # query_colbert_vecs: (Q, Tq, D)
-    # passage_colbert_vecs: (P, Tp, D)
-    # in the einsum expression: q: queries, p: passages, r: query token dim, c: passage token dim, d: embedding dim
+    """
+    Compute ColBERT scores:
+
+    ColBERT (Contextualized Late Interaction over BERT) evaluates the similarity
+    between a query and a passage at the token level. Instead of producing a single
+    dense vector for each query or passage, ColBERT maintains embeddings for every
+    token. This allows for finer-grained matching, capturing more subtle similarities.
+
+    Definitions of variables:
+    - q: Number of queries (Q)
+    - p: Number of passages (P)
+    - r: Number of tokens in each query (Tq)
+    - c: Number of tokens in each passage (Tp)
+    - d: Embedding dimension (D)
+
+    This function uses the operation `einsum("qrd,pcd->qprc", query_colbert_vecs, passage_colbert_vecs)`:
+    - einsum (Einstein summation) is a powerful notation and function for
+      expressing and computing multi-dimensional tensor contractions. It allows you
+      to specify how dimensions in input tensors correspond to each other and how
+      they should be combined (multiplied and summed) to produce the output.
+
+    In this particular case:
+    - "qrd" corresponds to (Q, Tq, D) for query token embeddings.
+    - "pcd" corresponds to (P, Tp, D) for passage token embeddings.
+    - "qrd,pcd->qprc" means:
+      1. For each query q and passage p, compute the dot product between every query token
+         embedding (r) and every passage token embedding (c) across the embedding dimension d.
+      2. This results in a (Q, P, Tq, Tp) tensor (qprc), where each element is the similarity
+         score between a single query token and a single passage token.
+
+    After computing this full matrix of token-to-token scores:
+    - We take the maximum over the passage token dimension (c) for each query token (r).
+      This step identifies, for each query token, which passage token is the "best match."
+    - Then we sum over all query tokens (r) to aggregate their best matches into a single
+      score per query-passage pair.
+
+    In summary:
+    1. einsum to get all pairwise token similarities.
+    2. max over passage tokens to find the best matching passage token for each query token.
+    3. sum over query tokens to combine all the best matches into a final ColBERT score
+       for each query-passage pair.
+    """
+
     dot_products = torch.einsum("qrd,pcd->qprc", query_colbert_vecs, passage_colbert_vecs)  # Q,P,Tq,Tp
-    max_per_query_token, _ = dot_products.max(dim=3)  # max over c (Tp)
-    colbert_scores = max_per_query_token.sum(dim=2)  # sum over r (Tq)
+    max_per_query_token, _ = dot_products.max(dim=3)
+    colbert_scores = max_per_query_token.sum(dim=2)
     return colbert_scores
 
 
-def hybrid_dbfs_ensemble(dense_scores, sparse_scores, colbert_scores, weights=(0.33, 0.33, 0.34)):
+def hybrid_dbfs_ensemble_simple_linear_combination(dense_scores, sparse_scores, colbert_scores, weights=(0.45, 0.45, 0.1)):
     w_dense, w_sparse, w_colbert = weights
-    # if all inputs are torch.Tensor, the computation below works correctly
     return w_dense * dense_scores + w_sparse * sparse_scores + w_colbert * colbert_scores
 
 
@@ -42,12 +101,12 @@ def test_m3_single_device():
     )
 
     queries = [
-        "What is BGE M3?",
-        "Defination of BM25"
+        "What is Sionic AI?",
+        "Try https://sionicstorm.ai today!"
     ] * 100
     passages = [
-        "BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.",
-        "BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document"
+        "Sionic AI delivers more accessible and cost-effective AI technology addressing the various needs to boost productivity and drive innovation.",
+        "The Large Language Model (LLM) is not for research and experimentation. We offer solutions that leverage LLM to add value to your business. Anyone can easily train and control AI."
     ] * 100
 
     queries_embeddings = model.encode_queries(
@@ -56,36 +115,32 @@ def test_m3_single_device():
         return_sparse=True,
         return_colbert_vecs=True,
     )
+
     passages_embeddings = model.encode_corpus(
         passages,
         return_dense=True,
         return_sparse=True,
         return_colbert_vecs=True,
     )
 
-    # set up the device
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-    # dense_vecs, lexical_weights, etc. may come back as numpy arrays, so convert them to tensors
     q_dense = torch.tensor(queries_embeddings["dense_vecs"], dtype=torch.float, device=device)
     p_dense = torch.tensor(passages_embeddings["dense_vecs"], dtype=torch.float, device=device)
     dense_scores = q_dense @ p_dense.T
 
-    # convert sparse_scores from a numpy array to a tensor as well
     sparse_scores_np = model.compute_lexical_matching_score(
         queries_embeddings["lexical_weights"],
         passages_embeddings["lexical_weights"]
     )
+
     sparse_scores = torch.tensor(sparse_scores_np, dtype=torch.float, device=device)
 
-    # pad colbert_vecs and convert them to tensors
     query_colbert_vecs = pad_colbert_vecs(queries_embeddings["colbert_vecs"], device)
     passage_colbert_vecs = pad_colbert_vecs(passages_embeddings["colbert_vecs"], device)
-
     colbert_scores = compute_colbert_scores(query_colbert_vecs, passage_colbert_vecs)
 
-    # all scores are torch.Tensor, so the combination below works without errors
-    hybrid_scores = hybrid_dbfs_ensemble(dense_scores, sparse_scores, colbert_scores)
+    hybrid_scores = hybrid_dbfs_ensemble_simple_linear_combination(dense_scores, sparse_scores, colbert_scores)
 
     print("Dense score:\n", dense_scores[:2, :2])
     print("Sparse score:\n", sparse_scores[:2, :2])
@@ -95,11 +150,14 @@ def test_m3_single_device():
 
 if __name__ == '__main__':
     test_m3_single_device()
+    print("Expected Vector Scores")
     print("--------------------------------")
-    print("Expected Output for Dense & Sparse (original):")
     print("Dense score:")
     print(" [[0.626 0.3477]\n [0.3496 0.678 ]]")
     print("Sparse score:")
     print(" [[0.19554901 0.00880432]\n [0. 0.18036556]]")
+    print("ColBERT score:")
+    print("[[5.8061, 3.1195] \n [5.6822, 4.6513]]")
+    print("Hybrid DBSF Ensemble score:")
+    print("[[0.9822, 0.5125] \n [0.8127, 0.6958]]")
     print("--------------------------------")
-    print("ColBERT and Hybrid DBSF scores will vary depending on the actual embeddings.")
