2 changes: 1 addition & 1 deletion deepspeed/ops/adam/__init__.py
@@ -6,4 +6,4 @@
from .cpu_adam import DeepSpeedCPUAdam
from .fused_adam import FusedAdam
from .zenflow_cpu_adam import ZenFlowCPUAdam
from .zenflow_torch_adam import ZenFlowSelectiveAdamW
from .zenflow_torch_adam import ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3
333 changes: 333 additions & 0 deletions deepspeed/ops/adam/zenflow_torch_adam.py
@@ -314,6 +314,339 @@ def _group_step_with_offload(self, group_to_paramlist):
param.selected_grad = None


class ZenFlowSelectiveAdamW_stage3(torch.optim.AdamW):
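    """ZeRO stage-3 variant of the ZenFlow selective AdamW optimizer.

    Only the slices of each partitioned parameter picked out by ``selected_indices``
    (or, for 1-D parameters, the whole local partition) are updated. With
    ``offload=True`` the selective optimizer states and the temporary copies of the
    selected slices live on CPU and are brought to the device in buckets of roughly
    ``bucket_size`` elements.
    """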

def __init__(self, *args, offload=False, bucket_size=5e8, **kwargs):
super(ZenFlowSelectiveAdamW_stage3, self).__init__(*args, **kwargs)
self.offload = offload

if offload:
self.step = self._step_with_offload
self.temp_copy_param = self._temp_copy_param_with_offload
self.group_step = self._group_step_with_offload
self.bucket_size = bucket_size
else:
self.step = self._step_without_offload
self.temp_copy_param = self._temp_copy_param_without_offload
self.group_step = self._group_step_without_offload

@torch.no_grad()
def _temp_copy_param_without_offload(self, paramlist):
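        # Stash a detached copy of the currently selected slices in param.temp_selected_param.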
for param in paramlist:
if hasattr(param, "selected_grad"):
num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)

if num_row != 1:
param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, param.complete_numel).view(
param.complete_numel // num_row, num_row)
param.temp_selected_param = param_2d[param.selected_indices, :].clone().detach()
else:
param.temp_selected_param = param.ds_tensor.data.clone().detach()

@torch.no_grad()
def _temp_copy_param_with_offload(self, paramlist):
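        # Same as the non-offload variant, but the detached copy is moved to CPU.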
for param in paramlist:
if hasattr(param, "selected_grad"):
num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)

if num_row != 1:
param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, param.complete_numel).view(
param.complete_numel // num_row, num_row)
temp_selected_param = param_2d[param.selected_indices, :].clone().detach()
else:
temp_selected_param = param.ds_tensor.data.clone().detach()
param.temp_selected_param = temp_selected_param.cpu()

def clear_selected_mv(self):
        print("ZenFlow: clearing selective optimizer states...")
for group in self.param_groups:
for param in group['params']:
state = self.state.setdefault(param, {})
if len(state) == 0:
continue
if self.offload:
param.exp_avg_cpu_data.zero_()
param.exp_avg_sq_cpu_data.zero_()
else:
state["exp_avg"].zero_()
state["exp_avg_sq"].zero_()

@torch.no_grad()
def _step_without_offload(self):
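        # Selective AdamW update with all optimizer state kept on the device.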
for group in self.param_groups:

params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
exp_avg_sqs: List[Tensor] = []
max_exp_avg_sqs: List[Tensor] = []
state_steps: List[Tensor] = []
amsgrad: bool = group["amsgrad"]
beta1, beta2 = cast(Tuple[float, float], group["betas"])
for param in group["params"]:
if hasattr(param, "selected_grad"):
num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
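                    # For 2-D parameters, view the flat partition slice as
                    # (complete_numel // num_row, num_row) so selected_indices picks whole rows of that view.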
if num_row != 1:
param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset,
param.complete_numel).view(
param.complete_numel // num_row, num_row)
selected_param = param_2d[param.selected_indices, :]
else:
selected_param = param.ds_tensor.data
if hasattr(param, 'temp_selected_param') and param.temp_selected_param is not None:
selected_param.copy_(param.temp_selected_param)

state = self.state.setdefault(param, {})
if len(state) == 0:
state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device)
state["exp_avg"] = torch.zeros_like(selected_param)
state["exp_avg_sq"] = torch.zeros_like(selected_param)
if amsgrad:
state["max_exp_avg_sq"] = torch.zeros_like(selected_param)

params_with_grad.append(selected_param)
grads.append(param.selected_grad)
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
if amsgrad:
max_exp_avg_sqs.append(state["max_exp_avg_sq"])
state_steps.append(state["step"])
adamw(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=group["lr"],
weight_decay=group["weight_decay"],
eps=group["eps"],
maximize=False,
)
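            # Advanced indexing produced copies of the selected rows, so write the updated values back into the partition.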
for i, param in enumerate(group["params"]):
if hasattr(param, "selected_grad"):
num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
if num_row != 1:
param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset,
param.complete_numel).view(
param.complete_numel // num_row, num_row)
param_2d[param.selected_indices, :] = params_with_grad[i]

for param in group["params"]:
if hasattr(param, "temp_selected_param"):
param.temp_selected_param = None
param.selected_grad = None

@torch.no_grad()
def _group_step_without_offload(self, paramlist):
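        # Group the incoming params by their param-group id and run one selective step per group.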

group_to_paramlist = {}
for param in paramlist:
group_id = param.group_id
if group_id not in group_to_paramlist:
group_to_paramlist[group_id] = []
group_to_paramlist[group_id].append(param)

for group_id in sorted(group_to_paramlist.keys()):
params = group_to_paramlist[group_id]
group = self.param_groups[group_id]

params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
exp_avg_sqs: List[Tensor] = []
max_exp_avg_sqs: List[Tensor] = []
state_steps: List[Tensor] = []
amsgrad: bool = group["amsgrad"]
beta1, beta2 = cast(Tuple[float, float], group["betas"])

for param in params:
if hasattr(param, "selected_grad"):
num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)

if num_row != 1:
param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset,
param.complete_numel).view(
param.complete_numel // num_row, num_row)
selected_param = param_2d[param.selected_indices, :]
else:
selected_param = param.ds_tensor.data

state = self.state.setdefault(param, {})
if len(state) == 0:
state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device)
state["exp_avg"] = torch.zeros_like(selected_param)
state["exp_avg_sq"] = torch.zeros_like(selected_param)
if amsgrad:
state["max_exp_avg_sq"] = torch.zeros_like(selected_param)

params_with_grad.append(selected_param)
grads.append(param.selected_grad)
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
if amsgrad:
max_exp_avg_sqs.append(state["max_exp_avg_sq"])
state_steps.append(state["step"])
adamw(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=group["lr"],
weight_decay=group["weight_decay"],
eps=group["eps"],
maximize=False,
)
for i, param in enumerate(params):
if hasattr(param, "selected_grad"):
num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
if num_row != 1:
param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset,
param.complete_numel).view(
param.complete_numel // num_row, num_row)
param_2d[param.selected_indices, :] = params_with_grad[i]

for param in params:
param.selected_grad = None

def copy_mv_from_cpu(self, params):
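        # Stage the CPU copies of exp_avg / exp_avg_sq onto the device (non-blocking where possible).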
for param in params:
param.exp_avg = param.exp_avg_cpu_data.to(param.device, non_blocking=True)
param.exp_avg_sq = param.exp_avg_sq_cpu_data.to(param.device, non_blocking=True)

def copy_mv_to_cpu(self, params):
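        # Persist the updated moments back to their CPU buffers and release the device copies.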
for param in params:
param.exp_avg_cpu_data.copy_(param.exp_avg.data, non_blocking=True)
param.exp_avg_sq_cpu_data.copy_(param.exp_avg_sq.data, non_blocking=True)
param.exp_avg = None
param.exp_avg_sq = None

@torch.no_grad()
def _group_step_with_offload(self, paramlist):
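        # Same as the non-offload group step, but exp_avg / exp_avg_sq are staged in from CPU
        # before the update and written back to CPU afterwards.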

group_to_paramlist = {}
for param in paramlist:
group_id = param.group_id
if group_id not in group_to_paramlist:
group_to_paramlist[group_id] = []
group_to_paramlist[group_id].append(param)

for group_id in sorted(group_to_paramlist.keys()):
params = group_to_paramlist[group_id]
group = self.param_groups[group_id]

self.copy_mv_from_cpu(params)
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
exp_avg_sqs: List[Tensor] = []
max_exp_avg_sqs: List[Tensor] = []
state_steps: List[Tensor] = []
amsgrad: bool = group["amsgrad"]
beta1, beta2 = cast(Tuple[float, float], group["betas"])

for param in params:
if hasattr(param, "selected_grad"):
num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)

if num_row != 1:
param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset,
param.complete_numel).view(
param.complete_numel // num_row, num_row)
selected_param = param_2d[param.selected_indices, :]
else:
selected_param = param.ds_tensor.data

state = self.state.setdefault(param, {})
if len(state) == 0:
state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device)
if amsgrad:
state["max_exp_avg_sq"] = torch.zeros_like(selected_param)

params_with_grad.append(selected_param)
grads.append(param.selected_grad)
exp_avgs.append(param.exp_avg.view_as(selected_param))
exp_avg_sqs.append(param.exp_avg_sq.view_as(selected_param))
if amsgrad:
max_exp_avg_sqs.append(state["max_exp_avg_sq"])
state_steps.append(state["step"])

adamw(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=group["lr"],
weight_decay=group["weight_decay"],
eps=group["eps"],
maximize=False,
)

for i, param in enumerate(params):
if hasattr(param, "selected_grad"):
num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
if num_row != 1:
param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset,
param.complete_numel).view(
param.complete_numel // num_row, num_row)
param_2d[param.selected_indices, :] = params_with_grad[i]

self.copy_mv_to_cpu(params)

for param in params:
param.selected_grad = None

@torch.no_grad()
def _step_with_offload(self):
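        # Process parameters in buckets of roughly bucket_size elements so that only one
        # bucket's worth of CPU-offloaded copies is restored to the device at a time.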
for group_id, group in enumerate(self.param_groups):
params = group["params"]

bucket = []
bucket_numel = 0

def flush_bucket():
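                # Restore the CPU snapshots of the selected slices into the partition,
                # then run the selective optimizer step on this bucket.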
if not bucket:
return
for param in bucket:
if hasattr(param, "temp_selected_param") and param.temp_selected_param is not None:
temp_selected_param = param.temp_selected_param.to(param.device, non_blocking=True)
num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
if num_row != 1:
param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset,
param.complete_numel).view(
param.complete_numel // num_row, num_row)
param_2d[param.selected_indices, :] = temp_selected_param
else:
param.ds_tensor.data.copy_(temp_selected_param)
param.temp_selected_param = None

self.group_step(bucket)
bucket.clear()

for param in params:
if hasattr(param, "selected_grad"):
bucket.append(param)
bucket_numel += param.numel()
if bucket_numel >= self.bucket_size:
flush_bucket()
bucket_numel = 0

flush_bucket()


def _single_tensor_adamw(
params: List[Tensor],
grads: List[Tensor],
1 change: 1 addition & 0 deletions deepspeed/runtime/engine.py
@@ -1858,6 +1858,7 @@ def _configure_zero_optimizer(self, optimizer):
overlap_comm=self.zero_overlap_comm(),
offload_optimizer_config=self.zero_offload_optimizer(),
offload_param_config=self.zero_offload_param(),
zenflow_config=self.zenflow_config(),
sub_group_size=self.zero_sub_group_size(),
offload_ratio=self.zero_partial_offload(),
mpu=self.mpu,