From 50d442b8449dc2dc4eeacf659fed68c41033015b Mon Sep 17 00:00:00 2001
From: shenliang03
Date: Mon, 9 Mar 2026 20:38:32 +0800
Subject: [PATCH] Add assert message for tensor_fusion_helper

---
 .../fleet/utils/tensor_fusion_helper.py | 126 ++++++++++++++----
 1 file changed, 97 insertions(+), 29 deletions(-)

diff --git a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
index 09d53b9c37e698..9de96ed521a044 100644
--- a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
+++ b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
@@ -147,7 +147,9 @@ def flatten_dense_tensors(
     dtype = parameters[0].dtype
 
     for param in parameters:
-        assert param.trainable, "param must be trainable..."
+        assert param.trainable, (
+            f"param '{param.name}' must be trainable, but got trainable={param.trainable}"
+        )
         size = np.prod(param.shape) * align[dtype]
         remaining = size % alignment[get_current_device_type()]
         ali = (
@@ -263,7 +265,10 @@ def _get_padding(self):
             padding = padding_end - padding_start
 
             grad_numel = self._slice_grad._numel()
-            assert grad_numel >= padding, f"{grad_numel} vs {padding}"
+            assert grad_numel >= padding, (
+                f"grad_numel ({grad_numel}) should be >= padding ({padding}), "
+                f"padding_start={padding_start}, padding_end={padding_end}"
+            )
             padding_grad = self._slice_grad._slice(
                 grad_numel - padding, grad_numel
             )
@@ -272,7 +277,10 @@ def _slice_grad_from_buffer(self):
             return None
 
     def _slice_grad_from_buffer(self):
-        assert self._grad_buffer is not None
+        assert self._grad_buffer is not None, (
+            "_grad_buffer should not be None when slicing grad from buffer, "
+            f"release_grad={self._release_grad}"
+        )
         if self._param_begin < self._param_end:
             self._slice_grad = self._grad_buffer._slice(
                 self._param_begin, self._param_end
@@ -311,9 +319,16 @@ def fill_slice_param(self, slice_param):
         slice_begin = self._param_begin
         slice_end = self._param_end
         if slice_param._is_initialized():
-            assert self._param_buffer._is_shared_buffer_with(slice_param)
-            assert len(slice_param.shape) == 1
-            assert slice_param.shape[0] == (slice_end - slice_begin)
+            assert self._param_buffer._is_shared_buffer_with(slice_param), (
+                f"param_buffer should share buffer with slice_param for param '{self._param.name}'"
+            )
+            assert len(slice_param.shape) == 1, (
+                f"slice_param should be a 1-D tensor, but got shape {slice_param.shape}"
+            )
+            assert slice_param.shape[0] == (slice_end - slice_begin), (
+                f"slice_param shape[0] ({slice_param.shape[0]}) should equal slice_end - slice_begin "
+                f"({slice_end - slice_begin}), slice_begin={slice_begin}, slice_end={slice_end}"
+            )
         slice_begin = self._param_begin
         slice_end = self._param_end
         slice_buffer = self._param_buffer._slice(slice_begin, slice_end)
@@ -321,7 +336,9 @@ def assign_slice_grad(self, slice_param):
         slice_param.get_tensor()._set_dims([slice_end - slice_begin])
 
     def assign_slice_grad(self, slice_param):
-        assert self._param_buffer._is_shared_buffer_with(self._param)
+        assert self._param_buffer._is_shared_buffer_with(self._param), (
+            f"param_buffer should share buffer with param '{self._param.name}'"
+        )
         slice_grad = self._slice_grad
         if slice_grad is None:
             return
@@ -330,12 +347,16 @@ def assign_slice_grad(self, slice_param):
             if not hasattr(slice_param, "main_grad"):
                 slice_param.main_grad = slice_grad
             else:
-                assert slice_param.main_grad is slice_grad
+                assert slice_param.main_grad is slice_grad, (
+                    f"slice_param.main_grad should be the same as slice_grad for param '{self._param.name}'"
+                )
         elif slice_grad is not None:
             if slice_param.grad is None:
                 slice_param._copy_gradient_from(slice_grad)
             else:
-                assert slice_param.grad._is_shared_buffer_with(slice_grad)
+                assert slice_param.grad._is_shared_buffer_with(slice_grad), (
+                    f"slice_param.grad should share buffer with slice_grad for param '{self._param.name}'"
+                )
 
     def _clear_param_buffer(self):
         self._param._clear_to_zero_allocation()
@@ -391,7 +412,9 @@ def get_padded_size(param):
         return padded_size
 
     for param in parameters:
-        assert param.trainable, "param must be trainable..."
+        assert param.trainable, (
+            f"param '{param.name}' must be trainable, but got trainable={param.trainable}"
+        )
         param2index[param.name] = total_buffer_size
         total_buffer_size += get_padded_size(param)
 
@@ -420,7 +443,10 @@ def get_padded_size(param):
             release_grad,
         )
         if init_slice_param and grad_view.has_effective_slice_param:
-            assert param.name in slice_params
+            assert param.name in slice_params, (
+                f"param '{param.name}' should be in slice_params when init_slice_param=True, "
+                f"available keys: {list(slice_params.keys())}"
+            )
             grad_view.fill_slice_param(slice_params[param.name])
         # hack main_grad
         sharding_grad_view[param.name] = grad_view
@@ -493,7 +519,10 @@ def __init__(
             assert act == HOOK_ACTION.REDUCE_SCATTER, (
                 "Currently, only support reduce_scatter"
             )
-            assert release_grads, "Currently, only support release_grads"
+            assert release_grads, (
+                "Currently, only support release_grads when using free_grads_in_comm, "
+                f"but got release_grads={release_grads}"
+            )
 
         assert not (self._fuse_param and self._release_grads), (
             "It's not supported when using fuse_param and release_grad at the same time."
@@ -517,11 +546,17 @@ def __init__(
 
         self._act = act
         if self._act == HOOK_ACTION.ALL_REDUCE:
-            assert dst == -1
+            assert dst == -1, (
+                f"dst should be -1 for ALL_REDUCE action, but got dst={dst}"
+            )
         elif self._act == HOOK_ACTION.REDUCE_SCATTER:
-            assert dst == -1
+            assert dst == -1, (
+                f"dst should be -1 for REDUCE_SCATTER action, but got dst={dst}"
+            )
         elif self._act == HOOK_ACTION.REDUCE:
-            assert dst != -1
+            assert dst != -1, (
+                f"dst should not be -1 for REDUCE action, but got dst={dst}"
+            )
         else:
             raise ValueError(
                 "The act should be allreduce for dp or reduce for sharding."
@@ -564,7 +599,10 @@ def __init__(
                 warp_buffer=False,
             )[0].buffer
         else:
-            assert not self._fuse_param, "not supported"
+            assert not self._fuse_param, (
+                "fuse_param is not supported when act is REDUCE_SCATTER, "
+                f"but got fuse_param={self._fuse_param}"
+            )
             (
                 self._sharding_param_grad_view,
                 self.buffer_size,
@@ -637,7 +675,10 @@ def _copy_grad_to_buffer(self, param):
             return
 
         if self.grad_storage is None:
-            assert self._params_step_dict[param.name] == 0
+            assert self._params_step_dict[param.name] == 0, (
+                f"_params_step_dict['{param.name}'] should be 0 when grad_storage is None, "
+                f"but got {self._params_step_dict[param.name]}"
+            )
 
             self.grad_storage = paddle.zeros(
                 [self.buffer_size], dtype=self._dtype
@@ -652,7 +693,10 @@ def _copy_grad_to_buffer(self, param):
             ]._slice_grad_from_buffer()
         else:
             grad_end = self.param2offset[param.name] + np.prod(param.shape)
-            assert grad_end <= self.buffer_size
+            assert grad_end <= self.buffer_size, (
+                f"grad_end ({grad_end}) should be <= buffer_size ({self.buffer_size}), "
+                f"param='{param.name}', offset={self.param2offset[param.name]}, shape={param.shape}"
+            )
             tmp_var = self.grad_storage._slice(
                 self.param2offset[param.name], grad_end
             )
@@ -692,7 +736,10 @@ def _all_params_checked_in(self):
         )
 
     def add_grad(self, param, use_comm=True):
-        assert param.name in self._params_step_dict
+        assert param.name in self._params_step_dict, (
+            f"param '{param.name}' should be in _params_step_dict, "
+            f"available params: {list(self._params_step_dict.keys())}"
+        )
 
         if not self._release_grads or self._params_step_dict[param.name] > 0:
             current_ptr = get_grad_address(param, self.use_main_grad)
@@ -717,8 +764,13 @@ def add_grad(self, param, use_comm=True):
 
     @imperative_base.no_grad
    def assign_slice_grad(self, param, slice_param):
-        assert self._act == HOOK_ACTION.REDUCE_SCATTER
-        assert param.name in self._sharding_param_grad_view
+        assert self._act == HOOK_ACTION.REDUCE_SCATTER, (
+            f"_act should be REDUCE_SCATTER for assign_slice_grad, but got act={self._act}"
+        )
+        assert param.name in self._sharding_param_grad_view, (
+            f"param '{param.name}' should be in _sharding_param_grad_view, "
+            f"available params: {list(self._sharding_param_grad_view.keys())}"
+        )
         grad_view = self._sharding_param_grad_view[param.name]
         grad_view.assign_slice_grad(slice_param)
 
@@ -726,7 +778,9 @@ def sync_params(self, sync=True, param2task={}):
     def sync_params(self, sync=True, param2task={}):
         if not self.need_reduce_scale_sync():
             return
-        assert self._act == HOOK_ACTION.REDUCE_SCATTER
+        assert self._act == HOOK_ACTION.REDUCE_SCATTER, (
+            f"_act should be REDUCE_SCATTER for sync_params, but got act={self._act}"
+        )
         full_buffer = self.param_storage
         group = self._comm_group
         shard_size = full_buffer._numel() // group.nranks
@@ -746,7 +800,10 @@ def sync_params(self, sync=True, param2task={}):
             self.sync_param_task = task
 
         for param in self.params:
-            assert param.name not in param2task
+            assert param.name not in param2task, (
+                f"param '{param.name}' should not already exist in param2task. "
+                "This might indicate duplicate param assignment."
+            )
             param2task[param.name] = task
 
     @property
@@ -844,7 +901,9 @@ def scale_grads(self):
         if self._comm_group.nranks == 1 and self._task is None:
             self._reset_params_checked_in()
             return
-        assert self._task is not None, "Task is not initialized."
+        assert self._task is not None, (
+            "Task is not initialized. This might indicate that comm_grads() was not called before scale_grads()."
+        )
         self._task.wait()
 
         # scale will be skipped when use reduce_avg comm operation
@@ -936,7 +995,9 @@ def filter_params(params, is_fp32, is_distributed, need_clip):
         if dtype is None:
             dtype = p.dtype
         else:
-            assert dtype == p.dtype
+            assert dtype == p.dtype, (
+                f"All params should have the same dtype. Expected {dtype}, but got {p.dtype} for param '{p.name}'"
+            )
 
     return params, dtype
 
@@ -971,7 +1032,10 @@ def _fused_parameters_impl(
             if no_fp32_dtype is None:
                 no_fp32_dtype = dtype
             elif dtype is not None:
-                assert no_fp32_dtype == dtype
+                assert no_fp32_dtype == dtype, (
+                    "All non-fp32 params should have the same dtype. "
+                    f"Expected {no_fp32_dtype}, but got {dtype}"
+                )
         attrs.append([dtype, dist, clip])
         param_groups.append(params)
 
@@ -1073,7 +1137,9 @@ def fused_parameters(
         )
         comm_group = paddle.distributed.collective._get_default_group()
     if act == HOOK_ACTION.REDUCE:
-        assert dst != -1
+        assert dst != -1, (
+            f"dst should not be -1 for REDUCE action, but got dst={dst}"
+        )
     elif act == HOOK_ACTION.ALL_REDUCE:
         dst = -1
 
@@ -1082,10 +1148,12 @@ def fused_parameters(
         comm_buffers = []
         for idx, group_param in enumerate(parameters):
             assert isinstance(group_param, dict), (
-                "For group params, each group should be a dictionary."
+                f"For group params, each group should be a dictionary, but got type {type(group_param)} "
+                f"at index {idx}"
             )
             assert 'params' in group_param.keys(), (
-                "For group params, each group should have parameters."
+                "For group params, each group should have a 'params' key, but got keys: "
+                f"{list(group_param.keys())} at index {idx}"
             )
             real_param = group_param['params']
             (
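Not part of the patch itself: the snippet below is a minimal, runnable sketch of the assert style the diff applies throughout tensor_fusion_helper.py, namely wrapping the message in parentheses and building it from f-strings so a failed check reports the offending parameter and value. The ToyParam class and check_trainable function are hypothetical stand-ins for illustration only, not Paddle APIs.

from dataclasses import dataclass


@dataclass
class ToyParam:
    # Hypothetical stand-in for a framework parameter object (not a Paddle API).
    name: str
    trainable: bool


def check_trainable(param):
    # Same pattern as the patch: parenthesized, multi-line f-string message
    # that names the parameter and echoes the value that failed the check.
    assert param.trainable, (
        f"param '{param.name}' must be trainable, "
        f"but got trainable={param.trainable}"
    )


if __name__ == "__main__":
    check_trainable(ToyParam("w_0", True))  # passes silently
    try:
        check_trainable(ToyParam("w_1", False))
    except AssertionError as err:
        # Prints: param 'w_1' must be trainable, but got trainable=False
        print(err)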