@@ -399,6 +399,16 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x)))
399
399
end
400
400
end
401
401
402
+ # Test im2col
403
+
404
+ for beta in (- 2.0 , - 1.0 , 0.0 , 0.5 , 1.0 , 2.0 )
405
+ cache_dx, cache_dy, cache_w = ([0.17 ;;; 0.19 ;;; 0.23 ], [0.11 ;;; 0.13 ;;; 0.15 ], [1.0 ;;;])
406
+ dx_old = copy (cache_dx)
407
+ cdims = DenseConvDims (cache_dx, cache_w)
408
+ NNlib.∇conv_data_im2col! (cache_dx, cache_dy, cache_w, cdims; alpha= 1.0 , beta)
409
+ @test isapprox (cache_dx, dx_old * beta + cache_dy, rtol = 1.0e-7 )
410
+ end
411
+
402
412
# Test all in-place implementations/interfaces
403
413
for (∇conv_filter!, ∇conv_data!) in (
404
414
(NNlib.∇conv_filter!, NNlib.∇conv_data!),
@@ -407,47 +417,46 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x)))
407
417
)
408
418
# α, β = 2*rand(rng) - 1, 2*rand(rng) - 1
409
419
α, β = 2e0 , - 1e0
410
- flag = ∇conv_data! in (NNlib.∇conv_data!, NNlib.∇conv_data_im2col!)
411
420
412
421
@testset " $(∇conv_filter!) /$(∇conv_data!) " begin
413
422
# First, your basic convolution with no parameters
414
423
cdims = DenseConvDims (x, w)
415
424
dy = NNlib. conv (x, w, cdims)
416
425
@test isapprox (ddims (∇conv_filter! (copy (w), x, dy, cdims; alpha= α, beta= β)), α* dw + β* w, rtol = 1.0e-7 )
417
- @test isapprox (ddims (∇conv_data! (copy (x), dy, w, cdims; alpha= α, beta= β)), α* dx + β* x, rtol = 1.0e-7 ) broken = flag
426
+ @test isapprox (ddims (∇conv_data! (copy (x), dy, w, cdims; alpha= α, beta= β)), α* dx + β* x, rtol = 1.0e-7 )
418
427
419
428
# Next, test convolution on views and alternate datatypes:
420
429
@test isapprox (ddims (∇conv_filter! (copy (w), x, view (dy, repeat ([:], ndims (dy))... ), cdims; alpha= α, beta= β)), α* dw + β* w, rtol = 1.0e-7 )
421
- @test isapprox (ddims (∇conv_data! (copy (x), view (dy, repeat ([:], ndims (dy))... ), w, cdims; alpha= α, beta= β)), α* dx + β* x, rtol = 1.0e-7 ) broken = flag
430
+ @test isapprox (ddims (∇conv_data! (copy (x), view (dy, repeat ([:], ndims (dy))... ), w, cdims; alpha= α, beta= β)), α* dx + β* x, rtol = 1.0e-7 )
422
431
423
432
@test isapprox (ddims (∇conv_filter! (Float32 .(copy (w)), Float32 .(x), Float32 .(dy), cdims; alpha= Float32 (α), beta= Float32 (β))), α* dw + β* w, rtol = 1.0e-7 )
424
- @test isapprox (ddims (∇conv_data! (Float32 .(copy (x)), Float32 .(dy), Float32 .(w), cdims; alpha= Float32 (α), beta= Float32 (β))), α* dx + β* x, rtol = 1.0e-7 ) broken = flag
433
+ @test isapprox (ddims (∇conv_data! (Float32 .(copy (x)), Float32 .(dy), Float32 .(w), cdims; alpha= Float32 (α), beta= Float32 (β))), α* dx + β* x, rtol = 1.0e-7 )
425
434
426
435
# Next, introduce stride:
427
436
cdims = DenseConvDims (x, w; stride= 2 )
428
437
dy = NNlib. conv (x, w, cdims)
429
438
flag_ = ∇conv_filter! == NNlib.∇conv_filter_direct! && rank in (1 ,3 )
430
439
@test isapprox (ddims (∇conv_filter! (copy (w), x, dy, cdims; alpha= α, beta= β)), α* dw_stride + β* w, rtol = 1.0e-7 ) broken= flag_
431
- @test isapprox (ddims (∇conv_data! (copy (x), dy, w, cdims; alpha= α, beta= β)), α* dx_stride + β* x, rtol = 1.0e-7 ) broken = flag
440
+ @test isapprox (ddims (∇conv_data! (copy (x), dy, w, cdims; alpha= α, beta= β)), α* dx_stride + β* x, rtol = 1.0e-7 )
432
441
433
442
# Next, introduce dilation:
434
443
cdims = DenseConvDims (x, w; dilation= 2 )
435
444
dy = NNlib. conv (x, w, cdims)
436
445
flag_ = ∇conv_data! == NNlib.∇conv_data_direct! && rank == 3
437
446
@test isapprox (ddims (∇conv_filter! (copy (w), x, dy, cdims; alpha= α, beta= β)), α* dw_dil + β* w, rtol = 1.0e-7 )
438
- @test isapprox (ddims (∇conv_data! (copy (x), dy, w, cdims; alpha= α, beta= β)), α* dx_dil + β* x, rtol = 1.0e-7 ) broken= flag || flag_
447
+ @test isapprox (ddims (∇conv_data! (copy (x), dy, w, cdims; alpha= α, beta= β)), α* dx_dil + β* x, rtol = 1.0e-7 ) broken= flag_
439
448
440
449
# Next, introduce padding:
441
450
cdims = DenseConvDims (x, w; padding= 1 )
442
451
dy = NNlib. conv (x, w, cdims)
443
452
@test isapprox (ddims (∇conv_filter! (copy (w), x, dy, cdims; alpha= α, beta= β)), α* dw_pad + β* w, rtol = 1.0e-7 )
444
- @test isapprox (ddims (∇conv_data! (copy (x), dy, w, cdims; alpha= α, beta= β)), α* dx_pad + β* x, rtol = 1.0e-7 ) broken = flag
453
+ @test isapprox (ddims (∇conv_data! (copy (x), dy, w, cdims; alpha= α, beta= β)), α* dx_pad + β* x, rtol = 1.0e-7 )
445
454
446
455
# Next, test crosscor/conv with a flipped kernel
447
456
cdims = DenseConvDims (x, w; flipkernel= true )
448
457
dy = NNlib. conv (x, w, cdims)
449
458
@test isapprox (ddims (∇conv_filter! (copy (w), x, dy, cdims; alpha= α, beta= β)), α* dw_flip + β* w, rtol = 1.0e-7 )
450
- @test isapprox (ddims (∇conv_data! (copy (x), dy, w, cdims; alpha= α, beta= β)), α* dx_flip + β* x, rtol = 1.0e-7 ) broken = flag
459
+ @test isapprox (ddims (∇conv_data! (copy (x), dy, w, cdims; alpha= α, beta= β)), α* dx_flip + β* x, rtol = 1.0e-7 )
451
460
end
452
461
end
453
462
end
0 commit comments