@@ -55,7 +55,49 @@ function conv_baseline!(out, A, kern)
55
55
out
56
56
end
57
57
58
+
59
+ struct DenseConvDims{N,K,C_in,C_out} end
60
+
61
+ function kernaxes (:: DenseConvDims{2,K,C_in, C_out} ) where {K,C_in, C_out}
62
+ K₁ = LoopVectorization. StaticInt (1 ): LoopVectorization. StaticInt (K[1 ])
63
+ K₂ = LoopVectorization. StaticInt (1 ): LoopVectorization. StaticInt (K[2 ])
64
+ Cᵢₙ = LoopVectorization. StaticInt (1 ): LoopVectorization. StaticInt (C_in)
65
+ Cₒᵤₜ = LoopVectorization. StaticInt (1 ): LoopVectorization. StaticInt (C_out)
66
+ (K₁, K₂, Cᵢₙ, Cₒᵤₜ)
67
+ end
68
+
69
+ function convlayer! (
70
+ out:: AbstractArray{<:Any,4} , img, kern,
71
+ dcd:: DenseConvDims{2, <:Any, <:Any, <:Any}
72
+ )
73
+ (K₁, K₂, Cᵢₙ, Cₒᵤₜ) = kernaxes (dcd)
74
+ @avxt for j₁ ∈ axes (out,1 ), j₂ ∈ axes (out,2 ), d ∈ axes (out,4 ), o ∈ Cₒᵤₜ
75
+ s = zero (eltype (out))
76
+ for k₁ ∈ K₁, k₂ ∈ K₂, i ∈ Cᵢₙ
77
+ s += img[j₁ + k₁ - 1 , j₂ + k₂ - 1 , i, d] * kern[k₁, k₂, i, o]
78
+ end
79
+ out[j₁, j₂, o, d] = s
80
+ end
81
+ out
82
+ end
83
+ function convlayer_direct! (
84
+ out:: AbstractArray{<:Any,4} , img, kern,
85
+ dcd:: DenseConvDims{2, <:Any, <:Any, <:Any}
86
+ )
87
+ (K₁, K₂, Cᵢₙ, Cₒᵤₜ) = kernaxes (dcd)
88
+ @inbounds @fastmath for j₁ ∈ axes (out,1 ), j₂ ∈ axes (out,2 ), d ∈ axes (out,4 ), o ∈ Cₒᵤₜ
89
+ s = zero (eltype (out))
90
+ for k₁ ∈ K₁, k₂ ∈ K₂, i ∈ Cᵢₙ
91
+ s += img[j₁ + k₁ - 1 , j₂ + k₂ - 1 , i, d] * kern[k₁, k₂, i, o]
92
+ end
93
+ out[j₁, j₂, o, d] = s
94
+ end
95
+ out
96
+ end
97
+
58
98
@testset " Threading" begin
99
+ dcd = DenseConvDims {2,(5,5),3,6} ()
100
+ kern4 = rand (Float32, 5 , 5 , 3 , 6 );
59
101
for M ∈ 17 : 399
60
102
# @show M
61
103
K = M; N = M;
74
116
out1 = OffsetArray (randn (size (A) .- 2 ), 1 , 1 )
75
117
out2 = similar (out1);
76
118
@test conv! (out1, A, kern) ≈ conv_baseline! (out2, A, kern)
119
+
120
+
121
+ img = rand (Float32, M, M, 3 , 100 );
122
+ out1 = Array {Float32} (undef, size (img,1 )+ 1 - size (kern4,1 ), size (img,2 )+ 1 - size (kern4,2 ), size (kern4,4 ), size (img,4 ));
123
+ out2 = similar (out1);
124
+
125
+ convlayer! (out1, img, kern4, dcd);
126
+ convlayer_direct! (out2, img, kern4, dcd);
127
+ @test out1 ≈ out2
128
+
77
129
end
78
130
end
79
131
0 commit comments