We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 298aac8 commit b638352Copy full SHA for b638352
paddleformers/transformers/model_utils.py
@@ -397,7 +397,12 @@ def _is_need_transpose(key):
397
398
def _transpose_hf_weight(key, weight):
399
if _is_need_transpose(key):
400
- return np.ascontiguousarray(weight.transpose([-1, -2]))
+ if isinstance(weight, np.ndarray):
401
+ return np.ascontiguousarray(weight.transpose([-1, -2]))
402
+ elif isinstance(weight, paddle.Tensor):
403
+ return weight.transpose([-1, -2]).contiguous()
404
+ else:
405
+ raise ValueError(f"Unsupported weight type: {type(weight)}. Expected np.ndarray or paddle.Tensor")
406
return weight
407
408
part_state_dict = {}
0 commit comments