diff --git a/REQUIRE b/REQUIRE
index 080e1af8..79a4e614 100644
--- a/REQUIRE
+++ b/REQUIRE
@@ -13,3 +13,4 @@ MacroTools 0.3.6
 AutoHashEquals 0.1.0
 MLDatasets 0.3.0
 SpecialFunctions 0.7.0
+Optim 0.17.0
diff --git a/deps/build.jl b/deps/build.jl
index 1a104aed..caabe43d 100644
--- a/deps/build.jl
+++ b/deps/build.jl
@@ -1,8 +1,8 @@
 using PyCall
 using Conda
 
-const cur_version = "1.10.0"
-const cur_py_version = "1.10.0"
+const cur_version = "1.12.0"
+const cur_py_version = "1.12.0"
 
 
 ############################
diff --git a/deps/default_imports.txt b/deps/default_imports.txt
index 9368e056..839cf659 100644
--- a/deps/default_imports.txt
+++ b/deps/default_imports.txt
@@ -166,3 +166,6 @@ Rank
 Conv2DBackpropInput
 Svd
 Cross
+FFT
+ComplexAbs
+MatrixSolve
diff --git a/src/TensorFlow.jl b/src/TensorFlow.jl
index 69b76c45..228c93ac 100644
--- a/src/TensorFlow.jl
+++ b/src/TensorFlow.jl
@@ -128,6 +128,7 @@ tf_versioninfo
 
 
 using Distributed
+using Optim
 
 const pyproc = Ref(0)
 
diff --git a/src/ops/nn.jl b/src/ops/nn.jl
index dcd47a3c..cb808a50 100644
--- a/src/ops/nn.jl
+++ b/src/ops/nn.jl
@@ -36,6 +36,27 @@ import .rnn_cell: zero_state, output_size, state_size
     conv2d(input, filter; padding=padding, strides=strides, kwargs...)
 end
 
+# 1-D convolution implemented via 2-D convolution: insert a singleton spatial
+# dimension, run Ops.conv2d, then squeeze that dimension back out.
+@tf.op function conv1d(input, filter_, strides_::Int64, padding::String; data_format="NHWC", kwargs...)
+    spatial_start_dim = 0
+    if data_format == "NHWC"
+        strides_ = [1, 1, strides_, 1]
+        spatial_start_dim = 2
+    elseif data_format == "NCHW" || data_format == "NCW"
+        data_format = "NCHW"
+        spatial_start_dim = 3
+        strides_ = [1, 1, 1, strides_]
+    else
+        error("data_format must be NHWC, NCHW or NCW")
+    end
+    input = Ops.expand_dims(input, spatial_start_dim)
+    filter_ = Ops.expand_dims(filter_, 1)
+    result = Ops.conv2d(input, filter_; strides=strides_, padding=padding, data_format=data_format, kwargs...)
+    result = Ops.squeeze(result, squeeze_dims=[spatial_start_dim-1])
+    return result
+end
+
 # Same for max pool
 @tf.op function max_pool(input, ksize, strides, padding; kwargs...)
     max_pool(input; ksize=ksize, strides=strides, padding=padding, kwargs...)
diff --git a/test/nn.jl b/test/nn.jl
index 6ef29e52..2bf1e64c 100644
--- a/test/nn.jl
+++ b/test/nn.jl
@@ -4,6 +4,26 @@ using StatsFuns
 using Random
 import LinearAlgebra
 
+@testset "conv1d" begin
+    let
+        sess = Session(Graph())
+        F = zeros(Float32, 2, 3, 4) # batch_size = 2, width = 3, channels = 4
+        for i = 1:2
+            for j = 1:3
+                for k = 1:4
+                    F[i, j, k] = Float32(i + j + k - 3)
+                end
+            end
+        end
+        input = constant(F)
+        filter_ = constant(ones(Float32, 3, 4, 1)) # width = 3, input channels = 4, output channels = 1
+        output = nn.conv1d(input, filter_, 2, "VALID")
+        output_val = run(sess, output)
+        ref_val = reshape(Float32[30.0; 42.0], 2, 1, 1)
+        @test ref_val ≈ output_val
+    end
+end
+
 @testset "conv2d_transpose" begin
     let
         sess = Session(Graph())
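
Usage sketch (not part of the patch): a minimal example of calling the new nn.conv1d wrapper, mirroring the added test and assuming the default NWC (batch, width, channels) layout; the input values here are arbitrary, so only the output shape is noted.

using TensorFlow

sess = Session(Graph())

# Input: batch of 2 sequences, width 3, 4 channels (NWC layout).
x = constant(randn(Float32, 2, 3, 4))

# Filter: width 3, 4 input channels, 1 output channel.
w = constant(ones(Float32, 3, 4, 1))

# Stride 2 along the width dimension, "VALID" (no) padding.
y = nn.conv1d(x, w, 2, "VALID")
run(sess, y)  # returns a 2x1x1 Array{Float32,3}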