Skip to content

Commit 14bfc1e

Browse files
committed
add test
Signed-off-by: HuiyingLi <[email protected]>
1 parent 3196c03 commit 14bfc1e

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

tests/unit_tests/models/biencoder/test_state_dict_adapter.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,22 @@ def test_to_hf_only_lm_q_keys(self, adapter):
8181
assert "model.layer1.weight" in hf_state_dict
8282
assert "model.layer1.bias" in hf_state_dict
8383

84+
def test_to_hf_includes_linear_pooler(self, adapter):
85+
"""Pooler weights should be retained during HF conversion."""
86+
biencoder_state_dict = {
87+
"lm_q.layer.weight": torch.randn(2, 2),
88+
"linear_pooler.weight": torch.randn(4, 4),
89+
"linear_pooler.bias": torch.randn(4),
90+
}
91+
92+
hf_state_dict = adapter.to_hf(biencoder_state_dict)
93+
94+
assert "model.layer.weight" in hf_state_dict # sanity for lm_q path
95+
assert "linear_pooler.weight" in hf_state_dict
96+
assert "linear_pooler.bias" in hf_state_dict
97+
assert torch.equal(hf_state_dict["linear_pooler.weight"], biencoder_state_dict["linear_pooler.weight"])
98+
assert torch.equal(hf_state_dict["linear_pooler.bias"], biencoder_state_dict["linear_pooler.bias"])
99+
84100
def test_from_hf_basic(self, adapter):
85101
"""Test basic conversion from HuggingFace to biencoder format."""
86102
hf_state_dict = {
@@ -103,6 +119,23 @@ def test_from_hf_basic(self, adapter):
103119
assert torch.equal(biencoder_state_dict["lm_q.layer2.bias"], hf_state_dict["model.layer2.bias"])
104120
assert torch.equal(biencoder_state_dict["lm_p.layer2.bias"], hf_state_dict["model.layer2.bias"])
105121

122+
def test_from_hf_includes_linear_pooler(self, adapter):
123+
"""Pooler weights should be retained when converting from HF."""
124+
hf_state_dict = {
125+
"model.layer.weight": torch.randn(2, 2),
126+
"linear_pooler.weight": torch.randn(4, 4),
127+
"linear_pooler.bias": torch.randn(4),
128+
}
129+
130+
biencoder_state_dict = adapter.from_hf(hf_state_dict)
131+
132+
assert "lm_q.layer.weight" in biencoder_state_dict
133+
assert "lm_p.layer.weight" in biencoder_state_dict
134+
assert "linear_pooler.weight" in biencoder_state_dict
135+
assert "linear_pooler.bias" in biencoder_state_dict
136+
assert torch.equal(biencoder_state_dict["linear_pooler.weight"], hf_state_dict["linear_pooler.weight"])
137+
assert torch.equal(biencoder_state_dict["linear_pooler.bias"], hf_state_dict["linear_pooler.bias"])
138+
106139
def test_from_hf_empty_state_dict(self, adapter):
107140
"""Test conversion with empty state dict."""
108141
biencoder_state_dict = adapter.from_hf({})
@@ -153,6 +186,15 @@ def test_convert_single_tensor_to_hf_other(self, adapter):
153186

154187
assert result == []
155188

189+
def test_convert_single_tensor_to_hf_linear_pooler(self, adapter):
190+
"""Test converting linear_pooler tensor (should be passed through)."""
191+
tensor = torch.randn(4)
192+
result = adapter.convert_single_tensor_to_hf("linear_pooler.bias", tensor)
193+
194+
assert len(result) == 1
195+
assert result[0][0] == "linear_pooler.bias"
196+
assert torch.equal(result[0][1], tensor)
197+
156198
def test_convert_single_tensor_to_hf_with_kwargs(self, adapter):
157199
"""Test that convert_single_tensor_to_hf accepts kwargs."""
158200
tensor = torch.randn(10, 10)

0 commit comments

Comments
 (0)