@@ -278,27 +278,27 @@ conv_answer_dict = Dict(
278
278
@testset " $(conv) " begin
279
279
# First, your basic convolution with no parameters
280
280
cdims = DenseConvDims (x, w)
281
- @test ddims (conv (x, w, cdims)) == y_plain
281
+ @test isapprox ( ddims (conv (x, w, cdims)), y_plain, rtol = 1.0e-7 )
282
282
283
283
# Next, test convolution on views and alternate datatypes:
284
- @test ddims (conv (view (x, repeat ([:], ndims (x))... ), w, cdims)) == y_plain
285
- @test ddims (conv (Float32 .(x), Float32 .(w), cdims)) == Float32 .(y_plain)
284
+ @test isapprox ( ddims (conv (view (x, repeat ([:], ndims (x))... ), w, cdims)), y_plain, rtol = 1.0e-7 )
285
+ @test isapprox ( ddims (conv (Float32 .(x), Float32 .(w), cdims)), Float32 .(y_plain), rtol = 1.0e-7 )
286
286
287
287
# Next, introduce stride:
288
288
cdims = DenseConvDims (x, w; stride= 2 )
289
- @test ddims (conv (x, w, cdims)) == y_stride
289
+ @test isapprox ( ddims (conv (x, w, cdims)), y_stride, rtol = 1.0e-7 )
290
290
291
291
# Next, introduce dilation:
292
292
cdims = DenseConvDims (x, w; dilation= 2 )
293
- @test ddims (conv (x, w, cdims)) == y_dil
293
+ @test isapprox ( ddims (conv (x, w, cdims)), y_dil, rtol = 1.0e-7 )
294
294
295
295
# Next, introduce padding:
296
296
cdims = DenseConvDims (x, w; padding= 1 )
297
- @test ddims (conv (x, w, cdims)) == y_pad
297
+ @test isapprox ( ddims (conv (x, w, cdims)), y_pad, rtol = 1.0e-7 )
298
298
299
299
# Next, test crosscor/conv with a flipped kernel
300
300
cdims = DenseConvDims (x, w; flipkernel= true )
301
- @test ddims (conv (x, w, cdims)) == y_flip
301
+ @test isapprox ( ddims (conv (x, w, cdims)), y_flip, rtol = 1.0e-7 )
302
302
end
303
303
end
304
304
@@ -312,39 +312,39 @@ conv_answer_dict = Dict(
312
312
# First, your basic convolution with no parameters
313
313
cdims = DenseConvDims (x, w)
314
314
dy = NNlib. conv (x, w, cdims)
315
- @test ddims (∇conv_filter (x, dy, cdims)) == dw
316
- @test ddims (∇conv_data (dy, w, cdims)) == dx
315
+ @test isapprox ( ddims (∇conv_filter (x, dy, cdims)), dw, rtol = 1.0e-7 )
316
+ @test isapprox ( ddims (∇conv_data (dy, w, cdims)), dx, rtol = 1.0e-7 )
317
317
318
318
# Next, test convolution on views and alternate datatypes:
319
- @test ddims (∇conv_filter (x, view (dy, repeat ([:], ndims (dy))... ), cdims)) == dw
320
- @test ddims (∇conv_data (view (dy, repeat ([:], ndims (dy))... ), w, cdims)) == dx
319
+ @test isapprox ( ddims (∇conv_filter (x, view (dy, repeat ([:], ndims (dy))... ), cdims)), dw, rtol = 1.0e-7 )
320
+ @test isapprox ( ddims (∇conv_data (view (dy, repeat ([:], ndims (dy))... ), w, cdims)), dx, rtol = 1.0e-7 )
321
321
322
- @test ddims (∇conv_filter (Float32 .(x), Float32 .(dy), cdims)) == dw
323
- @test ddims (∇conv_data (Float32 .(dy), Float32 .(w), cdims)) == dx
322
+ @test isapprox ( ddims (∇conv_filter (Float32 .(x), Float32 .(dy), cdims)), dw, rtol = 1.0e-7 )
323
+ @test isapprox ( ddims (∇conv_data (Float32 .(dy), Float32 .(w), cdims)), dx, rtol = 1.0e-7 )
324
324
325
325
# Next, introduce stride:
326
326
cdims = DenseConvDims (x, w; stride= 2 )
327
327
dy = NNlib. conv (x, w, cdims)
328
- @test ddims (∇conv_filter (x, dy, cdims)) == dw_stride
329
- @test ddims (∇conv_data (dy, w, cdims)) == dx_stride
328
+ @test isapprox ( ddims (∇conv_filter (x, dy, cdims)), dw_stride, rtol = 1.0e-7 )
329
+ @test isapprox ( ddims (∇conv_data (dy, w, cdims)), dx_stride, rtol = 1.0e-7 )
330
330
331
331
# Next, introduce dilation:
332
332
cdims = DenseConvDims (x, w; dilation= 2 )
333
333
dy = NNlib. conv (x, w, cdims)
334
- @test ddims (∇conv_filter (x, dy, cdims)) == dw_dil
335
- @test ddims (∇conv_data (dy, w, cdims)) == dx_dil
334
+ @test isapprox ( ddims (∇conv_filter (x, dy, cdims)), dw_dil, rtol = 1.0e-7 )
335
+ @test isapprox ( ddims (∇conv_data (dy, w, cdims)), dx_dil, rtol = 1.0e-7 )
336
336
337
337
# Next, introduce padding:
338
338
cdims = DenseConvDims (x, w; padding= 1 )
339
339
dy = NNlib. conv (x, w, cdims)
340
- @test ddims (∇conv_filter (x, dy, cdims)) == dw_pad
341
- @test ddims (∇conv_data (dy, w, cdims)) == dx_pad
340
+ @test isapprox ( ddims (∇conv_filter (x, dy, cdims)), dw_pad, rtol = 1.0e-7 )
341
+ @test isapprox ( ddims (∇conv_data (dy, w, cdims)), dx_pad, rtol = 1.0e-7 )
342
342
343
343
# Next, test crosscor/conv with a flipped kernel
344
344
cdims = DenseConvDims (x, w; flipkernel= true )
345
345
dy = NNlib. conv (x, w, cdims)
346
- @test ddims (∇conv_filter (x, dy, cdims)) == dw_flip
347
- @test ddims (∇conv_data (dy, w, cdims)) == dx_flip
346
+ @test isapprox ( ddims (∇conv_filter (x, dy, cdims)), dw_flip, rtol = 1.0e-7 )
347
+ @test isapprox ( ddims (∇conv_data (dy, w, cdims)), dx_flip, rtol = 1.0e-7 )
348
348
end
349
349
end
350
350
end
@@ -481,24 +481,24 @@ end
481
481
@test ddims (conv (x, w, cdims)) == y_plain
482
482
483
483
# Next, test convolution on views and alternate datatypes:
484
- @test ddims (conv (view (x, repeat ([:], ndims (x))... ), w, cdims)) == y_plain
485
- @test ddims (conv (Float32 .(x), Float32 .(w), cdims)) == Float32 .(y_plain)
484
+ @test isapprox ( ddims (conv (view (x, repeat ([:], ndims (x))... ), w, cdims)), y_plain, rtol = 1.0e-7 )
485
+ @test isapprox ( ddims (conv (Float32 .(x), Float32 .(w), cdims)), Float32 .(y_plain), rtol = 1.0e-7 )
486
486
487
487
# Next, introduce stride:
488
488
cdims = DepthwiseConvDims (x, w; stride= 2 )
489
- @test ddims (conv (x, w, cdims)) == y_stride
489
+ @test isapprox ( ddims (conv (x, w, cdims)), y_stride, rtol = 1.0e-7 )
490
490
491
491
# Next, introduce dilation:
492
492
cdims = DepthwiseConvDims (x, w; dilation= 2 )
493
- @test ddims (conv (x, w, cdims)) == y_dil
493
+ @test isapprox ( ddims (conv (x, w, cdims)), y_dil, rtol = 1.0e-7 )
494
494
495
495
# Next, introduce padding:
496
496
cdims = DepthwiseConvDims (x, w; padding= 1 )
497
- @test ddims (conv (x, w, cdims)) == y_pad
497
+ @test isapprox ( ddims (conv (x, w, cdims)), y_pad, rtol = 1.0e-7 )
498
498
499
499
# Next, test crosscor/conv with a flipped kernel
500
500
cdims = DepthwiseConvDims (x, w; flipkernel= true )
501
- @test ddims (conv (x, w, cdims)) == y_flip
501
+ @test isapprox ( ddims (conv (x, w, cdims)), y_flip, rtol = 1.0e-7 )
502
502
end
503
503
end
504
504
0 commit comments