1
+ #include < ATen/AccumulateType.h>
1
2
#include < ATen/Functions.h>
2
3
#include < ATen/TensorUtils.h>
3
4
#include < ATen/core/Reduction.h>
@@ -126,7 +127,7 @@ struct NllLossForwardReduce1DKernelFunctor {
126
127
int64_t reduction;
127
128
};
128
129
129
- template <typename scalar_t , typename index_t >
130
+ template <typename scalar_t , typename index_t , typename accscalar_t >
130
131
struct NllLossForwardReduce2DKernelFunctor
131
132
: public __SYCL_KER_CONFIG_CONVENTION__ {
132
133
void operator ()(sycl::nd_item<1 > item_id) const {
@@ -136,17 +137,18 @@ struct NllLossForwardReduce2DKernelFunctor
136
137
auto total_weight_ptr = total_weight_data;
137
138
auto output_ptr = output_data;
138
139
int64_t local_id = item_id.get_local_id (0 );
139
- local_output_acc[local_id] = 0.0 ;
140
- local_total_weight_acc[local_id] = 0.0 ;
140
+ local_output_acc[local_id] = accscalar_t ( 0 ) ;
141
+ local_total_weight_acc[local_id] = accscalar_t ( 0 ) ;
141
142
for (int i = local_id; i < batch_size; i += local_size) {
142
143
int cur_target = target_ptr[i];
143
144
if (cur_target != ignore_index) {
144
145
scalar_t cur_weight =
145
146
has_weight ? weight_ptr[cur_target] : static_cast <scalar_t >(1 .0f );
146
- local_total_weight_acc[local_id] += cur_weight;
147
+ local_total_weight_acc[local_id] +=
148
+ static_cast <accscalar_t >(cur_weight);
147
149
local_output_acc[local_id] -=
148
- static_cast <scalar_t >(input_ptr[i * n_target + cur_target]) *
149
- static_cast <scalar_t >(cur_weight);
150
+ static_cast <accscalar_t >(input_ptr[i * n_target + cur_target]) *
151
+ static_cast <accscalar_t >(cur_weight);
150
152
}
151
153
}
152
154
@@ -161,11 +163,13 @@ struct NllLossForwardReduce2DKernelFunctor
161
163
}
162
164
item_id.barrier (sycl_global_and_local_fence);
163
165
164
- output_ptr[0 ] = local_output_acc[0 ];
165
- total_weight_ptr[0 ] = local_total_weight_acc[0 ];
166
166
if (reduction == at::Reduction::Mean) {
167
- output_ptr[0 ] /= total_weight_ptr[0 ];
167
+ output_ptr[0 ] = static_cast <scalar_t >(
168
+ local_output_acc[0 ] / local_total_weight_acc[0 ]);
169
+ } else {
170
+ output_ptr[0 ] = static_cast <scalar_t >(local_output_acc[0 ]);
168
171
}
172
+ total_weight_ptr[0 ] = static_cast <scalar_t >(local_total_weight_acc[0 ]);
169
173
}
170
174
NllLossForwardReduce2DKernelFunctor (
171
175
scalar_t * input_data_,
@@ -192,8 +196,8 @@ struct NllLossForwardReduce2DKernelFunctor
192
196
reduction(reduction_) {}
193
197
194
198
void sycl_ker_config_convention (sycl::handler& cgh) {
195
- local_output_acc = sycl_local_acc_t <scalar_t >(local_size, cgh);
196
- local_total_weight_acc = sycl_local_acc_t <scalar_t >(local_size, cgh);
199
+ local_output_acc = sycl_local_acc_t <accscalar_t >(local_size, cgh);
200
+ local_total_weight_acc = sycl_local_acc_t <accscalar_t >(local_size, cgh);
197
201
}
198
202
199
203
private:
@@ -207,8 +211,8 @@ struct NllLossForwardReduce2DKernelFunctor
207
211
int64_t local_size;
208
212
int64_t ignore_index;
209
213
int n_target;
210
- sycl_local_acc_t <scalar_t > local_output_acc;
211
- sycl_local_acc_t <scalar_t > local_total_weight_acc;
214
+ sycl_local_acc_t <accscalar_t > local_output_acc;
215
+ sycl_local_acc_t <accscalar_t > local_total_weight_acc;
212
216
int64_t reduction;
213
217
};
214
218
@@ -309,8 +313,9 @@ void nll_loss_forward_template(
309
313
310
314
sycl_kernel_submit (sycl::range<1 >(local_size), queue, kfn);
311
315
} else if (input_cont.dim () == 2 ) {
316
+ using accscalar_t = at::acc_type<scalar_t , true >;
312
317
using NllLossForwardReduce2DKernel =
313
- NllLossForwardReduce2DKernelFunctor<scalar_t , index_t >;
318
+ NllLossForwardReduce2DKernelFunctor<scalar_t , index_t , accscalar_t >;
314
319
315
320
int64_t batch_size = input.size (0 );
316
321
int n_target = input.size (1 );
@@ -322,7 +327,7 @@ void nll_loss_forward_template(
322
327
auto target_data = _target_data;
323
328
auto total_weight_data = _total_weight_data;
324
329
auto output_data = _output_data;
325
- NllLossForwardReduce2DKernelFunctor<scalar_t , index_t > kfn (
330
+ NllLossForwardReduce2DKernelFunctor<scalar_t , index_t , accscalar_t > kfn (
326
331
input_data,
327
332
target_data,
328
333
weight_data,
0 commit comments