@@ -4,8 +4,8 @@
 try:
     import paddle
     import paddle.distributed as pdist
-    from paddle.framework import core as paddle_core
     from paddle.distributed.communication.group import Group
+    from paddle.framework import core as paddle_core
 except ImportError as e:
     raise ImportError(
         "could not import paddle, paddle_core, or numpy. Please install them."
@@ -38,11 +38,12 @@
     DType.F8_E4M3: DType.I8,
 }
 
-if hasattr(paddle, 'float8_e5m2'):
+if hasattr(paddle, "float8_e5m2"):
     dtype_convert[DType.F8_E5M2] = paddle.float8_e5m2
-if hasattr(paddle, 'float8_e4m3fn'):
+if hasattr(paddle, "float8_e4m3fn"):
     dtype_convert[DType.F8_E4M3] = paddle.float8_e4m3fn
 
+
 @dataclass
 class PaddleTensor(TensorBase):
     real_tensor: paddle.Tensor
@@ -222,7 +223,9 @@ def as_workaround_dtype(self, dtype: DType) -> DType:
 
     def get_process_group(self, pg: Optional[Any]) -> PaddleProcessGroup:
         if pg is not None and not isinstance(pg, Group):
-            raise Exception("pg must be an instance of paddle.distributed.communication.group.Group")
+            raise Exception(
+                "pg must be an instance of paddle.distributed.communication.group.Group"
+            )
         return PaddleProcessGroup(pg)
 
     # for testing
@@ -232,7 +235,11 @@ def is_equal(self, wrapped: PaddleTensor, real: Any) -> bool:
             raise Exception("real is not paddle.Tensor")
 
     def randn(self, s: tuple, device: Device, dtype: DType) -> PaddleTensor:
-        return PaddleTensor(device, dtype, paddle.randn(s, dtype=dtype_convert[dtype]).to(device=device.as_str()))
+        return PaddleTensor(
+            device,
+            dtype,
+            paddle.randn(s, dtype=dtype_convert[dtype]).to(device=device.as_str()),
+        )
 
     def support_fp8(self) -> bool:
-        return DType.F8_E5M2 in dtype_convert
+        return DType.F8_E5M2 in dtype_convert
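
Not part of the commit, but a minimal usage sketch of the feature-detection path above: the hasattr() checks only add FP8 entries to dtype_convert when the installed paddle build ships float8_e5m2 / float8_e4m3fn, and support_fp8() reports whether that happened. The PaddleBackend name and the Device("gpu:0") spelling below are assumptions for illustration; only randn(), support_fp8(), DType, and dtype_convert come from this diff.

backend = PaddleBackend()   # hypothetical name for the class holding these methods
device = Device("gpu:0")    # assumed Device constructor, not shown in this diff

if backend.support_fp8():
    # paddle.float8_e5m2 was detected, so dtype_convert[DType.F8_E5M2] exists and
    # randn() forwards it to paddle.randn(...).to(device=...).
    t = backend.randn((2, 3), device, DType.F8_E5M2)
else:
    # Older paddle builds lack float8_e5m2; requesting it would raise a KeyError
    # from dtype_convert, so callers should fall back to a supported dtype here.
    t = None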