@@ -72,38 +72,36 @@ struct Dim {
72
72
73
73
// Base case specialization
74
74
template <>
75
- struct Dim <1 > {
76
- static constexpr int dimensions = 1 ;
75
+ struct Dim <0 > {
76
+ static constexpr int dimensions = 0 ;
77
77
78
78
HOSTDEVICE
79
- Dim (int64_t _head) : head(_head) {}
79
+ Dim (int64_t _head) {}
80
80
81
81
HOSTDEVICE
82
- Dim () : head( 0 ) {}
82
+ Dim () {}
83
83
84
84
HOSTDEVICE
85
- Dim (int idx, const Dim<1 >& size) : head(idx ) {
85
+ Dim (int idx, const Dim<0 >& size) {
86
86
#ifndef __CUDA_ARCH__
87
- if (idx >= size. head ) {
87
+ if (idx > 0 ) {
88
88
throw std::invalid_argument (" Index out of range." );
89
89
}
90
90
#else
91
- PADDLE_ASSERT (idx < size. head );
91
+ PADDLE_ASSERT (idx == 0 );
92
92
#endif
93
93
}
94
94
95
95
HOSTDEVICE
96
- bool operator ==(const Dim<1 >& o) const { return (head == o. head ) ; }
96
+ bool operator ==(const Dim<0 >& o) const { return true ; }
97
97
98
98
HOSTDEVICE
99
- bool operator !=(const Dim<1 >& o) const { return !(* this == o) ; }
99
+ bool operator !=(const Dim<0 >& o) const { return false ; }
100
100
101
101
HOSTDEVICE
102
102
int64_t & operator [](int idx);
103
103
HOSTDEVICE
104
104
int64_t operator [](int idx) const ;
105
-
106
- int64_t head;
107
105
};
108
106
109
107
namespace {
@@ -154,15 +152,14 @@ HOSTDEVICE int64_t& indexer(Dim<D>& dim, int idx) {
154
152
}
155
153
156
154
template <>
157
- HOSTDEVICE int64_t & indexer<1 >(Dim<1 >& dim, int idx) {
155
+ HOSTDEVICE int64_t & indexer<0 >(Dim<0 >& dim, int idx) {
158
156
#ifndef __CUDA_ARCH__
159
- if (idx != 0 ) {
160
- throw std::invalid_argument (" Invalid index" );
161
- }
157
+ throw std::invalid_argument (" Invalid index" );
162
158
#else
163
- PADDLE_ASSERT (idx == 0 );
159
+ PADDLE_ASSERT (false );
164
160
#endif
165
- return dim.head ;
161
+ static int64_t head = 0 ;
162
+ return head;
166
163
}
167
164
168
165
template <int D>
@@ -181,15 +178,14 @@ HOSTDEVICE int64_t indexer(const Dim<D>& dim, int idx) {
181
178
}
182
179
183
180
template <>
184
- HOSTDEVICE int64_t indexer<1 >(const Dim<1 >& dim, int idx) {
181
+ HOSTDEVICE int64_t indexer<0 >(const Dim<0 >& dim, int idx) {
185
182
#ifndef __CUDA_ARCH__
186
- if (idx != 0 ) {
187
- throw std::invalid_argument (" Invalid index" );
188
- }
183
+ throw std::invalid_argument (" Invalid index" );
189
184
#else
190
- PADDLE_ASSERT (idx == 0 );
185
+ PADDLE_ASSERT (false );
191
186
#endif
192
- return dim.head ;
187
+ static int64_t head = 0 ;
188
+ return head;
193
189
}
194
190
195
191
} // namespace
@@ -218,12 +214,12 @@ HOSTDEVICE int64_t& Dim<l>::operator[](int i) {
218
214
}
219
215
220
216
// Dynamic access to constant Dim
221
- inline HOSTDEVICE int64_t Dim<1 >::operator [](int i) const {
217
+ inline HOSTDEVICE int64_t Dim<0 >::operator [](int i) const {
222
218
return indexer (*this , i);
223
219
}
224
220
225
221
// Dynamic access to mutable Dim
226
- inline HOSTDEVICE int64_t & Dim<1 >::operator [](int i) {
222
+ inline HOSTDEVICE int64_t & Dim<0 >::operator [](int i) {
227
223
return indexer (*this , i);
228
224
}
229
225
@@ -251,8 +247,8 @@ HOSTDEVICE int64_t linearize(const Dim<i>& a, const Dim<i>& b) {
251
247
// Base case dot product of two Dims
252
248
// Notice it is inline because it is no longer a template
253
249
template <>
254
- HOSTDEVICE inline int64_t linearize (const Dim<1 >& a, const Dim<1 >& b) {
255
- return a. head * b. head ;
250
+ HOSTDEVICE inline int64_t linearize (const Dim<0 >& a, const Dim<0 >& b) {
251
+ return 0 ;
256
252
}
257
253
258
254
// Product of a Dim
@@ -264,8 +260,8 @@ HOSTDEVICE int64_t product(const Dim<i>& a, int prod = 1) {
264
260
// Base case product of a Dim
265
261
// Notice it is inline because it is no longer a template
266
262
template <>
267
- HOSTDEVICE inline int64_t product (const Dim<1 >& a, int prod) {
268
- return prod * a. head ;
263
+ HOSTDEVICE inline int64_t product (const Dim<0 >& a, int prod) {
264
+ return prod;
269
265
}
270
266
271
267
// Is 0 <= idx_i < size_i for all i?
@@ -278,8 +274,8 @@ HOSTDEVICE bool contained(const Dim<i>& idx, const Dim<i>& size) {
278
274
// Base case of is 0 <= idx_i < size_i ?
279
275
// Notice it is inline because it is no longer a template
280
276
template <>
281
- HOSTDEVICE inline bool contained (const Dim<1 >& idx, const Dim<1 >& size) {
282
- return (( 0 <= idx. head ) && (idx. head < size. head )) ;
277
+ HOSTDEVICE inline bool contained (const Dim<0 >& idx, const Dim<0 >& size) {
278
+ return true ;
283
279
}
284
280
285
281
/* *
@@ -294,8 +290,8 @@ HOSTDEVICE Dim<i> ex_prefix_mul(const Dim<i>& src, int mul = 1) {
294
290
// Base case of ex_prefix_mul
295
291
// Notice it is inline because it is no longer a template
296
292
template <>
297
- HOSTDEVICE inline Dim<1 > ex_prefix_mul (const Dim<1 >& src, int mul) {
298
- return Dim<1 >(mul );
293
+ HOSTDEVICE inline Dim<0 > ex_prefix_mul (const Dim<0 >& src, int mul) {
294
+ return Dim<0 >( );
299
295
}
300
296
// /\endcond
301
297
@@ -309,8 +305,8 @@ HOSTDEVICE Dim<i> dim_plus(const Dim<i>& a, const Dim<i>& b) {
309
305
310
306
// Base case
311
307
template <>
312
- HOSTDEVICE inline Dim<1 > dim_plus (const Dim<1 >& a, const Dim<1 >& b) {
313
- return Dim<1 >(a. head + b. head );
308
+ HOSTDEVICE inline Dim<0 > dim_plus (const Dim<0 >& a, const Dim<0 >& b) {
309
+ return Dim<0 >( );
314
310
}
315
311
316
312
template <int i>
@@ -328,8 +324,8 @@ HOSTDEVICE Dim<i> dim_mult(const Dim<i>& a, const Dim<i>& b) {
328
324
329
325
// Base case
330
326
template <>
331
- HOSTDEVICE inline Dim<1 > dim_mult (const Dim<1 >& a, const Dim<1 >& b) {
332
- return Dim<1 >(a. head * b. head );
327
+ HOSTDEVICE inline Dim<0 > dim_mult (const Dim<0 >& a, const Dim<0 >& b) {
328
+ return Dim<0 >( );
333
329
}
334
330
335
331
template <int i>
@@ -356,10 +352,9 @@ HOSTDEVICE Dim<i> normalize_strides(const Dim<i>& size, const Dim<i>& stride) {
356
352
// /\cond HIDDEN
357
353
358
354
template <>
359
- HOSTDEVICE inline Dim<1 > normalize_strides (const Dim<1 >& size,
360
- const Dim<1 >& stride) {
361
- int norm_stride = size.head == 1 ? 0 : stride.head ;
362
- return Dim<1 >(norm_stride);
355
+ HOSTDEVICE inline Dim<0 > normalize_strides (const Dim<0 >& size,
356
+ const Dim<0 >& stride) {
357
+ return Dim<0 >();
363
358
}
364
359
365
360
// /\endcond
@@ -394,6 +389,10 @@ typename std::enable_if<(i == 1), std::ostream&>::type operator<<(
394
389
return os;
395
390
}
396
391
392
+ inline std::ostream& operator <<(std::ostream& os, const Dim<0 >& d) {
393
+ return os;
394
+ }
395
+
397
396
template <int i>
398
397
HOST std::string Dim<i>::to_string() const {
399
398
std::stringstream stream;
0 commit comments