|
20 | 20 |
|
21 | 21 |
|
22 | 22 | function conv_nnpack(x::Array{T1, 4}, w::Array{T2, 4}, cdims::ConvDims; kwargs...) where {T1, T2}
|
23 |
| - y = similar(x, output_size(cdims), channels_out(cdims), size(x, 4)) |
| 23 | + y = similar(x, output_size(cdims)..., channels_out(cdims), size(x, 4)) |
24 | 24 | return conv_nnpack!(y, x, w, cdims; kwargs...)
|
25 | 25 | end
|
26 | 26 |
|
27 | 27 |
|
28 | 28 | function ∇conv_data(dy::Array{T1, 4}, w::Array{T2, 4}, cdims::ConvDims; kwargs...) where {T1, T2}
|
29 |
| - dx = similar(dy, input_size(cdims), channels_in(cdims), size(dy, 4)) |
| 29 | + dx = similar(dy, input_size(cdims)..., channels_in(cdims), size(dy, 4)) |
30 | 30 | return ∇conv_data!(dx, dy, w, cdims; kwargs...)
|
31 | 31 | end
|
32 | 32 |
|
33 | 33 |
|
34 | 34 | function ∇conv_filter(x::Array{T1, 4}, dy::Array{T2, 4}, cdims::ConvDims; kwargs...) where {T1, T2}
|
35 |
| - dw = similar(x, kernel_size(cdims), channels_in(cdims), channels_out(cdims)) |
| 35 | + dw = similar(x, kernel_size(cdims)..., channels_in(cdims), channels_out(cdims)) |
36 | 36 | return ∇conv_filter!(dw, x, dy, cdims; kwargs...)
|
37 | 37 | end
|
38 | 38 |
|
|
0 commit comments