Skip to content

Commit 1f5872c

Browse files
committed
fix save load contiguous
1 parent f25a0f3 commit 1f5872c

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

paddleformers/transformers/conversion_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ def fn(x, is_column=True, transpose=False, is_old_qkv=False, is_naive_2fuse=Fals
745745
return None
746746
if transpose:
747747
if isinstance(x, paddle.Tensor):
748-
x = paddle.transpose(x, [1, 0])
748+
x = paddle.transpose(x, [1, 0]).contiguous()
749749
else:
750750
x = np.transpose(x, [1, 0])
751751
if is_old_qkv:
@@ -1252,7 +1252,7 @@ def convert_transpose_selected_weights(state_dict: dict, transpose_weight_keys:
12521252
continue
12531253
for trans_key in transpose_weight_keys:
12541254
if re.search(f"\.{trans_key}\.weight$", key) or re.fullmatch(f"^{trans_key}\.weight$", key):
1255-
state_dict[key] = state_dict.pop(key).transpose([-1, -2])
1255+
state_dict[key] = state_dict.pop(key).transpose([-1, -2]).contiguous()
12561256
return state_dict
12571257

12581258
@classmethod

paddleformers/transformers/model_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def _is_need_transpose(key):
397397

398398
def _transpose_hf_weight(key, weight):
399399
if _is_need_transpose(key):
400-
return weight.transpose([-1, -2])
400+
return weight.transpose([-1, -2]).contiguous()
401401
return weight
402402

403403
part_state_dict = {}

0 commit comments

Comments
 (0)