Skip to content

Commit 5d39efc

Browse files
committed
refactor: remove ReadOnlyNode and simplify interface
1 parent 471b1fa commit 5d39efc

10 files changed

+37
-188
lines changed

src/DynamicExpressions.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ using DispatchDoctor: @stable, @unstable
2222
include("Random.jl")
2323
include("Parse.jl")
2424
include("ParametricExpression.jl")
25-
include("ReadOnlyNode.jl")
2625
include("StructuredExpression.jl")
2726
end
2827

@@ -103,7 +102,6 @@ import .ExpressionAlgebraModule: declare_operator_alias
103102
@reexport import .ParseModule: @parse_expression, parse_expression
104103
import .ParseModule: parse_leaf
105104
@reexport import .ParametricExpressionModule: ParametricExpression, ParametricNode
106-
import .ReadOnlyNodeModule: ReadOnlyNode
107105
@reexport import .StructuredExpressionModule: StructuredExpression
108106
import .StructuredExpressionModule: AbstractStructuredExpression
109107

src/Interfaces.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ using ..ExpressionModule:
5757
with_metadata,
5858
default_node_type
5959
using ..ParametricExpressionModule: ParametricExpression, ParametricNode
60-
using ..ReadOnlyNodeModule: AbstractReadOnlyNode
6160
using ..StructuredExpressionModule: StructuredExpression
6261

6362
###############################################################################
@@ -76,7 +75,7 @@ end
7675
function _check_get_tree(
7776
ex::AbstractExpression{T,N}
7877
) where {T,D,N<:AbstractExpressionNode{T,D}}
79-
return get_tree(ex) isa N || get_tree(ex) isa AbstractReadOnlyNode{T,D,N}
78+
return get_tree(ex) isa N
8079
end
8180
function _check_get_operators(ex::AbstractExpression)
8281
return get_operators(ex) isa AbstractOperatorEnum
@@ -151,8 +150,7 @@ end
151150
function _check_tree_mapreduce(
152151
ex::AbstractExpression{T,N}
153152
) where {T,D,N<:AbstractExpressionNode{T,D}}
154-
return tree_mapreduce(node -> [node], vcat, ex) isa
155-
(Vector{N2} where {N2<:Union{N,AbstractReadOnlyNode{T,D,N}}})
153+
return tree_mapreduce(node -> [node], vcat, ex) isa Vector{<:N}
156154
end
157155

158156
#! format: off

src/Node.jl

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -298,25 +298,30 @@ Base.eltype(::AbstractExpressionNode{T}) where {T} = T
298298

299299
has_max_degree(::Type{<:AbstractNode}) = false
300300
has_max_degree(::Type{<:AbstractNode{D}}) where {D} = true
301+
has_eltype(::Type{<:AbstractExpressionNode}) = false
302+
has_eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = true
301303
# COV_EXCL_STOP
302304
#! format: on
303305

304-
@unstable function constructorof(::Type{N}) where {N<:Node}
305-
return Node{T,max_degree(N)} where {T}
306+
@unstable function node_wrapper(::Type{N}) where {N<:AbstractExpressionNode}
307+
return Base.typename(N).wrapper
306308
end
307-
@unstable function constructorof(::Type{N}) where {N<:GraphNode}
308-
return GraphNode{T,max_degree(N)} where {T}
309+
@unstable function constructorof(::Type{N}) where {N<:AbstractExpressionNode}
310+
return node_wrapper(N){T,max_degree(N)} where {T}
309311
end
310-
311-
function with_type_parameters(::Type{N}, ::Type{T}) where {N<:Node,T}
312-
return Node{T,max_degree(N)}
312+
function with_type_parameters(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T}
313+
return node_wrapper(N){T,max_degree(N)}
313314
end
314-
function with_type_parameters(::Type{N}, ::Type{T}) where {N<:GraphNode,T}
315-
return GraphNode{T,max_degree(N)}
315+
@unstable function with_max_degree(::Type{N}, ::Val{D}) where {N<:AbstractExpressionNode,D}
316+
if has_eltype(N)
317+
return node_wrapper(N){eltype(N),D}
318+
else
319+
return node_wrapper(N){T,D} where {T}
320+
end
321+
end
322+
@unstable function with_default_max_degree(::Type{N}) where {N<:AbstractNode}
323+
return with_max_degree(N, Val(max_degree(N)))
316324
end
317-
318-
with_max_degree(::Type{N}, ::Val{D}) where {T,N<:Node{T},D} = Node{T,D}
319-
with_max_degree(::Type{N}, ::Val{D}) where {T,N<:GraphNode{T},D} = GraphNode{T,D}
320325

321326
function default_allocator(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T}
322327
return with_type_parameters(N, T)()
@@ -341,17 +346,26 @@ include("base.jl")
341346
else
342347
children
343348
end
344-
validate_not_all_defaults(N, val, feature, op, _children)
349+
if all_defaults(N, val, feature, op, _children)
350+
return make_default(N, T1)
351+
end
345352
return node_factory(N, T1, val, feature, op, _children, allocator)
346353
end
347-
function validate_not_all_defaults(::Type{N}, val, feature, op, children) where {N<:AbstractExpressionNode}
348-
if all(isnothing, (val, feature, op, children))
354+
function make_default(::Type{N}, ::Type{T1}) where {T1,N<:AbstractExpressionNode}
355+
if has_max_degree(N)
349356
error(
350357
"Encountered the call for $N() inside the generic constructor. "
351358
* "Did you forget to define `$(Base.typename(N).wrapper){T,D}() where {T,D} = new{T,D}()`?"
352359
)
353360
end
354-
return nothing
361+
if T1 === Undefined
362+
return with_default_max_degree(N)()
363+
else
364+
return with_type_parameters(with_default_max_degree(N), T1)()
365+
end
366+
end
367+
function all_defaults(::Type{N}, val, feature, op, children) where {N<:AbstractExpressionNode}
368+
return all(isnothing, (val, feature, op, children))
355369
end
356370
"""Create a constant leaf."""
357371
@inline function node_factory(

src/ParametricExpression.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import ..NodeModule:
1313
constructorof,
1414
with_type_parameters,
1515
with_max_degree,
16+
with_default_max_degree,
1617
max_degree,
1718
preserve_sharing,
1819
get_children,
@@ -118,16 +119,10 @@ end
118119
###############################################################################
119120
# Abstract expression node interface ##########################################
120121
###############################################################################
121-
@unstable constructorof(::Type{N}) where {N<:ParametricNode} =
122-
ParametricNode{T,max_degree(N)} where {T}
123122
@unstable constructorof(::Type{<:ParametricExpression}) = ParametricExpression
124-
function with_type_parameters(::Type{N}, ::Type{T}) where {N<:ParametricNode,T}
125-
return ParametricNode{T,max_degree(N)}
126-
end
127-
function with_max_degree(::Type{N}, ::Val{D}) where {T,N<:ParametricNode{T},D}
128-
return ParametricNode{T,D}
123+
@unstable function default_node_type(::Type{<:ParametricExpression})
124+
return with_default_max_degree(ParametricNode)
129125
end
130-
@unstable default_node_type(::Type{<:ParametricExpression}) = ParametricNode{T,2} where {T}
131126
function default_node_type(::Type{N}) where {T,N<:ParametricExpression{T}}
132127
return ParametricNode{T,max_degree(N)}
133128
end

src/ReadOnlyNode.jl

Lines changed: 0 additions & 45 deletions
This file was deleted.

src/StructuredExpression.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import ..ExpressionModule:
2121
node_type,
2222
get_scalar_constants,
2323
set_scalar_constants!
24-
import ..ReadOnlyNodeModule: ReadOnlyNode
2524

2625
abstract type AbstractStructuredExpression{
2726
T,F<:Function,N<:AbstractExpressionNode{T},E<:AbstractExpression{T,N},D<:NamedTuple
@@ -132,7 +131,7 @@ function get_metadata(e::AbstractStructuredExpression)
132131
return e.metadata
133132
end
134133
function get_tree(e::AbstractStructuredExpression)
135-
return ReadOnlyNode(get_tree(get_metadata(e).structure(get_contents(e))))
134+
return get_tree(get_metadata(e).structure(get_contents(e)))
136135
end
137136
function get_operators(
138137
e::AbstractStructuredExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing

test/test_n_arity_nodes.jl

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -448,44 +448,6 @@ end
448448
@test node_from_pex.children[3].feature == 3
449449
end
450450

451-
@testitem "ReadOnlyNode with N-ary Node" tags = [:narity] begin
452-
using DynamicExpressions
453-
using Test
454-
455-
my_ro_unary_op(x) = x
456-
my_ro_binary_op(x, y) = x
457-
my_ro_ternary_op(x, y, z) = x
458-
459-
operators_ro = OperatorEnum(
460-
1 => (my_ro_unary_op,), 2 => (my_ro_binary_op,), 3 => (my_ro_ternary_op,)
461-
)
462-
DynamicExpressions.@extend_operators operators_ro
463-
464-
x1_ro = Node{Float64,3}(; feature=1)
465-
x2_ro = Node{Float64,3}(; feature=2)
466-
x3_ro = Node{Float64,3}(; feature=3)
467-
tree_ro_ter = Node{Float64,3}(; op=1, children=(x1_ro, x2_ro, x3_ro))
468-
469-
expr_ro = Expression(tree_ro_ter; operators=operators_ro)
470-
readonly_tree = DynamicExpressions.ReadOnlyNode(DynamicExpressions.get_tree(expr_ro))
471-
472-
@test readonly_tree isa DynamicExpressions.ReadOnlyNodeModule.AbstractReadOnlyNode
473-
inner_node_ro = DynamicExpressions.ReadOnlyNodeModule.inner(readonly_tree)
474-
@test DynamicExpressions.NodeModule.max_degree(inner_node_ro) == 3
475-
@test readonly_tree.degree == 3
476-
@test readonly_tree.op == 1
477-
478-
ro_children = DynamicExpressions.NodeModule.get_children(readonly_tree, Val(3))
479-
@test length(ro_children) == 3
480-
@test ro_children[1] isa DynamicExpressions.ReadOnlyNodeModule.AbstractReadOnlyNode
481-
@test ro_children[1].feature == 1
482-
@test ro_children[2].feature == 2
483-
@test ro_children[3].feature == 3
484-
485-
@test readonly_tree.l.feature == 1
486-
@test readonly_tree.r.feature == 2
487-
end
488-
489451
@testitem "NodeUtils.jl NodeIndex for N-ary" tags = [:narity] begin
490452
using DynamicExpressions
491453
using Test

test/test_operator_construction_edgecases.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@ end
1515

1616
@testitem "OperatorEnumConstruction helper internals & edge cases" begin
1717
using DynamicExpressions.OperatorEnumConstructionModule:
18-
_unpack_broadcast_function,
19-
OperatorEnum,
20-
empty_all_globals!
18+
_unpack_broadcast_function, OperatorEnum, empty_all_globals!
2119
using Base.Broadcast: BroadcastFunction
2220
using Test
2321

test/test_readonlynode.jl

Lines changed: 0 additions & 69 deletions
This file was deleted.

test/unittest.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ include("test_operator_construction_edgecases.jl")
131131
include("test_node_interface.jl")
132132
include("test_expression_math.jl")
133133
include("test_structured_expression.jl")
134-
include("test_readonlynode.jl")
135134
include("test_zygote_gradient_wrapper.jl")
136135
include("test_supposition_consistency.jl")
137136
include("test_n_arity_nodes.jl")

0 commit comments

Comments
 (0)