Skip to content

Commit ac3c880

Browse files
committed
rm float64 methods
1 parent 8b91c9b commit ac3c880

File tree

1 file changed

+0
-39
lines changed

1 file changed

+0
-39
lines changed

src/nnpack/interface.jl

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,15 @@ function check_support(x, k, pad, stride, dilation = 1)
66
return pad_, stride_, fallback
77
end
88

9-
function softmax!(x::A) where A<:AbstractVecOrMat{Float64}
10-
x = Float32.(x)
11-
softmax!(x)
12-
end
13-
149
softmax!(x::A) where A<:AbstractVecOrMat{Float32} =
1510
nnp_softmax_output(x, x)
1611

17-
softmax!(y::A, x::A) where A<:AbstractVecOrMat{Float64} = softmax!(Float32.(y), Float32.(x))
18-
1912
softmax!(y::A, x::A) where A<:AbstractVecOrMat{Float32} =
2013
nnp_softmax_output(x, y)
2114

22-
softmax(x::A) where A<:AbstractVecOrMat{Float64} = softmax(Float32.(x))
23-
2415
softmax(x::A) where A<:AbstractVecOrMat{Float32} =
2516
nnp_softmax_output(x, similar(x))
2617

27-
maxpool(x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{Float64, 4} =
28-
maxpool(Float32.(x), k, pad = pad, stride = stride)
29-
3018
function maxpool(x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{Float32, 4}
3119
pad_, stride_, fallback = check_support(x, k, pad, stride)
3220
if fallback
@@ -36,15 +24,9 @@ function maxpool(x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{
3624
end
3725
end
3826

39-
maxpool!(y::A, x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{Float64, 4} =
40-
maxpool!(Float32.(y), Float32.(x), k, pad = pad, stride = stride)
41-
4227
maxpool!(y::A, x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{Float32, 4} =
4328
nnp_max_pooling_output(x, y, k, padding = expand(Val{length(k)}, pad), stride = expand(Val{length(k)}, stride))
4429

45-
conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float64, 4} =
46-
conv(Float32.(x), Float32.(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
47-
4830
function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float32, 4}
4931
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
5032
y = similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_))
@@ -55,9 +37,6 @@ function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) w
5537
end
5638
end
5739

58-
conv(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
59-
conv(Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo)
60-
6140
function conv(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
6241
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
6342
y = similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_))
@@ -68,9 +47,6 @@ function conv(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UIn
6847
end
6948
end
7049

71-
crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
72-
crosscor(Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo)
73-
7450
function crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
7551
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
7652
y = similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_))
@@ -81,19 +57,13 @@ function crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo =
8157
end
8258
end
8359

84-
conv!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
85-
conv!(Float32.(y), Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = flipkernel)
86-
8760
function conv!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
8861
if flipkernel == 0
8962
w = reverse(reverse(w, dims=1), dims=2)
9063
end
9164
nnp_convolution_output(y, x, w, b, algo = algo, padding = pad, stride = stride)
9265
end
9366

94-
crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
95-
conv!(Float32.(y), Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1)
96-
9767
crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}} =
9868
conv!(y, x, w, b, pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1)
9969

@@ -109,17 +79,11 @@ function ∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo
10979
end
11080
end
11181

112-
∇conv_data!(dx::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:AbstractArray{Float64, 4} =
113-
∇conv_data!(Float32.(dx), Float32.(dy), Float32.(x), Float32.(w), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = flipkernel)
114-
11582
function ∇conv_data!(dx::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:AbstractArray{Float32, 4}
11683
flipkernel == 0 && (w = reverse(reverse(w, dims=1), dims=2))
11784
nnp_convolution_input_gradient(dx, x, dy, w, padding = pad, stride = stride, algo = algo)
11885
end
11986

120-
∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float64, 4} =
121-
∇conv_filter(Float32.(dy), Float32.(x), Float32.(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
122-
12387
function ∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float32, 4}
12488
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
12589
if fallback
@@ -129,9 +93,6 @@ function ∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, al
12993
end
13094
end
13195

132-
∇conv_filter!(dw::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:AbstractArray{Float64, 4} =
133-
∇conv_filter!(Float32.(dw), Float32.(dy), Float32.(x), Float32.(w), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = flipkernel)
134-
13596
function ∇conv_filter!(dw::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:AbstractArray{Float32, 4}
13697
flipkernel == 0 && (w = reverse(reverse(w, dims=1), dims=2))
13798
dw .= nnp_convolution_kernel_gradient(dw, x, dy, w, padding = pad, stride = stride, algo = algo)

0 commit comments

Comments
 (0)