Commit 6a7bf57

enable split sgd for embeddingbag's weight. conv/linear's bias (#127)
* enable split sgd for embeddingbag's weight, conv/linear's bias
* refine ipex.optimize code
* fix split sgd momentum_buffer dtype
1 parent c4809ed commit 6a7bf57

File tree

9 files changed: +494, -282 lines changed


intel_pytorch_extension_py/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -11,4 +11,6 @@
 from .ops import *
 from .utils import *
 from .weight_prepack import *
+from .optimizer_utils import *
+from .weight_cast import *
 from .optim import *

intel_pytorch_extension_py/optim/_functional.py

Lines changed: 33 additions & 10 deletions

@@ -81,25 +81,48 @@ def sgd(params: List[Tensor],
 
     for i, param in enumerate(params):
 
+
         d_p = d_p_list[i]
+        float_d_p, float_param = None, None
+        if d_p.dtype == torch.bfloat16:
+            assert param in attr, "split sgd requires record 'trail' part of params in attr"
+            trail = attr[param]['trail']
+
+        if weight_decay != 0 or momentum != 0:
+            float_d_p = d_p.float()
+            if d_p.dtype == torch.bfloat16:
+                float_d_p = d_p.float()
+                float_param = torch.ops.torch_ipex.cat_bfloat16_float(param, trail)
+            else:
+                float_param = param
+                float_d_p = d_p
+
         if weight_decay != 0:
-            d_p = d_p.add(param, alpha=weight_decay)
+            float_d_p = float_d_p.add(float_param, alpha=weight_decay)
 
         if momentum != 0:
             buf = momentum_buffer_list[i]
-
             if buf is None:
-                buf = torch.clone(d_p).detach()
+                buf = torch.clone(float_d_p).detach()
                 momentum_buffer_list[i] = buf
             else:
-                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
+                buf.mul_(momentum).add_(float_d_p, alpha=1 - dampening)
 
             if nesterov:
-                d_p = d_p.add(buf, alpha=momentum)
+                float_d_p = d_p.add(buf, alpha=momentum)
             else:
-                d_p = buf
-        if param.dtype == torch.bfloat16 and param in attr:
-            trail = attr[param]['trail']
-            torch.ops.torch_ipex.packed_add(param, trail, d_p, alpha=-lr)
+                float_d_p = buf
+
+        if param.dtype is torch.bfloat16:
+            if float_d_p is not None and float_param is not None:
+                float_param.add_(float_d_p, alpha=-lr)
+                top_half, bot_half = torch.ops.torch_ipex.split_float_bfloat16(float_param)
+                param.copy_(top_half)
+                trail.copy_(bot_half)
+            else:
+                torch.ops.torch_ipex.packed_add(param, trail, d_p, alpha=-lr)
         else:
-            param.add_(d_p, alpha=-lr)
+            if float_d_p is not None:
+                param.add_(float_d_p, alpha=-lr)
+            else:
+                param.add_(d_p, alpha=-lr)
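
To make the split layout above concrete: split_float_bfloat16 stores an fp32 master value as a bf16 "top half" (kept in the model parameter) plus a "trail" tensor holding the low-order bits, and cat_bfloat16_float glues them back into the exact fp32 value before the update. The snippet below is a pure-PyTorch, bit-level emulation of that round trip, not the IPEX kernels themselves; the helper names, the int32 trail storage, and the truncating split are assumptions for illustration (the native ops may round and pack differently).

import torch

# Hypothetical pure-PyTorch emulation of the split/cat ops used by split SGD.
# IPEX implements these as native kernels; dtypes and rounding here are assumed.

def split_float_bfloat16_emulated(x_fp32):
    # Reinterpret the fp32 bits and cut them in half: the high 16 bits become the
    # bf16 "top half" (the model weight), the low 16 bits are kept as the "trail".
    bits = x_fp32.view(torch.int32)
    top = (bits >> 16).to(torch.int16).view(torch.bfloat16)
    trail = bits & 0xFFFF                     # raw low bits, kept as int32 here
    return top, trail

def cat_bfloat16_float_emulated(top_bf16, trail):
    # Glue the two halves back together to recover the exact fp32 master value.
    bits = (top_bf16.view(torch.int16).to(torch.int32) << 16) | trail
    return bits.view(torch.float32)

master = torch.randn(4)                       # fp32 master weight
top, trail = split_float_bfloat16_emulated(master)
restored = cat_bfloat16_float_emulated(top, trail)
assert torch.equal(restored, master)          # lossless round trip

With such a layout, a fused update like packed_add(param, trail, grad, alpha=-lr) can apply an fp32-accurate step without keeping a separate fp32 master copy of every parameter.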

intel_pytorch_extension_py/optimizer_utils.py

Lines changed: 15 additions & 63 deletions

@@ -14,75 +14,27 @@ class _ipex_optimizer(torch.optim.Optimizer):
 
     Args:
         optimizer: optimized optimizer, contains optimized model's paramerter setting.
-        weight_params_attr: the prepacked parameters' attrs, to do prepack for corresponding
-            momentum_buffer or other state according those attrs.
-        dtype: can be torch.bfloat16 or torch.float32(torch.float), determin doing bfloat16 training
-            or float training.
+        params_attr: the parameters' attrs, to cat top_half and bottom(trail) half back to fp32
+
     """
 
-    def __init__(self, optimizer, weight_params_attr, dtype):
-        if type(optimizer) in IPEX_OPTIMIZER_MAPPING and dtype == torch.bfloat16:
-            self.optimizer = IPEX_OPTIMIZER_MAPPING[type(optimizer)] (optimizer, weight_params_attr)
+    def __init__(self, optimizer, params_attr):
+        if type(optimizer) in IPEX_OPTIMIZER_MAPPING:
+            self.optimizer = IPEX_OPTIMIZER_MAPPING[type(optimizer)] (optimizer, params_attr)
             self.master_weight_split = True
         else:
             self.optimizer = optimizer
             self.master_weight_split = False
-        self.weight_params_attr = weight_params_attr
+        self.params_attr = params_attr
         self.param_groups = self.optimizer.param_groups
-        self.dtype = dtype
-
-    def state_dict(self):
-        optimizer_temp = copy.deepcopy(self.optimizer)
-        weight_params_attr_ = {}
-        # For bf16 path, the optimizer's params are master weight,
-        # but self.weight_params_attr's keys are bf16 weight, it hard to
-        # query the weight's attr, so recreate a dic which using master weight
-        # as key for easily to query.
-        if self.dtype == torch.bfloat16 and not self.master_weight_split:
-            for _, values in self.weight_params_attr.items():
-                master_weight = values['master_weight']
-                weight_params_attr_[master_weight] = values
-        else:
-            weight_params_attr_ = self.weight_params_attr
-
-        for (k1, _), (_, v2) in zip(self.optimizer.state.items(), optimizer_temp.state.items()):
-            # unpack tensor state using weight's attr.
-            if k1 in weight_params_attr_:
-                weight_attr = weight_params_attr_[k1]
-                for state_key, state_value in v2.items():
-                    if isinstance(state_value, torch.Tensor):
-                        # It covers both conv and linear now. TODO: LSTM or other ops.
-                        if weight_attr['op'] is torch.nn.Conv2d:
-                            if self.master_weight_split and state_value.dtype == torch.bfloat16:
-                                state_value = torch.ops.torch_ipex.cat_bfloat16_float(state_value, weight_attr['trail'])
-                            v2[state_key] = torch.ops.torch_ipex.conv2d_weight_unpack(
-                                state_value,
-                                weight_attr['padding'],
-                                weight_attr['stride'],
-                                weight_attr['dilation'],
-                                weight_attr['kernel_size'],
-                                weight_attr['groups'],
-                                weight_attr['out_channels'],
-                                weight_attr['in_channels'],
-                                weight_attr['weight_channels_last'],
-                                weight_attr['dtype'])
-                        elif weight_attr['op'] is torch.nn.Linear:
-                            if self.master_weight_split and state_value.dtype == torch.bfloat16:
-                                state_value = torch.ops.torch_ipex.cat_bfloat16_float(state_value, weight_attr['trail'])
-                            v2[state_key] = torch.ops.torch_ipex.linear_weight_unpack(
-                                state_value,
-                                weight_attr['out_features'],
-                                weight_attr['in_features'],
-                                weight_attr['weight_transposed'],
-                                weight_attr['dtype'])
-        return optimizer_temp.state_dict()
+        self.state = self.optimizer.state
 
     def load_state_dict(self, state_dict):
         assert False, "_ipex_optimizer does not suppory load_state_dict"
 
     def zero_grad(self, set_to_none: bool = False):
-        if self.dtype == torch.bfloat16:
-            for p in self.weight_params_attr:
+        if not self.master_weight_split:
+            for p in self.params_attr:
                 if p.grad is not None:
                     if set_to_none:
                         p.grad = None

@@ -96,14 +48,14 @@ def zero_grad(self, set_to_none: bool = False):
         self.optimizer.zero_grad(set_to_none)
 
     def step(self, closure=None):
-        if self.dtype == torch.bfloat16 and not self.master_weight_split:
+        if not self.master_weight_split:
             # convert bf16 weight'grad to float.
-            for k, value in self.weight_params_attr.items():
-                value["master_weight"].grad = k.grad.detach().to(torch.float)
+            for k, value in self.params_attr.items():
+                value["master_param"].grad = k.grad.detach().to(torch.float)
         loss = self.optimizer.step(closure)
         # sync mater weight to model's paramerter
-        if self.dtype == torch.bfloat16 and not self.master_weight_split:
-            for k, value in self.weight_params_attr.items():
-                torch.ops.torch_ipex.sync_master_weight_to_bf16(value["master_weight"], k)
+        if not self.master_weight_split:
+            for k, value in self.params_attr.items():
+                torch.ops.torch_ipex.sync_master_weight_to_bf16(value["master_param"], k)
         return loss
 
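
The non-split branch above (master_weight_split == False) is the usual fp32 master-weight recipe: step() copies each bf16 parameter's grad into an fp32 master copy, runs the wrapped optimizer in fp32, then writes the result back down to bf16. Below is a minimal, self-contained sketch of that flow in plain PyTorch, assuming a PyTorch build with bf16 CPU support; the master_of dict and the final copy_ stand in for the params_attr "master_param" entries and torch.ops.torch_ipex.sync_master_weight_to_bf16, and all names are illustrative.

import torch

# Sketch of the fp32 master-weight scheme (the non-split path) on a bf16 model.
model = torch.nn.Linear(8, 4).to(torch.bfloat16)

# Keep an fp32 master copy per bf16 parameter and optimize the masters.
master_of = {p: p.detach().float().requires_grad_() for p in model.parameters()}
opt = torch.optim.SGD(master_of.values(), lr=0.1, momentum=0.9)

def master_step():
    # 1) route the bf16 grads to the fp32 master copies
    for p, master in master_of.items():
        master.grad = p.grad.detach().to(torch.float)
    # 2) update in fp32
    opt.step()
    # 3) sync the updated masters back into the bf16 model parameters
    with torch.no_grad():
        for p, master in master_of.items():
            p.copy_(master.to(torch.bfloat16))

x = torch.randn(2, 8, dtype=torch.bfloat16)
model(x).sum().backward()
master_step()

The split path avoids the extra fp32 copies entirely by folding the master-weight precision into the param/trail pair shown in _functional.py.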

intel_pytorch_extension_py/utils.py

Lines changed: 4 additions & 6 deletions

@@ -88,12 +88,10 @@ def optimize(model, dtype=torch.bfloat16, optimizer=None, level='O1', inplace=Fa
         pass
     elif level == 'O1':
         # Do weight prepack, and convert optimizer for training case.
-        optimized_model, optimized_optimizer, weight_params_attr = _weight_prepack_with_ipex(optimized_model, optimized_optimizer, dtype)
-        if dtype == torch.bfloat16 and model.training and optimizer is not None:
-            optimized_model, optimized_optimizer, weight_params_attr = _weight_dtype_convert_with_ipex(optimized_model, optimized_optimizer, weight_params_attr)
-        if optimizer is not None:
-            assert model.training, "please call model.train() if you want to convert the optimizer to ipex optimizer."
-            optimized_optimizer = _ipex_optimizer(optimized_optimizer, weight_params_attr, dtype)
+        params_attr = {}
+        if dtype == torch.bfloat16 and model.training:
+            optimized_model, optimized_optimizer, params_attr = _weight_dtype_convert_with_ipex(optimized_model, optimized_optimizer, params_attr)
+        optimized_model, optimized_optimizer, params_attr = _weight_prepack_with_ipex(optimized_model, optimized_optimizer, params_attr)
     else:
         assert False, "Only support level O0 and O1 now for optimize"
 
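
Taken together with the optimizer and weight-cast changes above, the refined O1 path would typically be driven as sketched below. This is a hedged usage sketch, not verified against this exact revision: the import alias and the assumption that optimize() returns the (model, optimizer) pair follow the common IPEX calling convention, and the tiny model, data, and loop are placeholders.

import torch
import intel_pytorch_extension as ipex  # import alias assumed for this era of IPEX

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(),
                            torch.nn.Linear(16, 4)).train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# level='O1' with bf16: cast/split weights, prepack conv/linear weights,
# and hand back an optimizer wrapped by _ipex_optimizer.
model, optimizer = ipex.optimize(model, dtype=torch.bfloat16,
                                 optimizer=optimizer, level='O1')

for _ in range(3):
    data = torch.randn(8, 16, dtype=torch.bfloat16)
    target = torch.randint(0, 4, (8,))
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(data).float(), target)
    loss.backward()
    optimizer.step()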
