
Commit c4565c3

RohitRathore1 authored and pytorchmergebot committed
[distributed] Replace 164 assert statements in fsdp directory (pytorch#165235)
Replace assert statements with explicit if/raise patterns across 20 files:

- _optim_utils.py (38 asserts)
- _flat_param.py (25 asserts)
- _fully_shard/_fsdp_param.py (23 asserts)
- sharded_grad_scaler.py (12 asserts)
- fully_sharded_data_parallel.py (11 asserts)
- wrap.py (10 asserts)
- _state_dict_utils.py (9 asserts)
- _fully_shard/_fsdp_param_group.py (8 asserts)
- _runtime_utils.py (6 asserts)
- _init_utils.py (6 asserts)
- 10 additional files (16 asserts)

This prevents the checks from being silently disabled when Python runs with the -O flag.

Partially fixes pytorch#164878

Pull Request resolved: pytorch#165235
Approved by: https://github.com/albanD
1 parent 6918f17 · commit c4565c3

20 files changed: +595 -328 lines
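
For context on why the change matters: CPython strips `assert` statements when the interpreter runs with the -O flag (`__debug__` becomes False), so validation written as a bare assert silently disappears in optimized runs, while an explicit raise always executes. Below is a minimal sketch of the pattern this commit applies throughout; the `check_world_size_*` helpers are illustrative examples, not FSDP APIs.

# Illustrative sketch only: these helpers are made up for the example.

def check_world_size_with_assert(world_size):
    # Compiled away under `python -O` (__debug__ is False), so the check can silently vanish.
    assert world_size is not None, "Expected world_size to not be None"
    return world_size


def check_world_size_explicit(world_size):
    # Runs unconditionally, regardless of the -O flag.
    if world_size is None:
        raise AssertionError("Expected world_size to not be None")
    return world_size

Raising AssertionError explicitly keeps the exception type that existing callers expect while making the check unconditional.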

torch/distributed/fsdp/_common_utils.py

Lines changed: 8 additions & 6 deletions
@@ -203,9 +203,10 @@ def _module_handle(state: _FSDPState, module: nn.Module) -> Optional["FlatParamH
         # handles, meaning no entry in `_fully_sharded_module_to_handles`
         if state._handle is None:
             return None
-        assert module in state._fully_sharded_module_to_handle, (
-            f"Expects a fully sharded module but got {module} on rank {state.rank}"
-        )
+        if module not in state._fully_sharded_module_to_handle:
+            raise AssertionError(
+                f"Expects a fully sharded module but got {module} on rank {state.rank}"
+            )
         return state._fully_sharded_module_to_handle[module]
     else:
         # NOTE: This assumes `module` is a `FullyShardedDataParallel` instance.
@@ -258,9 +259,10 @@ def _named_parameters_with_duplicates(
     This API is required as some modules overwrite `named_parameters()` but do not support
     `remove_duplicate`.
     """
-    assert "remove_duplicate" not in kwargs, (
-        "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument."
-    )
+    if "remove_duplicate" in kwargs:
+        raise AssertionError(
+            "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument."
+        )
     kwargs["remove_duplicate"] = False
     try:
         ret = list(module.named_parameters(**kwargs))

torch/distributed/fsdp/_debug_utils.py

Lines changed: 8 additions & 6 deletions
@@ -39,11 +39,12 @@ def reset(cls) -> None:
     @classmethod
     @contextmanager
     def profile(cls, profile_type: str) -> Iterator[None]:
-        assert profile_type not in cls.profiling, (
-            f"{profile_type} is already being profiled. "
-            "SimpleProfiler does not support profiling multiple instances at "
-            "the same time. "
-        )
+        if profile_type in cls.profiling:
+            raise AssertionError(
+                f"{profile_type} is already being profiled. "
+                "SimpleProfiler does not support profiling multiple instances at "
+                "the same time. "
+            )
 
         cls.profiling.add(profile_type)
         begin = time.monotonic()
@@ -129,7 +130,8 @@ def module_fn(
 
         if handle:
             param = handle.flat_param
-            assert isinstance(param, flat_param_file.FlatParameter)
+            if not isinstance(param, flat_param_file.FlatParameter):
+                raise AssertionError(f"Expected FlatParameter, got {type(param)}")
             global_fqns = [
                 clean_tensor_name(prefix + name) for name in param._fqns
             ]  # prefixed from the top level `model` (i.e. including `prefix`)

torch/distributed/fsdp/_exec_order_utils.py

Lines changed: 2 additions & 1 deletion
@@ -214,7 +214,8 @@ def _check_order(self, handle: FlatParamHandle, is_training: bool) -> None:
             # parameters
             # TODO (awgu): Since every module has at most one handle in the
             # current implementation, this should never raise the error.
-            assert self.world_size is not None  # mypy
+            if self.world_size is None:
+                raise AssertionError("Expected world_size to not be None")
             if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
                 # TODO(voz): Don't graph break on this - dynamo hates the n1 != n2
                 # tensor comparison control flow.

torch/distributed/fsdp/_flat_param.py

Lines changed: 85 additions & 39 deletions
@@ -360,7 +360,8 @@ class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
     _is_padding_mask: list[bool]
 
     def __new__(cls, data=None, requires_grad=True):
-        assert cls is FlatParameter, "subclasses FlatParameter not supported"
+        if cls is not FlatParameter:
+            raise AssertionError("subclasses FlatParameter not supported")
         r = nn.Parameter.__new__(nn.Parameter, data, requires_grad)  # type: ignore[call-arg]
         r._is_flat_param = True  # type: ignore[attr-defined]
         return r
@@ -398,11 +399,26 @@ def _init_metadata(
         Args:
             See the Attributes in the class docstring.
         """
-        assert len(param_infos) == len(shapes)
-        assert len(param_infos) == len(strides)
-        assert len(param_infos) == len(contiguities)
-        assert len(param_infos) == len(fqns)
-        assert len(param_infos) == len(param_extensions)
+        if len(param_infos) != len(shapes):
+            raise AssertionError(
+                f"Expected param_infos length {len(param_infos)} to match shapes length {len(shapes)}"
+            )
+        if len(param_infos) != len(strides):
+            raise AssertionError(
+                f"Expected param_infos length {len(param_infos)} to match strides length {len(strides)}"
+            )
+        if len(param_infos) != len(contiguities):
+            raise AssertionError(
+                f"Expected param_infos length {len(param_infos)} to match contiguities length {len(contiguities)}"
+            )
+        if len(param_infos) != len(fqns):
+            raise AssertionError(
+                f"Expected param_infos length {len(param_infos)} to match fqns length {len(fqns)}"
+            )
+        if len(param_infos) != len(param_extensions):
+            raise AssertionError(
+                f"Expected param_infos length {len(param_infos)} to match param_extensions length {len(param_extensions)}"
+            )
         self._num_params = len(param_infos)
         self._param_infos = param_infos
         self._shapes = shapes
@@ -418,22 +434,32 @@ def _init_metadata(
                 numels_without_padding.append(numel)
         self._numels = tuple(numels_without_padding)
         self._numels_with_padding = tuple(numels)
-        assert len(self._numels) == self._num_params
+        if len(self._numels) != self._num_params:
+            raise AssertionError(
+                f"Expected _numels length {len(self._numels)} to equal _num_params {self._num_params}"
+            )
 
         self._shared_param_infos = tuple(shared_param_infos)
         self._modules = {pi.module for pi in self._param_infos}.union(
             {spi.module for spi in self._shared_param_infos}
         )
-        assert (params is None) == (shared_params is None)
-        if params is not None:
-            assert shared_params is not None and len(shared_params) == len(
-                shared_param_infos
+        if (params is None) != (shared_params is None):
+            raise AssertionError(
+                "Expected params and shared_params to both be None or both be not None"
             )
+        if params is not None:
+            if shared_params is None or len(shared_params) != len(shared_param_infos):
+                raise AssertionError(
+                    f"Expected shared_params to be not None and have length {len(shared_param_infos)}, got {shared_params}"
+                )
             self._params = []
             for param, is_padding in zip(params, is_padding_mask):
                 if not is_padding:
                     self._params.append(param)
-            self._shared_params = shared_params
+            if shared_params is not None:
+                self._shared_params = shared_params
+            else:
+                self._shared_params = []
         # Mark the original parameters to avoid flattening them into
         # another `FlatParameter` during recursive construction
         for param in chain(self._params, self._shared_params):
@@ -579,7 +605,8 @@ def __init__(
         # before `_init_flat_param()`, which performs the actual validation
         self._orig_param_dtype = params[0].dtype
         self._init_param_reduce_dtypes(mp_param_dtype, mp_reduce_dtype)
-        assert self._fwd_bwd_param_dtype is not None  # mypy
+        if self._fwd_bwd_param_dtype is None:
+            raise AssertionError("Expected _fwd_bwd_param_dtype to be not None")  # mypy
         self._aligned_numel = (
             _get_aligned_numel(unsharded_dtype=self._fwd_bwd_param_dtype)
             if align_addresses
@@ -807,7 +834,8 @@ def _validate_tensors_to_flatten(
             dtype = tensor.dtype
             flat_param_requires_grad = flat_param_requires_grad or tensor.requires_grad
             device = tensor.device
-        assert flat_param_requires_grad is not None, "Requires non-empty `tensors` list"
+        if flat_param_requires_grad is None:
+            raise AssertionError("Requires non-empty `tensors` list")
         return dtype, flat_param_requires_grad, device
 
     def flatten_tensors(
@@ -908,8 +936,10 @@ def _init_param_reduce_dtypes(
         else:
             self._fwd_bwd_param_dtype = mp_param_dtype or self._orig_param_dtype
             self._reduce_dtype = mp_reduce_dtype or self._orig_param_dtype
-        assert self._fwd_bwd_param_dtype is not None
-        assert self._reduce_dtype is not None
+        if self._fwd_bwd_param_dtype is None:
+            raise AssertionError("Expected _fwd_bwd_param_dtype to be not None")
+        if self._reduce_dtype is None:
+            raise AssertionError("Expected _reduce_dtype to be not None")
 
     ###################################
     # SHARD INITIALIZATION & METADATA #
@@ -985,9 +1015,10 @@ def _init_shard_metadata(
         shard_param_infos = self._get_shard_metadata(
             unsharded_start_idx, unsharded_end_idx
         )
-        assert len(shard_param_infos) == flat_param._num_params, (
-            f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
-        )
+        if len(shard_param_infos) != flat_param._num_params:
+            raise AssertionError(
+                f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
+            )
         flat_param._shard_param_infos = shard_param_infos  # type: ignore[attr-defined]
         flat_param._shard_numel_padded = numel_padded  # type: ignore[attr-defined]
 
@@ -1003,9 +1034,10 @@ def _get_shard_metadata(
         unsharded flat parameter specifying the shard.
         """
         flat_param_offsets = self._get_flat_param_offsets()
-        assert len(flat_param_offsets) == len(self.flat_param._numels_with_padding), (
-            f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
-        )
+        if len(flat_param_offsets) != len(self.flat_param._numels_with_padding):
+            raise AssertionError(
+                f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
+            )
         shard_param_infos: list[_ShardParamInfo] = []
         sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
         # `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
@@ -1033,12 +1065,13 @@ def _get_shard_metadata(
                         unsharded_start_idx - unsharded_param_start_idx
                     )
                     offset_in_shard = 0
-                assert (
+                if not (
                     offset_in_shard >= 0 and offset_in_shard < sharded_flat_param_numel
-                ), (
-                    f"Invalid `offset_in_shard` of {offset_in_shard} for "
-                    f"sharded flat parameter with {sharded_flat_param_numel} numel"
-                )
+                ):
+                    raise AssertionError(
+                        f"Invalid `offset_in_shard` of {offset_in_shard} for "
+                        f"sharded flat parameter with {sharded_flat_param_numel} numel"
+                    )
                 intra_param_end_idx = (
                     min(unsharded_param_end_idx, unsharded_end_idx)
                     - unsharded_param_start_idx
@@ -1082,9 +1115,10 @@ def _get_unpadded_shard(
         else:
             chunk = chunks[rank]
             numel_to_pad = chunks[0].numel() - chunk.numel()
-            assert numel_to_pad >= 0, (
-                "Chunk's size should be at most the first chunk's size"
-            )
+            if numel_to_pad < 0:
+                raise AssertionError(
+                    "Chunk's size should be at most the first chunk's size"
+                )
         return chunk, numel_to_pad
 
     @staticmethod
@@ -1115,12 +1149,16 @@ def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size:
         This requires ``tensor`` to have 1D shape and ensures that the returned
         shape is 1D.
         """
-        assert len(tensor.shape) == 1, f"{tensor.shape}"
+        if len(tensor.shape) != 1:
+            raise AssertionError(f"Expected 1D tensor shape, got {tensor.shape}")
         unpadded_sharded_tensor, numel_to_pad = FlatParamHandle._get_unpadded_shard(
             tensor, rank, world_size
        )
         unpadded_sharded_size = unpadded_sharded_tensor.size()
-        assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}"
+        if len(unpadded_sharded_size) != 1:
+            raise AssertionError(
+                f"Expected 1D unpadded_sharded_size, got {unpadded_sharded_size}"
+            )
         return torch.Size([unpadded_sharded_size[0] + numel_to_pad])
 
     def _get_flat_param_offsets(self) -> list[tuple[int, int]]:
@@ -2059,7 +2097,7 @@ def _use_unsharded_grad_views(self) -> None:
             _p_assert(
                 hasattr(module, param_name),
                 f"{module_name + '.' + param_name if module_name else param_name} is missing",
-            )  # did not save FQN info in `_shared_param_infos`
+            )
             param = getattr(module, param_name)
             prim_param = getattr(prim_module, prim_param_name)
             if (
@@ -2130,7 +2168,8 @@ def _use_sharded_views(self) -> None:
                 offset = shard_param_info.offset_in_shard
                 numel_in_shard = shard_param_info.numel_in_shard
                 param.data = flat_param[offset : offset + numel_in_shard]
-        assert self.flat_param._shared_params is not None
+        if self.flat_param._shared_params is None:
+            raise AssertionError("Expected _shared_params to be not None")
         for i, (
             param,
             (param_name, module, _, prim_param_name, prim_module, _),
@@ -2194,7 +2233,8 @@ def _use_sharded_grad_views(self) -> None:
                 )
             else:
                 param.grad = None
-        assert flat_param._shared_params is not None
+        if flat_param._shared_params is None:
+            raise AssertionError("Expected _shared_params to be not None")
         for param, (_, _, _, prim_param_name, prim_module, _) in zip(
             flat_param._shared_params, flat_param._shared_param_infos
         ):
@@ -2408,7 +2448,8 @@ def _writeback_tensor(
             dst_tensor[offset : offset + expected_shape.numel()].copy_(src_tensor)
         else:
             dst_tensor[offset : offset + expected_shape.numel()].zero_()
-            assert self.flat_param._is_grad_none_mask is not None
+            if self.flat_param._is_grad_none_mask is None:
+                raise AssertionError("Expected _is_grad_none_mask to be not None")
             self.flat_param._is_grad_none_mask[tensor_index] = True
 
     def _reset_flat_param_grad_info_if_needed(self):
@@ -2427,7 +2468,8 @@ def _reset_flat_param_grad_info_if_needed(self):
         if not self._use_orig_params:
             return
         flat_param = self.flat_param
-        assert flat_param._params is not None  # mypy
+        if flat_param._params is None:
+            raise AssertionError("Expected _params to be not None")  # mypy
         all_grad_none = True
         requires_grad = False
         for param in flat_param._params:
@@ -2571,12 +2613,16 @@ def _reset_is_grad_none(self) -> None:
             "Expects to only be called in the post-backward after gradient computation",
         )
         flat_param = self.flat_param
-        assert flat_param._params is not None  # mypy
+        if flat_param._params is None:
+            raise AssertionError("Expected _params to be not None")  # mypy
         for i, param in enumerate(flat_param._params):  # type: ignore[arg-type]
             # As long as the parameter requires gradient, it should receive a
             # meaningful gradient (even if the gradient happens to be zeros)
             if param.requires_grad:
-                assert flat_param._is_grad_none_mask is not None  # mypy
+                if flat_param._is_grad_none_mask is None:
+                    raise AssertionError(
+                        "Expected _is_grad_none_mask to be not None"
+                    )  # mypy
                 flat_param._is_grad_none_mask[i] = False
 
     #######################

torch/distributed/fsdp/_fsdp_extensions.py

Lines changed: 2 additions & 1 deletion
@@ -161,7 +161,8 @@ def _ext_pre_load_state_dict_transform(
     if fsdp_extension is not None:
         return fsdp_extension.pre_load_state_dict_transform(tensor)
 
-    assert type(tensor) is ShardedTensor
+    if type(tensor) is not ShardedTensor:
+        raise AssertionError(f"Expected ShardedTensor, got {type(tensor)}")
     shards = tensor.local_shards()
     return (tensor, shards)

torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py

Lines changed: 8 additions & 4 deletions
@@ -502,9 +502,10 @@ def foreach_reduce(
     ):
         if (shard_dim := fsdp_param.fsdp_placement.dim) == 0:
             continue
-        assert unsharded_grad.size(shard_dim) % world_size == 0, (
-            f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}"
-        )
+        if unsharded_grad.size(shard_dim) % world_size != 0:
+            raise AssertionError(
+                f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}"
+            )
         chunks = torch.chunk(unsharded_grad, world_size, dim=shard_dim)
         unsharded_grads[i] = torch.cat(chunks, dim=0)
 
@@ -621,7 +622,10 @@ def foreach_reduce(
             # ensure that the D2H copy finishes before the optimizer
             fsdp_param.grad_offload_event = post_reduce_stream.record_event()
         if to_accumulate_grad:
-            assert isinstance(fsdp_param.sharded_param.grad, DTensor)
+            if not isinstance(fsdp_param.sharded_param.grad, DTensor):
+                raise AssertionError(
+                    f"Expected fsdp_param.sharded_param.grad to be DTensor, got {type(fsdp_param.sharded_param.grad)}"
+                )
             fsdp_param.sharded_param.grad._local_tensor += new_sharded_grad
         else:
             new_sharded_dtensor_grad = fsdp_param.to_sharded_dtensor(

torch/distributed/fsdp/_fully_shard/_fsdp_common.py

Lines changed: 4 additions & 3 deletions
@@ -17,9 +17,10 @@
 
 
 def detect_compiled_autograd():
-    assert not torch.compiler.is_compiling(), (
-        "`detect_compiled_autograd()` is designed to be called in eager mode"
-    )
+    if torch.compiler.is_compiling():
+        raise AssertionError(
+            "`detect_compiled_autograd()` is designed to be called in eager mode"
+        )
     global _compiled_autograd_enabled
     import torch._dynamo.compiled_autograd as ca
 