
Commit 4d319d4

fix: add pooler weights to biencoder state dict adapter (#998)
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
1 parent 272c9d7 commit 4d319d4

File tree

2 files changed: 48 additions, 0 deletions

nemo_automodel/components/models/biencoder/state_dict_adapter.py

Lines changed: 6 additions & 0 deletions
@@ -48,6 +48,8 @@ def to_hf(self, state_dict: dict[str, Any], **kwargs) -> dict[str, Any]:
             if key.startswith("lm_q."):
                 new_key = key.replace("lm_q.", "model.")
                 hf_state_dict[new_key] = value
+            elif key.startswith("linear_pooler."):
+                hf_state_dict[key] = value

         return hf_state_dict

@@ -76,6 +78,8 @@ def from_hf(
                 biencoder_state_dict[new_key_q] = value
                 new_key_p = key.replace("model.", "lm_p.")
                 biencoder_state_dict[new_key_p] = value
+            elif key.startswith("linear_pooler."):
+                biencoder_state_dict[key] = value

         return biencoder_state_dict

@@ -94,6 +98,8 @@ def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[t
         if fqn.startswith("lm_q."):
             new_fqn = fqn.replace("lm_q.", "model.")
             return [(new_fqn, tensor)]
+        if fqn.startswith("linear_pooler."):
+            return [(fqn, tensor)]

         # Skip tensors that are not part of lm_q
         return []
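
To make the effect of these hunks concrete, here is a small, self-contained sketch that restates the to_hf / from_hf key mapping shown above and round-trips a toy state dict through it. The standalone function names are illustrative rather than code from the repo, and the model.-prefix condition in from_hf_sketch is an assumption (the diff only shows the assignment lines inside that branch); in the repo this logic lives on the adapter class in state_dict_adapter.py.

import torch
from typing import Any


def to_hf_sketch(state_dict: dict[str, Any]) -> dict[str, Any]:
    """Mirror of the to_hf hunk: remap lm_q.* to model.* and (after this fix) keep linear_pooler.*."""
    hf_state_dict: dict[str, Any] = {}
    for key, value in state_dict.items():
        if key.startswith("lm_q."):
            hf_state_dict[key.replace("lm_q.", "model.")] = value
        elif key.startswith("linear_pooler."):
            hf_state_dict[key] = value  # new pass-through branch from this commit
    return hf_state_dict


def from_hf_sketch(hf_state_dict: dict[str, Any]) -> dict[str, Any]:
    """Mirror of the from_hf hunk: fan model.* out to lm_q.*/lm_p.* and keep linear_pooler.*."""
    biencoder_state_dict: dict[str, Any] = {}
    for key, value in hf_state_dict.items():
        if key.startswith("model."):  # assumed condition; not visible in the hunk
            biencoder_state_dict[key.replace("model.", "lm_q.")] = value
            biencoder_state_dict[key.replace("model.", "lm_p.")] = value
        elif key.startswith("linear_pooler."):
            biencoder_state_dict[key] = value  # new pass-through branch from this commit
    return biencoder_state_dict


# Round trip on toy tensors: the pooler weights now survive both directions.
sd = {"lm_q.layer.weight": torch.randn(2, 2), "linear_pooler.bias": torch.randn(4)}
hf = to_hf_sketch(sd)
assert "model.layer.weight" in hf and "linear_pooler.bias" in hf
back = from_hf_sketch(hf)
assert "lm_q.layer.weight" in back and "lm_p.layer.weight" in back and "linear_pooler.bias" in back

Keys that match neither branch are simply dropped by this sketch; whether the real adapter handles additional prefixes is not visible in this diff.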

tests/unit_tests/models/biencoder/test_state_dict_adapter.py

Lines changed: 42 additions & 0 deletions
@@ -81,6 +81,22 @@ def test_to_hf_only_lm_q_keys(self, adapter):
         assert "model.layer1.weight" in hf_state_dict
         assert "model.layer1.bias" in hf_state_dict

+    def test_to_hf_includes_linear_pooler(self, adapter):
+        """Pooler weights should be retained during HF conversion."""
+        biencoder_state_dict = {
+            "lm_q.layer.weight": torch.randn(2, 2),
+            "linear_pooler.weight": torch.randn(4, 4),
+            "linear_pooler.bias": torch.randn(4),
+        }
+
+        hf_state_dict = adapter.to_hf(biencoder_state_dict)
+
+        assert "model.layer.weight" in hf_state_dict  # sanity for lm_q path
+        assert "linear_pooler.weight" in hf_state_dict
+        assert "linear_pooler.bias" in hf_state_dict
+        assert torch.equal(hf_state_dict["linear_pooler.weight"], biencoder_state_dict["linear_pooler.weight"])
+        assert torch.equal(hf_state_dict["linear_pooler.bias"], biencoder_state_dict["linear_pooler.bias"])
+
     def test_from_hf_basic(self, adapter):
         """Test basic conversion from HuggingFace to biencoder format."""
         hf_state_dict = {

@@ -103,6 +119,23 @@ def test_from_hf_basic(self, adapter):
         assert torch.equal(biencoder_state_dict["lm_q.layer2.bias"], hf_state_dict["model.layer2.bias"])
         assert torch.equal(biencoder_state_dict["lm_p.layer2.bias"], hf_state_dict["model.layer2.bias"])

+    def test_from_hf_includes_linear_pooler(self, adapter):
+        """Pooler weights should be retained when converting from HF."""
+        hf_state_dict = {
+            "model.layer.weight": torch.randn(2, 2),
+            "linear_pooler.weight": torch.randn(4, 4),
+            "linear_pooler.bias": torch.randn(4),
+        }
+
+        biencoder_state_dict = adapter.from_hf(hf_state_dict)
+
+        assert "lm_q.layer.weight" in biencoder_state_dict
+        assert "lm_p.layer.weight" in biencoder_state_dict
+        assert "linear_pooler.weight" in biencoder_state_dict
+        assert "linear_pooler.bias" in biencoder_state_dict
+        assert torch.equal(biencoder_state_dict["linear_pooler.weight"], hf_state_dict["linear_pooler.weight"])
+        assert torch.equal(biencoder_state_dict["linear_pooler.bias"], hf_state_dict["linear_pooler.bias"])
+
     def test_from_hf_empty_state_dict(self, adapter):
         """Test conversion with empty state dict."""
         biencoder_state_dict = adapter.from_hf({})

@@ -153,6 +186,15 @@ def test_convert_single_tensor_to_hf_other(self, adapter):

         assert result == []

+    def test_convert_single_tensor_to_hf_linear_pooler(self, adapter):
+        """Test converting linear_pooler tensor (should be passed through)."""
+        tensor = torch.randn(4)
+        result = adapter.convert_single_tensor_to_hf("linear_pooler.bias", tensor)
+
+        assert len(result) == 1
+        assert result[0][0] == "linear_pooler.bias"
+        assert torch.equal(result[0][1], tensor)
+
     def test_convert_single_tensor_to_hf_with_kwargs(self, adapter):
         """Test that convert_single_tensor_to_hf accepts kwargs."""
         tensor = torch.randn(10, 10)
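
To exercise just the new cases locally, one option (assuming pytest and torch are installed in the environment) is to filter on the shared linear_pooler substring in the names of the tests added above:

import pytest

# Runs the three tests added in this commit; all of their names contain "linear_pooler".
pytest.main(["tests/unit_tests/models/biencoder/test_state_dict_adapter.py", "-k", "linear_pooler"])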

0 commit comments
