Skip to content

Commit eac2c3c

Browse files
authored
Merge pull request #8505 from emailweixu/math_op
Correctly handling variable with batch dimension for math ops.
2 parents 0d878e4 + e9b8ebf commit eac2c3c

File tree

10 files changed

+167
-123
lines changed

10 files changed

+167
-123
lines changed

paddle/fluid/framework/ddim.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,15 @@ Dim<i> make_dim(const int64_t* d) {
2626
}
2727

2828
template <>
29-
Dim<1> make_dim<1>(const int64_t* d) {
30-
return Dim<1>(*d);
29+
Dim<0> make_dim<0>(const int64_t* d) {
30+
return Dim<0>(*d);
3131
}
3232

3333
void make_ddim(DDim& ddim, const int64_t* dims, int n) {
3434
switch (n) {
35+
case 0:
36+
ddim = make_dim<0>(dims);
37+
break;
3538
case 1:
3639
ddim = make_dim<1>(dims);
3740
break;
@@ -190,7 +193,7 @@ struct VectorizeVisitor : public boost::static_visitor<> {
190193
this->operator()(t.tail);
191194
}
192195

193-
void operator()(const Dim<1>& t) { vector.push_back(t.head); }
196+
void operator()(const Dim<0>& t) {}
194197
};
195198
/// @endcond
196199

@@ -247,9 +250,8 @@ struct SliceVectorizeVisitor : public boost::static_visitor<> {
247250
}
248251
}
249252

250-
void operator()(const Dim<1>& dim) {
251-
PADDLE_ENFORCE(end == 1, "End index in ddim slice is out of bound.");
252-
vector.push_back(dim.head);
253+
void operator()(const Dim<0>& dim) {
254+
PADDLE_ENFORCE(end == 0, "End index in ddim slice is out of bound.");
253255
}
254256
};
255257

paddle/fluid/framework/ddim.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ namespace framework {
3030
* The number of dimensions must be between [0, 9].
3131
*/
3232
struct DDim {
33-
typedef boost::variant<Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>, Dim<7>,
34-
Dim<8>, Dim<9>>
33+
typedef boost::variant<Dim<0>, Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>,
34+
Dim<7>, Dim<8>, Dim<9>>
3535
DDimVar;
3636
DDimVar var;
3737

paddle/fluid/framework/dim.h

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -72,38 +72,36 @@ struct Dim {
7272

7373
// Base case specialization
7474
template <>
75-
struct Dim<1> {
76-
static constexpr int dimensions = 1;
75+
struct Dim<0> {
76+
static constexpr int dimensions = 0;
7777

7878
HOSTDEVICE
79-
Dim(int64_t _head) : head(_head) {}
79+
Dim(int64_t _head) {}
8080

8181
HOSTDEVICE
82-
Dim() : head(0) {}
82+
Dim() {}
8383

8484
HOSTDEVICE
85-
Dim(int idx, const Dim<1>& size) : head(idx) {
85+
Dim(int idx, const Dim<0>& size) {
8686
#ifndef __CUDA_ARCH__
87-
if (idx >= size.head) {
87+
if (idx > 0) {
8888
throw std::invalid_argument("Index out of range.");
8989
}
9090
#else
91-
PADDLE_ASSERT(idx < size.head);
91+
PADDLE_ASSERT(idx == 0);
9292
#endif
9393
}
9494

9595
HOSTDEVICE
96-
bool operator==(const Dim<1>& o) const { return (head == o.head); }
96+
bool operator==(const Dim<0>& o) const { return true; }
9797

9898
HOSTDEVICE
99-
bool operator!=(const Dim<1>& o) const { return !(*this == o); }
99+
bool operator!=(const Dim<0>& o) const { return false; }
100100

101101
HOSTDEVICE
102102
int64_t& operator[](int idx);
103103
HOSTDEVICE
104104
int64_t operator[](int idx) const;
105-
106-
int64_t head;
107105
};
108106

109107
namespace {
@@ -154,15 +152,14 @@ HOSTDEVICE int64_t& indexer(Dim<D>& dim, int idx) {
154152
}
155153

156154
template <>
157-
HOSTDEVICE int64_t& indexer<1>(Dim<1>& dim, int idx) {
155+
HOSTDEVICE int64_t& indexer<0>(Dim<0>& dim, int idx) {
158156
#ifndef __CUDA_ARCH__
159-
if (idx != 0) {
160-
throw std::invalid_argument("Invalid index");
161-
}
157+
throw std::invalid_argument("Invalid index");
162158
#else
163-
PADDLE_ASSERT(idx == 0);
159+
PADDLE_ASSERT(false);
164160
#endif
165-
return dim.head;
161+
static int64_t head = 0;
162+
return head;
166163
}
167164

168165
template <int D>
@@ -181,15 +178,14 @@ HOSTDEVICE int64_t indexer(const Dim<D>& dim, int idx) {
181178
}
182179

183180
template <>
184-
HOSTDEVICE int64_t indexer<1>(const Dim<1>& dim, int idx) {
181+
HOSTDEVICE int64_t indexer<0>(const Dim<0>& dim, int idx) {
185182
#ifndef __CUDA_ARCH__
186-
if (idx != 0) {
187-
throw std::invalid_argument("Invalid index");
188-
}
183+
throw std::invalid_argument("Invalid index");
189184
#else
190-
PADDLE_ASSERT(idx == 0);
185+
PADDLE_ASSERT(false);
191186
#endif
192-
return dim.head;
187+
static int64_t head = 0;
188+
return head;
193189
}
194190

195191
} // namespace
@@ -218,12 +214,12 @@ HOSTDEVICE int64_t& Dim<l>::operator[](int i) {
218214
}
219215

220216
// Dynamic access to constant Dim
221-
inline HOSTDEVICE int64_t Dim<1>::operator[](int i) const {
217+
inline HOSTDEVICE int64_t Dim<0>::operator[](int i) const {
222218
return indexer(*this, i);
223219
}
224220

225221
// Dynamic access to mutable Dim
226-
inline HOSTDEVICE int64_t& Dim<1>::operator[](int i) {
222+
inline HOSTDEVICE int64_t& Dim<0>::operator[](int i) {
227223
return indexer(*this, i);
228224
}
229225

@@ -251,8 +247,8 @@ HOSTDEVICE int64_t linearize(const Dim<i>& a, const Dim<i>& b) {
251247
// Base case dot product of two Dims
252248
// Notice it is inline because it is no longer a template
253249
template <>
254-
HOSTDEVICE inline int64_t linearize(const Dim<1>& a, const Dim<1>& b) {
255-
return a.head * b.head;
250+
HOSTDEVICE inline int64_t linearize(const Dim<0>& a, const Dim<0>& b) {
251+
return 0;
256252
}
257253

258254
// Product of a Dim
@@ -264,8 +260,8 @@ HOSTDEVICE int64_t product(const Dim<i>& a, int prod = 1) {
264260
// Base case product of a Dim
265261
// Notice it is inline because it is no longer a template
266262
template <>
267-
HOSTDEVICE inline int64_t product(const Dim<1>& a, int prod) {
268-
return prod * a.head;
263+
HOSTDEVICE inline int64_t product(const Dim<0>& a, int prod) {
264+
return prod;
269265
}
270266

271267
// Is 0 <= idx_i < size_i for all i?
@@ -278,8 +274,8 @@ HOSTDEVICE bool contained(const Dim<i>& idx, const Dim<i>& size) {
278274
// Base case of is 0 <= idx_i < size_i ?
279275
// Notice it is inline because it is no longer a template
280276
template <>
281-
HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) {
282-
return ((0 <= idx.head) && (idx.head < size.head));
277+
HOSTDEVICE inline bool contained(const Dim<0>& idx, const Dim<0>& size) {
278+
return true;
283279
}
284280

285281
/**
@@ -294,8 +290,8 @@ HOSTDEVICE Dim<i> ex_prefix_mul(const Dim<i>& src, int mul = 1) {
294290
// Base case of ex_prefix_mul
295291
// Notice it is inline because it is no longer a template
296292
template <>
297-
HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) {
298-
return Dim<1>(mul);
293+
HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0>& src, int mul) {
294+
return Dim<0>();
299295
}
300296
///\endcond
301297

@@ -309,8 +305,8 @@ HOSTDEVICE Dim<i> dim_plus(const Dim<i>& a, const Dim<i>& b) {
309305

310306
// Base case
311307
template <>
312-
HOSTDEVICE inline Dim<1> dim_plus(const Dim<1>& a, const Dim<1>& b) {
313-
return Dim<1>(a.head + b.head);
308+
HOSTDEVICE inline Dim<0> dim_plus(const Dim<0>& a, const Dim<0>& b) {
309+
return Dim<0>();
314310
}
315311

316312
template <int i>
@@ -328,8 +324,8 @@ HOSTDEVICE Dim<i> dim_mult(const Dim<i>& a, const Dim<i>& b) {
328324

329325
// Base case
330326
template <>
331-
HOSTDEVICE inline Dim<1> dim_mult(const Dim<1>& a, const Dim<1>& b) {
332-
return Dim<1>(a.head * b.head);
327+
HOSTDEVICE inline Dim<0> dim_mult(const Dim<0>& a, const Dim<0>& b) {
328+
return Dim<0>();
333329
}
334330

335331
template <int i>
@@ -356,10 +352,9 @@ HOSTDEVICE Dim<i> normalize_strides(const Dim<i>& size, const Dim<i>& stride) {
356352
///\cond HIDDEN
357353

358354
template <>
359-
HOSTDEVICE inline Dim<1> normalize_strides(const Dim<1>& size,
360-
const Dim<1>& stride) {
361-
int norm_stride = size.head == 1 ? 0 : stride.head;
362-
return Dim<1>(norm_stride);
355+
HOSTDEVICE inline Dim<0> normalize_strides(const Dim<0>& size,
356+
const Dim<0>& stride) {
357+
return Dim<0>();
363358
}
364359

365360
///\endcond
@@ -394,6 +389,10 @@ typename std::enable_if<(i == 1), std::ostream&>::type operator<<(
394389
return os;
395390
}
396391

392+
inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) {
393+
return os;
394+
}
395+
397396
template <int i>
398397
HOST std::string Dim<i>::to_string() const {
399398
std::stringstream stream;

paddle/fluid/operators/detail/strided_memcpy.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,29 @@ namespace detail {
2424
template <typename T, int Rank>
2525
struct StridedMemcpyFunctor;
2626

27+
template <typename T>
28+
struct StridedMemcpyFunctor<T, 0> {
29+
void operator()(const platform::DeviceContext& dev_ctx, const T* src,
30+
framework::Dim<0> src_stride, framework::Dim<0> dst_dim,
31+
framework::Dim<0> dst_stride, T* dst) const {
32+
auto place = dev_ctx.GetPlace();
33+
if (platform::is_cpu_place(place)) {
34+
auto& cpu_place = boost::get<platform::CPUPlace>(place);
35+
memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T));
36+
} else {
37+
#ifdef PADDLE_WITH_CUDA
38+
auto& gpu_place = boost::get<platform::CUDAPlace>(place);
39+
auto& cuda_ctx =
40+
reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
41+
memory::Copy(gpu_place, dst, gpu_place, src, sizeof(T),
42+
cuda_ctx.stream());
43+
#else
44+
PADDLE_THROW("Paddle is not compiled with GPU");
45+
#endif
46+
}
47+
}
48+
};
49+
2750
template <typename T>
2851
struct StridedMemcpyFunctor<T, 1> {
2952
void operator()(const platform::DeviceContext& dev_ctx, const T* src,

paddle/fluid/operators/elementwise_op.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,17 @@ smaller than or equal to the dimensions of $X$.
6565
6666
There are two cases for this operator:
6767
1. The shape of $Y$ is same with $X$;
68-
2. The shape of $Y$ is a subset of $X$.
68+
2. The shape of $Y$ is a contiguous subsequence of $X$. The trailing dimensions
69+
of size 1 for $Y$ will be ignored for the consideration of subsequence.
70+
6971
7072
For case 2:
73+
7174
$Y$ will be broadcasted to match the shape of $X$ and axis should be
7275
set to the index of the start dimension to broadcast $Y$ onto $X$.
7376
77+
If axis is -1, it is treated as axis=rank(X)-rank(Y).
78+
7479
For example
7580
.. code-block:: python
7681
@@ -79,6 +84,7 @@ For example
7984
shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5)
8085
shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
8186
shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
87+
shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0
8288
8389
Either of the inputs $X$ and $Y$ or none can carry the LoD (Level of Details)
8490
information. However, the output only shares the LoD information with input $X$.

0 commit comments

Comments
 (0)