@@ -164,6 +164,7 @@ void ComputeConv2DBackward(const platform::CUDADeviceContext &ctx,
   attrs.insert({"groups", groups});
   attrs.insert({"exhaustive_search", exhaustive_search});
   attrs.insert({"use_addto", use_addto});
+  attrs.insert({"workspace_size_MB", 512});

   auto op = framework::OpRegistry::CreateOp(
       "conv2d_grad", {{"Input", {"Input"}},
@@ -408,7 +409,7 @@ TEST(CudnnNormConvFp16, K1S1) {
   platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
       platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));

-  if (ctx->GetComputeCapability() <= 70) {
+  if (ctx->GetComputeCapability() < 70) {
     ASSERT_THROW(test.CheckForward(1e-3, true),
                  paddle::platform::EnforceNotMet);
     ASSERT_THROW(test.CheckBackward(1e-3, true),
@@ -434,7 +435,7 @@ TEST(CudnnNormConvFp16, K3S1) {
   platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
       platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));

-  if (ctx->GetComputeCapability() <= 70) {
+  if (ctx->GetComputeCapability() < 70) {
     ASSERT_THROW(test.CheckForward(1e-3, true),
                  paddle::platform::EnforceNotMet);
     ASSERT_THROW(test.CheckBackward(1e-3, true),
@@ -460,7 +461,7 @@ TEST(CudnnNormConvFp16, K1S1O4) {
   platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
       platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));

-  if (ctx->GetComputeCapability() <= 70) {
+  if (ctx->GetComputeCapability() < 70) {
     ASSERT_THROW(test.CheckForward(1e-3, true),
                  paddle::platform::EnforceNotMet);
     ASSERT_THROW(test.CheckBackward(1e-3, true),
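Note: the three guard changes are identical. The FP16 norm-conv tests previously expected `EnforceNotMet` on compute capability 7.0 as well; with `< 70` they only expect a throw on pre-Volta GPUs, so SM 7.0 devices now run the real numeric checks. A minimal sketch of the resulting skip-or-throw pattern; the `ASSERT_NO_THROW` branch is illustrative of what the non-throwing path could look like, not a quote of the test file:

```cpp
// Sketch (assumption): below SM 7.0 the fused FP16 path is unsupported and
// should raise EnforceNotMet; on SM 7.0 and newer the checks should succeed.
if (ctx->GetComputeCapability() < 70) {
  ASSERT_THROW(test.CheckForward(1e-3, true),
               paddle::platform::EnforceNotMet);
  ASSERT_THROW(test.CheckBackward(1e-3, true),
               paddle::platform::EnforceNotMet);
} else {
  ASSERT_NO_THROW(test.CheckForward(1e-3, true));
  ASSERT_NO_THROW(test.CheckBackward(1e-3, true));
}
```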