@@ -402,13 +402,21 @@ void DropoutRawKernel(const Context& dev_ctx,
402
402
}
403
403
404
404
if (!is_upscale) {
405
- phi::Scalar revert_scale = static_cast <T>(1.0 - dropout_prob);
406
- EXEC_NPU_CMD (aclnnInplaceMuls, dev_ctx, *out, revert_scale);
405
+ auto revert_scale = static_cast <T>(1.0 - dropout_prob);
406
+ aclDataType acl_data_type = ConvertToNpuDtype (x.dtype ());
407
+ static const auto aclCreateScalar = GET_OP_API_FUNC (aclCreateScalar);
408
+ aclScalar* acl_scalar_revert_scale =
409
+ aclCreateScalar (&revert_scale, acl_data_type);
410
+ EXEC_NPU_CMD (aclnnInplaceMuls, dev_ctx, *out, acl_scalar_revert_scale);
407
411
}
408
412
} else {
409
413
if (!is_upscale) {
410
- phi::Scalar down_scale = static_cast <T>(1.0 - dropout_prob);
411
- EXEC_NPU_CMD (aclnnMuls, dev_ctx, x, down_scale, *out);
414
+ auto down_scale = static_cast <T>(1.0 - dropout_prob);
415
+ aclDataType acl_data_type = ConvertToNpuDtype (x.dtype ());
416
+ static const auto aclCreateScalar = GET_OP_API_FUNC (aclCreateScalar);
417
+ aclScalar* acl_scalar_down_scale =
418
+ aclCreateScalar (&down_scale, acl_data_type);
419
+ EXEC_NPU_CMD (aclnnMuls, dev_ctx, x, acl_scalar_down_scale, *out);
412
420
return ;
413
421
}
414
422
TensorCopy (dev_ctx, x, false , out);
@@ -565,8 +573,12 @@ void DropoutGradRawKernel(const Context& dev_ctx,
565
573
}
566
574
567
575
if (!is_upscale) {
568
- phi::Scalar revert_scale = static_cast <T>(1.0 - dropout_prob);
569
- EXEC_NPU_CMD (aclnnInplaceMuls, dev_ctx, *dx, revert_scale);
576
+ auto revert_scale = static_cast <T>(1.0 - dropout_prob);
577
+ aclDataType acl_data_type = ConvertToNpuDtype (dx->dtype ());
578
+ static const auto aclCreateScalar = GET_OP_API_FUNC (aclCreateScalar);
579
+ aclScalar* acl_scalar_revert_scale =
580
+ aclCreateScalar (&revert_scale, acl_data_type);
581
+ EXEC_NPU_CMD (aclnnInplaceMuls, dev_ctx, *dx, acl_scalar_revert_scale);
570
582
}
571
583
return ;
572
584
}
0 commit comments