Skip to content

Commit c5c84af

Browse files
authored
Fixkeybuffer (#2512)
* fix buffers loading (awq)
1 parent da2fe8f commit c5c84af

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

onmt/models/model.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def count_parameters(self, log=print):
4646
raise NotImplementedError
4747

4848
def _load_param(self, name, module, param_name, param, buf_list, ckpt_t, offset):
49-
5049
if name.split(".")[-1] in [
5150
"linear_keys",
5251
"linear_values",
@@ -73,7 +72,7 @@ def _load_param(self, name, module, param_name, param, buf_list, ckpt_t, offset)
7372
row_slice_start:row_slice_end,
7473
].size()
7574
), "An error in model's partition and checkpoint's slice was detected"
76-
if param_name in buf_list:
75+
if name + "." + param_name in buf_list:
7776
module.register_buffer(
7877
param_name,
7978
ckpt_t[
@@ -90,7 +89,7 @@ def _load_param(self, name, module, param_name, param, buf_list, ckpt_t, offset)
9089
assert (
9190
param.data.size() == ckpt_t[col_slice_start:col_slice_end].size()
9291
), "An error in model's partition and checkpoint's slice was detected"
93-
if param_name in buf_list:
92+
if name + "." + param_name in buf_list:
9493
module.register_buffer(
9594
param_name, ckpt_t[col_slice_start:col_slice_end]
9695
)
@@ -120,9 +119,9 @@ def load_state_dict(
120119
if device == torch.device("cpu"):
121120
offset = 0
122121
buf_list = []
122+
for buf_name, buf in self.named_buffers():
123+
buf_list.append(buf_name)
123124
for name, module in self.named_modules():
124-
for buf_name, buf in module.named_buffers():
125-
buf_list.append(buf_name)
126125
named_buf_and_param = list(module.named_buffers()) + list(
127126
module.named_parameters()
128127
)
@@ -205,9 +204,9 @@ def load_safe_state_dict(
205204
if device == torch.device("cpu"):
206205
offset = 0
207206
buf_list = []
207+
for buf_name, buf in self.named_buffers():
208+
buf_list.append(buf_name)
208209
for name, module in self.named_modules():
209-
for buf_name, buf in module.named_buffers():
210-
buf_list.append(buf_name)
211210
named_buf_and_param = list(module.named_buffers()) + list(
212211
module.named_parameters()
213212
)

0 commit comments

Comments
 (0)