# Generate the reduction ops: reduce_sum, reduce_prod, reduce_min, reduce_max,
# reduce_all, reduce_any and reduce_mean.
#
# Each generated function reduces tensor `n` over `axis` (1-based; a scalar,
# a collection, or `nothing` meaning all axes). `keep_dims=true` retains the
# reduced dimensions with size 1. `name` sets the op name within the current
# `with_op_name` scope.
#
# TODO Clean this up
for reduction in [:sum, :prod, :min, :max, :all, :any, :mean]
    @eval @op function $(Symbol("reduce_", reduction))(n::AbstractTensor; axis=nothing, keep_dims=false, name=nothing)
        local desc
        shape = get_shape(n)
        nodetype = $(capitalize(reduction))

        if axis === nothing && shape.rank_unknown
            n = Tensor(n)
            # Rank is unknown statically, so build a Rank op and reduce over
            # the runtime range 0:rank-1.
            # TODO: rewrite this
            rank = tf.with_op_name(nothing, "Rank") do
                desc_rank = NodeDescription("Rank")
                add_input(desc_rank, n)
                Tensor(Operation(desc_rank), 1)
            end
            range = tf.with_op_name(nothing, "range") do
                @tf start = constant(Int32(0))
                @tf delta = constant(Int32(1))
                desc_range = NodeDescription("Range")
                add_input(desc_range, start)
                add_input(desc_range, rank)
                add_input(desc_range, delta)
                Tensor(Operation(desc_range), 1)
            end
            tf.with_op_name(name, nodetype) do
                desc = NodeDescription(nodetype)
                add_input(desc, n)
                add_input(desc, range)
                # BUG FIX: previously keep_dims was silently dropped on this
                # path; set the attribute here just as on the known-rank path.
                desc["keep_dims"] = keep_dims
            end
        else
            tf.with_op_name(name, nodetype) do
                if axis === nothing
                    # Rank is known: reduce over every dimension.
                    axis = 1:length(shape.dims)
                end
                # TensorFlow axes are 0-based, Julia's are 1-based.
                @tf reduction_indices = constant(Int32.(axis .- 1))
                desc = NodeDescription(nodetype)
                add_input(desc, Tensor(n))
                add_input(desc, reduction_indices)
                desc["keep_dims"] = keep_dims
            end
        end
        # `desc` was assigned inside the closure above (declared `local` in
        # this scope), so the finished node can be wrapped here.
        Tensor(Operation(desc), 1)
    end
end
NodeDescription($(capitalize(reduction)), name) - add_input(desc, Tensor(n)) - add_input(desc, Tensor(axis)) - desc["keep_dims"] = keep_dims - Tensor(Operation(desc), 1) end + Tensor(Operation(desc), 1) end end diff --git a/test/meta.jl b/test/meta.jl index 952b0b75..a0d4b2aa 100644 --- a/test/meta.jl +++ b/test/meta.jl @@ -18,7 +18,9 @@ end @testset "Naming" begin let g = Graph() - local i, j_jl, j, k, ijk, ij, ij2, fq, m, W, Y, Ysum1, Ysum2, Ysum3, Ysum4 + local i, j_jl, j, k, ijk, ij, ij2, fq, m, W, Y, + Ysum1, Ysum2, Ysum3, Ysum4, Ysum5, Ysum6, Ysum7, Ysum8, + p, psum1, psum2, psum3, psum4, psum5 as_default(g) do @tf begin i = constant(1.0) @@ -46,6 +48,29 @@ end Ysum3 = reduce_sum(Y, keep_dims=true) # With a comma (issue #188) Ysum4 = reduce_sum(Y, keep_dims=true, name="namefor_Ysum4") # With a comma (issue #188) + + Ysum5 = reduce_sum(Y, axis=2) + + nn.tf.with_op_name("level1") do + Ysum6 = reduce_sum(Y) + nn.tf.with_op_name("level2") do + Ysum7 = reduce_sum(Y) + Ysum8 = reduce_sum(Y, axis=1) + end + end + + p = placeholder(Float32) + psum1 = reduce_sum(p) + psum2 = reduce_sum(p, axis=1) + + nn.tf.with_op_name("anotherlevel1") do + psum3 = reduce_sum(p) + + nn.tf.with_op_name("level2") do + psum4 = reduce_sum(p) + psum5 = reduce_sum(p, axis=1) + end + end end end @@ -68,8 +93,18 @@ end @test Ysum2 == get_tensor_by_name(g, "Ysum2") @test Ysum3 == get_tensor_by_name(g, "Ysum3") @test Ysum4 == get_tensor_by_name(g, "namefor_Ysum4") - - + @test Ysum5 == get_tensor_by_name(g, "Ysum5") + @test Ysum6 == get_tensor_by_name(g, "level1/Ysum6") + @test Ysum7 == get_tensor_by_name(g, "level1/level2/Ysum7") + @test Ysum8 == get_tensor_by_name(g, "level1/level2/Ysum8") + + @test psum1 == get_tensor_by_name(g, "psum1") + @test psum2 == get_tensor_by_name(g, "psum2") + @test psum3 == get_tensor_by_name(g, "anotherlevel1/psum3") + @test psum4 == get_tensor_by_name(g, "anotherlevel1/level2/psum4") + @test psum5 == get_tensor_by_name(g, "anotherlevel1/level2/psum5") + + 
@test_throws TensorFlow.TFException reduce_sum(p, name="Ysum1") end end