Skip to content

Commit 867b105

Browse files
committed
Tidy up and use normal PB OneOf, switch to allowing Real instead of just Float64
1 parent b8fbbb6 commit 867b105

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

src/hparams.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import .tensorboard_plugin_hparams.hparams: Interval, MetricInfo, MetricName, HParamInfo, Experiment, HParams
22
import .tensorboard_plugin_hparams.hparams: var"#DataType" as HParamDataType, DatasetType as HDatasetType
3-
import .tensorboard_plugin_hparams.google.protobuf: ListValue as HListValue, OneOf as HOneOf, Value as HValue
3+
import .tensorboard_plugin_hparams.google.protobuf: ListValue as HListValue, Value as HValue
44
import .tensorboard_plugin_hparams.hparams as HP
55
import .tensorboard: DataClass
66

@@ -29,24 +29,24 @@ default_domain(::Val{Float64}) = HParamRealDomain(typemin(Float64), typemax(Floa
2929
default_domain(::Val{String}) = nothing
3030

3131
_to_proto_hparam_dtype(::Val{Bool}) = HParamDataType.DATA_TYPE_BOOL
32-
_to_proto_hparam_dtype(::Val{String}) = HParamDataType.DATA_TYPE_STRING
3332
_to_proto_hparam_dtype(::Val{Float64}) = HParamDataType.DATA_TYPE_FLOAT64
33+
_to_proto_hparam_dtype(::Val{String}) = HParamDataType.DATA_TYPE_STRING
3434

35-
function _convert_value(v::T) where {T<:Union{String, Bool, Float64}}
35+
function _convert_value(v::T) where {T<:Union{String, Bool, Real}}
3636
if v isa String
37-
return HValue(HOneOf(:string_value, v))
37+
return HValue(OneOf(:string_value, v))
3838
elseif v isa Bool
39-
return HValue(HOneOf(:bool_value, v))
40-
elseif v isa Float64
41-
return HValue(HOneOf(:number_value, v))
39+
return HValue(OneOf(:bool_value, v))
40+
elseif v isa Real
41+
return HValue(OneOf(:number_value, Float64(v)))
4242
else
4343
error("Unrecognised type!")
4444
end
4545
end
4646

4747
_convert_hparam_domain(::Nothing) = nothing
48-
_convert_hparam_domain(domain::HParamRealDomain) = HOneOf(:domain_interval, Interval(domain.min_value, domain.max_value))
49-
_convert_hparam_domain(domain::HParamSetDomain) = HOneOf(:domain_discrete, HListValue([_convert_value(v) for v in domain.values]))
48+
_convert_hparam_domain(domain::HParamRealDomain) = OneOf(:domain_interval, Interval(domain.min_value, domain.max_value))
49+
_convert_hparam_domain(domain::HParamSetDomain) = OneOf(:domain_discrete, HListValue([_convert_value(v) for v in domain.values]))
5050

5151

5252
function hparam_info(c::HParamConfig)
@@ -62,7 +62,7 @@ function hparam_info(c::HParamConfig)
6262
elseif isa(domain, HParamSetDomain{Bool})
6363
@assert datatype==Bool "Domains with bools require a datatype of Bool"
6464
elseif isa(domain, HParamSetDomain{Float64})
65-
@assert datatype==Float64 "Domains with floats require a datatype of Float64"
65+
@assert datatype<:Real "Domains with floats require a datatype a Real datatype"
6666
end
6767
end
6868

@@ -91,7 +91,7 @@ function write_hparams!(logger::TBLogger, hparams::Dict{String, Any}, metrics::A
9191

9292
# Check for datatypes
9393
for v in values(hparams)
94-
@assert typeof(v) <: Union{Bool, String, Float64} "Hyperparameters must be of types String, Bool or Float64"
94+
@assert typeof(v) <: Union{Bool, String, Real} "Hyperparameters must be of types String, Bool or Real"
9595
end
9696

9797
hparam_infos = [hparam_info(HParamConfig(; name=k, datatype=typeof(v))) for (k, v) in hparams]
@@ -101,18 +101,18 @@ function write_hparams!(logger::TBLogger, hparams::Dict{String, Any}, metrics::A
101101
hparams_dict = Dict{String, Any}(k=>_convert_value(v) for (k,v) in hparams)
102102

103103
session_start_info = HP.SessionStartInfo(hparams_dict, "", "", "", zero(Float64))
104-
session_start_content = HP.HParamsPluginData(PLUGIN_DATA_VERSION, HOneOf(:session_start_info, session_start_info))
104+
session_start_content = HP.HParamsPluginData(PLUGIN_DATA_VERSION, OneOf(:session_start_info, session_start_info))
105105
session_start_md = SummaryMetadata(SummaryMetadata_PluginData(PLUGIN_NAME, encode_bytes(session_start_content)), PLUGIN_NAME, "", DataClass.DATA_CLASS_UNKNOWN)
106106
session_start_summary = Summary([Summary_Value("", SESSION_START_INFO_TAG, session_start_md, nothing)])
107107

108108
experiment = HP.Experiment("", "", "", zero(Float64), hparam_infos, metric_infos)
109-
experiment_content = HP.HParamsPluginData(PLUGIN_DATA_VERSION, HOneOf(:experiment, experiment))
110-
experiment_md = SummaryMetadata(SummaryMetadata_PluginData(PLUGIN_NAME, encode_bytes(experiment_content)), PLUGIN_NAME, "", DataClass.DATA_CLASS_UNKNOWN)
109+
experiment_content = HP.HParamsPluginData(PLUGIN_DATA_VERSION, OneOf(:experiment, experiment))
110+
experiment_md = SummaryMetadata(SummaryMetadata_PluginData(PLUGIN_NAME, encode_bytes(experiment_content)), "", "", DataClass.DATA_CLASS_UNKNOWN)
111111
experiment_summary = Summary([Summary_Value("", EXPERIMENT_TAG, experiment_md, nothing)])
112112

113113
session_end_info = HP.SessionEndInfo(HP.Status.STATUS_SUCCESS, zero(Float64))
114-
session_end_content = HP.HParamsPluginData(PLUGIN_DATA_VERSION, HOneOf(:session_end_info, session_end_info))
115-
session_end_md = SummaryMetadata(SummaryMetadata_PluginData(PLUGIN_NAME, encode_bytes(session_end_content)), PLUGIN_NAME, "", DataClass.DATA_CLASS_UNKNOWN)
114+
session_end_content = HP.HParamsPluginData(PLUGIN_DATA_VERSION, OneOf(:session_end_info, session_end_info))
115+
session_end_md = SummaryMetadata(SummaryMetadata_PluginData(PLUGIN_NAME, encode_bytes(session_end_content)), "", "", DataClass.DATA_CLASS_UNKNOWN)
116116
session_end_summary = Summary([Summary_Value("", SESSION_END_INFO_TAG, session_end_md, nothing)])
117117

118118
for s in (experiment_summary, session_start_summary, session_end_summary)

0 commit comments

Comments
 (0)