Skip to content

Commit 13daad7

Browse files
committed
Tidied up the functional code and added explanations
1 parent 08c89cd commit 13daad7

File tree

1 file changed

+19
-24
lines changed

1 file changed

+19
-24
lines changed

src/hparams.jl

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@ struct HParamRealDomain
77
min_value::Float64
88
max_value::Float64
99
end
10-
struct HParamSetDomain{T<:Union{String, Bool, Float64}}
10+
struct HParamSetDomain{T<:Union{String,Bool,Float64}}
1111
values::Vector{T}
1212
end
1313
Base.@kwdef struct HParamConfig
1414
name::String
1515
datatype::DataType
1616
displayname::String = ""
1717
description::String = ""
18-
domain::Union{Nothing, HParamRealDomain, HParamSetDomain} = nothing
18+
domain::Union{Nothing,HParamRealDomain,HParamSetDomain} = nothing
1919
end
2020
Base.@kwdef struct MetricConfig
2121
name::String
@@ -31,7 +31,7 @@ _to_proto_hparam_dtype(::Val{Bool}) = HParamDataType.DATA_TYPE_BOOL
3131
_to_proto_hparam_dtype(::Val{Float64}) = HParamDataType.DATA_TYPE_FLOAT64
3232
_to_proto_hparam_dtype(::Val{String}) = HParamDataType.DATA_TYPE_STRING
3333

34-
function _convert_value(v::T) where {T<:Union{String, Bool, Real}}
34+
function _convert_value(v::T) where {T<:Union{String,Bool,Real}}
3535
if v isa String
3636
return HValue(OneOf(:string_value, v))
3737
elseif v isa Bool
@@ -47,24 +47,23 @@ _convert_hparam_domain(::Nothing) = nothing
4747
_convert_hparam_domain(domain::HParamRealDomain) = OneOf(:domain_interval, HP.Interval(domain.min_value, domain.max_value))
4848
_convert_hparam_domain(domain::HParamSetDomain) = OneOf(:domain_discrete, HListValue([_convert_value(v) for v in domain.values]))
4949

50-
5150
function hparam_info(c::HParamConfig)
5251
datatype = c.datatype
5352
domain = c.domain
5453
if isnothing(c.domain)
5554
domain = default_domain(Val(datatype))
5655
else
5756
if isa(domain, HParamRealDomain)
58-
@assert datatype==Float64 "Real domains require Float64"
57+
@assert datatype == Float64 "Real domains require Float64"
5958
elseif isa(domain, HParamSetDomain{String})
60-
@assert datatype==String "Domains with strings require a datatype of String"
59+
@assert datatype == String "Domains with strings require a datatype of String"
6160
elseif isa(domain, HParamSetDomain{Bool})
62-
@assert datatype==Bool "Domains with bools require a datatype of Bool"
61+
@assert datatype == Bool "Domains with bools require a datatype of Bool"
6362
elseif isa(domain, HParamSetDomain{Float64})
64-
@assert datatype<:Real "Domains with floats require a datatype a Real datatype"
63+
@assert datatype <: Real "Domains with floats require a datatype a Real datatype"
6564
end
6665
end
67-
66+
6867
dtype = _to_proto_hparam_dtype(Val(datatype))
6968
converted_domain = _convert_hparam_domain(domain)
7069
return HP.HParamInfo(c.name, c.displayname, c.description, dtype, converted_domain)
@@ -75,17 +74,15 @@ function metric_info(c::MetricConfig)
7574
end
7675

7776
function encode_bytes(content::HP.HParamsPluginData)
78-
data = PipeBuffer();
77+
data = PipeBuffer()
7978
encode(ProtoEncoder(data), content)
8079
return take!(data)
8180
end
8281

83-
# Overload the dictionary encoder
82+
# Dictionary serialisation in ProtoBuf does not work for this specific map type
83+
# and must be overloaded so that it can be parsed. The format was derived by
84+
# looking at the binary output of a log file created by tensorboardX.
8485
function PB.encode(e::ProtoEncoder, i::Int, x::Dict{String,HValue})
85-
# PB.Codecs.encode_tag(e, i, PB.Codecs.LENGTH_DELIMITED)
86-
# remaining_size = PB.Codecs._encoded_size(x, i) - 2 # remove two for the field name and length
87-
# PB.Codecs.vbyte_encode(e.io, UInt32(remaining_size))
88-
8986
for (k, v) in x
9087
PB.Codecs.encode_tag(e, 1, PB.Codecs.LENGTH_DELIMITED)
9188
total_size = PB.Codecs._encoded_size(k, 1) + PB.Codecs._encoded_size(v, 2)
@@ -95,6 +92,8 @@ function PB.encode(e::ProtoEncoder, i::Int, x::Dict{String,HValue})
9592
end
9693
return nothing
9794
end
95+
# Similarly, we must overload the size calculation to take into account the new
96+
# format.
9897
function PB.Codecs._encoded_size(x::Dict{String,HValue}, i::Int)
9998
# Field number and length is another 2 bytes
10099
# There are two bytes for each key value pair extra
@@ -113,7 +112,7 @@ to `Float64` when writing the logs.
113112
`metrics` should be a list of tags, which correspond to scalars that have been logged. Tensorboard will
114113
automatically extract the latest metric logged to use for this value.
115114
"""
116-
function write_hparams!(logger::TBLogger, hparams::Dict{String, Any}, metrics::AbstractArray{String})
115+
function write_hparams!(logger::TBLogger, hparams::Dict{String,Any}, metrics::AbstractArray{String})
117116
PLUGIN_NAME = "hparams"
118117
PLUGIN_DATA_VERSION = 0
119118

@@ -122,9 +121,9 @@ function write_hparams!(logger::TBLogger, hparams::Dict{String, Any}, metrics::A
122121
SESSION_END_INFO_TAG = "_hparams_/session_end_info"
123122

124123
# Check for datatypes
125-
for (k,v) in hparams
126-
@assert typeof(v) <: Union{Bool, String, Real} "Hyperparameters must be of types String, Bool or Real"
127-
# Cast to other values
124+
for (k, v) in hparams
125+
@assert typeof(v) <: Union{Bool,String,Real} "Hyperparameters must be of types String, Bool or Real"
126+
# Cast non-supported numerical values to Float64
128127
if !(typeof(v) <: Bool) && typeof(v) <: Real
129128
hparams[k] = Float64(v)
130129
end
@@ -133,12 +132,8 @@ function write_hparams!(logger::TBLogger, hparams::Dict{String, Any}, metrics::A
133132
hparam_infos = [hparam_info(HParamConfig(; name=k, datatype=typeof(v))) for (k, v) in hparams]
134133
metric_infos = [metric_info(MetricConfig(; name=metric)) for metric in metrics]
135134

136-
137-
hparams_dict = Dict(k=>_convert_value(v) for (k,v) in hparams)
138-
# NOTE: THE ABOVE DICTIONARY IS NOT BEING SERIALISED TO THE FILE PROPERLY,
139-
# WE MAY NEED TO EXPLICITLY WRITE AN ENCODER/DECODER FOR THIS TYPE.
135+
hparams_dict = Dict(k => _convert_value(v) for (k, v) in hparams)
140136

141-
142137
experiment = HP.Experiment("", "", "", time(), hparam_infos, metric_infos)
143138
experiment_content = HP.HParamsPluginData(PLUGIN_DATA_VERSION, OneOf(:experiment, experiment))
144139
experiment_md = SummaryMetadata(SummaryMetadata_PluginData(PLUGIN_NAME, encode_bytes(experiment_content)), "", "", DataClass.DATA_CLASS_UNKNOWN)

0 commit comments

Comments
 (0)