@@ -88,57 +88,6 @@ struct MergeAdd {
88
88
framework::SelectedRows* output);
89
89
};
90
90
91
- template <typename DeviceContext, typename T>
92
- struct Add {
93
- framework::SelectedRows operator ()(const DeviceContext& context,
94
- const framework::SelectedRows& input1,
95
- const framework::SelectedRows& input2) {
96
- framework::SelectedRows out;
97
- out.set_rows (input1.rows ());
98
- out.set_height (input1.height ());
99
- out.mutable_value ()->mutable_data <T>(input1.value ().dims (),
100
- context.GetPlace ());
101
- auto e_out = framework::EigenVector<T>::Flatten (*(out.mutable_value ()));
102
- auto e_in1 = framework::EigenVector<T>::Flatten (input1.value ());
103
- auto e_in2 = framework::EigenVector<T>::Flatten (input2.value ());
104
- e_out.device (*context.eigen_device ()) = e_in1 + e_in2;
105
- return out;
106
- }
107
- };
108
-
109
- template <typename DeviceContext, typename T>
110
- struct Mul {
111
- // multiply two SelectedRows
112
- framework::SelectedRows operator ()(const DeviceContext& context,
113
- const framework::SelectedRows& input1,
114
- const framework::SelectedRows& input2) {
115
- framework::SelectedRows out;
116
- out.set_rows (input1.rows ());
117
- out.set_height (input1.height ());
118
- out.mutable_value ()->mutable_data <T>(input1.value ().dims (),
119
- context.GetPlace ());
120
- auto e_out = framework::EigenVector<T>::Flatten (*(out.mutable_value ()));
121
- auto e_in1 = framework::EigenVector<T>::Flatten (input1.value ());
122
- auto e_in2 = framework::EigenVector<T>::Flatten (input2.value ());
123
- e_out.device (*context.eigen_device ()) = e_in1 * e_in2;
124
- return out;
125
- }
126
- // multiply scalar to SelectedRows
127
- framework::SelectedRows operator ()(const DeviceContext& context,
128
- const framework::SelectedRows& input1,
129
- const T input2) {
130
- framework::SelectedRows out;
131
- out.set_rows (input1.rows ());
132
- out.set_height (input1.height ());
133
- out.mutable_value ()->mutable_data <T>(input1.value ().dims (),
134
- context.GetPlace ());
135
- auto e_out = framework::EigenVector<T>::Flatten (*(out.mutable_value ()));
136
- auto e_in1 = framework::EigenVector<T>::Flatten (input1.value ());
137
- e_out.device (*context.eigen_device ()) = input2 * e_in1;
138
- return out;
139
- }
140
- };
141
-
142
91
enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
143
92
144
93
// out = seleted_rows_in / tensor
0 commit comments