Skip to content

Commit b0b6ca0

Browse files
committed
Remove SumNode/ProdNode, use Node2 chains via reduce(+/*) instead
exa_sum now folds children with reduce(+, ...) producing a tree of Node2{typeof(+)}, and exa_prod uses reduce(*, ...) for Node2{typeof(*)}. This reuses the existing registered +/* dispatch for all three passes (primal, adjoint, second-adjoint) without needing dedicated node types. Empty iterators return Null(0) / Null(1) respectively. https://claude.ai/code/session_01QsVaXnG1Cw7LdtCzvVYRoo
1 parent 08d8b4e commit b0b6ca0

File tree

3 files changed

+16
-86
lines changed

3 files changed

+16
-86
lines changed

src/graph.jl

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -318,67 +318,3 @@ end
318318
@inline (v::Null{Nothing})(i, x::SecondAdjointNodeSource{T}, θ) where {T} = SecondAdjointNull(zero(eltype(T)))
319319
@inline (v::Null{N})(i, x::SecondAdjointNodeSource{T}, θ) where {N, T} = SecondAdjointNull(eltype(T)(v.value))
320320

321-
# ── SumNode / ProdNode ────────────────────────────────────────────────────────
322-
323-
"""
324-
SumNode{I} <: AbstractNode
325-
326-
A node representing the sum of a tuple of child nodes.
327-
328-
Constructed by [`exa_sum`](@ref). Within `@obj`, `@con`, and `@expr` macros,
329-
`sum(body for k in range)` is automatically rewritten to
330-
`exa_sum(k -> body, Val(range))` with the `Val` hoisted outside the generator
331-
closure for type stability under `juliac --trim=safe`.
332-
333-
In adjoint / second-adjoint mode the children are evaluated and folded via
334-
`reduce(+, …)` (or `reduce(*, …)` for [`ProdNode`](@ref)), reusing the existing
335-
registered `+` / `*` dispatch. No dedicated adjoint node types are needed.
336-
"""
337-
struct SumNode{I} <: AbstractNode
338-
inners::I
339-
end
340-
341-
"""
342-
ProdNode{I} <: AbstractNode
343-
344-
A node representing the product of a tuple of child nodes.
345-
346-
Constructed by [`exa_prod`](@ref). See [`SumNode`](@ref) for design notes.
347-
"""
348-
struct ProdNode{I} <: AbstractNode
349-
inners::I
350-
end
351-
352-
# ── Primal evaluation (x::AbstractVector → scalar) ───────────────────────────
353-
354-
@inline (n::SumNode{Tuple{}})(i, x::V, θ) where {T, V<:AbstractVector{T}} = zero(T)
355-
@inline (n::SumNode)(i, x::V, θ) where {T, V<:AbstractVector{T}} =
356-
mapreduce(inner -> inner(i, x, θ), +, n.inners)
357-
358-
@inline (n::ProdNode{Tuple{}})(i, x::V, θ) where {T, V<:AbstractVector{T}} = one(T)
359-
@inline (n::ProdNode)(i, x::V, θ) where {T, V<:AbstractVector{T}} =
360-
mapreduce(inner -> inner(i, x, θ), *, n.inners)
361-
362-
# ── Adjoint tree (gradient) ───────────────────────────────────────────────────
363-
364-
@inline (n::SumNode{Tuple{}})(i, x::AdjointNodeSource{VT}, θ) where {VT} =
365-
AdjointNull(zero(eltype(VT)))
366-
@inline (n::SumNode)(i, x::AdjointNodeSource, θ) =
367-
reduce(+, map(inner -> inner(i, x, θ), n.inners))
368-
369-
@inline (n::ProdNode{Tuple{}})(i, x::AdjointNodeSource{VT}, θ) where {VT} =
370-
AdjointNull(one(eltype(VT)))
371-
@inline (n::ProdNode)(i, x::AdjointNodeSource, θ) =
372-
reduce(*, map(inner -> inner(i, x, θ), n.inners))
373-
374-
# ── Second-adjoint tree (Hessian) ─────────────────────────────────────────────
375-
376-
@inline (n::SumNode{Tuple{}})(i, x::SecondAdjointNodeSource{VT}, θ) where {VT} =
377-
SecondAdjointNull(zero(eltype(VT)))
378-
@inline (n::SumNode)(i, x::SecondAdjointNodeSource, θ) =
379-
reduce(+, map(inner -> inner(i, x, θ), n.inners))
380-
381-
@inline (n::ProdNode{Tuple{}})(i, x::SecondAdjointNodeSource{VT}, θ) where {VT} =
382-
SecondAdjointNull(one(eltype(VT)))
383-
@inline (n::ProdNode)(i, x::SecondAdjointNodeSource, θ) =
384-
reduce(*, map(inner -> inner(i, x, θ), n.inners))

src/simdfunction.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -178,11 +178,3 @@ end
178178
@inline replace_T(t, n::Null{T}) where T <: Real = Null{t}(t(n.value))
179179
@inline replace_T(::Type{T1}, n::T2) where {T1, T2 <: Real} = T1(n)
180180
@inline replace_T(::Type{T1}, ::Val{V}) where {T1, V} = Val(T1(V))
181-
@inline function replace_T(t, n::SumNode{I}) where {I}
182-
inners = map(x -> replace_T(t, x), n.inners)
183-
SumNode(inners)
184-
end
185-
@inline function replace_T(t, n::ProdNode{I}) where {I}
186-
inners = map(x -> replace_T(t, x), n.inners)
187-
ProdNode(inners)
188-
end

src/specialization.jl

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -247,10 +247,11 @@ end
247247
exa_sum(f, ::Val{range})
248248
exa_sum(gen::Base.Generator)
249249
250-
Build a [`SumNode`](@ref) representing `∑ f(k)` for `k ∈ itr`. Inside `@obj`,
251-
`@con`, and `@expr` macros, `sum(body for k in range)` is automatically
252-
rewritten to `exa_sum(k -> body, Val(range))` with the `Val` hoisted outside
253-
the generator closure.
250+
Build a node tree representing `∑ f(k)` for `k ∈ itr`, folded via `+` into
251+
a chain of `Node2{typeof(+)}`. Inside `@obj`, `@con`, and `@expr` macros,
252+
`sum(body for k in range)` is automatically rewritten to
253+
`exa_sum(k -> body, Val(range))` with the `Val` hoisted outside the generator
254+
closure.
254255
255256
# Supported iterators
256257
- `Tuple`: type-stable via tail recursion.
@@ -268,23 +269,24 @@ v = Val(1:nc) # outside generator
268269
c, con = add_con(c, (exa_sum(j -> x[j], v) for i in 1:nh)) # v captured
269270
```
270271
"""
271-
@inline exa_sum(f, ::Tuple{}) = SumNode(())
272-
@inline exa_sum(f, t::Tuple) = SumNode(_exa_map(f, t))
272+
@inline exa_sum(f, ::Tuple{}) = Null(0)
273+
@inline exa_sum(f, t::Tuple) = reduce(+, _exa_map(f, t))
273274

274275
"""
275276
exa_prod(f, itr)
276277
exa_prod(f, ::Val{range})
277278
exa_prod(gen::Base.Generator)
278279
279-
Build a [`ProdNode`](@ref) representing `∏ f(k)` for `k ∈ itr`. Inside `@obj`,
280-
`@con`, and `@expr` macros, `prod(body for k in range)` is automatically
281-
rewritten to `exa_prod(k -> body, Val(range))` with the `Val` hoisted outside
282-
the generator closure.
280+
Build a node tree representing `∏ f(k)` for `k ∈ itr`, folded via `*` into
281+
a chain of `Node2{typeof(*)}`. Inside `@obj`, `@con`, and `@expr` macros,
282+
`prod(body for k in range)` is automatically rewritten to
283+
`exa_prod(k -> body, Val(range))` with the `Val` hoisted outside the
284+
generator closure.
283285
284286
See [`exa_sum`](@ref) for supported iterators and juliac usage notes.
285287
"""
286-
@inline exa_prod(f, ::Tuple{}) = ProdNode(())
287-
@inline exa_prod(f, t::Tuple) = ProdNode(_exa_map(f, t))
288+
@inline exa_prod(f, ::Tuple{}) = Null(1)
289+
@inline exa_prod(f, t::Tuple) = reduce(*, _exa_map(f, t))
288290

289291
# ── UnitRange{Int}: Val{N}-based recursion ────────────────────────────────────
290292

@@ -305,10 +307,10 @@ end
305307
end
306308

307309
@inline _exa_sum_range(f, lo::Int, ::Val{N}) where {N} =
308-
SumNode(ntuple(i -> f(lo + i - 1), Val{N}()))
310+
reduce(+, ntuple(i -> f(lo + i - 1), Val{N}()))
309311

310312
@inline _exa_prod_range(f, lo::Int, ::Val{N}) where {N} =
311-
ProdNode(ntuple(i -> f(lo + i - 1), Val{N}()))
313+
reduce(*, ntuple(i -> f(lo + i - 1), Val{N}()))
312314

313315
# ── Generator form (direct use outside macro) ─────────────────────────────────
314316

0 commit comments

Comments
 (0)