|
52 | 52 | @test isequal(metrics_from_inputs, metrics_from_table) |
53 | 53 |
|
54 | 54 | # Multiple labelers: round-trip `ObservationRow`... |
55 | | - num_voters = 5 |
56 | | - possible_vote_labels = collect(0:length(classes)) # vote 0 == "no vote" |
57 | | - vote_rng = StableRNG(22) |
58 | | - votes = [rand(vote_rng, possible_vote_labels) |
59 | | - for sample in 1:num_observations, voter in 1:num_voters] |
60 | | - votes[:, 3] .= votes[:, 4] # Voter 4 voted identically to voter 3 (force non-zero agreement) |
61 | | - elected_hard_multilabeller = map(row -> majority(vote_rng, row, 1:length(classes)), |
62 | | - eachrow(votes)) |
63 | | - table = test_roundtrip_observation_table(; predicted_soft_labels, predicted_hard_labels, |
64 | | - elected_hard_labels=elected_hard_multilabeller, |
65 | | - votes) |
66 | | - |
67 | | - # ...is there parity in evaluation_metrics calculations? |
68 | | - metrics_from_inputs = Lighthouse.evaluation_metrics_row(predicted_hard_labels, |
69 | | - predicted_soft_labels, |
70 | | - elected_hard_multilabeller, |
71 | | - classes; votes) |
72 | | - metrics_from_table = Lighthouse.evaluation_metrics_row(table, classes) |
73 | | - @test isequal(metrics_from_inputs, metrics_from_table) |
74 | | - |
75 | | - r_table = Lighthouse._inputs_to_observation_table(; predicted_soft_labels, |
76 | | - predicted_hard_labels, |
77 | | - elected_hard_labels=elected_hard_multilabeller, |
78 | | - votes) |
79 | | - @test isnothing(Legolas.validate(r_table, Lighthouse.OBSERVATION_ROW_SCHEMA)) |
80 | | - |
81 | | - # ...can we handle both dataframe input and more generic row iterators? |
82 | | - df_table = DataFrame(r_table) |
83 | | - output_r = Lighthouse._observation_table_to_inputs(r_table) |
84 | | - output_df = Lighthouse._observation_table_to_inputs(df_table) |
85 | | - @test isequal(output_r, output_df) |
86 | | - |
87 | | - # Safety last! |
88 | | - transform!(df_table, :votes => ByRow(v -> isodd(sum(v)) ? missing : v); |
89 | | - renamecols=false) |
90 | | - @test_throws ArgumentError Lighthouse._observation_table_to_inputs(df_table) |
91 | | - |
92 | | - transform!(df_table, :votes => ByRow(v -> ismissing(v) ? [1, 2, 3] : v); |
93 | | - renamecols=false) |
94 | | - @test_throws ArgumentError Lighthouse._observation_table_to_inputs(df_table) |
95 | | - |
96 | | - @test_throws DimensionMismatch Lighthouse._inputs_to_observation_table(; |
97 | | - predicted_soft_labels, |
98 | | - predicted_hard_labels=predicted_hard_labels[1:4], |
99 | | - elected_hard_labels=elected_hard_multilabeller, |
100 | | - votes) |
| 55 | + for num_voters in (1, 5) |
| 56 | + possible_vote_labels = collect(0:length(classes)) # vote 0 == "no vote" |
| 57 | + vote_rng = StableRNG(22) |
| 58 | + votes = [rand(vote_rng, possible_vote_labels) |
| 59 | + for sample in 1:num_observations, voter in 1:num_voters] |
| 60 | + if num_voters >= 4 |
| 61 | + votes[:, 3] .= votes[:, 4] # Voter 4 voted identically to voter 3 (force non-zero agreement) |
| 62 | + end |
| 63 | + elected_hard_multilabeller = map(row -> majority(vote_rng, row, 1:length(classes)), |
| 64 | + eachrow(votes)) |
| 65 | + table = test_roundtrip_observation_table(; predicted_soft_labels, |
| 66 | + predicted_hard_labels, |
| 67 | + elected_hard_labels=elected_hard_multilabeller, |
| 68 | + votes) |
| 69 | + |
| 70 | + # ...is there parity in evaluation_metrics calculations? |
| 71 | + metrics_from_inputs = Lighthouse.evaluation_metrics_row(predicted_hard_labels, |
| 72 | + predicted_soft_labels, |
| 73 | + elected_hard_multilabeller, |
| 74 | + classes; votes) |
| 75 | + metrics_from_table = Lighthouse.evaluation_metrics_row(table, classes) |
| 76 | + @test isequal(metrics_from_inputs, metrics_from_table) |
| 77 | + |
| 78 | + r_table = Lighthouse._inputs_to_observation_table(; predicted_soft_labels, |
| 79 | + predicted_hard_labels, |
| 80 | + elected_hard_labels=elected_hard_multilabeller, |
| 81 | + votes) |
| 82 | + @test isnothing(Legolas.validate(r_table, Lighthouse.OBSERVATION_ROW_SCHEMA)) |
| 83 | + |
| 84 | + # ...can we handle both dataframe input and more generic row iterators? |
| 85 | + df_table = DataFrame(r_table) |
| 86 | + output_r = Lighthouse._observation_table_to_inputs(r_table) |
| 87 | + output_df = Lighthouse._observation_table_to_inputs(df_table) |
| 88 | + @test isequal(output_r, output_df) |
| 89 | + |
| 90 | + # Safety last! |
| 91 | + transform!(df_table, :votes => ByRow(v -> isodd(sum(v)) ? missing : v); |
| 92 | + renamecols=false) |
| 93 | + @test_throws ArgumentError Lighthouse._observation_table_to_inputs(df_table) |
| 94 | + |
| 95 | + transform!(df_table, :votes => ByRow(v -> ismissing(v) ? [1, 2, 3] : v); |
| 96 | + renamecols=false) |
| 97 | + @test_throws ArgumentError Lighthouse._observation_table_to_inputs(df_table) |
| 98 | + |
| 99 | + @test_throws DimensionMismatch Lighthouse._inputs_to_observation_table(; |
| 100 | + predicted_soft_labels, |
| 101 | + predicted_hard_labels=predicted_hard_labels[1:4], |
| 102 | + elected_hard_labels=elected_hard_multilabeller, |
| 103 | + votes) |
| 104 | + end |
101 | 105 | end |
102 | 106 |
|
103 | 107 | @testset "`ClassRow" begin |
104 | 108 | @test isa(Lighthouse.ClassRow(; class_index=3, class_labels=missing).class_index, Int64) |
105 | | - @test isa(Lighthouse.ClassRow(; class_index=Int8(3), class_labels=missing).class_index, Int64) |
| 109 | + @test isa(Lighthouse.ClassRow(; class_index=Int8(3), class_labels=missing).class_index, |
| 110 | + Int64) |
106 | 111 | @test Lighthouse.ClassRow(; class_index=:multiclass).class_index == :multiclass |
107 | | - @test Lighthouse.ClassRow(; class_index=:multiclass, class_labels=["a", "b"]).class_labels == ["a", "b"] |
| 112 | + @test Lighthouse.ClassRow(; class_index=:multiclass, |
| 113 | + class_labels=["a", "b"]).class_labels == ["a", "b"] |
108 | 114 |
|
109 | | - @test_throws ArgumentError Lighthouse.ClassRow(; class_index=3.0f0, class_labels=missing) |
110 | | - @test_throws ArgumentError Lighthouse.ClassRow(; class_index=:mUlTiClAsS, class_labels=missing) |
| 115 | + @test_throws ArgumentError Lighthouse.ClassRow(; class_index=3.0f0, |
| 116 | + class_labels=missing) |
| 117 | + @test_throws ArgumentError Lighthouse.ClassRow(; class_index=:mUlTiClAsS, |
| 118 | + class_labels=missing) |
111 | 119 | end |
112 | 120 |
|
113 | 121 | @testset "class_labels" begin |
|
117 | 125 | 0.9 0.1 |
118 | 126 | 0.0 1.0] |
119 | 127 | elected_hard_labels = [1, 2, 2, 2, 1] |
120 | | - predicted_hard_labels = [1,2,2,1,2] |
| 128 | + predicted_hard_labels = [1, 2, 2, 1, 2] |
121 | 129 | thresholds = [0.25, 0.5, 0.75] |
122 | 130 | i_class = 2 |
123 | 131 | class_labels = ["a", "b"] |
|
127 | 135 | @test default_metrics.class_labels == class_labels |
128 | 136 | end |
129 | 137 |
|
130 | | - |
131 | 138 | @testset "`Curve`" begin |
132 | 139 | c = ([1, 2, 3], [4, 5, 6]) |
133 | 140 | @test Curve(c...) isa Curve |
|
0 commit comments