Skip to content

Commit 32c6336

Browse files
authored
don't error if only one voter (#94)
1 parent 0ca288a commit 32c6336

File tree

3 files changed

+61
-54
lines changed

3 files changed

+61
-54
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Lighthouse"
22
uuid = "ac2c24cd-07f0-4848-96b2-1b82c3ea0e59"
33
authors = ["Beacon Biosignals, Inc."]
4-
version = "0.14.14"
4+
version = "0.14.15"
55

66
[deps]
77
ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd"

src/metrics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ function evaluation_metrics_row(predicted_hard_labels::AbstractVector,
483483
# Step 4: Calculate all metrics derived directly from labels (does not depend on
484484
# predictions)
485485
labels_metrics_table = LabelMetricsRow[]
486-
if has_value(votes)
486+
if has_value(votes) && size(votes, 2) > 1
487487
labels_metrics_table = map(c -> get_label_metrics_multirater(votes, c),
488488
class_indices)
489489
labels_metrics_table = vcat(labels_metrics_table,

test/row.jl

Lines changed: 59 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -52,62 +52,70 @@ end
5252
@test isequal(metrics_from_inputs, metrics_from_table)
5353

5454
# Multiple labelers: round-trip `ObservationRow`...
55-
num_voters = 5
56-
possible_vote_labels = collect(0:length(classes)) # vote 0 == "no vote"
57-
vote_rng = StableRNG(22)
58-
votes = [rand(vote_rng, possible_vote_labels)
59-
for sample in 1:num_observations, voter in 1:num_voters]
60-
votes[:, 3] .= votes[:, 4] # Voter 4 voted identically to voter 3 (force non-zero agreement)
61-
elected_hard_multilabeller = map(row -> majority(vote_rng, row, 1:length(classes)),
62-
eachrow(votes))
63-
table = test_roundtrip_observation_table(; predicted_soft_labels, predicted_hard_labels,
64-
elected_hard_labels=elected_hard_multilabeller,
65-
votes)
66-
67-
# ...is there parity in evaluation_metrics calculations?
68-
metrics_from_inputs = Lighthouse.evaluation_metrics_row(predicted_hard_labels,
69-
predicted_soft_labels,
70-
elected_hard_multilabeller,
71-
classes; votes)
72-
metrics_from_table = Lighthouse.evaluation_metrics_row(table, classes)
73-
@test isequal(metrics_from_inputs, metrics_from_table)
74-
75-
r_table = Lighthouse._inputs_to_observation_table(; predicted_soft_labels,
76-
predicted_hard_labels,
77-
elected_hard_labels=elected_hard_multilabeller,
78-
votes)
79-
@test isnothing(Legolas.validate(r_table, Lighthouse.OBSERVATION_ROW_SCHEMA))
80-
81-
# ...can we handle both dataframe input and more generic row iterators?
82-
df_table = DataFrame(r_table)
83-
output_r = Lighthouse._observation_table_to_inputs(r_table)
84-
output_df = Lighthouse._observation_table_to_inputs(df_table)
85-
@test isequal(output_r, output_df)
86-
87-
# Safety last!
88-
transform!(df_table, :votes => ByRow(v -> isodd(sum(v)) ? missing : v);
89-
renamecols=false)
90-
@test_throws ArgumentError Lighthouse._observation_table_to_inputs(df_table)
91-
92-
transform!(df_table, :votes => ByRow(v -> ismissing(v) ? [1, 2, 3] : v);
93-
renamecols=false)
94-
@test_throws ArgumentError Lighthouse._observation_table_to_inputs(df_table)
95-
96-
@test_throws DimensionMismatch Lighthouse._inputs_to_observation_table(;
97-
predicted_soft_labels,
98-
predicted_hard_labels=predicted_hard_labels[1:4],
99-
elected_hard_labels=elected_hard_multilabeller,
100-
votes)
55+
for num_voters in (1, 5)
56+
possible_vote_labels = collect(0:length(classes)) # vote 0 == "no vote"
57+
vote_rng = StableRNG(22)
58+
votes = [rand(vote_rng, possible_vote_labels)
59+
for sample in 1:num_observations, voter in 1:num_voters]
60+
if num_voters >= 4
61+
votes[:, 3] .= votes[:, 4] # Voter 4 voted identically to voter 3 (force non-zero agreement)
62+
end
63+
elected_hard_multilabeller = map(row -> majority(vote_rng, row, 1:length(classes)),
64+
eachrow(votes))
65+
table = test_roundtrip_observation_table(; predicted_soft_labels,
66+
predicted_hard_labels,
67+
elected_hard_labels=elected_hard_multilabeller,
68+
votes)
69+
70+
# ...is there parity in evaluation_metrics calculations?
71+
metrics_from_inputs = Lighthouse.evaluation_metrics_row(predicted_hard_labels,
72+
predicted_soft_labels,
73+
elected_hard_multilabeller,
74+
classes; votes)
75+
metrics_from_table = Lighthouse.evaluation_metrics_row(table, classes)
76+
@test isequal(metrics_from_inputs, metrics_from_table)
77+
78+
r_table = Lighthouse._inputs_to_observation_table(; predicted_soft_labels,
79+
predicted_hard_labels,
80+
elected_hard_labels=elected_hard_multilabeller,
81+
votes)
82+
@test isnothing(Legolas.validate(r_table, Lighthouse.OBSERVATION_ROW_SCHEMA))
83+
84+
# ...can we handle both dataframe input and more generic row iterators?
85+
df_table = DataFrame(r_table)
86+
output_r = Lighthouse._observation_table_to_inputs(r_table)
87+
output_df = Lighthouse._observation_table_to_inputs(df_table)
88+
@test isequal(output_r, output_df)
89+
90+
# Safety last!
91+
transform!(df_table, :votes => ByRow(v -> isodd(sum(v)) ? missing : v);
92+
renamecols=false)
93+
@test_throws ArgumentError Lighthouse._observation_table_to_inputs(df_table)
94+
95+
transform!(df_table, :votes => ByRow(v -> ismissing(v) ? [1, 2, 3] : v);
96+
renamecols=false)
97+
@test_throws ArgumentError Lighthouse._observation_table_to_inputs(df_table)
98+
99+
@test_throws DimensionMismatch Lighthouse._inputs_to_observation_table(;
100+
predicted_soft_labels,
101+
predicted_hard_labels=predicted_hard_labels[1:4],
102+
elected_hard_labels=elected_hard_multilabeller,
103+
votes)
104+
end
101105
end
102106

103107
@testset "`ClassRow" begin
104108
@test isa(Lighthouse.ClassRow(; class_index=3, class_labels=missing).class_index, Int64)
105-
@test isa(Lighthouse.ClassRow(; class_index=Int8(3), class_labels=missing).class_index, Int64)
109+
@test isa(Lighthouse.ClassRow(; class_index=Int8(3), class_labels=missing).class_index,
110+
Int64)
106111
@test Lighthouse.ClassRow(; class_index=:multiclass).class_index == :multiclass
107-
@test Lighthouse.ClassRow(; class_index=:multiclass, class_labels=["a", "b"]).class_labels == ["a", "b"]
112+
@test Lighthouse.ClassRow(; class_index=:multiclass,
113+
class_labels=["a", "b"]).class_labels == ["a", "b"]
108114

109-
@test_throws ArgumentError Lighthouse.ClassRow(; class_index=3.0f0, class_labels=missing)
110-
@test_throws ArgumentError Lighthouse.ClassRow(; class_index=:mUlTiClAsS, class_labels=missing)
115+
@test_throws ArgumentError Lighthouse.ClassRow(; class_index=3.0f0,
116+
class_labels=missing)
117+
@test_throws ArgumentError Lighthouse.ClassRow(; class_index=:mUlTiClAsS,
118+
class_labels=missing)
111119
end
112120

113121
@testset "class_labels" begin
@@ -117,7 +125,7 @@ end
117125
0.9 0.1
118126
0.0 1.0]
119127
elected_hard_labels = [1, 2, 2, 2, 1]
120-
predicted_hard_labels = [1,2,2,1,2]
128+
predicted_hard_labels = [1, 2, 2, 1, 2]
121129
thresholds = [0.25, 0.5, 0.75]
122130
i_class = 2
123131
class_labels = ["a", "b"]
@@ -127,7 +135,6 @@ end
127135
@test default_metrics.class_labels == class_labels
128136
end
129137

130-
131138
@testset "`Curve`" begin
132139
c = ([1, 2, 3], [4, 5, 6])
133140
@test Curve(c...) isa Curve

0 commit comments

Comments (0)