Skip to content

Commit 655c498

Browse files
authored
[Cherry-Pick] Fix test_cudnn_norm_conv and test_cudnn_bn_add_relu in CUDA 11.2 (#42406)
* Fix test_cudnn_norm_conv and test_cudnn_bn_add_relu in CUDA 11.2 * Do not expect an exception on V100 (compute capability 70) for some cases
1 parent 778ec77 commit 655c498

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ limitations under the License. */
2323
#include "paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h"
2424
#include "paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h"
2525
#include "paddle/fluid/platform/float16.h"
26+
#include "paddle/phi/core/kernel_registry.h"
2627
#include "paddle/phi/kernels/funcs/math_function.h"
2728

2829
DECLARE_bool(cudnn_batchnorm_spatial_persistent);
@@ -33,6 +34,7 @@ namespace op = paddle::operators;
3334
using Tensor = paddle::framework::Tensor;
3435

3536
USE_OP_ITSELF(batch_norm);
37+
PD_DECLARE_KERNEL(batch_norm, GPU, ALL_LAYOUT);
3638
USE_CUDA_ONLY_OP(fused_bn_add_activation);
3739
USE_CUDA_ONLY_OP(fused_bn_add_activation_grad);
3840

paddle/fluid/operators/fused/cudnn_norm_conv_test.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ void ComputeConv2DBackward(const platform::CUDADeviceContext &ctx,
164164
attrs.insert({"groups", groups});
165165
attrs.insert({"exhaustive_search", exhaustive_search});
166166
attrs.insert({"use_addto", use_addto});
167+
attrs.insert({"workspace_size_MB", 512});
167168

168169
auto op = framework::OpRegistry::CreateOp(
169170
"conv2d_grad", {{"Input", {"Input"}},
@@ -408,7 +409,7 @@ TEST(CudnnNormConvFp16, K1S1) {
408409
platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
409410
platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
410411

411-
if (ctx->GetComputeCapability() <= 70) {
412+
if (ctx->GetComputeCapability() < 70) {
412413
ASSERT_THROW(test.CheckForward(1e-3, true),
413414
paddle::platform::EnforceNotMet);
414415
ASSERT_THROW(test.CheckBackward(1e-3, true),
@@ -434,7 +435,7 @@ TEST(CudnnNormConvFp16, K3S1) {
434435
platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
435436
platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
436437

437-
if (ctx->GetComputeCapability() <= 70) {
438+
if (ctx->GetComputeCapability() < 70) {
438439
ASSERT_THROW(test.CheckForward(1e-3, true),
439440
paddle::platform::EnforceNotMet);
440441
ASSERT_THROW(test.CheckBackward(1e-3, true),
@@ -460,7 +461,7 @@ TEST(CudnnNormConvFp16, K1S1O4) {
460461
platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
461462
platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
462463

463-
if (ctx->GetComputeCapability() <= 70) {
464+
if (ctx->GetComputeCapability() < 70) {
464465
ASSERT_THROW(test.CheckForward(1e-3, true),
465466
paddle::platform::EnforceNotMet);
466467
ASSERT_THROW(test.CheckBackward(1e-3, true),

0 commit comments

Comments
 (0)