Skip to content

Commit ac8695f

Browse files
author
tangbinhan
committed
fix paddle from_dlpack error
1 parent e32703f commit ac8695f

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

fastsafetensors/frameworks/_paddle.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
dtype_convert: Dict[DType, Any] = {
2323
DType.BOOL: paddle.bool,
24-
DType.I8: paddle.uint8,
24+
DType.U8: paddle.uint8,
2525
DType.I8: paddle.int8,
2626
DType.I16: paddle.int16,
2727
DType.U16: paddle.bfloat16,
@@ -37,6 +37,7 @@
3737
need_workaround_dtypes: Dict[DType, DType] = {
3838
DType.F8_E5M2: DType.I8,
3939
DType.F8_E4M3: DType.I8,
40+
DType.U16: DType.BF16,
4041
}
4142

4243
if hasattr(paddle, "float8_e5m2"):
@@ -198,7 +199,7 @@ def concat_tensors(self, tensors: List[PaddleTensor], dim) -> PaddleTensor:
198199
def get_dtype_size(self, dtype: DType) -> int:
199200
return paddle_core.size_of_dtype(dtype_convert[dtype])
200201

201-
def from_dlpack(self, dl_tensor: Any, device: Device, dtype: DType) -> PaddleTensor:
202+
def from_dlpack(self, dl_tensor: Any, device: Device, dtype) -> PaddleTensor:
202203
return PaddleTensor(device, dtype, paddle.utils.dlpack.from_dlpack(dl_tensor))
203204

204205
def copy_tensor(self, dst: PaddleTensor, src: PaddleTensor) -> None:

0 commit comments

Comments
 (0)