@@ -360,7 +360,8 @@ class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
     _is_padding_mask: list[bool]
 
     def __new__(cls, data=None, requires_grad=True):
-        assert cls is FlatParameter, "subclasses FlatParameter not supported"
+        if cls is not FlatParameter:
+            raise AssertionError("subclasses FlatParameter not supported")
         r = nn.Parameter.__new__(nn.Parameter, data, requires_grad)  # type: ignore[call-arg]
         r._is_flat_param = True  # type: ignore[attr-defined]
         return r
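The hunk above establishes the pattern repeated throughout this diff: each bare `assert` becomes an explicit `if`/`raise AssertionError(...)`, so the check keeps firing when Python runs with `-O` or `PYTHONOPTIMIZE`, which strip `assert` statements. A minimal standalone sketch of the two behaviors follows; the helper names are illustrative and not part of this change.

# Minimal sketch of the assert-to-raise pattern; the helper names are
# illustrative and not part of this diff.


def check_with_assert(cls: type, base: type) -> None:
    # Compiled away under `python -O`, so invalid subclasses slip through.
    assert cls is base, "subclasses not supported"


def check_with_raise(cls: type, base: type) -> None:
    # Always executed, regardless of interpreter optimization flags.
    if cls is not base:
        raise AssertionError("subclasses not supported")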
@@ -398,11 +399,26 @@ def _init_metadata(
         Args:
             See the Attributes in the class docstring.
         """
-        assert len(param_infos) == len(shapes)
-        assert len(param_infos) == len(strides)
-        assert len(param_infos) == len(contiguities)
-        assert len(param_infos) == len(fqns)
-        assert len(param_infos) == len(param_extensions)
+        if len(param_infos) != len(shapes):
+            raise AssertionError(
+                f"Expected param_infos length {len(param_infos)} to match shapes length {len(shapes)}"
+            )
+        if len(param_infos) != len(strides):
+            raise AssertionError(
+                f"Expected param_infos length {len(param_infos)} to match strides length {len(strides)}"
+            )
+        if len(param_infos) != len(contiguities):
+            raise AssertionError(
+                f"Expected param_infos length {len(param_infos)} to match contiguities length {len(contiguities)}"
+            )
+        if len(param_infos) != len(fqns):
+            raise AssertionError(
+                f"Expected param_infos length {len(param_infos)} to match fqns length {len(fqns)}"
+            )
+        if len(param_infos) != len(param_extensions):
+            raise AssertionError(
+                f"Expected param_infos length {len(param_infos)} to match param_extensions length {len(param_extensions)}"
+            )
         self._num_params = len(param_infos)
         self._param_infos = param_infos
         self._shapes = shapes
@@ -418,22 +434,32 @@ def _init_metadata(
                 numels_without_padding.append(numel)
         self._numels = tuple(numels_without_padding)
         self._numels_with_padding = tuple(numels)
-        assert len(self._numels) == self._num_params
+        if len(self._numels) != self._num_params:
+            raise AssertionError(
+                f"Expected _numels length {len(self._numels)} to equal _num_params {self._num_params}"
+            )
 
         self._shared_param_infos = tuple(shared_param_infos)
         self._modules = {pi.module for pi in self._param_infos}.union(
             {spi.module for spi in self._shared_param_infos}
         )
-        assert (params is None) == (shared_params is None)
-        if params is not None:
-            assert shared_params is not None and len(shared_params) == len(
-                shared_param_infos
+        if (params is None) != (shared_params is None):
+            raise AssertionError(
+                "Expected params and shared_params to both be None or both be not None"
             )
+        if params is not None:
+            if shared_params is None or len(shared_params) != len(shared_param_infos):
+                raise AssertionError(
+                    f"Expected shared_params to be not None and have length {len(shared_param_infos)}, got {shared_params}"
+                )
             self._params = []
             for param, is_padding in zip(params, is_padding_mask):
                 if not is_padding:
                     self._params.append(param)
-            self._shared_params = shared_params
+            if shared_params is not None:
+                self._shared_params = shared_params
+            else:
+                self._shared_params = []
             # Mark the original parameters to avoid flattening them into
             # another `FlatParameter` during recursive construction
             for param in chain(self._params, self._shared_params):
@@ -579,7 +605,8 @@ def __init__(
         # before `_init_flat_param()`, which performs the actual validation
         self._orig_param_dtype = params[0].dtype
         self._init_param_reduce_dtypes(mp_param_dtype, mp_reduce_dtype)
-        assert self._fwd_bwd_param_dtype is not None  # mypy
+        if self._fwd_bwd_param_dtype is None:
+            raise AssertionError("Expected _fwd_bwd_param_dtype to be not None")  # mypy
         self._aligned_numel = (
             _get_aligned_numel(unsharded_dtype=self._fwd_bwd_param_dtype)
             if align_addresses
@@ -807,7 +834,8 @@ def _validate_tensors_to_flatten(
             dtype = tensor.dtype
             flat_param_requires_grad = flat_param_requires_grad or tensor.requires_grad
             device = tensor.device
-        assert flat_param_requires_grad is not None, "Requires non-empty `tensors` list"
+        if flat_param_requires_grad is None:
+            raise AssertionError("Requires non-empty `tensors` list")
         return dtype, flat_param_requires_grad, device
 
     def flatten_tensors(
@@ -908,8 +936,10 @@ def _init_param_reduce_dtypes(
         else:
             self._fwd_bwd_param_dtype = mp_param_dtype or self._orig_param_dtype
             self._reduce_dtype = mp_reduce_dtype or self._orig_param_dtype
-        assert self._fwd_bwd_param_dtype is not None
-        assert self._reduce_dtype is not None
+        if self._fwd_bwd_param_dtype is None:
+            raise AssertionError("Expected _fwd_bwd_param_dtype to be not None")
+        if self._reduce_dtype is None:
+            raise AssertionError("Expected _reduce_dtype to be not None")
 
     ###################################
     # SHARD INITIALIZATION & METADATA #
@@ -985,9 +1015,10 @@ def _init_shard_metadata(
         shard_param_infos = self._get_shard_metadata(
             unsharded_start_idx, unsharded_end_idx
         )
-        assert len(shard_param_infos) == flat_param._num_params, (
-            f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
-        )
+        if len(shard_param_infos) != flat_param._num_params:
+            raise AssertionError(
+                f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
+            )
         flat_param._shard_param_infos = shard_param_infos  # type: ignore[attr-defined]
         flat_param._shard_numel_padded = numel_padded  # type: ignore[attr-defined]
 
@@ -1003,9 +1034,10 @@ def _get_shard_metadata(
             unsharded flat parameter specifying the shard.
         """
         flat_param_offsets = self._get_flat_param_offsets()
-        assert len(flat_param_offsets) == len(self.flat_param._numels_with_padding), (
-            f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
-        )
+        if len(flat_param_offsets) != len(self.flat_param._numels_with_padding):
+            raise AssertionError(
+                f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
+            )
         shard_param_infos: list[_ShardParamInfo] = []
         sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
         # `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
@@ -1033,12 +1065,13 @@ def _get_shard_metadata(
                         unsharded_start_idx - unsharded_param_start_idx
                     )
                     offset_in_shard = 0
-                assert (
+                if not (
                     offset_in_shard >= 0 and offset_in_shard < sharded_flat_param_numel
-                ), (
-                    f"Invalid `offset_in_shard` of {offset_in_shard} for "
-                    f"sharded flat parameter with {sharded_flat_param_numel} numel"
-                )
+                ):
+                    raise AssertionError(
+                        f"Invalid `offset_in_shard` of {offset_in_shard} for "
+                        f"sharded flat parameter with {sharded_flat_param_numel} numel"
+                    )
                 intra_param_end_idx = (
                     min(unsharded_param_end_idx, unsharded_end_idx)
                     - unsharded_param_start_idx
@@ -1082,9 +1115,10 @@ def _get_unpadded_shard(
         else:
             chunk = chunks[rank]
         numel_to_pad = chunks[0].numel() - chunk.numel()
-        assert numel_to_pad >= 0, (
-            "Chunk's size should be at most the first chunk's size"
-        )
+        if numel_to_pad < 0:
+            raise AssertionError(
+                "Chunk's size should be at most the first chunk's size"
+            )
         return chunk, numel_to_pad
 
     @staticmethod
@@ -1115,12 +1149,16 @@ def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size:
         This requires ``tensor`` to have 1D shape and ensures that the returned
         shape is 1D.
         """
-        assert len(tensor.shape) == 1, f"{tensor.shape}"
+        if len(tensor.shape) != 1:
+            raise AssertionError(f"Expected 1D tensor shape, got {tensor.shape}")
         unpadded_sharded_tensor, numel_to_pad = FlatParamHandle._get_unpadded_shard(
             tensor, rank, world_size
         )
         unpadded_sharded_size = unpadded_sharded_tensor.size()
-        assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}"
+        if len(unpadded_sharded_size) != 1:
+            raise AssertionError(
+                f"Expected 1D unpadded_sharded_size, got {unpadded_sharded_size}"
+            )
         return torch.Size([unpadded_sharded_size[0] + numel_to_pad])
 
     def _get_flat_param_offsets(self) -> list[tuple[int, int]]:
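For context on the shard bookkeeping validated in the two hunks above, the excerpt's `_get_unpadded_shard`/`_get_sharded_size` logic chunks the flattened tensor across ranks and pads every rank's shard up to the first chunk's numel. A rough standalone sketch of that arithmetic follows; `shard_1d` is an illustrative helper, and the empty-chunk branch for ranks beyond the last chunk is an assumption not shown in the excerpt.

# Rough sketch of the chunk-and-pad arithmetic visible in the context lines
# above. `shard_1d` is an illustrative helper, not FSDP's API.
import torch


def shard_1d(tensor: torch.Tensor, rank: int, world_size: int) -> tuple[torch.Tensor, int]:
    chunks = torch.flatten(tensor).chunk(world_size)
    if len(chunks) < rank + 1:
        # Assumption: ranks past the last chunk get an empty shard (all padding).
        chunk = chunks[0].new_empty(0)
    else:
        chunk = chunks[rank]
    # Every rank's padded shard has the same numel as the first chunk.
    numel_to_pad = chunks[0].numel() - chunk.numel()
    if numel_to_pad < 0:
        raise AssertionError("Chunk's size should be at most the first chunk's size")
    return chunk, numel_to_pad


# Example: 10 elements over 4 ranks -> chunk sizes (3, 3, 3, 1);
# rank 3 needs 2 padding elements to match the first chunk's size of 3.
_, pad = shard_1d(torch.arange(10.0), rank=3, world_size=4)
assert pad == 2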
@@ -2059,7 +2097,7 @@ def _use_unsharded_grad_views(self) -> None:
             _p_assert(
                 hasattr(module, param_name),
                 f"{module_name + '.' + param_name if module_name else param_name} is missing",
-            )  # did not save FQN info in `_shared_param_infos`
+            )
             param = getattr(module, param_name)
             prim_param = getattr(prim_module, prim_param_name)
             if (
@@ -2130,7 +2168,8 @@ def _use_sharded_views(self) -> None:
                 offset = shard_param_info.offset_in_shard
                 numel_in_shard = shard_param_info.numel_in_shard
                 param.data = flat_param[offset : offset + numel_in_shard]
-        assert self.flat_param._shared_params is not None
+        if self.flat_param._shared_params is None:
+            raise AssertionError("Expected _shared_params to be not None")
         for i, (
             param,
             (param_name, module, _, prim_param_name, prim_module, _),
@@ -2194,7 +2233,8 @@ def _use_sharded_grad_views(self) -> None:
                         )
                 else:
                     param.grad = None
-        assert flat_param._shared_params is not None
+        if flat_param._shared_params is None:
+            raise AssertionError("Expected _shared_params to be not None")
         for param, (_, _, _, prim_param_name, prim_module, _) in zip(
             flat_param._shared_params, flat_param._shared_param_infos
         ):
@@ -2408,7 +2448,8 @@ def _writeback_tensor(
             dst_tensor[offset : offset + expected_shape.numel()].copy_(src_tensor)
         else:
             dst_tensor[offset : offset + expected_shape.numel()].zero_()
-            assert self.flat_param._is_grad_none_mask is not None
+            if self.flat_param._is_grad_none_mask is None:
+                raise AssertionError("Expected _is_grad_none_mask to be not None")
             self.flat_param._is_grad_none_mask[tensor_index] = True
 
     def _reset_flat_param_grad_info_if_needed(self):
@@ -2427,7 +2468,8 @@ def _reset_flat_param_grad_info_if_needed(self):
         if not self._use_orig_params:
             return
         flat_param = self.flat_param
-        assert flat_param._params is not None  # mypy
+        if flat_param._params is None:
+            raise AssertionError("Expected _params to be not None")  # mypy
         all_grad_none = True
         requires_grad = False
         for param in flat_param._params:
@@ -2571,12 +2613,16 @@ def _reset_is_grad_none(self) -> None:
25712613 "Expects to only be called in the post-backward after gradient computation" ,
25722614 )
25732615 flat_param = self .flat_param
2574- assert flat_param ._params is not None # mypy
2616+ if flat_param ._params is None :
2617+ raise AssertionError ("Expected _params to be not None" ) # mypy
25752618 for i , param in enumerate (flat_param ._params ): # type: ignore[arg-type]
25762619 # As long as the parameter requires gradient, it should receive a
25772620 # meaningful gradient (even if the gradient happens to be zeros)
25782621 if param .requires_grad :
2579- assert flat_param ._is_grad_none_mask is not None # mypy
2622+ if flat_param ._is_grad_none_mask is None :
2623+ raise AssertionError (
2624+ "Expected _is_grad_none_mask to be not None"
2625+ ) # mypy
25802626 flat_param ._is_grad_none_mask [i ] = False
25812627
25822628 #######################
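Several of the `_use_sharded_views` and `_use_sharded_grad_views` checks above guard the slicing that turns the flat buffer back into per-parameter views (`flat_param[offset : offset + numel_in_shard]`). The sketch below is an illustrative reconstruction of that kind of offset/numel slicing, not FSDP code; `rebuild_views` and its metadata layout are assumptions.

# Illustrative sketch (not FSDP code): rebuilding per-parameter views from a
# flat 1D buffer using per-parameter shapes, mirroring the
# `flat_param[offset : offset + numel_in_shard]` slicing in the diff above.
import torch


def rebuild_views(flat_param: torch.Tensor, shapes: list[torch.Size]) -> list[torch.Tensor]:
    if flat_param.dim() != 1:
        raise AssertionError(f"Expected 1D tensor shape, got {flat_param.shape}")
    views = []
    offset = 0
    for shape in shapes:
        numel = shape.numel()
        # Each view aliases the flat buffer; no copy is made.
        views.append(flat_param[offset : offset + numel].view(shape))
        offset += numel
    if offset != flat_param.numel():
        raise AssertionError(
            f"Expected views to cover {flat_param.numel()} elements, got {offset}"
        )
    return views


# Example: a 6-element flat buffer split into a (2, 2) matrix and a (2,) vector.
flat = torch.arange(6.0)
mat, vec = rebuild_views(flat, [torch.Size([2, 2]), torch.Size([2])])
assert mat.shape == (2, 2) and vec.shape == (2,)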