Skip to content

Commit ed3442d

Browse files
jianyizhCopilot
andauthored
fix NllLossForwardReduce2DKernelFunctor accuracy (#1868)
follow cuda and make NllLossForwardReduce2DKernelFunctor more accurate using accumulate type --------- Co-authored-by: Copilot <[email protected]>
1 parent db80b86 commit ed3442d

File tree

1 file changed

+20
-15
lines changed

1 file changed

+20
-15
lines changed

src/ATen/native/xpu/sycl/LossNLLKernel.cpp

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <ATen/AccumulateType.h>
12
#include <ATen/Functions.h>
23
#include <ATen/TensorUtils.h>
34
#include <ATen/core/Reduction.h>
@@ -126,7 +127,7 @@ struct NllLossForwardReduce1DKernelFunctor {
126127
int64_t reduction;
127128
};
128129

129-
template <typename scalar_t, typename index_t>
130+
template <typename scalar_t, typename index_t, typename accscalar_t>
130131
struct NllLossForwardReduce2DKernelFunctor
131132
: public __SYCL_KER_CONFIG_CONVENTION__ {
132133
void operator()(sycl::nd_item<1> item_id) const {
@@ -136,17 +137,18 @@ struct NllLossForwardReduce2DKernelFunctor
136137
auto total_weight_ptr = total_weight_data;
137138
auto output_ptr = output_data;
138139
int64_t local_id = item_id.get_local_id(0);
139-
local_output_acc[local_id] = 0.0;
140-
local_total_weight_acc[local_id] = 0.0;
140+
local_output_acc[local_id] = accscalar_t(0);
141+
local_total_weight_acc[local_id] = accscalar_t(0);
141142
for (int i = local_id; i < batch_size; i += local_size) {
142143
int cur_target = target_ptr[i];
143144
if (cur_target != ignore_index) {
144145
scalar_t cur_weight =
145146
has_weight ? weight_ptr[cur_target] : static_cast<scalar_t>(1.0f);
146-
local_total_weight_acc[local_id] += cur_weight;
147+
local_total_weight_acc[local_id] +=
148+
static_cast<accscalar_t>(cur_weight);
147149
local_output_acc[local_id] -=
148-
static_cast<scalar_t>(input_ptr[i * n_target + cur_target]) *
149-
static_cast<scalar_t>(cur_weight);
150+
static_cast<accscalar_t>(input_ptr[i * n_target + cur_target]) *
151+
static_cast<accscalar_t>(cur_weight);
150152
}
151153
}
152154

@@ -161,11 +163,13 @@ struct NllLossForwardReduce2DKernelFunctor
161163
}
162164
item_id.barrier(sycl_global_and_local_fence);
163165

164-
output_ptr[0] = local_output_acc[0];
165-
total_weight_ptr[0] = local_total_weight_acc[0];
166166
if (reduction == at::Reduction::Mean) {
167-
output_ptr[0] /= total_weight_ptr[0];
167+
output_ptr[0] = static_cast<scalar_t>(
168+
local_output_acc[0] / local_total_weight_acc[0]);
169+
} else {
170+
output_ptr[0] = static_cast<scalar_t>(local_output_acc[0]);
168171
}
172+
total_weight_ptr[0] = static_cast<scalar_t>(local_total_weight_acc[0]);
169173
}
170174
NllLossForwardReduce2DKernelFunctor(
171175
scalar_t* input_data_,
@@ -192,8 +196,8 @@ struct NllLossForwardReduce2DKernelFunctor
192196
reduction(reduction_) {}
193197

194198
void sycl_ker_config_convention(sycl::handler& cgh) {
195-
local_output_acc = sycl_local_acc_t<scalar_t>(local_size, cgh);
196-
local_total_weight_acc = sycl_local_acc_t<scalar_t>(local_size, cgh);
199+
local_output_acc = sycl_local_acc_t<accscalar_t>(local_size, cgh);
200+
local_total_weight_acc = sycl_local_acc_t<accscalar_t>(local_size, cgh);
197201
}
198202

199203
private:
@@ -207,8 +211,8 @@ struct NllLossForwardReduce2DKernelFunctor
207211
int64_t local_size;
208212
int64_t ignore_index;
209213
int n_target;
210-
sycl_local_acc_t<scalar_t> local_output_acc;
211-
sycl_local_acc_t<scalar_t> local_total_weight_acc;
214+
sycl_local_acc_t<accscalar_t> local_output_acc;
215+
sycl_local_acc_t<accscalar_t> local_total_weight_acc;
212216
int64_t reduction;
213217
};
214218

@@ -309,8 +313,9 @@ void nll_loss_forward_template(
309313

310314
sycl_kernel_submit(sycl::range<1>(local_size), queue, kfn);
311315
} else if (input_cont.dim() == 2) {
316+
using accscalar_t = at::acc_type<scalar_t, true>;
312317
using NllLossForwardReduce2DKernel =
313-
NllLossForwardReduce2DKernelFunctor<scalar_t, index_t>;
318+
NllLossForwardReduce2DKernelFunctor<scalar_t, index_t, accscalar_t>;
314319

315320
int64_t batch_size = input.size(0);
316321
int n_target = input.size(1);
@@ -322,7 +327,7 @@ void nll_loss_forward_template(
322327
auto target_data = _target_data;
323328
auto total_weight_data = _total_weight_data;
324329
auto output_data = _output_data;
325-
NllLossForwardReduce2DKernelFunctor<scalar_t, index_t> kfn(
330+
NllLossForwardReduce2DKernelFunctor<scalar_t, index_t, accscalar_t> kfn(
326331
input_data,
327332
target_data,
328333
weight_data,

0 commit comments

Comments
 (0)