Skip to content

Commit d423576

Browse files
authored
Support conv_direct! on custom datatypes (#592)
* Add test for unusual input datatypes
* Add fix to `conv_direct!`
* Only set y to zero if beta is false or zero
* Test output eltype
1 parent 85b17cf commit d423576

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

src/impl/conv_direct.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ function conv_direct!(
8181
# Use `calc_padding_regions` to determine where we do or don't need to worry about padding
8282
padded_regions, central_region = calc_padding_regions(cdims)
8383

84+
# Set outputs to zero to support custom datatypes (https://github.com/FluxML/NNlib.jl/issues/490)
85+
if iszero(beta)
86+
y = fill!(y, zero(yT))
87+
end
88+
8489
# Start with the central region
8590
w_region, h_region, d_region = central_region
8691
@inbounds for batch in 1:size(x, 5),

test/conv.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using NNlib, Test
22
using NNlib: input_size, kernel_size, channels_in, channels_out, channel_multiplier,
33
stride, padding, dilation, flipkernel, output_size,
44
groupcount
5+
using Random: AbstractRNG, SamplerType
56

67
@testset "ConvDims" begin
78
for T in (DenseConvDims, DepthwiseConvDims)
@@ -865,6 +866,44 @@ end
865866
@test size(NNlib.∇conv_filter_direct!(w, x, y, cdims)) == w_size
866867
end
867868

869+
# https://github.com/FluxML/NNlib.jl/issues/490
870+
# https://github.com/FluxML/NNlib.jl/issues/405
871+
@testset "conv_direct! - Unusual input types" begin
872+
# Create test type that can't be indexed when undefined.
873+
# This simulates the worst-case scenario for custom types.
874+
struct MyFloat <: Real
875+
set::Set{Float32}
876+
end
877+
878+
# Test that direct indexing fails when undefined.
879+
v = Array{MyFloat}(undef, 3)
880+
@test_throws UndefRefError v[1]
881+
882+
# Define minimal set of functions required for conv_direct!
883+
MyFloat(x::MyFloat) = x
884+
MyFloat(x::Real) = MyFloat(Set(Float32(x)))
885+
886+
Base.:+(x::MyFloat, y::MyFloat) = MyFloat(only(x.set) + only(y.set))
887+
Base.:*(x::MyFloat, y::MyFloat) = MyFloat(only(x.set) * only(y.set))
888+
Base.promote_rule(::Type{MyFloat}, ::Type{Float32}) = MyFloat
889+
Base.rand(::AbstractRNG, ::SamplerType{MyFloat}) = MyFloat(rand(Float32))
890+
Base.zero(::MyFloat) = MyFloat(zero(Float32))
891+
Base.zero(::Type{MyFloat}) = MyFloat(zero(Float32))
892+
893+
# Test conv_direct!
894+
x_size = (6, 7, 8, 5, 3)
895+
y_size = (5, 6, 7, 4, 3)
896+
w_size = (2, 2, 2, 5, 4)
897+
x = rand(MyFloat, x_size);
898+
w = randn(Float32, w_size);
899+
y = Array{MyFloat}(undef, y_size...);
900+
cdims = DenseConvDims(x_size, w_size)
901+
y_out = NNlib.conv_direct!(y, x, w, cdims)
902+
903+
@test eltype(y_out) == MyFloat
904+
@test size(y_out) == y_size
905+
end
906+
868907
@testset "AutoDiff: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
869908
x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
870909
w = rand(rng, repeat([3], spatial_rank)..., 3, 3)

0 commit comments

Comments
 (0)