@@ -7,15 +7,15 @@ struct HParamRealDomain
77 min_value:: Float64
88 max_value:: Float64
99end
10- struct HParamSetDomain{T<: Union{String, Bool, Float64} }
10+ struct HParamSetDomain{T<: Union{String,Bool,Float64} }
1111 values:: Vector{T}
1212end
1313Base. @kwdef struct HParamConfig
1414 name:: String
1515 datatype:: DataType
1616 displayname:: String = " "
1717 description:: String = " "
18- domain:: Union{Nothing, HParamRealDomain, HParamSetDomain} = nothing
18+ domain:: Union{Nothing,HParamRealDomain,HParamSetDomain} = nothing
1919end
2020Base. @kwdef struct MetricConfig
2121 name:: String
@@ -31,7 +31,7 @@ _to_proto_hparam_dtype(::Val{Bool}) = HParamDataType.DATA_TYPE_BOOL
3131_to_proto_hparam_dtype (:: Val{Float64} ) = HParamDataType. DATA_TYPE_FLOAT64
3232_to_proto_hparam_dtype (:: Val{String} ) = HParamDataType. DATA_TYPE_STRING
3333
34- function _convert_value (v:: T ) where {T<: Union{String, Bool, Real} }
34+ function _convert_value (v:: T ) where {T<: Union{String,Bool,Real} }
3535 if v isa String
3636 return HValue (OneOf (:string_value , v))
3737 elseif v isa Bool
@@ -47,24 +47,23 @@ _convert_hparam_domain(::Nothing) = nothing
4747_convert_hparam_domain (domain:: HParamRealDomain ) = OneOf (:domain_interval , HP. Interval (domain. min_value, domain. max_value))
4848_convert_hparam_domain (domain:: HParamSetDomain ) = OneOf (:domain_discrete , HListValue ([_convert_value (v) for v in domain. values]))
4949
50-
5150function hparam_info (c:: HParamConfig )
5251 datatype = c. datatype
5352 domain = c. domain
5453 if isnothing (c. domain)
5554 domain = default_domain (Val (datatype))
5655 else
5756 if isa (domain, HParamRealDomain)
58- @assert datatype== Float64 " Real domains require Float64"
57+ @assert datatype == Float64 " Real domains require Float64"
5958 elseif isa (domain, HParamSetDomain{String})
60- @assert datatype== String " Domains with strings require a datatype of String"
59+ @assert datatype == String " Domains with strings require a datatype of String"
6160 elseif isa (domain, HParamSetDomain{Bool})
62- @assert datatype== Bool " Domains with bools require a datatype of Bool"
61+ @assert datatype == Bool " Domains with bools require a datatype of Bool"
6362 elseif isa (domain, HParamSetDomain{Float64})
64- @assert datatype<: Real " Domains with floats require a datatype a Real datatype"
63+ @assert datatype <: Real " Domains with floats require a datatype a Real datatype"
6564 end
6665 end
67-
66+
6867 dtype = _to_proto_hparam_dtype (Val (datatype))
6968 converted_domain = _convert_hparam_domain (domain)
7069 return HP. HParamInfo (c. name, c. displayname, c. description, dtype, converted_domain)
@@ -75,17 +74,15 @@ function metric_info(c::MetricConfig)
7574end
7675
7776function encode_bytes (content:: HP.HParamsPluginData )
78- data = PipeBuffer ();
77+ data = PipeBuffer ()
7978 encode (ProtoEncoder (data), content)
8079 return take! (data)
8180end
8281
83- # Overload the dictionary encoder
82+ # Dictionary serialisation in ProtoBuf does not work for this specific map type
83+ # and must be overloaded so that it can be parsed. The format was derived by
84+ # looking at the binary output of a log file created by tensorboardX.
8485function PB. encode (e:: ProtoEncoder , i:: Int , x:: Dict{String,HValue} )
85- # PB.Codecs.encode_tag(e, i, PB.Codecs.LENGTH_DELIMITED)
86- # remaining_size = PB.Codecs._encoded_size(x, i) - 2 # remove two for the field name and length
87- # PB.Codecs.vbyte_encode(e.io, UInt32(remaining_size))
88-
8986 for (k, v) in x
9087 PB. Codecs. encode_tag (e, 1 , PB. Codecs. LENGTH_DELIMITED)
9188 total_size = PB. Codecs. _encoded_size (k, 1 ) + PB. Codecs. _encoded_size (v, 2 )
@@ -95,6 +92,8 @@ function PB.encode(e::ProtoEncoder, i::Int, x::Dict{String,HValue})
9592 end
9693 return nothing
9794end
95+ # Similarly, we must overload the size calculation to take into account the new
96+ # format.
9897function PB. Codecs. _encoded_size (x:: Dict{String,HValue} , i:: Int )
9998 # Field number and length is another 2 bytes
10099 # There are two bytes for each key value pair extra
@@ -113,7 +112,7 @@ to `Float64` when writing the logs.
113112`metrics` should be a list of tags, which correspond to scalars that have been logged. Tensorboard will
114113automatically extract the latest metric logged to use for this value.
115114"""
116- function write_hparams! (logger:: TBLogger , hparams:: Dict{String, Any} , metrics:: AbstractArray{String} )
115+ function write_hparams! (logger:: TBLogger , hparams:: Dict{String,Any} , metrics:: AbstractArray{String} )
117116 PLUGIN_NAME = " hparams"
118117 PLUGIN_DATA_VERSION = 0
119118
@@ -122,9 +121,9 @@ function write_hparams!(logger::TBLogger, hparams::Dict{String, Any}, metrics::A
122121 SESSION_END_INFO_TAG = " _hparams_/session_end_info"
123122
124123 # Check for datatypes
125- for (k,v) in hparams
126- @assert typeof (v) <: Union{Bool, String, Real} " Hyperparameters must be of types String, Bool or Real"
127- # Cast to other values
124+ for (k, v) in hparams
125+ @assert typeof (v) <: Union{Bool,String,Real} " Hyperparameters must be of types String, Bool or Real"
126+ # Cast non-supported numerical values to Float64
128127 if ! (typeof (v) <: Bool ) && typeof (v) <: Real
129128 hparams[k] = Float64 (v)
130129 end
@@ -133,12 +132,8 @@ function write_hparams!(logger::TBLogger, hparams::Dict{String, Any}, metrics::A
133132 hparam_infos = [hparam_info (HParamConfig (; name= k, datatype= typeof (v))) for (k, v) in hparams]
134133 metric_infos = [metric_info (MetricConfig (; name= metric)) for metric in metrics]
135134
136-
137- hparams_dict = Dict (k=> _convert_value (v) for (k,v) in hparams)
138- # NOTE: THE ABOVE DICTIONARY IS NOT BEING SERIALISED TO THE FILE PROPERLY,
139- # WE MAY NEED TO EXPLICITLY WRITE AN ENCODER/DECODER FOR THIS TYPE.
135+ hparams_dict = Dict (k => _convert_value (v) for (k, v) in hparams)
140136
141-
142137 experiment = HP. Experiment (" " , " " , " " , time (), hparam_infos, metric_infos)
143138 experiment_content = HP. HParamsPluginData (PLUGIN_DATA_VERSION, OneOf (:experiment , experiment))
144139 experiment_md = SummaryMetadata (SummaryMetadata_PluginData (PLUGIN_NAME, encode_bytes (experiment_content)), " " , " " , DataClass. DATA_CLASS_UNKNOWN)
0 commit comments