@@ -6,27 +6,15 @@ function check_support(x, k, pad, stride, dilation = 1)
6
6
return pad_, stride_, fallback
7
7
end
8
8
9
- function softmax! (x:: A ) where A<: AbstractVecOrMat{Float64}
10
- x = Float32 .(x)
11
- softmax! (x)
12
- end
13
-
14
9
softmax! (x:: A ) where A<: AbstractVecOrMat{Float32} =
15
10
nnp_softmax_output (x, x)
16
11
17
- softmax! (y:: A , x:: A ) where A<: AbstractVecOrMat{Float64} = softmax! (Float32 .(y), Float32 .(x))
18
-
19
12
softmax! (y:: A , x:: A ) where A<: AbstractVecOrMat{Float32} =
20
13
nnp_softmax_output (x, y)
21
14
22
- softmax (x:: A ) where A<: AbstractVecOrMat{Float64} = softmax (Float32 .(x))
23
-
24
15
softmax (x:: A ) where A<: AbstractVecOrMat{Float32} =
25
16
nnp_softmax_output (x, similar (x))
26
17
27
- maxpool (x:: A , k; pad = map (_-> 0 ,k), stride = k) where A<: AbstractArray{Float64, 4} =
28
- maxpool (Float32 .(x), k, pad = pad, stride = stride)
29
-
30
18
function maxpool (x:: A , k; pad = map (_-> 0 ,k), stride = k) where A<: AbstractArray{Float32, 4}
31
19
pad_, stride_, fallback = check_support (x, k, pad, stride)
32
20
if fallback
@@ -36,15 +24,9 @@ function maxpool(x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{
36
24
end
37
25
end
38
26
39
- maxpool! (y:: A , x:: A , k; pad = map (_-> 0 ,k), stride = k) where A<: AbstractArray{Float64, 4} =
40
- maxpool! (Float32 .(y), Float32 .(x), k, pad = pad, stride = stride)
41
-
42
27
maxpool! (y:: A , x:: A , k; pad = map (_-> 0 ,k), stride = k) where A<: AbstractArray{Float32, 4} =
43
28
nnp_max_pooling_output (x, y, k, padding = expand (Val{length (k)}, pad), stride = expand (Val{length (k)}, stride))
44
29
45
- conv (x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where A<: AbstractArray{Float64, 4} =
46
- conv (Float32 .(x), Float32 .(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
47
-
48
30
function conv (x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where A<: AbstractArray{Float32, 4}
49
31
pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride, dilation)
50
32
y = similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_))
@@ -55,9 +37,6 @@ function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) w
55
37
end
56
38
end
57
39
58
- conv (x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where {A1<: AbstractArray{Float64, 4} , A2<: AbstractArray{Float64, 1} } =
59
- conv (Float32 .(x), Float32 .(w), Float32 .(b), pad = pad, stride = stride, dilation = dilation, algo = algo)
60
-
61
40
function conv (x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where {A1<: AbstractArray{Float32, 4} , A2<: AbstractArray{Float32, 1} }
62
41
pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride, dilation)
63
42
y = similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_))
@@ -68,9 +47,6 @@ function conv(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UIn
68
47
end
69
48
end
70
49
71
- crosscor (x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where {A1<: AbstractArray{Float64, 4} , A2<: AbstractArray{Float64, 1} } =
72
- crosscor (Float32 .(x), Float32 .(w), Float32 .(b), pad = pad, stride = stride, dilation = dilation, algo = algo)
73
-
74
50
function crosscor (x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where {A1<: AbstractArray{Float32, 4} , A2<: AbstractArray{Float32, 1} }
75
51
pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride, dilation)
76
52
y = similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_))
@@ -81,19 +57,13 @@ function crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo =
81
57
end
82
58
end
83
59
84
- conv! (y:: A1 , x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 ), flipkernel = 0 ) where {A1<: AbstractArray{Float64, 4} , A2<: AbstractArray{Float64, 1} } =
85
- conv! (Float32 .(y), Float32 .(x), Float32 .(w), Float32 .(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = flipkernel)
86
-
87
60
function conv! (y:: A1 , x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 ), flipkernel = 0 ) where {A1<: AbstractArray{Float32, 4} , A2<: AbstractArray{Float32, 1} }
88
61
if flipkernel == 0
89
62
w = reverse (reverse (w, dims= 1 ), dims= 2 )
90
63
end
91
64
nnp_convolution_output (y, x, w, b, algo = algo, padding = pad, stride = stride)
92
65
end
93
66
94
- crosscor! (y:: A1 , x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where {A1<: AbstractArray{Float64, 4} , A2<: AbstractArray{Float64, 1} } =
95
- conv! (Float32 .(y), Float32 .(x), Float32 .(w), Float32 .(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1 )
96
-
97
67
crosscor! (y:: A1 , x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where {A1<: AbstractArray{Float32, 4} , A2<: AbstractArray{Float32, 1} } =
98
68
conv! (y, x, w, b, pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1 )
99
69
@@ -109,17 +79,11 @@ function ∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo
109
79
end
110
80
end
111
81
112
- ∇conv_data! (dx:: A , dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 ), flipkernel = 0 ) where A<: AbstractArray{Float64, 4} =
113
- ∇conv_data! (Float32 .(dx), Float32 .(dy), Float32 .(x), Float32 .(w), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = flipkernel)
114
-
115
82
function ∇conv_data! (dx:: A , dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 ), flipkernel = 0 ) where A<: AbstractArray{Float32, 4}
116
83
flipkernel == 0 && (w = reverse (reverse (w, dims= 1 ), dims= 2 ))
117
84
nnp_convolution_input_gradient (dx, x, dy, w, padding = pad, stride = stride, algo = algo)
118
85
end
119
86
120
- ∇conv_filter (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where A<: AbstractArray{Float64, 4} =
121
- ∇conv_filter (Float32 .(dy), Float32 .(x), Float32 .(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
122
-
123
87
function ∇conv_filter (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where A<: AbstractArray{Float32, 4}
124
88
pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride, dilation)
125
89
if fallback
@@ -129,9 +93,6 @@ function ∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, al
129
93
end
130
94
end
131
95
132
- ∇conv_filter! (dw:: A , dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 ), flipkernel = 0 ) where A<: AbstractArray{Float64, 4} =
133
- ∇conv_filter! (Float32 .(dw), Float32 .(dy), Float32 .(x), Float32 .(w), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = flipkernel)
134
-
135
96
function ∇conv_filter! (dw:: A , dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 ), flipkernel = 0 ) where A<: AbstractArray{Float32, 4}
136
97
flipkernel == 0 && (w = reverse (reverse (w, dims= 1 ), dims= 2 ))
137
98
dw .= nnp_convolution_kernel_gradient (dw, x, dy, w, padding = pad, stride = stride, algo = algo)
0 commit comments