Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 34 additions & 26 deletions src/ops/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,36 +176,44 @@ end
# TODO Clean this up
for reduction in [:sum, :prod, :min, :max, :all, :any, :mean]
@eval @op function $(Symbol("reduce_", reduction))(n::AbstractTensor; axis=nothing, keep_dims=false, name=nothing)
if name === nothing
name = get_name("reduce")
end
if axis == nothing
local desc
shape = get_shape(n)
nodetype = $(capitalize(reduction))

if axis == nothing && shape.rank_unknown
n = Tensor(n) # TODO: rewrite this
range_start = constant(Int32(0))
range_delta = constant(Int32(1))
desc = NodeDescription("Rank", "$name/rank")
add_input(desc, n)
rank = Tensor(Operation(desc), 1)
desc = NodeDescription("Range", "$name/range")
add_input(desc, range_start)
add_input(desc, rank)
add_input(desc, range_delta)
range = Tensor(Operation(desc), 1)
desc = NodeDescription($(capitalize(reduction)), name)
add_input(desc, n)
add_input(desc, range)
Tensor(Operation(desc), 1)
rank = tf.with_op_name(nothing, "Rank") do
desc_rank = NodeDescription("Rank")
add_input(desc_rank, n)
Tensor(Operation(desc_rank), 1)
end
range = tf.with_op_name(nothing, "range") do
@tf start = constant(Int32(0))
@tf delta = constant(Int32(1))
desc_range = NodeDescription("Range")
add_input(desc_range, start)
add_input(desc_range, rank)
add_input(desc_range, delta)
Tensor(Operation(desc_range), 1)
end
tf.with_op_name(name, nodetype) do
desc = NodeDescription(nodetype)
add_input(desc, n)
add_input(desc, range)
end
else
if isa(axis, Number)
axis = [axis]
tf.with_op_name(name, nodetype) do
if axis == nothing
axis = 1:length(shape.dims)
end
@tf reduction_indices = constant(Int32.(axis.-1))
desc = NodeDescription(nodetype)
add_input(desc, Tensor(n))
add_input(desc, reduction_indices)
desc["keep_dims"] = keep_dims
end
axis = [Int32(idx-1) for idx in axis]
desc = NodeDescription($(capitalize(reduction)), name)
add_input(desc, Tensor(n))
add_input(desc, Tensor(axis))
desc["keep_dims"] = keep_dims
Tensor(Operation(desc), 1)
end
Tensor(Operation(desc), 1)
end
end

Expand Down
41 changes: 38 additions & 3 deletions test/meta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ end
@testset "Naming" begin
let
g = Graph()
local i, j_jl, j, k, ijk, ij, ij2, fq, m, W, Y, Ysum1, Ysum2, Ysum3, Ysum4
local i, j_jl, j, k, ijk, ij, ij2, fq, m, W, Y,
Ysum1, Ysum2, Ysum3, Ysum4, Ysum5, Ysum6, Ysum7, Ysum8,
p, psum1, psum2, psum3, psum4, psum5
as_default(g) do
@tf begin
i = constant(1.0)
Expand Down Expand Up @@ -46,6 +48,29 @@ end
Ysum3 = reduce_sum(Y, keep_dims=true) # With a comma (issue #188)

Ysum4 = reduce_sum(Y, keep_dims=true, name="namefor_Ysum4") # With a comma (issue #188)

Ysum5 = reduce_sum(Y, axis=2)

nn.tf.with_op_name("level1") do
Ysum6 = reduce_sum(Y)
nn.tf.with_op_name("level2") do
Ysum7 = reduce_sum(Y)
Ysum8 = reduce_sum(Y, axis=1)
end
end

p = placeholder(Float32)
psum1 = reduce_sum(p)
psum2 = reduce_sum(p, axis=1)

nn.tf.with_op_name("anotherlevel1") do
psum3 = reduce_sum(p)

nn.tf.with_op_name("level2") do
psum4 = reduce_sum(p)
psum5 = reduce_sum(p, axis=1)
end
end
end
end

Expand All @@ -68,8 +93,18 @@ end
@test Ysum2 == get_tensor_by_name(g, "Ysum2")
@test Ysum3 == get_tensor_by_name(g, "Ysum3")
@test Ysum4 == get_tensor_by_name(g, "namefor_Ysum4")


@test Ysum5 == get_tensor_by_name(g, "Ysum5")
@test Ysum6 == get_tensor_by_name(g, "level1/Ysum6")
@test Ysum7 == get_tensor_by_name(g, "level1/level2/Ysum7")
@test Ysum8 == get_tensor_by_name(g, "level1/level2/Ysum8")

@test psum1 == get_tensor_by_name(g, "psum1")
@test psum2 == get_tensor_by_name(g, "psum2")
@test psum3 == get_tensor_by_name(g, "anotherlevel1/psum3")
@test psum4 == get_tensor_by_name(g, "anotherlevel1/level2/psum4")
@test psum5 == get_tensor_by_name(g, "anotherlevel1/level2/psum5")

@test_throws TensorFlow.TFException reduce_sum(p, name="Ysum1")
end
end

Expand Down