@@ -52,12 +52,15 @@ template <typename T, int Dims>
52
52
struct CachedData
53
53
{
54
54
static constexpr bool const sync_after_init = true ;
55
- using pointer_type = T *;
55
+ using Shape = sycl::range<Dims>;
56
+ using value_type = T;
57
+ using pointer_type = value_type *;
58
+ static constexpr auto dims = Dims;
56
59
57
- using ncT = typename std::remove_const<T >::type;
60
+ using ncT = typename std::remove_const<value_type >::type;
58
61
using LocalData = sycl::local_accessor<ncT, Dims>;
59
62
60
- CachedData (T *global_data, sycl::range<Dims> shape, sycl::handler &cgh)
63
+ CachedData (T *global_data, Shape shape, sycl::handler &cgh)
61
64
{
62
65
this ->global_data = global_data;
63
66
local_data = LocalData (shape, cgh);
@@ -71,13 +74,13 @@ struct CachedData
71
74
template <int _Dims>
72
75
void init (const sycl::nd_item<_Dims> &item) const
73
76
{
74
- int32_t llid = item.get_local_linear_id ();
77
+ uint32_t llid = item.get_local_linear_id ();
75
78
auto local_ptr = &local_data[0 ];
76
- int32_t size = local_data.size ();
79
+ uint32_t size = local_data.size ();
77
80
auto group = item.get_group ();
78
- int32_t local_size = group.get_local_linear_range ();
81
+ uint32_t local_size = group.get_local_linear_range ();
79
82
80
- for (int32_t i = llid; i < size; i += local_size) {
83
+ for (uint32_t i = llid; i < size; i += local_size) {
81
84
local_ptr[i] = global_data[i];
82
85
}
83
86
}
@@ -87,17 +90,30 @@ struct CachedData
87
90
return local_data.size ();
88
91
}
89
92
93
+ T &operator [](const sycl::id<Dims> &id) const
94
+ {
95
+ return local_data[id];
96
+ }
97
+
98
+ template <typename = std::enable_if_t <Dims == 1 >>
99
+ T &operator [](const size_t id) const
100
+ {
101
+ return local_data[id];
102
+ }
103
+
90
104
private:
91
105
LocalData local_data;
92
- T *global_data = nullptr ;
106
+ value_type *global_data = nullptr ;
93
107
};
94
108
95
109
template <typename T, int Dims>
96
110
struct UncachedData
97
111
{
98
112
static constexpr bool const sync_after_init = false ;
99
113
using Shape = sycl::range<Dims>;
100
- using pointer_type = T *;
114
+ using value_type = T;
115
+ using pointer_type = value_type *;
116
+ static constexpr auto dims = Dims;
101
117
102
118
UncachedData (T *global_data, const Shape &shape, sycl::handler &)
103
119
{
@@ -120,6 +136,17 @@ struct UncachedData
120
136
return _shape.size ();
121
137
}
122
138
139
+ T &operator [](const sycl::id<Dims> &id) const
140
+ {
141
+ return global_data[id];
142
+ }
143
+
144
+ template <typename = std::enable_if_t <Dims == 1 >>
145
+ T &operator [](const size_t id) const
146
+ {
147
+ return global_data[id];
148
+ }
149
+
123
150
private:
124
151
T *global_data = nullptr ;
125
152
Shape _shape;
@@ -191,15 +218,15 @@ struct HistWithLocalCopies
191
218
template <int _Dims>
192
219
void finalize (const sycl::nd_item<_Dims> &item) const
193
220
{
194
- int32_t llid = item.get_local_linear_id ();
195
- int32_t bins_count = local_hist.get_range ().get (1 );
196
- int32_t local_hist_count = local_hist.get_range ().get (0 );
221
+ uint32_t llid = item.get_local_linear_id ();
222
+ uint32_t bins_count = local_hist.get_range ().get (1 );
223
+ uint32_t local_hist_count = local_hist.get_range ().get (0 );
197
224
auto group = item.get_group ();
198
- int32_t local_size = group.get_local_linear_range ();
225
+ uint32_t local_size = group.get_local_linear_range ();
199
226
200
- for (int32_t i = llid; i < bins_count; i += local_size) {
227
+ for (uint32_t i = llid; i < bins_count; i += local_size) {
201
228
auto value = local_hist[0 ][i];
202
- for (int32_t lhc = 1 ; lhc < local_hist_count; ++lhc) {
229
+ for (uint32_t lhc = 1 ; lhc < local_hist_count; ++lhc) {
203
230
value += local_hist[lhc][i];
204
231
}
205
232
if (value != T (0 )) {
@@ -290,9 +317,9 @@ class histogram_kernel;
290
317
291
318
template <typename T, typename HistImpl, typename Edges, typename Weights>
292
319
void submit_histogram (const T *in,
293
- size_t size,
294
- size_t dims,
295
- uint32_t WorkPI,
320
+ const size_t size,
321
+ const size_t dims,
322
+ const uint32_t WorkPI,
296
323
const HistImpl &hist,
297
324
const Edges &edges,
298
325
const Weights &weights,
0 commit comments