Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions src/fitting/result.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export FitResult, update_model!, get_objective, get_objective_variance
export FitResult, update_model!, get_objective, get_objective_variance, dof, reduced_statistic

struct FitResult{Config<:FittingConfig,U,Err,T,Sol}
config::Config
Expand All @@ -15,6 +15,10 @@ function Base.show(io::IO, ::MIME"text/plain", @nospecialize(result::FitResult))
buff_c = IOContext(buff, io)

total_stat = prettyfloat(sum(result.stats))
total_n_bins = sum(length(result.config.data_cache[i].objective) for i in 1:length(result.stats))
# we need to count dof manually here since we want the total dof across all datasets, which may have tied parameters
total_dof = total_n_bins - count(result.config.parameter_cache.free_mask)
total_reduced = prettyfloat(sum(result.stats) / total_dof)

println(buff_c, "FitResult:")
print(buff_c, " ")
Expand All @@ -34,7 +38,7 @@ function Base.show(io::IO, ::MIME"text/plain", @nospecialize(result::FitResult))
print(
io,
encapsulate(text) *
"Σ$(statistic_symbol(fit_statistic(result.config))) = $(total_stat)",
"Σ$(statistic_symbol(fit_statistic(result.config))) = $(total_stat), $(reduced_statistic_symbol(fit_statistic(result.config))) = $(total_reduced) (dof=$(total_dof))",
)
end

Expand All @@ -48,6 +52,19 @@ struct FitResultSlice{P<:FitResult,U,Err,T}
stats::T
end

function dof(slice::FitResultSlice)
i = slice.index
parent_config = slice.parent.config

n_bins = length(parent_config.data_cache[i].objective)
local_free_mask = parent_config.parameter_cache.free_mask[parent_config.parameter_bindings[i]]
n_free = count(local_free_mask)

n_bins - n_free
end

reduced_statistic(slice::FitResultSlice) = slice.stats / dof(slice)

function Base.show(io::IO, ::MIME"text/plain", @nospecialize(slice::FitResultSlice))
buff = IOBuffer()
buff_c = IOContext(buff, io)
Expand Down Expand Up @@ -89,8 +106,12 @@ function _pretty_print_result(io::IO, slice::FitResultSlice)
println(io)

stat_sym = statistic_symbol(fit_statistic(slice.parent.config))
red_stat_sym = reduced_statistic_symbol(fit_statistic(slice.parent.config))
ν = dof(slice)
print(io, " . $(rpad(stat_sym, param_padding - 3)): $(prettyfloat(slice.stats))")
println(io)
print(io, " . $(rpad(red_stat_sym, param_padding - 3)): $(prettyfloat(reduced_statistic(slice))) (dof=$ν)")
println(io)
end

function Base.getindex(result::FitResult, i)
Expand Down
4 changes: 3 additions & 1 deletion src/plots-recipes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ end
@recipe function _plotting_func(slice::FitResultSlice)
label -->
statistic_symbol(fit_statistic(slice.parent.config)) *
Printf.@sprintf("=%.2f", slice.stats)
Printf.@sprintf("=%.2f, ", slice.stats) *
reduced_statistic_symbol(fit_statistic(slice.parent.config)) *
Printf.@sprintf("=%.2f", reduced_statistic(slice))
seriestype --> :stepmid
x = plotting_domain(slice)
y = calculate_objective!(slice, slice.u)
Expand Down
2 changes: 2 additions & 0 deletions src/statistics.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
abstract type AbstractStatistic end
statistic_symbol(s::AbstractStatistic) = Base.typename(typeof(s)).name
reduced_statistic_symbol(s::AbstractStatistic) = statistic_symbol(s) * "_reduced"

struct ChiSquared <: AbstractStatistic end

statistic_symbol(::ChiSquared) = "χ²"
reduced_statistic_symbol(::ChiSquared) = "χᵥ²"

measure(::ChiSquared, y, ŷ, σ²) = sum(@.((y - ŷ)^2 / σ²))

Expand Down
26 changes: 26 additions & 0 deletions test/fitting/test-results.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,29 @@ calculate_objective!(result, [11.0, 2.0])
# mismatch
@test_throws "" calculate_objective!(result, [11.0])
@test_throws "" calculate_objective!(result[1], [11.0])

# test reduced statistic
# dof = n_bins - n_free_params = 100 - 2 = 98
# reduced_statistic should be χ² (stats) / dof (98)
@test dof(result[1]) == 98
@test reduced_statistic(result[1]) ≈ result[1].stats / 98 rtol = 1e-10

# multi-dataset with one tied parameter (K)
dummy_data2 = make_dummy_dataset((E) -> (E^(-3.0)); units = u"counts / (s * keV)")
prob_tied = FittingProblem(PowerLaw() => dummy_data, PowerLaw() => dummy_data2)
bindall!(prob_tied, :K)
result_tied = fit(prob_tied, LevenbergMarquadt())

# K is shared, so there are only 3 free params for all
@test count(result_tied.config.parameter_cache.free_mask) == 3

# dof for each slice shows that K still counts as free for each dataset
@test dof(result_tied[1]) == 98
@test dof(result_tied[2]) == 98

# use the free_mask directly instead of dof function to check total free params across all datasets
# should not double count the tied parameter, so total free params should be 3, not 4
total_n_bins =
length(result_tied.config.data_cache[1].objective) +
length(result_tied.config.data_cache[2].objective)
@test total_n_bins - count(result_tied.config.parameter_cache.free_mask) == 197
34 changes: 33 additions & 1 deletion test/io/test-result-printing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,38 @@ slice_string = showstring(result[1])
# │ . m.a : 3.0810 0.012679
# │
# │ . χ² : 14.963
# │ . χᵥ² : 0.15268 (dof=98)
# └
expected_slice = "┌ FitResultSlice:\n│ Model: PowerLaw\n│ . Name : u Δu \n│ . m.K : 12.066 0.24374 \n│ . m.a : 3.0810 0.012679 \n│ \n│ . χ² : 14.963\n└"
expected_slice = "┌ FitResultSlice:\n│ Model: PowerLaw\n│ . Name : u Δu \n│ . m.K : 12.066 0.24374 \n│ . m.a : 3.0810 0.012679 \n│ \n│ . χ² : 14.963\n│ . χᵥ² : 0.15268 (dof=98)\n└"
@test slice_string == expected_slice

# Two-dataset with K bound across models: global dof should be 3 and use free_mask count, not sum of
# per-slice counts (which would double-count the shared K and give dof 196 instead of 197)
dummy_data2 = make_dummy_dataset((E) -> (E^(-3.0)); units = u"counts / (s * keV)")
prob_bound = FittingProblem(PowerLaw() => dummy_data, PowerLaw() => dummy_data2)
bindall!(prob_bound, :K)
result_bound = fit(prob_bound, LevenbergMarquadt())
result_bound_string = showstring(result_bound)

# Expected output (formatted):
#
# ┌ FitResult:
# │ Model: PowerLaw
# │ . Name : u Δu
# │ . m.K : 12.066 0.17235
# │ . m.a : 3.0810 0.010006
# │
# │ . χ² : 14.963
# │ . χᵥ² : 0.15268 (dof=98)
# │ Model: PowerLaw
# │ . Name : u Δu
# │ . m.K : 12.066 0.17235
# │ . m.a : 3.0810 0.010006
# │
# │ . χ² : 14.963
# │ . χᵥ² : 0.15268 (dof=98)
# └ Σχ² = 29.925, χᵥ² = 0.15190 (dof=197)
#
# total dof = 200 total bins - 3 free params (K shared, a in first dataset, a in second dataset) = 197
expected_result_bound = "┌ FitResult:\n│ Model: PowerLaw\n│ . Name : u Δu \n│ . m.K : 12.066 0.17235 \n│ . m.a : 3.0810 0.010006 \n│ \n│ . χ² : 14.963\n│ . χᵥ² : 0.15268 (dof=98)\n│ Model: PowerLaw\n│ . Name : u Δu \n│ . m.K : 12.066 0.17235 \n│ . m.a : 3.0810 0.010006 \n│ \n│ . χ² : 14.963\n│ . χᵥ² : 0.15268 (dof=98)\n└ Σχ² = 29.925, χᵥ² = 0.15190 (dof=197)"
@test result_bound_string == expected_result_bound
Loading