@@ -35,77 +35,77 @@ using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
35
35
36
36
struct SumFunctor {
37
37
template <typename DeviceContext, typename X, typename Y, typename Dim>
38
- void operator ()(const DeviceContext& place, X& x, Y& y, const Dim& dim) {
39
- y. device (place) = x. sum (dim);
38
+ void operator ()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
39
+ y-> device (place) = x-> sum (dim);
40
40
}
41
41
};
42
42
43
43
struct SumGradFunctor {
44
44
template <typename DeviceContext, typename X, typename Y, typename DX,
45
45
typename DY, typename Dim>
46
- void operator ()(const DeviceContext& place, X& x, Y& y, DX& dx, DY& dy,
46
+ void operator ()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
47
47
const Dim& dim, int size) {
48
- dx. device (place) = dy. broadcast (dim);
48
+ dx-> device (place) = dy-> broadcast (dim);
49
49
}
50
50
};
51
51
52
52
struct MeanFunctor {
53
53
template <typename DeviceContext, typename X, typename Y, typename Dim>
54
- void operator ()(const DeviceContext& place, X& x, Y& y, const Dim& dim) {
55
- y. device (place) = x. mean (dim);
54
+ void operator ()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
55
+ y-> device (place) = x-> mean (dim);
56
56
}
57
57
};
58
58
59
59
struct MeanGradFunctor {
60
60
template <typename DeviceContext, typename X, typename Y, typename DX,
61
61
typename DY, typename Dim>
62
- void operator ()(const DeviceContext& place, X& x, Y& y, DX& dx, DY& dy,
62
+ void operator ()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
63
63
const Dim& dim, int size) {
64
- dx. device (place) = dy. broadcast (dim) / dx. constant (size);
64
+ dx-> device (place) = dy-> broadcast (dim) / dx-> constant (size);
65
65
}
66
66
};
67
67
68
68
struct MaxFunctor {
69
69
template <typename DeviceContext, typename X, typename Y, typename Dim>
70
- void operator ()(const DeviceContext& place, X& x, Y& y, const Dim& dim) {
71
- y. device (place) = x. maximum (dim);
70
+ void operator ()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
71
+ y-> device (place) = x-> maximum (dim);
72
72
}
73
73
};
74
74
75
75
struct MinFunctor {
76
76
template <typename DeviceContext, typename X, typename Y, typename Dim>
77
- void operator ()(const DeviceContext& place, X& x, Y& y, const Dim& dim) {
78
- y. device (place) = x. minimum (dim);
77
+ void operator ()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
78
+ y-> device (place) = x-> minimum (dim);
79
79
}
80
80
};
81
81
82
82
struct MaxOrMinGradFunctor {
83
83
template <typename DeviceContext, typename X, typename Y, typename DX,
84
84
typename DY, typename Dim>
85
- void operator ()(const DeviceContext& place, X& x, Y& y, DX& dx, DY& dy,
85
+ void operator ()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
86
86
const Dim& dim, int size) {
87
- auto equals = x == y. broadcast (dim);
88
- auto ones = dx. constant (1 );
89
- auto zeros = dx. constant (0 );
87
+ auto equals = (*x) == y-> broadcast (dim);
88
+ auto ones = dx-> constant (1 );
89
+ auto zeros = dx-> constant (0 );
90
90
// If there are multiple minimum or maximum elements, the subgradient of
91
91
// each is the set [0, 1], and we pass gradient to all of them here.
92
- dx. device (place) = dy. broadcast (dim) * equals.select (ones, zeros);
92
+ dx-> device (place) = dy-> broadcast (dim) * equals.select (ones, zeros);
93
93
}
94
94
};
95
95
96
96
struct ProdFunctor {
97
97
template <typename DeviceContext, typename X, typename Y, typename Dim>
98
- void operator ()(const DeviceContext& place, X& x, Y& y, const Dim& dim) {
99
- y. device (place) = x. prod (dim);
98
+ void operator ()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
99
+ y-> device (place) = x-> prod (dim);
100
100
}
101
101
};
102
102
103
103
struct ProdGradFunctor {
104
104
template <typename DeviceContext, typename X, typename Y, typename DX,
105
105
typename DY, typename Dim>
106
- void operator ()(const DeviceContext& place, X& x, Y& y, DX& dx, DY& dy,
106
+ void operator ()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
107
107
const Dim& dim, int size) {
108
- dx. device (place) = dy. broadcast (dim) * y. broadcast (dim) * x. inverse ();
108
+ dx-> device (place) = dy-> broadcast (dim) * y-> broadcast (dim) * x-> inverse ();
109
109
}
110
110
};
111
111
@@ -125,7 +125,7 @@ class ReduceKernel : public framework::OpKernel<T> {
125
125
*context.template device_context <DeviceContext>().eigen_device ();
126
126
auto reduce_dim = Eigen::array<int , 1 >({{0 }});
127
127
Functor functor;
128
- functor (place, x, out, reduce_dim);
128
+ functor (place, & x, & out, reduce_dim);
129
129
} else {
130
130
int rank = context.Input <Tensor>(" X" )->dims ().size ();
131
131
switch (rank) {
@@ -178,10 +178,10 @@ class ReduceKernel : public framework::OpKernel<T> {
178
178
179
179
if (D == 1 ) {
180
180
auto out = EigenScalar<T>::From (*output);
181
- functor (place, x, out, reduce_dim);
181
+ functor (place, & x, & out, reduce_dim);
182
182
} else {
183
183
auto out = EigenTensor<T, (D - 1 )>::From (*output, dims);
184
- functor (place, x, out, reduce_dim);
184
+ functor (place, & x, & out, reduce_dim);
185
185
}
186
186
}
187
187
};
@@ -206,7 +206,7 @@ class ReduceGradKernel : public framework::OpKernel<T> {
206
206
auto broadcast_dim =
207
207
Eigen::array<int , 1 >({{static_cast <int >(input0->numel ())}});
208
208
Functor functor;
209
- functor (place, x, x_reduce, x_grad, x_reduce_grad, broadcast_dim,
209
+ functor (place, & x, & x_reduce, & x_grad, & x_reduce_grad, broadcast_dim,
210
210
broadcast_dim[0 ]);
211
211
} else {
212
212
int rank = context.Input <Tensor>(" X" )->dims ().size ();
@@ -258,7 +258,7 @@ class ReduceGradKernel : public framework::OpKernel<T> {
258
258
auto & place =
259
259
*context.template device_context <DeviceContext>().eigen_device ();
260
260
Functor functor;
261
- functor (place, x, x_reduce, x_grad, x_reduce_grad, broadcast_dim,
261
+ functor (place, & x, & x_reduce, & x_grad, & x_reduce_grad, broadcast_dim,
262
262
broadcast_dim[dim]);
263
263
}
264
264
};
0 commit comments