11import . tensorboard_plugin_hparams. hparams: Interval, MetricInfo, MetricName, HParamInfo, Experiment, HParams
22import . tensorboard_plugin_hparams. hparams: var"#DataType" as HParamDataType, DatasetType as HDatasetType
3- import . tensorboard_plugin_hparams. google. protobuf: ListValue as HListValue, OneOf as HOneOf, Value as HValue
3+ import . tensorboard_plugin_hparams. google. protobuf: ListValue as HListValue, Value as HValue
44import . tensorboard_plugin_hparams. hparams as HP
55import . tensorboard: DataClass
66
@@ -29,24 +29,24 @@ default_domain(::Val{Float64}) = HParamRealDomain(typemin(Float64), typemax(Floa
2929default_domain (:: Val{String} ) = nothing
3030
3131_to_proto_hparam_dtype (:: Val{Bool} ) = HParamDataType. DATA_TYPE_BOOL
32- _to_proto_hparam_dtype (:: Val{String} ) = HParamDataType. DATA_TYPE_STRING
3332_to_proto_hparam_dtype (:: Val{Float64} ) = HParamDataType. DATA_TYPE_FLOAT64
33+ _to_proto_hparam_dtype (:: Val{String} ) = HParamDataType. DATA_TYPE_STRING
3434
35- function _convert_value (v:: T ) where {T<: Union{String, Bool, Float64 } }
35+ function _convert_value (v:: T ) where {T<: Union{String, Bool, Real } }
3636 if v isa String
37- return HValue (HOneOf (:string_value , v))
37+ return HValue (OneOf (:string_value , v))
3838 elseif v isa Bool
39- return HValue (HOneOf (:bool_value , v))
40- elseif v isa Float64
41- return HValue (HOneOf (:number_value , v ))
39+ return HValue (OneOf (:bool_value , v))
40+ elseif v isa Real
41+ return HValue (OneOf (:number_value , Float64 (v) ))
4242 else
4343 error (" Unrecognised type!" )
4444 end
4545end
4646
4747_convert_hparam_domain (:: Nothing ) = nothing
48- _convert_hparam_domain (domain:: HParamRealDomain ) = HOneOf (:domain_interval , Interval (domain. min_value, domain. max_value))
49- _convert_hparam_domain (domain:: HParamSetDomain ) = HOneOf (:domain_discrete , HListValue ([_convert_value (v) for v in domain. values]))
48+ _convert_hparam_domain (domain:: HParamRealDomain ) = OneOf (:domain_interval , Interval (domain. min_value, domain. max_value))
49+ _convert_hparam_domain (domain:: HParamSetDomain ) = OneOf (:domain_discrete , HListValue ([_convert_value (v) for v in domain. values]))
5050
5151
5252function hparam_info (c:: HParamConfig )
@@ -62,7 +62,7 @@ function hparam_info(c::HParamConfig)
6262 elseif isa (domain, HParamSetDomain{Bool})
6363 @assert datatype== Bool " Domains with bools require a datatype of Bool"
6464 elseif isa (domain, HParamSetDomain{Float64})
65- @assert datatype== Float64 " Domains with floats require a datatype of Float64 "
65+ @assert datatype<: Real " Domains with floats require a datatype a Real datatype "
6666 end
6767 end
6868
@@ -91,7 +91,7 @@ function write_hparams!(logger::TBLogger, hparams::Dict{String, Any}, metrics::A
9191
9292 # Check for datatypes
9393 for v in values (hparams)
94- @assert typeof (v) <: Union{Bool, String, Float64 } " Hyperparameters must be of types String, Bool or Float64 "
94+ @assert typeof (v) <: Union{Bool, String, Real } " Hyperparameters must be of types String, Bool or Real "
9595 end
9696
9797 hparam_infos = [hparam_info (HParamConfig (; name= k, datatype= typeof (v))) for (k, v) in hparams]
@@ -101,18 +101,18 @@ function write_hparams!(logger::TBLogger, hparams::Dict{String, Any}, metrics::A
101101 hparams_dict = Dict {String, Any} (k=> _convert_value (v) for (k,v) in hparams)
102102
103103 session_start_info = HP. SessionStartInfo (hparams_dict, " " , " " , " " , zero (Float64))
104- session_start_content = HP. HParamsPluginData (PLUGIN_DATA_VERSION, HOneOf (:session_start_info , session_start_info))
104+ session_start_content = HP. HParamsPluginData (PLUGIN_DATA_VERSION, OneOf (:session_start_info , session_start_info))
105105 session_start_md = SummaryMetadata (SummaryMetadata_PluginData (PLUGIN_NAME, encode_bytes (session_start_content)), PLUGIN_NAME, " " , DataClass. DATA_CLASS_UNKNOWN)
106106 session_start_summary = Summary ([Summary_Value (" " , SESSION_START_INFO_TAG, session_start_md, nothing )])
107107
108108 experiment = HP. Experiment (" " , " " , " " , zero (Float64), hparam_infos, metric_infos)
109- experiment_content = HP. HParamsPluginData (PLUGIN_DATA_VERSION, HOneOf (:experiment , experiment))
110- experiment_md = SummaryMetadata (SummaryMetadata_PluginData (PLUGIN_NAME, encode_bytes (experiment_content)), PLUGIN_NAME , " " , DataClass. DATA_CLASS_UNKNOWN)
109+ experiment_content = HP. HParamsPluginData (PLUGIN_DATA_VERSION, OneOf (:experiment , experiment))
110+ experiment_md = SummaryMetadata (SummaryMetadata_PluginData (PLUGIN_NAME, encode_bytes (experiment_content)), " " , " " , DataClass. DATA_CLASS_UNKNOWN)
111111 experiment_summary = Summary ([Summary_Value (" " , EXPERIMENT_TAG, experiment_md, nothing )])
112112
113113 session_end_info = HP. SessionEndInfo (HP. Status. STATUS_SUCCESS, zero (Float64))
114- session_end_content = HP. HParamsPluginData (PLUGIN_DATA_VERSION, HOneOf (:session_end_info , session_end_info))
115- session_end_md = SummaryMetadata (SummaryMetadata_PluginData (PLUGIN_NAME, encode_bytes (session_end_content)), PLUGIN_NAME , " " , DataClass. DATA_CLASS_UNKNOWN)
114+ session_end_content = HP. HParamsPluginData (PLUGIN_DATA_VERSION, OneOf (:session_end_info , session_end_info))
115+ session_end_md = SummaryMetadata (SummaryMetadata_PluginData (PLUGIN_NAME, encode_bytes (session_end_content)), " " , " " , DataClass. DATA_CLASS_UNKNOWN)
116116 session_end_summary = Summary ([Summary_Value (" " , SESSION_END_INFO_TAG, session_end_md, nothing )])
117117
118118 for s in (experiment_summary, session_start_summary, session_end_summary)
0 commit comments