@@ -71,6 +71,9 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
71
71
{" adam" ,
72
72
{" Param" , " Grad" , " LearningRate" , " Moment1" , " Moment2" , " Beta1Pow" ,
73
73
" Beta2Pow" , " MasterParam" }},
74
+ {" adamw" ,
75
+ {" Param" , " Grad" , " LearningRate" , " Moment1" , " Moment2" , " Beta1Pow" ,
76
+ " Beta2Pow" , " MasterParam" }},
74
77
};
75
78
76
79
// NOTE(zhiqiu): Like op_ins_map.
@@ -110,6 +113,9 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
110
113
{" adam" ,
111
114
{" ParamOut" , " Moment1Out" , " Moment2Out" , " Beta1PowOut" , " Beta2PowOut" ,
112
115
" MasterParamOut" }},
116
+ {" adamw" ,
117
+ {" ParamOut" , " Moment1Out" , " Moment2Out" , " Beta1PowOut" , " Beta2PowOut" ,
118
+ " MasterParamOut" }},
113
119
};
114
120
115
121
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
@@ -129,7 +135,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
129
135
{" ParamOut" , " Moment1Out" , " Moment2Out" , " Beta1PowOut" , " Beta2PowOut" ,
130
136
" MasterParamOut" }},
131
137
{" adamw" ,
132
- {" ParamOut" , " Moment1Out" , " Moment2Out" , " Beta1PowOut" , " Beta2PowOut" }},
138
+ {" ParamOut" , " Moment1Out" , " Moment2Out" , " Beta1PowOut" , " Beta2PowOut" ,
139
+ " MasterParamOut" }},
133
140
{" average_accumulates" ,
134
141
{" out_sum_1" , " out_sum_2" , " out_sum_3" , " out_num_accumulates" ,
135
142
" out_old_num_accumulates" , " out_num_updates" }},
0 commit comments