@@ -2,6 +2,7 @@ using NNlib, Test
2
2
using NNlib: input_size, kernel_size, channels_in, channels_out, channel_multiplier,
3
3
stride, padding, dilation, flipkernel, output_size,
4
4
groupcount
5
+ using Random: AbstractRNG, SamplerType
5
6
6
7
@testset " ConvDims" begin
7
8
for T in (DenseConvDims, DepthwiseConvDims)
865
866
@test size (NNlib.∇conv_filter_direct! (w, x, y, cdims)) == w_size
866
867
end
867
868
869
+ # https://github.com/FluxML/NNlib.jl/issues/490
870
+ # https://github.com/FluxML/NNlib.jl/issues/405
871
+ @testset " conv_direct! - Unusual input types" begin
872
+ # Create test type that can't be indexed when undefined.
873
+ # This simulates the worst-case scenario for custom types.
874
+ struct MyFloat <: Real
875
+ set:: Set{Float32}
876
+ end
877
+
878
+ # Test that direct indexing fails when undefined.
879
+ v = Array {MyFloat} (undef, 3 )
880
+ @test_throws UndefRefError v[1 ]
881
+
882
+ # Define minimal set of functions required for conv_direct!
883
+ MyFloat (x:: MyFloat ) = x
884
+ MyFloat (x:: Real ) = MyFloat (Set (Float32 (x)))
885
+
886
+ Base.:+ (x:: MyFloat , y:: MyFloat ) = MyFloat (only (x. set) + only (y. set))
887
+ Base.:* (x:: MyFloat , y:: MyFloat ) = MyFloat (only (x. set) * only (y. set))
888
+ Base. promote_rule (:: Type{MyFloat} , :: Type{Float32} ) = MyFloat
889
+ Base. rand (:: AbstractRNG , :: SamplerType{MyFloat} ) = MyFloat (rand (Float32))
890
+ Base. zero (:: MyFloat ) = MyFloat (zero (Float32))
891
+ Base. zero (:: Type{MyFloat} ) = MyFloat (zero (Float32))
892
+
893
+ # Test conv_direct!
894
+ x_size = (6 , 7 , 8 , 5 , 3 )
895
+ y_size = (5 , 6 , 7 , 4 , 3 )
896
+ w_size = (2 , 2 , 2 , 5 , 4 )
897
+ x = rand (MyFloat, x_size);
898
+ w = randn (Float32, w_size);
899
+ y = Array {MyFloat} (undef, y_size... );
900
+ cdims = DenseConvDims (x_size, w_size)
901
+ y_out = NNlib. conv_direct! (y, x, w, cdims)
902
+
903
+ @test eltype (y_out) == MyFloat
904
+ @test size (y_out) == y_size
905
+ end
906
+
868
907
@testset " AutoDiff: spatial_rank=$spatial_rank " for spatial_rank in (1 , 2 , 3 )
869
908
x = rand (rng, repeat ([5 ], spatial_rank)... , 3 , 2 )
870
909
w = rand (rng, repeat ([3 ], spatial_rank)... , 3 , 3 )
0 commit comments