@@ -8,29 +8,72 @@ const data_dir = joinpath(@__DIR__, "data")
8
8
const x1_bal = NPZ. npzread (joinpath (data_dir, " x1_bal.npy" ))
9
9
const x1_dns = NPZ. npzread (joinpath (data_dir, " x1_dns.npy" ))
10
10
const x1_onl = NPZ. npzread (joinpath (data_dir, " x1_onl.npy" ))
11
- const w1_dns_bal = 0.03755967829782972
12
- const w1_dns_onl = 0.004489688974663949
13
- const w1_bal_onl = 0.037079734072606625
14
- const w1_dns_bal_unnorm = 0.8190688772401341
11
+ const x2_bal = NPZ. npzread (joinpath (data_dir, " x2_bal.npy" ))
12
+ const x2_onl = NPZ. npzread (joinpath (data_dir, " x2_onl.npy" ))
13
+ # x2_dns is not needed
14
+
15
+ const w1_dns_bal = 0.8190688772401341
16
+ const w1_dns_onl = 0.10026156568190529
17
+ const w1_bal_onl = 0.8280467119588161
18
+ const w1_dns_bal_normed = 0.03755967829782972
19
+ const w1_dns_bal_comb = 0.818516185308166
20
+ const w1_bal_onl_x2 = 0.8781673689588835
21
+ const w1_bal_onl_comb = 0.8529846357976019
15
22
16
23
# ###############################################################################
17
24
# unit testing #################################################################
18
25
# ###############################################################################
19
- @testset " unit testing" begin
26
+ @testset " W1 testing" begin
20
27
arr1 = [1 , 1 , 1 , 2 , 3 , 4 , 4 , 4 ]
21
28
arr2 = [1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 4 ]
22
- @test Hgm. W1 (arr1, arr2, normalize = false ) == 0.25
23
- @test Hgm. W1 (arr2, arr1, normalize = false ) == 0.25
24
- @test Hgm. W1 (arr1, arr2) == Hgm. W1 (arr2, arr1)
25
-
29
+ @test Hgm. W1 (arr1, arr2; normalize = false ) == 0.25
30
+ @test Hgm. W1 (arr2, arr1; normalize = false ) == 0.25
31
+ @test Hgm. W1 (arr1, arr2; normalize = true ) ==
32
+ Hgm. W1 (arr2, arr1; normalize = true )
33
+
26
34
@test isapprox (Hgm. W1 (x1_dns, x1_bal), w1_dns_bal)
27
35
@test isapprox (Hgm. W1 (x1_dns, x1_onl), w1_dns_onl)
28
36
@test isapprox (Hgm. W1 (x1_bal, x1_onl), w1_bal_onl)
29
- @test isapprox (Hgm. W1 (x1_dns, x1_bal, normalize = false ), w1_dns_bal_unnorm )
37
+ @test isapprox (Hgm. W1 (x1_dns, x1_bal, normalize = true ), w1_dns_bal_normed )
30
38
31
39
@test size (Hgm. W1 (rand (3 ,100 ), rand (3 ,100 ))) == (3 ,)
32
40
@test size (Hgm. W1 (rand (9 ,100 ), rand (3 ,100 ))) == (3 ,)
33
41
end
42
+
43
+ @testset " HistData testing" begin
44
+ hd = Hgm. HistData ()
45
+ Hgm. load! (hd, :unif , rand (3 ,100000 ))
46
+ Hgm. load! (hd, :norm , randn (3 ,100000 ))
47
+ w1_rand = Hgm. W1 (hd, :unif , :norm )
48
+ @test size (w1_rand) == (3 ,)
49
+ @test all (isapprox .(w1_rand, 0.7 ; atol = 3e-2 ))
50
+
51
+ Hgm. load! (hd, :unif , rand (3 ,10 ))
52
+ Hgm. load! (hd, :norm , randn (1000 ))
53
+ @test size (hd. samples[:unif ], 1 ) == 3
54
+ @test ndims (hd. samples[:norm ]) == 1
55
+
56
+ delete! (hd. samples, :unif )
57
+ delete! (hd. samples, :norm )
58
+
59
+ Hgm. load! (hd, :dns , joinpath (data_dir, " x1_dns.npy" ))
60
+ Hgm. load! (hd, :bal , [x1_bal' ; x2_bal' ])
61
+ Hgm. load! (hd, :onl , [x1_onl' ; x2_onl' ])
62
+ @test haskey (hd. samples, :dns )
63
+ @test haskey (hd. samples, :onl )
64
+ @test ndims (hd. samples[:dns ]) == 1
65
+ @test size (hd. samples[:onl ], 1 ) == 2
66
+
67
+ k2a_vectorized = Hgm. W1 (hd, :bal )
68
+ @test isa (k2a_vectorized[:dns ], Real)
69
+ @test size (k2a_vectorized[:onl ]) == (2 ,)
70
+ @test isapprox (k2a_vectorized[:dns ], w1_dns_bal_comb)
71
+ @test isapprox (k2a_vectorized[:onl ], [w1_bal_onl, w1_bal_onl_x2])
72
+
73
+ @test isapprox (Hgm. W1 (hd, :bal , :onl , 1 : 2 ), w1_bal_onl_comb)
74
+ @test isapprox (Hgm. W1 (hd, :dns , :bal , 1 : 2 ), w1_dns_bal_comb)
75
+ end
76
+
34
77
println (" " )
35
78
36
79
0 commit comments