Skip to content

Commit 3196c03

Browse files
committed
Add linear_pooler weights to the biencoder state dict adapter
Signed-off-by: HuiyingLi <[email protected]>
1 parent 1d42deb commit 3196c03

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

nemo_automodel/components/models/biencoder/state_dict_adapter.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,8 @@ def to_hf(self, state_dict: dict[str, Any], **kwargs) -> dict[str, Any]:
4848
if key.startswith("lm_q."):
4949
new_key = key.replace("lm_q.", "model.")
5050
hf_state_dict[new_key] = value
51+
elif key.startswith("linear_pooler."):
52+
hf_state_dict[key] = value
5153

5254
return hf_state_dict
5355

@@ -76,6 +78,8 @@ def from_hf(
7678
biencoder_state_dict[new_key_q] = value
7779
new_key_p = key.replace("model.", "lm_p.")
7880
biencoder_state_dict[new_key_p] = value
81+
elif key.startswith("linear_pooler."):
82+
biencoder_state_dict[key] = value
7983

8084
return biencoder_state_dict
8185

@@ -94,6 +98,8 @@ def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[t
9498
if fqn.startswith("lm_q."):
9599
new_fqn = fqn.replace("lm_q.", "model.")
96100
return [(new_fqn, tensor)]
101+
if fqn.startswith("linear_pooler."):
102+
return [(fqn, tensor)]
97103

98104
# Skip tensors that belong to neither lm_q nor linear_pooler
99105
return []

0 commit comments

Comments
 (0)