Skip to content

Commit 24eccc9

Browse files
committed
local desc
1 parent c68f188 commit 24eccc9

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/ops/math.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -176,18 +176,20 @@ end
176176
# TODO Clean this up
177177
for reduction in [:sum, :prod, :min, :max, :all, :any, :mean]
178178
@eval @op function $(Symbol("reduce_", reduction))(n::AbstractTensor; axis=nothing, keep_dims=false, name=nothing)
179+
local desc
180+
shape = get_shape(n)
181+
179182
if name == nothing
180183
name = $(capitalize(reduction))
181184
end
182185

183-
shape = get_shape(n)
184186
if axis == nothing && shape.rank_unknown
185187
n = Tensor(n) # TODO: rewrite this
186-
desc_rank = tf.with_op_name(nothing, "Rank") do
187-
NodeDescription("Rank")
188+
rank = tf.with_op_name(nothing, "Rank") do
189+
desc_rank = NodeDescription("Rank")
190+
add_input(desc_rank, n)
191+
Tensor(Operation(desc_rank), 1)
188192
end
189-
add_input(desc_rank, n)
190-
rank = Tensor(Operation(desc_rank), 1)
191193
range = tf.with_op_name(nothing, "range") do
192194
@tf start = constant(Int32(0))
193195
@tf delta = constant(Int32(1))
@@ -197,14 +199,13 @@ for reduction in [:sum, :prod, :min, :max, :all, :any, :mean]
197199
add_input(desc_range, delta)
198200
Tensor(Operation(desc_range), 1)
199201
end
200-
desc = tf.with_op_name(nothing, name) do
202+
tf.with_op_name(nothing, name) do
201203
desc = NodeDescription($(capitalize(reduction)))
202204
add_input(desc, n)
203205
add_input(desc, range)
204-
desc
205206
end
206207
else
207-
desc = tf.with_op_name(nothing, name) do
208+
tf.with_op_name(nothing, name) do
208209
if axis == nothing
209210
axis = 1:length(shape.dims)
210211
end
@@ -213,7 +214,6 @@ for reduction in [:sum, :prod, :min, :max, :all, :any, :mean]
213214
add_input(desc, Tensor(n))
214215
add_input(desc, reduction_indices)
215216
desc["keep_dims"] = keep_dims
216-
desc
217217
end
218218
end
219219
Tensor(Operation(desc), 1)

0 commit comments

Comments
 (0)