Skip to content

Commit 0a6f14c

Browse files
committed
Added a unit test to ensure the hparams content is being serialised corrected
1 parent c5e2a3f commit 0a6f14c

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

test/test_hparams.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using TensorBoardLogger, Logging
22
using Test
3+
import ProtoBuf as PB
34

45
test_hparams_log_dir = "test_hparams_logs/"
56

@@ -34,4 +35,46 @@ test_hparams_log_dir = "test_hparams_logs/"
3435
# # Check that a new event file has been created
3536
# @test length(logger.all_files) == 2
3637
end
38+
end
39+
40+
@testset "Serialise hparams" begin
41+
hparams_config = Dict{String, Any}(
42+
"id"=>Float64(1),
43+
"alpha"=>0.5,
44+
"p1"=>0.1,
45+
"optimisations"=>false,
46+
"method"=>"MC"
47+
)
48+
metrics = ["scalar/loss"]
49+
50+
PLUGIN_NAME = "hparams"
51+
PLUGIN_DATA_VERSION = 0
52+
53+
EXPERIMENT_TAG = "_hparams_/experiment"
54+
SESSION_START_INFO_TAG = "_hparams_/session_start_info"
55+
SESSION_END_INFO_TAG = "_hparams_/session_end_info"
56+
57+
hparam_infos = [TensorBoardLogger.hparam_info(TensorBoardLogger.HParamConfig(; name=k, datatype=typeof(v))) for (k, v) in hparams_config]
58+
metric_infos = [TensorBoardLogger.metric_info(TensorBoardLogger.MetricConfig(; name=metric)) for metric in metrics]
59+
60+
61+
hparams_dict = Dict(k => TensorBoardLogger._convert_value(v) for (k, v) in hparams_config)
62+
63+
session_start_info = TensorBoardLogger.HP.SessionStartInfo(hparams_dict, "", "", "", zero(Float64))
64+
session_start_content = TensorBoardLogger.HP.HParamsPluginData(PLUGIN_DATA_VERSION, TensorBoardLogger.OneOf(:session_start_info, session_start_info))
65+
66+
expected_bytes = UInt8[0x1a, 0x5b, 0x0a, 0x0e, 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x12, 0x04, 0x1a, 0x02, 0x4d, 0x43, 0x0a, 0x12, 0x0a, 0x05, 0x61, 0x6c, 0x70, 0x68, 0x61, 0x12, 0x09, 0x11, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xe0, 0x3f, 0x0a, 0x0f, 0x0a, 0x02, 0x69, 0x64, 0x12, 0x09, 0x11, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x3f, 0x0a, 0x13, 0x0a, 0x0d, 0x6f, 0x70, 0x74, 0x69, 0x6d, 0x69, 0x73, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x02, 0x20, 0x00, 0x0a, 0x0f, 0x0a, 0x02, 0x70, 0x31, 0x12, 0x09, 0x11, 0x9a, 0x99, 0x99, 0x99, 0x99, 0x99, 0xb9, 0x3f]
67+
68+
@test TensorBoardLogger.encode_bytes(session_start_content) == expected_bytes
69+
70+
d = TensorBoardLogger.ProtoDecoder(IOBuffer(expected_bytes))
71+
decoded_content = PB.Codecs.decode(d, TensorBoardLogger.HP.HParamsPluginData)
72+
decoded_session_info = decoded_content.data.value
73+
74+
@test all(haskey(decoded_session_info.hparams, k) for k in keys(hparams_config))
75+
76+
for (k, hv) in decoded_session_info.hparams
77+
decoded_v = hv.kind.value
78+
@test hparams_config[k] == decoded_v
79+
end
3780
end

0 commit comments

Comments
 (0)