Skip to content

Commit e5e3964

Browse files
author
Avik Pal
committed
Fixes for proper type
1 parent fd2c602 commit e5e3964

File tree

1 file changed

+39
-39
lines changed

1 file changed

+39
-39
lines changed

src/nnpack/nnlib.jl

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,40 +8,40 @@ const AA1 = Union{AA{2}, AA{3}, AA{4}, AA{5}}
88
# leakyrelu(x::AA1, a = oftype(x/1, 0.01)) =
99
# nnp_relu_output(x, inplace ? x : similar(x), negative_slope = a, threadpool = shared_threadpool)
1010

11-
"""
    softmax!(x::AbstractVecOrMat{Float64})

Run the `Float32` `softmax!` method on a `Float32` copy of `x` and store the result
back into `x`.

Returns the `Float32` array produced by the `Float32` method (same return value as
before); `x` additionally receives those values converted to `Float64`.
"""
function softmax!(x::A) where A<:AbstractVecOrMat{Float64}
    x32 = Float32.(x)
    out = softmax!(x32)
    # The Float32 method only ever sees the temporary copy; without this write-back
    # the caller's array would stay untouched even though this is a `!` method.
    copyto!(x, out)
    return out
end
1212

13-
"""
    softmax!(x::AbstractVecOrMat{Float32})

Call `nnp_softmax_output` with `x` as both positional arguments and the shared
thread pool, and return whatever that call returns.
"""
function softmax!(x::A) where A<:AbstractVecOrMat{Float32}
    return nnp_softmax_output(x, x; threadpool = shared_threadpool)
end
1515

16-
"""
    softmax!(y::AbstractVecOrMat{Float64}, x::AbstractVecOrMat{Float64})

Run the two-argument `Float32` `softmax!` method on `Float32` copies of `y` and `x`
and store the result back into `y`.

Returns the `Float32` array produced by the `Float32` method (same return value as
before); `y` additionally receives those values converted to `Float64`.
"""
function softmax!(y::A, x::A) where A<:AbstractVecOrMat{Float64}
    y32 = Float32.(y)
    x32 = Float32.(x)
    out = softmax!(y32, x32)
    # Write the result back: the Float32 call only filled the temporary `y32`.
    copyto!(y, out)
    return out
end
1717

18-
"""
    softmax!(y::AbstractVecOrMat{Float32}, x::AbstractVecOrMat{Float32})

Call `nnp_softmax_output(x, y)` with the shared thread pool and return its result.
Note that `x` is passed first and `y` second.
"""
function softmax!(y::A, x::A) where A<:AbstractVecOrMat{Float32}
    return nnp_softmax_output(x, y; threadpool = shared_threadpool)
end
2020

21-
"""
    softmax(x::AbstractVecOrMat{Float64})

Convert `x` element-wise to `Float32` and dispatch to the `Float32` `softmax`
method. The result of that method is returned as-is (it is not converted back to
`Float64`).
"""
function softmax(x::A) where A<:AbstractVecOrMat{Float64}
    x32 = Float32.(x)
    return softmax(x32)
end
2222

23-
"""
    softmax(x::AbstractVecOrMat{Float32})

Allocate `similar(x)` and pass it, after `x`, to `nnp_softmax_output` together with
the shared thread pool. Returns whatever that call returns.
"""
function softmax(x::A) where A<:AbstractVecOrMat{Float32}
    out = similar(x)
    return nnp_softmax_output(x, out; threadpool = shared_threadpool)
end
2525

26-
"""
    maxpool(x::AbstractArray{Float64, 4}, k; pad, stride)

Convert `x` element-wise to `Float32` and forward to `maxpool` with the same `k`,
`pad` and `stride`. The forwarded call's result is returned unchanged (no conversion
back to `Float64`).
"""
function maxpool(x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{Float64, 4}
    x32 = Float32.(x)
    return maxpool(x32, k; pad = pad, stride = stride)
end
2828

29-
"""
    maxpool(x::AA{4}, k; pad = map(_->0,k), stride = k)

Expand `pad` and `stride` with `expand(Val{length(k)}, ...)`, check that the window
tiles the first two dimensions of `x`, allocate the output with
`similar(x, pdims(...))`, and fill it through `maxpool!`.

Raises (via `error`) when `size(x, d) - k[d] + 2 * pad[d]` is not a multiple of the
stride for `d = 1` or `d = 2`.
"""
function maxpool(x::A, k; pad = map(_->0,k), stride = k) where A<:AA{4}
    rank = Val{length(k)}
    pad_ = expand(rank, pad)
    stride_ = expand(rank, stride)

    # Short-circuit `&&` is kept so the second remainder is only computed when the
    # first dimension already fits.
    fits = (size(x, 1) - k[1] + 2 * pad_[1]) % stride_[1] == 0 &&
           (size(x, 2) - k[2] + 2 * pad_[2]) % stride_[2] == 0
    if !fits
        error("Choose the stride, pad and kernel size properly")
    end

    y = similar(x, pdims(size(x), k, pad_, stride_))
    return maxpool!(y, x, k; pad = pad_, stride = stride_)
end
3434

35-
"""
    maxpool!(y::AbstractArray{Float64, 4}, x::AbstractArray{Float64, 4}, k; pad, stride)

Run `maxpool!` on `Float32` copies of `y` and `x` and store the result back into
`y`.

Returns the value produced by the `Float32` call (same return value as before); `y`
additionally receives those values converted to `Float64`.
"""
function maxpool!(y::A, x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{Float64, 4}
    y32 = Float32.(y)
    x32 = Float32.(x)
    out = maxpool!(y32, x32, k; pad = pad, stride = stride)
    # The Float32 call only filled the temporary `y32`; propagate it to the caller's
    # destination so the in-place method actually updates `y`.
    copyto!(y, out)
    return out
end
3737

38-
"""
    maxpool!(y::AA{4}, x::AA{4}, k; pad = map(_->0,k), stride = k)

Expand `pad` and `stride` with `expand(Val{length(k)}, ...)` and forward to
`nnp_max_pooling_output(x, y, k; padding, stride, threadpool)` using the shared
thread pool. Returns whatever that call returns.
"""
function maxpool!(y::A, x::A, k; pad = map(_->0,k), stride = k) where A<:AA{4}
    rank = Val{length(k)}
    padding_ = expand(rank, pad)
    stride_ = expand(rank, stride)
    return nnp_max_pooling_output(
        x, y, k;
        padding = padding_, stride = stride_, threadpool = shared_threadpool
    )
end
4040

41-
"""
    conv(x::AbstractArray{Float64, 4}, w::AbstractArray{Float64, 4}; pad, stride, dilation, algo)

Convert `x` and `w` element-wise to `Float32` and forward to `conv` with the same
keyword arguments. The forwarded call's result is returned unchanged.
"""
function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float64, 4}
    x32 = Float32.(x)
    w32 = Float32.(w)
    return conv(x32, w32; pad = pad, stride = stride, dilation = dilation, algo = algo)
end
4343

44-
function conv(x::AA{4}, w::AA{4}; pad = 0, stride = 1, dilation = 1, algo = UInt32(0))
44+
function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AA{4}
4545
dilation == 1 || dilation == (1, 1) || error("NNPACK does not support dilation > 1")
4646
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
4747
((size(x, 1) - size(w, 1) + 2 * pad_[1]) % stride_[1] == 0 && (size(x, 2) - size(w, 2) + 2 * pad_[2]) % stride_[2] == 0) || error("Choose the stride, pad and kernel size properly")
@@ -50,72 +50,72 @@ function conv(x::AA{4}, w::AA{4}; pad = 0, stride = 1, dilation = 1, algo = UInt
5050
conv!(y, x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
5151
end
5252

53-
"""
    conv(x, w, b; pad, stride, dilation, algo)   # Float64 inputs

Convert `x`, `w` and the bias `b` element-wise to `Float32` and forward to `conv`
with the same keyword arguments. The forwarded call's result is returned unchanged.
"""
function conv(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}}
    x32 = Float32.(x)
    w32 = Float32.(w)
    b32 = Float32.(b)
    return conv(x32, w32, b32; pad = pad, stride = stride, dilation = dilation, algo = algo)
end
5555

56-
"""
    conv(x::AA{4}, w::AA{4}, b::AA{1}; pad = 0, stride = 1, dilation = 1, algo = UInt32(0))

Validate the arguments, allocate the output with `similar(x, cdims(...))`, and fill
it through `conv!` (with `conv!`'s default `flipkernel`).

Raises (via `error`) when `dilation` is neither `1` nor `(1, 1)`, or when
`size(x, d) - size(w, d) + 2 * pad[d]` is not a multiple of the stride for `d = 1`
or `d = 2`.
"""
function conv(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AA{4}, A2<:AA{1}}
    if !(dilation == 1 || dilation == (1, 1))
        error("NNPACK does not support dilation > 1")
    end

    pad_ = padtuple(x, pad)
    stride_ = padtuple(x, stride)

    # Short-circuit `&&` is kept so the second remainder is only computed when the
    # first dimension already fits.
    fits = (size(x, 1) - size(w, 1) + 2 * pad_[1]) % stride_[1] == 0 &&
           (size(x, 2) - size(w, 2) + 2 * pad_[2]) % stride_[2] == 0
    if !fits
        error("Choose the stride, pad and kernel size properly")
    end

    y = similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_))
    return conv!(
        y, x, w, b;
        pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo)
    )
end
6262

63-
"""
    crosscor(x, w, b; pad, stride, dilation, algo)   # Float64 inputs

Convert `x`, `w` and `b` element-wise to `Float32` and forward to `crosscor` with
the same keyword arguments. The forwarded call's result is returned unchanged.
"""
function crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}}
    x32 = Float32.(x)
    w32 = Float32.(w)
    b32 = Float32.(b)
    return crosscor(x32, w32, b32; pad = pad, stride = stride, dilation = dilation, algo = algo)
end
6565

66-
"""
    crosscor(x::AA{4}, w::AA{4}, b::AA{1}; pad = 0, stride = 1, dilation = 1, algo = UInt32(0))

Same validation and output allocation as the three-argument `conv`, but calls
`conv!` with `flipkernel = 1`, so `conv!` does not reverse the kernel.

Raises (via `error`) when `dilation` is neither `1` nor `(1, 1)`, or when
`size(x, d) - size(w, d) + 2 * pad[d]` is not a multiple of the stride for `d = 1`
or `d = 2`.
"""
function crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AA{4}, A2<:AA{1}}
    if !(dilation == 1 || dilation == (1, 1))
        error("NNPACK does not support dilation > 1")
    end

    pad_ = padtuple(x, pad)
    stride_ = padtuple(x, stride)

    # Short-circuit `&&` is kept so the second remainder is only computed when the
    # first dimension already fits.
    fits = (size(x, 1) - size(w, 1) + 2 * pad_[1]) % stride_[1] == 0 &&
           (size(x, 2) - size(w, 2) + 2 * pad_[2]) % stride_[2] == 0
    if !fits
        error("Choose the stride, pad and kernel size properly")
    end

    y = similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_))
    return conv!(
        y, x, w, b;
        pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo),
        flipkernel = 1
    )
end
7272

73-
"""
    conv!(y, x, w, b; pad, stride, dilation, algo, flipkernel)   # Float64 inputs

Run `conv!` on `Float32` copies of `y`, `x`, `w` and `b`, then store the result back
into `y`.

The previous wrapper forwarded to `conv` (without the `!`) with four arrays and a
`flipkernel` keyword, a combination that none of the `conv` methods in this file
accept; it now forwards to `conv!`, mirroring the `crosscor!` wrapper.

Returns the value produced by the `Float32` `conv!` call; `y` additionally receives
those values converted to `Float64`.
"""
function conv!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}}
    y32 = Float32.(y)
    x32 = Float32.(x)
    w32 = Float32.(w)
    b32 = Float32.(b)
    out = conv!(
        y32, x32, w32, b32;
        pad = pad, stride = stride, dilation = dilation, algo = algo,
        flipkernel = flipkernel
    )
    # The Float32 call only filled the temporary `y32`; propagate it to the caller's
    # destination so the in-place method actually updates `y`.
    copyto!(y, out)
    return out
end
7575

76-
"""
    conv!(y::AA{4}, x::AA{4}, w::AA{4}, b::AA{1}; pad, stride, dilation, algo, flipkernel = 0)

Forward to `nnp_convolution_output(y, x, kernel, b; ...)` using the shared thread
pool. When `flipkernel == 0` the kernel passed on is `w` reversed along dimensions 1
and 2; otherwise `w` is passed unchanged. `pad` is forwarded as `padding`; the
`dilation` keyword is accepted but not used here.
"""
function conv!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where {A1<:AA{4}, A2<:AA{1}}
    kernel = flipkernel == 0 ? reverse(reverse(w, dims = 1), dims = 2) : w
    return nnp_convolution_output(
        y, x, kernel, b;
        algo = algo, padding = pad, stride = stride, threadpool = shared_threadpool
    )
end
8080

81-
"""
    crosscor!(y, x, w, b; pad, stride, dilation, algo)   # Float64 inputs

Run `conv!` with `flipkernel = 1` on `Float32` copies of `y`, `x`, `w` and `b`, then
store the result back into `y`.

Returns the value produced by the `Float32` `conv!` call (same return value as
before); `y` additionally receives those values converted to `Float64`.
"""
function crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}}
    y32 = Float32.(y)
    x32 = Float32.(x)
    w32 = Float32.(w)
    b32 = Float32.(b)
    out = conv!(
        y32, x32, w32, b32;
        pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1
    )
    # The Float32 call only filled the temporary `y32`; propagate it to the caller's
    # destination so the in-place method actually updates `y`.
    copyto!(y, out)
    return out
end
8383

84-
"""
    crosscor!(y::AA{4}, x::AA{4}, w::AA{4}, b::AA{1}; pad, stride, dilation, algo)

Thin wrapper around `conv!` that forwards every argument unchanged and fixes
`flipkernel = 1`.
"""
function crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AA{4}, A2<:AA{1}}
    return conv!(
        y, x, w, b;
        pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1
    )
end
8686

87-
"""
    ∇conv_data(dy, x, w; pad, stride, dilation, algo)   # Float64 inputs

Convert `dy`, `x` and `w` element-wise to `Float32` and forward to `∇conv_data` with
the same keyword arguments. The forwarded call's result is returned unchanged.
"""
function ∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float64, 4}
    dy32 = Float32.(dy)
    x32 = Float32.(x)
    w32 = Float32.(w)
    return ∇conv_data(dy32, x32, w32; pad = pad, stride = stride, dilation = dilation, algo = algo)
end
8989

90-
"""
    ∇conv_data(dy::AA{4}, x::AA{4}, w::AA{4}; pad = 0, stride = 1, dilation = 1, algo = UInt32(0))

Validate the arguments, allocate a zero-filled `Float32` array of `size(x)`, and
fill it through `∇conv_data!` (with that method's default `flipkernel`).

Raises (via `error`) when `dilation` is neither `1` nor `(1, 1)`, or when
`size(x, d) - size(w, d) + 2 * pad[d]` is not a multiple of the stride for `d = 1`
or `d = 2`.
"""
function ∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AA{4}
    if !(dilation == 1 || dilation == (1, 1))
        error("NNPACK does not support dilation > 1")
    end

    pad_ = padtuple(x, pad)
    stride_ = padtuple(x, stride)

    # Short-circuit `&&` is kept so the second remainder is only computed when the
    # first dimension already fits.
    fits = (size(x, 1) - size(w, 1) + 2 * pad_[1]) % stride_[1] == 0 &&
           (size(x, 2) - size(w, 2) + 2 * pad_[2]) % stride_[2] == 0
    if !fits
        error("Choose the stride, pad and kernel size properly")
    end

    dx = zeros(Float32, size(x))
    return ∇conv_data!(
        dx, dy, x, w;
        pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo)
    )
end
9696

97-
"""
    ∇conv_data!(dx, dy, x, w; pad, stride, dilation, algo, flipkernel)   # Float64 inputs

Run `∇conv_data!` on `Float32` copies of `dx`, `dy`, `x` and `w`, then store the
result back into `dx`.

Returns the value produced by the `Float32` call (same return value as before);
`dx` additionally receives those values converted to `Float64`.
"""
function ∇conv_data!(dx::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:AbstractArray{Float64, 4}
    dx32 = Float32.(dx)
    dy32 = Float32.(dy)
    x32 = Float32.(x)
    w32 = Float32.(w)
    out = ∇conv_data!(
        dx32, dy32, x32, w32;
        pad = pad, stride = stride, dilation = dilation, algo = algo,
        flipkernel = flipkernel
    )
    # The Float32 call only filled the temporary `dx32`; propagate it to the caller's
    # destination so the in-place method actually updates `dx`.
    copyto!(dx, out)
    return out
end
9999

100-
"""
    ∇conv_data!(dx::AA{4}, dy::AA{4}, x::AA{4}, w::AA{4}; pad, stride, dilation, algo, flipkernel = 0)

Forward to `nnp_convolution_input_gradient(dx, x, dy, kernel; ...)` using the shared
thread pool. When `flipkernel == 0` the kernel passed on is `w` reversed along
dimensions 1 and 2; otherwise `w` is passed unchanged. `pad` is forwarded as
`padding`; the `dilation` keyword is accepted but not used here.
"""
function ∇conv_data!(dx::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:AA{4}
    kernel = flipkernel == 0 ? reverse(reverse(w, dims = 1), dims = 2) : w
    return nnp_convolution_input_gradient(
        dx, x, dy, kernel;
        padding = pad, stride = stride, algo = algo, threadpool = shared_threadpool
    )
end
104104

105-
"""
    ∇conv_filter(dy, x, w; pad, stride, dilation, algo)   # Float64 inputs

Convert `dy`, `x` and `w` element-wise to `Float32` and forward to `∇conv_filter`
with the same keyword arguments. The forwarded call's result is returned unchanged.
"""
function ∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float64, 4}
    dy32 = Float32.(dy)
    x32 = Float32.(x)
    w32 = Float32.(w)
    return ∇conv_filter(dy32, x32, w32; pad = pad, stride = stride, dilation = dilation, algo = algo)
end
107107

108-
"""
    ∇conv_filter(dy::AA{4}, x::AA{4}, w::AA{4}; pad = 0, stride = 1, dilation = 1, algo = UInt32(0))

Validate the arguments, allocate a zero-filled `Float32` array of `size(w)`, and
fill it through `∇conv_filter!` (with that method's default `flipkernel`).

Raises (via `error`) when `dilation` is neither `1` nor `(1, 1)`, or when
`size(x, d) - size(w, d) + 2 * pad[d]` is not a multiple of the stride for `d = 1`
or `d = 2`.
"""
function ∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AA{4}
    if !(dilation == 1 || dilation == (1, 1))
        error("NNPACK does not support dilation > 1")
    end

    pad_ = padtuple(x, pad)
    stride_ = padtuple(x, stride)

    # Short-circuit `&&` is kept so the second remainder is only computed when the
    # first dimension already fits.
    fits = (size(x, 1) - size(w, 1) + 2 * pad_[1]) % stride_[1] == 0 &&
           (size(x, 2) - size(w, 2) + 2 * pad_[2]) % stride_[2] == 0
    if !fits
        error("Choose the stride, pad and kernel size properly")
    end

    dw = zeros(Float32, size(w))
    return ∇conv_filter!(
        dw, dy, x, w;
        pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo)
    )
end
114114

115-
"""
    ∇conv_filter!(dw, dy, x, w; pad, stride, dilation, algo, flipkernel)   # Float64 inputs

Run `∇conv_filter!` on `Float32` copies of `dw`, `dy`, `x` and `w`, then copy the
`Float32` gradient buffer back into `dw`.

Returns the value produced by the `Float32` call (same return value as before).
`dw` is filled from the `Float32` buffer rather than from that return value,
because the `Float32` method may return a reversed copy while leaving its own `dw`
un-reversed; this keeps `dw` consistent with what the `Float32` method stores.
"""
function ∇conv_filter!(dw::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:AbstractArray{Float64, 4}
    dw32 = Float32.(dw)
    dy32 = Float32.(dy)
    x32 = Float32.(x)
    w32 = Float32.(w)
    out = ∇conv_filter!(
        dw32, dy32, x32, w32;
        pad = pad, stride = stride, dilation = dilation, algo = algo,
        flipkernel = flipkernel
    )
    # The Float32 call only filled the temporary `dw32`; propagate it to the caller's
    # destination so the in-place method actually updates `dw`.
    copyto!(dw, dw32)
    return out
end
117117

118-
function ∇conv_filter!(dw::AA{4}, dy::AA{4}, x::AA{4}, w::AA{4}; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0)
118+
function ∇conv_filter!(dw::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:AA{4}
119119
flipkernel == 0 && (w = reverse(reverse(w, dims=1), dims=2))
120120
dw .= nnp_convolution_kernel_gradient(dw, x, dy, w, padding = pad, stride = stride, algo = algo, threadpool = shared_threadpool)
121121
flipkernel == 0 ? reverse(reverse(dw, dims=1), dims=2) : dw

0 commit comments

Comments
 (0)