Skip to content

Commit 8a560ba

Browse files
authored
[NPU] Fix aclnnInplaceMuls parameter bug in dropout kernel (#1413)
1 parent 4f4eddb commit 8a560ba

File tree

2 files changed

+40
-6
lines changed

2 files changed

+40
-6
lines changed

backends/npu/kernels/dropout_kernel.cc

Lines changed: 18 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -402,13 +402,21 @@ void DropoutRawKernel(const Context& dev_ctx,
402402
}
403403

404404
if (!is_upscale) {
405-
phi::Scalar revert_scale = static_cast<T>(1.0 - dropout_prob);
406-
EXEC_NPU_CMD(aclnnInplaceMuls, dev_ctx, *out, revert_scale);
405+
auto revert_scale = static_cast<T>(1.0 - dropout_prob);
406+
aclDataType acl_data_type = ConvertToNpuDtype(x.dtype());
407+
static const auto aclCreateScalar = GET_OP_API_FUNC(aclCreateScalar);
408+
aclScalar* acl_scalar_revert_scale =
409+
aclCreateScalar(&revert_scale, acl_data_type);
410+
EXEC_NPU_CMD(aclnnInplaceMuls, dev_ctx, *out, acl_scalar_revert_scale);
407411
}
408412
} else {
409413
if (!is_upscale) {
410-
phi::Scalar down_scale = static_cast<T>(1.0 - dropout_prob);
411-
EXEC_NPU_CMD(aclnnMuls, dev_ctx, x, down_scale, *out);
414+
auto down_scale = static_cast<T>(1.0 - dropout_prob);
415+
aclDataType acl_data_type = ConvertToNpuDtype(x.dtype());
416+
static const auto aclCreateScalar = GET_OP_API_FUNC(aclCreateScalar);
417+
aclScalar* acl_scalar_down_scale =
418+
aclCreateScalar(&down_scale, acl_data_type);
419+
EXEC_NPU_CMD(aclnnMuls, dev_ctx, x, acl_scalar_down_scale, *out);
412420
return;
413421
}
414422
TensorCopy(dev_ctx, x, false, out);
@@ -565,8 +573,12 @@ void DropoutGradRawKernel(const Context& dev_ctx,
565573
}
566574

567575
if (!is_upscale) {
568-
phi::Scalar revert_scale = static_cast<T>(1.0 - dropout_prob);
569-
EXEC_NPU_CMD(aclnnInplaceMuls, dev_ctx, *dx, revert_scale);
576+
auto revert_scale = static_cast<T>(1.0 - dropout_prob);
577+
aclDataType acl_data_type = ConvertToNpuDtype(dx->dtype());
578+
static const auto aclCreateScalar = GET_OP_API_FUNC(aclCreateScalar);
579+
aclScalar* acl_scalar_revert_scale =
580+
aclCreateScalar(&revert_scale, acl_data_type);
581+
EXEC_NPU_CMD(aclnnInplaceMuls, dev_ctx, *dx, acl_scalar_revert_scale);
570582
}
571583
return;
572584
}

backends/npu/tests/unittests/test_dropout_op_npu.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,28 @@ def setUp(self):
148148
}
149149

150150

151+
class TestDropoutModeDownOp1Fp16(TestDropoutOp):
152+
# the dropout_prob is 0.2
153+
def init_dtype(self):
154+
self.dtype = np.float16
155+
156+
def setUp(self):
157+
self.op_type = "dropout"
158+
self.set_npu()
159+
self.init_dtype()
160+
self.inputs = {"X": np.random.random((32, 64)).astype(self.dtype)}
161+
self.attrs = {
162+
"dropout_prob": 0.2,
163+
"fix_seed": True,
164+
"is_test": False,
165+
"dropout_implementation": "downgrade_in_infer",
166+
}
167+
self.outputs = {
168+
"Out": np.zeros((32, 64)).astype("float16"),
169+
"Mask": convert_to_npu_mask(np.zeros((32, 64)).astype("uint8")),
170+
}
171+
172+
151173
class TestDropoutOp2(TestDropoutOp):
152174
# the dropout_prob is 1.0
153175
def setUp(self):

0 commit comments

Comments
 (0)