@@ -14,75 +14,27 @@ class _ipex_optimizer(torch.optim.Optimizer):

     Args:
         optimizer: the optimized optimizer, which contains the optimized model's parameter settings.
-        weight_params_attr: the prepacked parameters' attrs, used to prepack the corresponding
-            momentum_buffer or other state according to those attrs.
-        dtype: can be torch.bfloat16 or torch.float32 (torch.float); determines whether to do
-            bfloat16 or float training.
+        params_attr: the parameters' attrs, used to cat the top half and the bottom (trail) half back to fp32.
+
     """

-    def __init__(self, optimizer, weight_params_attr, dtype):
-        if type(optimizer) in IPEX_OPTIMIZER_MAPPING and dtype == torch.bfloat16:
-            self.optimizer = IPEX_OPTIMIZER_MAPPING[type(optimizer)](optimizer, weight_params_attr)
+    def __init__(self, optimizer, params_attr):
+        if type(optimizer) in IPEX_OPTIMIZER_MAPPING:
+            self.optimizer = IPEX_OPTIMIZER_MAPPING[type(optimizer)](optimizer, params_attr)
             self.master_weight_split = True
         else:
             self.optimizer = optimizer
             self.master_weight_split = False
-        self.weight_params_attr = weight_params_attr
+        self.params_attr = params_attr
         self.param_groups = self.optimizer.param_groups
-        self.dtype = dtype
-
-    def state_dict(self):
-        optimizer_temp = copy.deepcopy(self.optimizer)
-        weight_params_attr_ = {}
-        # For the bf16 path, the optimizer's params are the master weights,
-        # but self.weight_params_attr's keys are the bf16 weights, which makes it
-        # hard to query a weight's attr, so recreate a dict keyed by master weight
-        # to make the lookup easy.
-        if self.dtype == torch.bfloat16 and not self.master_weight_split:
-            for _, values in self.weight_params_attr.items():
-                master_weight = values['master_weight']
-                weight_params_attr_[master_weight] = values
-        else:
-            weight_params_attr_ = self.weight_params_attr
-
-        for (k1, _), (_, v2) in zip(self.optimizer.state.items(), optimizer_temp.state.items()):
-            # unpack tensor state using the weight's attr.
-            if k1 in weight_params_attr_:
-                weight_attr = weight_params_attr_[k1]
-                for state_key, state_value in v2.items():
-                    if isinstance(state_value, torch.Tensor):
-                        # It covers both conv and linear now. TODO: LSTM or other ops.
-                        if weight_attr['op'] is torch.nn.Conv2d:
-                            if self.master_weight_split and state_value.dtype == torch.bfloat16:
-                                state_value = torch.ops.torch_ipex.cat_bfloat16_float(state_value, weight_attr['trail'])
-                            v2[state_key] = torch.ops.torch_ipex.conv2d_weight_unpack(
-                                state_value,
-                                weight_attr['padding'],
-                                weight_attr['stride'],
-                                weight_attr['dilation'],
-                                weight_attr['kernel_size'],
-                                weight_attr['groups'],
-                                weight_attr['out_channels'],
-                                weight_attr['in_channels'],
-                                weight_attr['weight_channels_last'],
-                                weight_attr['dtype'])
-                        elif weight_attr['op'] is torch.nn.Linear:
-                            if self.master_weight_split and state_value.dtype == torch.bfloat16:
-                                state_value = torch.ops.torch_ipex.cat_bfloat16_float(state_value, weight_attr['trail'])
-                            v2[state_key] = torch.ops.torch_ipex.linear_weight_unpack(
-                                state_value,
-                                weight_attr['out_features'],
-                                weight_attr['in_features'],
-                                weight_attr['weight_transposed'],
-                                weight_attr['dtype'])
-        return optimizer_temp.state_dict()
+        self.state = self.optimizer.state

     def load_state_dict(self, state_dict):
         assert False, "_ipex_optimizer does not support load_state_dict"

     def zero_grad(self, set_to_none: bool = False):
-        if self.dtype == torch.bfloat16:
-            for p in self.weight_params_attr:
+        if not self.master_weight_split:
+            for p in self.params_attr:
                 if p.grad is not None:
                     if set_to_none:
                         p.grad = None
@@ -96,14 +48,14 @@ def zero_grad(self, set_to_none: bool = False):
         self.optimizer.zero_grad(set_to_none)

     def step(self, closure=None):
-        if self.dtype == torch.bfloat16 and not self.master_weight_split:
+        if not self.master_weight_split:
             # convert bf16 weight's grad to float.
-            for k, value in self.weight_params_attr.items():
-                value["master_weight"].grad = k.grad.detach().to(torch.float)
+            for k, value in self.params_attr.items():
+                value["master_param"].grad = k.grad.detach().to(torch.float)
         loss = self.optimizer.step(closure)
         # sync master weight to model's parameter
-        if self.dtype == torch.bfloat16 and not self.master_weight_split:
-            for k, value in self.weight_params_attr.items():
-                torch.ops.torch_ipex.sync_master_weight_to_bf16(value["master_weight"], k)
+        if not self.master_weight_split:
+            for k, value in self.params_attr.items():
+                torch.ops.torch_ipex.sync_master_weight_to_bf16(value["master_param"], k)
         return loss

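Note on the change: in the non-split path (master_weight_split == False), step() above follows the usual bf16 master-weight recipe: copy each bf16 parameter's gradient into its fp32 master copy, run the wrapped optimizer on the fp32 masters, then write the updated masters back into the bf16 parameters. The sketch below illustrates that pattern in plain PyTorch under stated assumptions: the params_attr layout (each bf16 parameter mapped to a dict holding its fp32 tensor under "master_param") and the helper name master_weight_step are illustrative only, and the actual copy-back in the diff uses torch.ops.torch_ipex.sync_master_weight_to_bf16 rather than a plain copy_.

import torch

def master_weight_step(optimizer, params_attr, closure=None):
    # Assumed layout: params_attr maps each bf16 parameter to
    # {"master_param": <fp32 copy that `optimizer` actually updates>}.
    for bf16_param, attr in params_attr.items():
        # Route the bf16 gradient to the fp32 master copy.
        attr["master_param"].grad = bf16_param.grad.detach().to(torch.float)
    loss = optimizer.step(closure)
    for bf16_param, attr in params_attr.items():
        # Copy the updated fp32 master weights back into the bf16 parameters
        # (the diff calls torch.ops.torch_ipex.sync_master_weight_to_bf16 here).
        with torch.no_grad():
            bf16_param.copy_(attr["master_param"].to(torch.bfloat16))
    return loss

Keeping the fp32 masters as the tensors the optimizer updates preserves full-precision weights and optimizer state while forward/backward run in bfloat16; the split path (master_weight_split == True) instead keeps only the bottom ("trail") half of each fp32 weight and reassembles the full fp32 value when needed (cf. cat_bfloat16_float in the removed state_dict above).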