diff --git a/src/ProtoStruct.jl b/src/ProtoStruct.jl index 1ff6462..f288eb6 100644 --- a/src/ProtoStruct.jl +++ b/src/ProtoStruct.jl @@ -1,7 +1,17 @@ const revise_uuid = Base.UUID("295af30f-e4ad-537b-8983-00126c2a3abe") +using Base: nothing_sentinel const revise_pkgid = Base.PkgId(revise_uuid, "Revise") +Base.@kwdef struct ProtoInfo + world::UInt64 + paramcount::Int + fields::NamedTuple +end + +"Definitions of proto structs so we can upgrade instances on demand" +const DEFS = Dict{Tuple{Module,Symbol}, ProtoInfo}() + function checkrev() revise_pkgid in keys(Base.loaded_modules) || return false d = @__DIR__ @@ -12,6 +22,48 @@ function checkrev() return f in keys(trackedfiles) end +# Base._kwdef! is not available in 1.11 +function local_kwdef!(expr, params, calls) + for (i, ex) in enumerate(expr.args) + ex isa LineNumberNode && continue + isconst = ex isa Expr && ex.head == :const + if isconst + ex = ex.args[1] + end + name = ex + if name isa Expr && name.head == :(=) + name = name.args[1] + end + if name isa Expr && name.head == :(::) + name = name.args[1] + end + if !(name isa Symbol) + error("ProtoStructs cannot handle struct field $ex") + end + if ex isa Expr && ex.head == :(=) + # remove assignment from expr + newexpr = ex.args[1] + push!(params, Expr(:kw, name, ex.args[2])) + else + newexpr = ex + push!(params, name) + end + if isconst + newexpr = Expr(:const, newexpr) + end + expr.args[i] = newexpr + push!(calls, name) + end +end + +function defsfor(t::Type) + get(DEFS, (parentmodule(t), nameof(t)), nothing) +end + +function setdefs!(t::Type, info::ProtoInfo) + DEFS[(parentmodule(t), nameof(t))] = info +end + macro proto(expr) if checkrev() @eval __module__ $(_proto(expr)) @@ -81,177 +133,342 @@ function _proto(expr) end end i = 0 - field_info = map(fields) do field + field_info = (; (map(fields) do field i += 1 if field isa Symbol - return (field, Any, const_fields[i]) + return field => (; name=field, type=Any, isconst=!ismutable || const_fields[i]) else - return (field.args[1], field.args[2], const_fields[i]) + return field.args[1] => (; name=field.args[1], type=field.args[2], isconst=!ismutable || const_fields[i]) end - end + end)..., + ) - field_names = Tuple(getindex.(field_info, 1)) - const_field_names = [f for (f, fi) in zip(field_names, field_info) if fi[3] == true] + field_names = keys(field_info) + const_field_names = [info.name for info in field_info if info.isconst] if ismutable - field_types = :(Tuple{$((x in const_field_names ? :($x where {$x}) : (:(Base.RefValue{$x} where {$x})) - for x in getindex.(field_info, 2))...)}) + base_field_types = :(Tuple{$(getindex.(values(field_info), :type)...)}) + field_types = :(Tuple{$((info.isconst ? :($(info.type)) : + :(Base.RefValue{$(info.type)}) + for info in field_info)...)}) fields_with_ref = (x in const_field_names ? :($x=$x) : (:($x=Ref($x))) for x in field_names) else - field_types = :(Tuple{$(getindex.(field_info, 2)...)}) + field_types = :(Tuple{$(getindex.(values(field_info), :type)...)}) end - field_subtype_info = map(getindex.(field_info, 2)) do ft - if ft in type_parameter_names - return type_parameter_types[ft] - else - return ft - end - end + # UNUSED + #field_subtype_info = map(getindex.(field_info, 2)) do ft + # if ft in type_parameter_names + # return type_parameter_types[ft] + # else + # return ft + # end + #end params_ex = Expr(:parameters) call_args = Any[] - - Base._kwdef!(expr.args[3], params_ex.args, call_args) - - # remove escapes - params_ex.args = map(params_ex.args) do ex - if ex isa Symbol return ex end - ex.args[2] = ex.args[2].args[1] - ex - end - + local_kwdef!(expr.args[3], params_ex.args, call_args) default_params = [Symbol("P", i) for i in 1:15] N_any_params = length(default_params) - length(type_parameter_names) N_any_params <= 0 && error("The number of parameters of the proto struct is too high") any_params = [:(Any) for _ in 1:N_any_params] + # merge default values into field_info + field_info = + (; + (name => (; name, type, isconst, hasdefault, default) + for ((name, type, isconst), (hasdefault, default)) in + zip(field_info, getdefault.(params_ex.args)))..., + ) + world = UInt64(Base.get_world_counter()) + UNIQ = gensym() ex = if ismutable quote if !@isdefined $name - Base.@__doc__ struct $name{$(default_params...), NT<:NamedTuple} <: $abstract_type - properties::NT + Base.@__doc__ struct $name{$(default_params...)} <: $abstract_type + info::Ref{$ProtoStructs.ProtoInfo} + properties::Ref{NamedTuple} end else if ($abstract_type != Any) && ($abstract_type != Base.supertype($name)) error("The supertype of a proto struct is not redefinable. Please restart your julia session.") end - the_methods = collect(methods($name)) - if length(the_methods) >= 1 - Base.delete_method(the_methods[1]) - end - if length(the_methods) >= 2 - Base.delete_method(the_methods[2]) - end + $ProtoStructs.cleanup_struct($name, $(Symbol("new_$name"))) end - function $name($(fields...)) where {$(type_parameters...)} + function $(Symbol("new_$name"))($(fields...)) where {$(type_parameters...)} v = NamedTuple{$field_names}(($(fields_with_ref...),)) - return $name{$(type_parameter_names...), $(any_params...), typeof(v)}(v) + return $name{$(type_parameter_names...), $(any_params...)}(Ref($ProtoStructs.defsfor($name)), v) end - function $name{$(type_parameter_names...)}($(fields...)) where {$(type_parameters...)} - v = NamedTuple{$field_names}(($(fields_with_ref...),)) - return $name{$(type_parameter_names...), $(any_params...), typeof(v)}(v) + function $name($(field_names...)) + local prop_types = $ProtoStructs.property_types($name, $((:Any for _ in 1:length(type_parameters))...)).parameters[2].parameters + # call new_NAME so it can infer the type parameters + return $(Symbol("new_$name"))( + $((:($ProtoStructs.convert_field($name, $arg, prop_types[$i])) + for (i, arg) in enumerate(field_names))...)) end - function $name($params_ex) - return $name($(call_args...)) + function $name{$(type_parameter_names...)}($(field_names...)) where {$(type_parameters...)} + v = NamedTuple{$field_names, $field_types}(($( + (info.isconst ? :($ProtoStructs.convert_field($name, $(info.name), $type)) : + :(Ref{$type}($ProtoStructs.convert_field($name, $(info.name), $type))) + for (type, info) in zip(base_field_types.args[2:end], field_info))...),)) + return $name{$(type_parameter_names...), $(any_params...)}(Ref($ProtoStructs.defsfor($name)), v) + end + + function $name($params_ex) where {$(type_parameters...)} + $name($(call_args...)) end function $name{$(type_parameter_names...)}($params_ex) where {$(type_parameters...)} $name{$(type_parameter_names...)}($(call_args...)) end + function $ProtoStructs.default_for(::$name{$(type_parameter_names...)}, field::Symbol) where {$(type_parameter_names...)} + return (; + $((Expr(:kw, p.args[1], p.args[2]) for p in params_ex.args if !(p isa Symbol))...) + )[field] + end + + function $ProtoStructs.property_types(::Type{$UNIQ}, $(type_parameter_names...)) where {$UNIQ <: $name} + return NamedTuple{$field_names, $base_field_types} + end + function Base.getproperty(o::$name, s::Symbol) - p = getproperty(getfield(o, :properties), s) - if p isa Base.RefValue - p[] - else - p - end + $ProtoStructs.updateproto(o) + p = getproperty(getfield(o, :properties)[], s) + s ∈ $const_field_names ? p : p[] end function Base.setproperty!(o::$name, s::Symbol, v) - p = getproperty(getfield(o, :properties), s) - if p isa Base.RefValue - p[] = v - else + $ProtoStructs.updateproto(o) + p = getproperty(getfield(o, :properties)[], s) + if s ∈ $const_field_names error("const field $s of type ", $name, " cannot be changed") + else + p[] = v end end function Base.propertynames(o::$name) - return propertynames(getfield(o, :properties)) + $ProtoStructs.updateproto(o) + return propertynames(getfield(o, :properties)[]) end function Base.show(io::IO, o::$name) - vals = join([x isa Base.RefValue ? (x[] isa String ? "\"$(x[])\"" : x[]) : x for x in getfield(o, :properties)], ", ") - params = typeof(o).parameters[1:end-$N_any_params-1] + $ProtoStructs.updateproto(o) + vals = join([x isa Base.RefValue ? (x[] isa String ? "\"$(x[])\"" : x[]) : x for x in getfield(o, :properties)[]], ", ") + params = typeof(o).parameters[1:end-$N_any_params] if isempty(params) print(io, string($name), "($vals)") else print(io, string($name, "{", join(params, ", "), "}"), "($vals)") end end + $ProtoStructs.setdefs!($name, $ProtoStructs.ProtoInfo( + world=$world, + paramcount=$(length(type_parameters)), + fields=$(runtime_field_info(field_info)), + )) end else quote if !@isdefined $name - Base.@__doc__ struct $name{$(default_params...), NT<:NamedTuple} <: $abstract_type - properties::NT + Base.@__doc__ struct $name{$(default_params...)} <: $abstract_type + info::Ref{$ProtoStructs.ProtoInfo} + properties::Ref{NamedTuple} end else if ($abstract_type != Any) && ($abstract_type != Base.supertype($name)) error("The supertype of a proto struct is not redefinable. Please restart your julia session.") end - the_methods = collect(methods($name)) - if length(the_methods) >= 1 - Base.delete_method(the_methods[1]) - end - if length(the_methods) >= 2 - Base.delete_method(the_methods[2]) - end + $ProtoStructs.cleanup_struct($name, $(Symbol("new_$name"))) end - function $name($(fields...)) where {$(type_parameters...)} + function $(Symbol("new_$name"))($(fields...)) where {$(type_parameters...)} v = NamedTuple{$field_names, $field_types}(($(field_names...),)) - return $name{$(type_parameter_names...), $(any_params...), typeof(v)}(v) + return $name{$(type_parameter_names...), $(any_params...)}(Ref($ProtoStructs.defsfor($name)), v) end - function $name{$(type_parameter_names...)}($(fields...)) where {$(type_parameters...)} - v = NamedTuple{$field_names, $field_types}(($(field_names...),)) - return $name{$(type_parameter_names...), $(any_params...), typeof(v)}(v) + function $name($(field_names...)) + local prop_types = $ProtoStructs.property_types($name, $((:Any for _ in 1:length(type_parameters))...)).parameters[2].parameters + # call new_NAME so it can infer the type parameters + return $(Symbol("new_$name"))( + $((:($ProtoStructs.convert_field($name, $arg, prop_types[$i])) + for (i, arg) in enumerate(field_names))...)) end - function $name($params_ex) - return $name($(call_args...)) + function $name{$(type_parameter_names...)}($(field_names...)) where {$(type_parameters...)} + v = NamedTuple{$field_names, $field_types}(($( + (:($ProtoStructs.convert_field($name, $arg, $type)) + for (type, arg) in zip(field_types.args[2:end], field_names))...),)) + return $name{$(type_parameter_names...), $(any_params...)}(Ref($ProtoStructs.defsfor($name)), v) + end + + function $name($params_ex) where {$(type_parameters...)} + $name($(call_args...)) end function $name{$(type_parameter_names...)}($params_ex) where {$(type_parameters...)} $name{$(type_parameter_names...)}($(call_args...)) end + function $ProtoStructs.default_for(::$name{$(type_parameter_names...)}, field::Symbol) where {$(type_parameter_names...)} + return (; + $((Expr(:kw, p.args[1], p.args[2]) for p in params_ex.args if !(p isa Symbol))...) + )[field] + end + + $((:(function $ProtoStructs.convert_field(::Type{$UNIQ}, value::$(info.type), target::Type) where {$UNIQ <: $name, $(type_parameters...)} + value + end) + for info in field_info)...) + + function $ProtoStructs.property_types(::Type{$UNIQ}, $(type_parameter_names...)) where {$UNIQ <: $name} + return NamedTuple{$field_names, $field_types} + end + function Base.getproperty(o::$name, s::Symbol) - return getproperty(getfield(o, :properties), s) + $ProtoStructs.updateproto(o) + return getproperty(getfield(o, :properties)[], s) end function Base.propertynames(o::$name) - return propertynames(getfield(o, :properties)) + $ProtoStructs.updateproto(o) + return propertynames(getfield(o, :properties)[]) end function Base.show(io::IO, o::$name) - vals = join([x isa String ? "\"$x\"" : x for x in getfield(o, :properties)], ", ") - params = typeof(o).parameters[1:end-$N_any_params-1] + $ProtoStructs.updateproto(o) + vals = join([x isa String ? "\"$x\"" : x for x in getfield(o, :properties)[]], ", ") + params = typeof(o).parameters[1:end-$N_any_params] if isempty(params) print(io, string($name), "($vals)") else print(io, string($name, "{", join(params, ", "), "}"), "($vals)") end end + $ProtoStructs.setdefs!($name, $ProtoStructs.ProtoInfo( + world=$world, + paramcount=$(length(type_parameters)), + fields=$(runtime_field_info(field_info)), + )) end end return ex end +function runtime_field_info(info) + return :((; + $((:($(i.name) = (; name = $(QuoteNode(i.name)), + isconst = $(i.isconst), + hasdefault = $(i.hasdefault), + )) + for i in info)...), + )) +end + +function getdefault(ex) + ex isa Expr && ex.head == :kw && + return true, ex.args[2] + ex isa Symbol && + return false, nothing + @warn "Cannot compute default for field $ex" + return false, nothing +end + +function property_types end + +function updateproto(o::T) where {T} + local defs = defsfor(T) + local olddefs = getfield(o, :info)[] + olddefs.world == defs.world && + return false + getfield(o, :info)[] = defs + # get type for old struct's new properties + local raw_prop_type = property_types(T, T.parameters[1:defs.paramcount]...) + local prop_types = (; zip(raw_prop_type.parameters[1], raw_prop_type.parameters[2].parameters)...) + local oldfields = getfield(o, :properties)[] + + getfield(o, :properties)[] = (; (name => updatefield(o, type, defs.fields[name], oldfields) + for (name, type) in pairs(prop_types))...) + @info "Warning, T has changed" + return true +end + +function default_for end + +""" +newtype contains the fully instantiated type for this struct +""" +function updatefield(o, newtype, info, oldfields) + local err = "" + local orig_type = newtype + local default = default_for(o, info.name) + + if !info.isconst + newtype = Ref{newtype} + end + if info.name ∈ keys(oldfields) + try + return updatevalue(newtype, oldfields[info.name]) + catch + err = "Could not convert old value $(oldfields[info.name]) to type $newtype" + end + end + if !info.hasdefault + try + default = typemin(orig_type) + if !isempty(err) + err = "$err and no default value for field $(info.name), choosing typemin" + else + err = "No default value for field $(info.name), choosing typemin" + end + catch + if !isempty(err) + error("$err and no default value or typemin for field $(info.name)") + else + error("No default value or typemin for field $(info.name)") + end + end + elseif !isempty(err) + err = "$err, using default instead" + end + !isempty(err) && + @warn err + return default +end + +updatevalue(newtype, value) = convert(newtype, value) +updatevalue(::Type{Ref{T}}, value::Ref{U}) where {T, U} = Ref{T}(updatevalue(T, value[])) +updatevalue(::Type{T}, value::Ref{U}) where {T, U} = updatevalue(T, value[]) +updatevalue(::Type{Ref{T}}, value::U) where {T, U} = Ref{T}(updatevalue(T, value)) + +function convert_field(_, value, target::Type) + convert(target, value) +end + +function firstparam(method) + local sig = method.sig + sig isa UnionAll ? Nothing : sig.parameters[1] +end + +function cleanup_struct(t::Type, constructor) + local info = defsfor(t) + local the_methods = collect(methods(t)) + + if length(the_methods) >= 1 + Base.delete_method(the_methods[1]) + end + if length(the_methods) >= 2 + Base.delete_method(the_methods[2]) + end + local anys = (Any for _ in 1:info.paramcount) + Base.delete_method.([ + methods(property_types, (Type{t}, anys...))..., + [m for m in methods(convert_field, (Type{t}, Any, Any)) if firstparam(m) <: t]..., + methods(constructor)..., + ]) +end diff --git a/test/test_ProtoStruct.jl b/test/test_ProtoStruct.jl index 7f918c2..c6474da 100644 --- a/test/test_ProtoStruct.jl +++ b/test/test_ProtoStruct.jl @@ -87,6 +87,9 @@ end @test tm.F == 8 && tm.G == 2.0 @test_throws MethodError tm.F = "2" @test propertynames(tm) == (:F, :G) + # value conversion + tm.G = 8 + @test tm.G == 8.0 end abstract type AbstractMutation end @@ -168,6 +171,8 @@ end end end +@static if VERSION <= v"1.10" + @proto struct TestMethods end @testset "Constuctor updating I" begin @@ -183,6 +188,10 @@ end @test length(collect(methods(TestMethods))) == 2 end +end + +@static if VERSION == v"1.10" + """ This is a docstring. """ @@ -193,3 +202,34 @@ end @testset "Docstring" begin @test string(@doc DocTestMe) == "This is a docstring.\n" end + +end + +@proto @kwdef struct A{T} + x::Array{T} = T[] +end + +a = A{Int}() + +@test a.x == [] + +a = A(; x = Int64[]) + +@test typeof(a).parameters[1] == Int64 + +@proto @kwdef struct A{T} + x::Array{T} = T[] + y::Int = 3 +end + +@test a.y == 3 + +@test A(; x = ["hello"]).x == ["hello"] + +@proto @kwdef struct A{T} + y::Float64 = 3 +end + +@test_throws ErrorException a.x +@test a.y == 3.0 +@test A{:florp}(;y = Int(1)).y === 1.0