Skip to content

Commit 47560f1

Browse files
committed
Add some splatting to fix tests
1 parent 5b35148 commit 47560f1

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/nnpack/interface.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,19 @@ end
2020

2121

2222
function conv_nnpack(x::Array{T1, 4}, w::Array{T2, 4}, cdims::ConvDims; kwargs...) where {T1, T2}
23-
y = similar(x, output_size(cdims), channels_out(cdims), size(x, 4))
23+
y = similar(x, output_size(cdims)..., channels_out(cdims), size(x, 4))
2424
return conv_nnpack!(y, x, w, cdims; kwargs...)
2525
end
2626

2727

2828
function ∇conv_data(dy::Array{T1, 4}, w::Array{T2, 4}, cdims::ConvDims; kwargs...) where {T1, T2}
29-
dx = similar(dy, input_size(cdims), channels_in(cdims), size(dy, 4))
29+
dx = similar(dy, input_size(cdims)..., channels_in(cdims), size(dy, 4))
3030
return ∇conv_data!(dx, dy, w, cdims; kwargs...)
3131
end
3232

3333

3434
function ∇conv_filter(x::Array{T1, 4}, dy::Array{T2, 4}, cdims::ConvDims; kwargs...) where {T1, T2}
35-
dw = similar(x, kernel_size(cdims), channels_in(cdims), channels_out(cdims))
35+
dw = similar(x, kernel_size(cdims)..., channels_in(cdims), channels_out(cdims))
3636
return ∇conv_filter!(dw, x, dy, cdims; kwargs...)
3737
end
3838

0 commit comments

Comments
 (0)