Skip to content

Commit 87102d8

Browse files
committed
feat: create n-arity operator enum
1 parent b6d187b commit 87102d8

File tree

2 files changed

+38
-12
lines changed

2 files changed

+38
-12
lines changed

src/Evaluate.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,14 @@ end
251251

252252
# These are marked unstable due to issues discussed on
253253
# https://github.com/JuliaLang/julia/issues/55147
254-
@unstable get_nuna(::Type{<:OperatorEnum{B,U}}) where {B,U} = counttuple(U)
255-
@unstable get_nbin(::Type{<:OperatorEnum{B}}) where {B} = counttuple(B)
254+
@unstable function get_nuna(::Type{<:OperatorEnum{OPS}}) where {OPS}
255+
ts = OPS.types
256+
return isempty(ts) ? 0 : counttuple(ts[1])
257+
end
258+
@unstable function get_nbin(::Type{<:OperatorEnum{OPS}}) where {OPS}
259+
ts = OPS.types
260+
return length(ts) == 1 ? 0 : counttuple(ts[2])
261+
end
256262

257263
function _eval_tree_array(
258264
tree::AbstractExpressionNode{T},

src/OperatorEnum.jl

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,49 @@ abstract type AbstractOperatorEnum end
66
OperatorEnum
77
88
Defines an enum over operators, along with their derivatives.
9+
910
# Fields
10-
- `binops`: A tuple of binary operators. Scalar input type.
11-
- `unaops`: A tuple of unary operators. Scalar input type.
11+
- `ops`: A tuple of operators, with index `i` corresponding to the operator tuple for a node of degree `i`.
1212
"""
13-
struct OperatorEnum{B,U} <: AbstractOperatorEnum
14-
binops::B
15-
unaops::U
13+
struct OperatorEnum{OPS<:Tuple{Vararg{Tuple}}} <: AbstractOperatorEnum
14+
ops::OPS
15+
end
16+
17+
function OperatorEnum(binary_operators::Tuple, unary_operators::Tuple)
18+
return OperatorEnum((unary_operators, binary_operators))
1619
end
1720

1821
"""
1922
GenericOperatorEnum
2023
2124
Defines an enum over operators, along with their derivatives.
25+
2226
# Fields
23-
- `binops`: A tuple of binary operators.
24-
- `unaops`: A tuple of unary operators.
27+
- `ops`: A tuple of operators, with index `i` corresponding to the operator tuple for a node of degree `i`.
2528
"""
26-
struct GenericOperatorEnum{B,U} <: AbstractOperatorEnum
27-
binops::B
28-
unaops::U
29+
struct GenericOperatorEnum{OPS<:Tuple{Vararg{Tuple}}} <: AbstractOperatorEnum
30+
ops::OPS
31+
end
32+
33+
function GenericOperatorEnum(binops::Tuple, unaops::Tuple)
34+
return GenericOperatorEnum((unaops, binops))
2935
end
3036

3137
Base.copy(op::AbstractOperatorEnum) = op
3238
# TODO: Is this safe? What if a vector is passed here?
3339

40+
@inline function Base.getindex(op::AbstractOperatorEnum, i::Int)
41+
return getfield(op, :ops)[i]
42+
end
43+
@inline function Base.getproperty(op::AbstractOperatorEnum, k::Symbol)
44+
if k == :unaops
45+
return getfield(op, :ops)[1]
46+
elseif k == :binops
47+
ops = getfield(op, :ops)
48+
return length(ops) > 1 ? ops[2] : ()
49+
else
50+
return getfield(op, k)
51+
end
52+
end
53+
3454
end

0 commit comments

Comments
 (0)