@@ -81,6 +81,22 @@ def test_to_hf_only_lm_q_keys(self, adapter):
         assert "model.layer1.weight" in hf_state_dict
         assert "model.layer1.bias" in hf_state_dict
 
+    def test_to_hf_includes_linear_pooler(self, adapter):
+        """Pooler weights should be retained during HF conversion."""
+        biencoder_state_dict = {
+            "lm_q.layer.weight": torch.randn(2, 2),
+            "linear_pooler.weight": torch.randn(4, 4),
+            "linear_pooler.bias": torch.randn(4),
+        }
+
+        hf_state_dict = adapter.to_hf(biencoder_state_dict)
+
+        assert "model.layer.weight" in hf_state_dict  # sanity for lm_q path
+        assert "linear_pooler.weight" in hf_state_dict
+        assert "linear_pooler.bias" in hf_state_dict
+        assert torch.equal(hf_state_dict["linear_pooler.weight"], biencoder_state_dict["linear_pooler.weight"])
+        assert torch.equal(hf_state_dict["linear_pooler.bias"], biencoder_state_dict["linear_pooler.bias"])
+
     def test_from_hf_basic(self, adapter):
         """Test basic conversion from HuggingFace to biencoder format."""
         hf_state_dict = {
@@ -103,6 +119,23 @@ def test_from_hf_basic(self, adapter):
         assert torch.equal(biencoder_state_dict["lm_q.layer2.bias"], hf_state_dict["model.layer2.bias"])
         assert torch.equal(biencoder_state_dict["lm_p.layer2.bias"], hf_state_dict["model.layer2.bias"])
 
+    def test_from_hf_includes_linear_pooler(self, adapter):
+        """Pooler weights should be retained when converting from HF."""
+        hf_state_dict = {
+            "model.layer.weight": torch.randn(2, 2),
+            "linear_pooler.weight": torch.randn(4, 4),
+            "linear_pooler.bias": torch.randn(4),
+        }
+
+        biencoder_state_dict = adapter.from_hf(hf_state_dict)
+
+        assert "lm_q.layer.weight" in biencoder_state_dict
+        assert "lm_p.layer.weight" in biencoder_state_dict
+        assert "linear_pooler.weight" in biencoder_state_dict
+        assert "linear_pooler.bias" in biencoder_state_dict
+        assert torch.equal(biencoder_state_dict["linear_pooler.weight"], hf_state_dict["linear_pooler.weight"])
+        assert torch.equal(biencoder_state_dict["linear_pooler.bias"], hf_state_dict["linear_pooler.bias"])
+
     def test_from_hf_empty_state_dict(self, adapter):
         """Test conversion with empty state dict."""
         biencoder_state_dict = adapter.from_hf({})
@@ -153,6 +186,15 @@ def test_convert_single_tensor_to_hf_other(self, adapter):
 
         assert result == []
 
+    def test_convert_single_tensor_to_hf_linear_pooler(self, adapter):
+        """Test converting linear_pooler tensor (should be passed through)."""
+        tensor = torch.randn(4)
+        result = adapter.convert_single_tensor_to_hf("linear_pooler.bias", tensor)
+
+        assert len(result) == 1
+        assert result[0][0] == "linear_pooler.bias"
+        assert torch.equal(result[0][1], tensor)
+
     def test_convert_single_tensor_to_hf_with_kwargs(self, adapter):
        """Test that convert_single_tensor_to_hf accepts kwargs."""
        tensor = torch.randn(10, 10)
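
For context, below is a minimal sketch of the pass-through behavior these tests assume. The adapter's actual class name, module, and prefix handling are not part of this diff, so BiEncoderHFAdapter and the key prefixes here are illustrative assumptions only: lm_q.* is mapped onto the shared model.* backbone, model.* fans out to both lm_q.* and lm_p.*, and linear_pooler.* keys are copied through unchanged in both directions.

class BiEncoderHFAdapter:
    """Hypothetical adapter sketch; not the actual implementation under test."""

    def convert_single_tensor_to_hf(self, key, tensor, **kwargs):
        if key.startswith("lm_q."):
            # Query-encoder weights map onto the shared HF backbone.
            return [("model." + key[len("lm_q."):], tensor)]
        if key.startswith("linear_pooler."):
            # Pooler weights are passed through unchanged.
            return [(key, tensor)]
        # lm_p.* duplicates lm_q.* in HF form; unknown keys are dropped.
        return []

    def to_hf(self, biencoder_state_dict):
        hf_state_dict = {}
        for key, tensor in biencoder_state_dict.items():
            for new_key, new_tensor in self.convert_single_tensor_to_hf(key, tensor):
                hf_state_dict[new_key] = new_tensor
        return hf_state_dict

    def from_hf(self, hf_state_dict):
        biencoder_state_dict = {}
        for key, tensor in hf_state_dict.items():
            if key.startswith("model."):
                suffix = key[len("model."):]
                # Each backbone tensor is shared by the query and passage encoders.
                biencoder_state_dict["lm_q." + suffix] = tensor
                biencoder_state_dict["lm_p." + suffix] = tensor
            elif key.startswith("linear_pooler."):
                biencoder_state_dict[key] = tensor
        return biencoder_state_dict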