Skip to content

Commit b638352

Browse files
authored
fix transpose key (#2693)
1 parent 298aac8 commit b638352

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

paddleformers/transformers/model_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,12 @@ def _is_need_transpose(key):
397397

398398
def _transpose_hf_weight(key, weight):
399399
if _is_need_transpose(key):
400-
return np.ascontiguousarray(weight.transpose([-1, -2]))
400+
if isinstance(weight, np.ndarray):
401+
return np.ascontiguousarray(weight.transpose([-1, -2]))
402+
elif isinstance(weight, paddle.Tensor):
403+
return weight.transpose([-1, -2]).contiguous()
404+
else:
405+
raise ValueError(f"Unsupported weight type: {type(weight)}. Expected np.ndarray or paddle.Tensor")
401406
return weight
402407

403408
part_state_dict = {}

0 commit comments

Comments
 (0)