python/paddle/distributed/fleet/utils/tensor_fusion_helper.py: 126 changes (97 additions, 29 deletions)
@@ -147,7 +147,9 @@ def flatten_dense_tensors(
dtype = parameters[0].dtype

for param in parameters:
assert param.trainable, "param must be trainable..."
assert param.trainable, (
f"param '{param.name}' must be trainable, but got trainable={param.trainable}"
)
size = np.prod(param.shape) * align[dtype]
remaining = size % alignment[get_current_device_type()]
ali = (
@@ -263,7 +265,10 @@ def _get_padding(self):

padding = padding_end - padding_start
grad_numel = self._slice_grad._numel()
assert grad_numel >= padding, f"{grad_numel} vs {padding}"
assert grad_numel >= padding, (
f"grad_numel ({grad_numel}) should be >= padding ({padding}), "
f"padding_start={padding_start}, padding_end={padding_end}"
)
padding_grad = self._slice_grad._slice(
grad_numel - padding, grad_numel
)
@@ -272,7 +277,10 @@ def _get_padding(self):
return None

def _slice_grad_from_buffer(self):
assert self._grad_buffer is not None
assert self._grad_buffer is not None, (
f"_grad_buffer should not be None when slicing grad from buffer, "
f"release_grad={self._release_grad}"
)
if self._param_begin < self._param_end:
self._slice_grad = self._grad_buffer._slice(
self._param_begin, self._param_end
@@ -311,17 +319,26 @@ def fill_slice_param(self, slice_param):
slice_begin = self._param_begin
slice_end = self._param_end
if slice_param._is_initialized():
assert self._param_buffer._is_shared_buffer_with(slice_param)
assert len(slice_param.shape) == 1
assert slice_param.shape[0] == (slice_end - slice_begin)
assert self._param_buffer._is_shared_buffer_with(slice_param), (
f"param_buffer should share buffer with slice_param for param '{self._param.name}'"
)
assert len(slice_param.shape) == 1, (
f"slice_param should be 1-D tensor, but got shape {slice_param.shape}"
)
assert slice_param.shape[0] == (slice_end - slice_begin), (
f"slice_param shape[0] ({slice_param.shape[0]}) should equal to slice_end - slice_begin "
f"({slice_end - slice_begin}), slice_begin={slice_begin}, slice_end={slice_end}"
)
slice_begin = self._param_begin
slice_end = self._param_end
slice_buffer = self._param_buffer._slice(slice_begin, slice_end)
slice_buffer._share_buffer_to(slice_param)
slice_param.get_tensor()._set_dims([slice_end - slice_begin])

def assign_slice_grad(self, slice_param):
assert self._param_buffer._is_shared_buffer_with(self._param)
assert self._param_buffer._is_shared_buffer_with(self._param), (
f"param_buffer should share buffer with param '{self._param.name}'"
)
slice_grad = self._slice_grad
if slice_grad is None:
return
@@ -330,12 +347,16 @@ def assign_slice_grad(self, slice_param):
if not hasattr(slice_param, "main_grad"):
slice_param.main_grad = slice_grad
else:
assert slice_param.main_grad is slice_grad
assert slice_param.main_grad is slice_grad, (
f"slice_param.main_grad should be the same as slice_grad for param '{self._param.name}'"
)
elif slice_grad is not None:
if slice_param.grad is None:
slice_param._copy_gradient_from(slice_grad)
else:
assert slice_param.grad._is_shared_buffer_with(slice_grad)
assert slice_param.grad._is_shared_buffer_with(slice_grad), (
f"slice_param.grad should share buffer with slice_grad for param '{self._param.name}'"
)

def _clear_param_buffer(self):
self._param._clear_to_zero_allocation()
@@ -391,7 +412,9 @@ def get_padded_size(param):
return padded_size

for param in parameters:
assert param.trainable, "param must be trainable..."
assert param.trainable, (
f"param '{param.name}' must be trainable, but got trainable={param.trainable}"
)
param2index[param.name] = total_buffer_size
total_buffer_size += get_padded_size(param)

@@ -420,7 +443,10 @@ def get_padded_size(param):
release_grad,
)
if init_slice_param and grad_view.has_effective_slice_param:
assert param.name in slice_params
assert param.name in slice_params, (
f"param '{param.name}' should be in slice_params when init_slice_param=True, "
f"available keys: {list(slice_params.keys())}"
)
grad_view.fill_slice_param(slice_params[param.name])
# hack main_grad
sharding_grad_view[param.name] = grad_view
@@ -493,7 +519,10 @@ def __init__(
assert act == HOOK_ACTION.REDUCE_SCATTER, (
"Currently, only support reduce_scatter"
)
assert release_grads, "Currently, only support release_grads"
assert release_grads, (
"Currently, only support release_grads when using free_grads_in_comm, "
f"but got release_grads={release_grads}"
)

assert not (self._fuse_param and self._release_grads), (
"It's not supported when using fuse_param and release_grad at the same time."
@@ -517,11 +546,17 @@ def __init__(

self._act = act
if self._act == HOOK_ACTION.ALL_REDUCE:
assert dst == -1
assert dst == -1, (
f"dst should be -1 for ALL_REDUCE action, but got dst={dst}"
)
elif self._act == HOOK_ACTION.REDUCE_SCATTER:
assert dst == -1
assert dst == -1, (
f"dst should be -1 for REDUCE_SCATTER action, but got dst={dst}"
)
elif self._act == HOOK_ACTION.REDUCE:
assert dst != -1
assert dst != -1, (
f"dst should not be -1 for REDUCE action, but got dst={dst}"
)
else:
raise ValueError(
"The act should be allreduce for dp or reduce for sharding."
@@ -564,7 +599,10 @@ def __init__(
warp_buffer=False,
)[0].buffer
else:
assert not self._fuse_param, "not supported"
assert not self._fuse_param, (
"fuse_param is not supported when act is REDUCE_SCATTER, "
f"but got fuse_param={self._fuse_param}"
)
(
self._sharding_param_grad_view,
self.buffer_size,
@@ -637,7 +675,10 @@ def _copy_grad_to_buffer(self, param):
return

if self.grad_storage is None:
assert self._params_step_dict[param.name] == 0
assert self._params_step_dict[param.name] == 0, (
f"_params_step_dict['{param.name}'] should be 0 when grad_storage is None, "
f"but got {self._params_step_dict[param.name]}"
)

self.grad_storage = paddle.zeros(
[self.buffer_size], dtype=self._dtype
@@ -652,7 +693,10 @@ def _copy_grad_to_buffer(self, param):
]._slice_grad_from_buffer()
else:
grad_end = self.param2offset[param.name] + np.prod(param.shape)
assert grad_end <= self.buffer_size
assert grad_end <= self.buffer_size, (
f"grad_end ({grad_end}) should be <= buffer_size ({self.buffer_size}), "
f"param='{param.name}', offset={self.param2offset[param.name]}, shape={param.shape}"
)
tmp_var = self.grad_storage._slice(
self.param2offset[param.name], grad_end
)
@@ -692,7 +736,10 @@ def _all_params_checked_in(self):
)

def add_grad(self, param, use_comm=True):
assert param.name in self._params_step_dict
assert param.name in self._params_step_dict, (
f"param '{param.name}' should be in _params_step_dict, "
f"available params: {list(self._params_step_dict.keys())}"
)

if not self._release_grads or self._params_step_dict[param.name] > 0:
current_ptr = get_grad_address(param, self.use_main_grad)
@@ -717,16 +764,23 @@

@imperative_base.no_grad
def assign_slice_grad(self, param, slice_param):
assert self._act == HOOK_ACTION.REDUCE_SCATTER
assert param.name in self._sharding_param_grad_view
assert self._act == HOOK_ACTION.REDUCE_SCATTER, (
f"_act should be REDUCE_SCATTER for assign_slice_grad, but got act={self._act}"
)
assert param.name in self._sharding_param_grad_view, (
f"param '{param.name}' should be in _sharding_param_grad_view, "
f"available params: {list(self._sharding_param_grad_view.keys())}"
)
grad_view = self._sharding_param_grad_view[param.name]
grad_view.assign_slice_grad(slice_param)

@imperative_base.no_grad
def sync_params(self, sync=True, param2task={}):
if not self.need_reduce_scale_sync():
return
assert self._act == HOOK_ACTION.REDUCE_SCATTER
assert self._act == HOOK_ACTION.REDUCE_SCATTER, (
f"_act should be REDUCE_SCATTER for sync_params, but got act={self._act}"
)
full_buffer = self.param_storage
group = self._comm_group
shard_size = full_buffer._numel() // group.nranks
@@ -746,7 +800,10 @@ def sync_params(self, sync=True, param2task={}):

self.sync_param_task = task
for param in self.params:
assert param.name not in param2task
assert param.name not in param2task, (
f"param '{param.name}' should not already exist in param2task. "
f"This might indicate duplicate param assignment."
)
param2task[param.name] = task

@property
@@ -844,7 +901,9 @@ def scale_grads(self):
if self._comm_group.nranks == 1 and self._task is None:
self._reset_params_checked_in()
return
assert self._task is not None, "Task is not initialized."
assert self._task is not None, (
"Task is not initialized. This might indicate that comm_grads() was not called before scale_grads()."
)
self._task.wait()

# scale will be skipped when use reduce_avg comm operation
@@ -936,7 +995,9 @@ def filter_params(params, is_fp32, is_distributed, need_clip):
if dtype is None:
dtype = p.dtype
else:
assert dtype == p.dtype
assert dtype == p.dtype, (
f"All params should have the same dtype. Expected {dtype}, but got {p.dtype} for param '{p.name}'"
)

return params, dtype

@@ -971,7 +1032,10 @@ def _fused_parameters_impl(
if no_fp32_dtype is None:
no_fp32_dtype = dtype
elif dtype is not None:
assert no_fp32_dtype == dtype
assert no_fp32_dtype == dtype, (
f"All non-fp32 params should have the same dtype. "
f"Expected {no_fp32_dtype}, but got {dtype}"
)
attrs.append([dtype, dist, clip])
param_groups.append(params)

@@ -1073,7 +1137,9 @@ def fused_parameters(
)
comm_group = paddle.distributed.collective._get_default_group()
if act == HOOK_ACTION.REDUCE:
assert dst != -1
assert dst != -1, (
f"dst should not be -1 for REDUCE action, but got dst={dst}"
)
elif act == HOOK_ACTION.ALL_REDUCE:
dst = -1

@@ -1082,10 +1148,12 @@
comm_buffers = []
for idx, group_param in enumerate(parameters):
assert isinstance(group_param, dict), (
"For group params, each group should be a dictionary."
f"For group params, each group should be a dictionary, but got type {type(group_param)} "
f"at index {idx}"
)
assert 'params' in group_param.keys(), (
"For group params, each group should have parameters."
f"For group params, each group should have 'params' key, but got keys: "
f"{list(group_param.keys())} at index {idx}"
)
real_param = group_param['params']
(
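The change applied throughout this diff follows one pattern: each bare assert gains a parenthesized f-string message that embeds the runtime values needed to diagnose a failure (parameter name, sizes, offsets, available keys). A minimal, self-contained sketch of that pattern is below; the Param dataclass and check_trainable helper are illustrative only and do not exist in tensor_fusion_helper.py.

# Illustrative sketch of the assert-message pattern used in this PR.
# Param and check_trainable are hypothetical stand-ins for paddle parameters.
from dataclasses import dataclass


@dataclass
class Param:
    name: str
    trainable: bool


def check_trainable(param: Param) -> None:
    # Before: assert param.trainable, "param must be trainable..."
    # After: the message names the failing parameter and its value, so a
    # failure inside a fused buffer of many parameters is attributable.
    assert param.trainable, (
        f"param '{param.name}' must be trainable, "
        f"but got trainable={param.trainable}"
    )


if __name__ == "__main__":
    check_trainable(Param(name="linear_0.w_0", trainable=True))  # passes
    try:
        check_trainable(Param(name="embedding_0.w_0", trainable=False))
    except AssertionError as err:
        print(err)  # param 'embedding_0.w_0' must be trainable, but got trainable=False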