Skip to content

Commit bde077d

Browse files
committed
wip improvement of metrics system
1 parent 4eeda07 commit bde077d

File tree

10 files changed

+680
-1
lines changed

10 files changed

+680
-1
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,19 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
99
InferOpt = "4846b161-c94e-4150-8dac-c7ae193c601f"
1010
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1111
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
12+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1213
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
14+
ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7"
1315

1416
[compat]
1517
DecisionFocusedLearningBenchmarks = "0.3.0"
1618
Flux = "0.16.5"
1719
InferOpt = "0.7.1"
1820
MLUtils = "0.4.8"
1921
ProgressMeter = "1.11.0"
22+
Statistics = "1.11.1"
2023
UnicodePlots = "3.8.1"
24+
ValueHistories = "0.5.4"
2125
julia = "1.11"
2226

2327
[extras]

examples/consistent_signature.jl

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Consistent Metric Function Signature
2+
3+
using DecisionFocusedLearningAlgorithms
4+
using DecisionFocusedLearningBenchmarks
5+
using MLUtils: splitobs
6+
using Statistics
7+
8+
b = ArgmaxBenchmark()
9+
dataset = generate_dataset(b, 100)
10+
train_instances, val_instances, test_instances = splitobs(dataset; at=(0.3, 0.3, 0.4))
11+
12+
model = generate_statistical_model(b; seed=0)
13+
maximizer = generate_maximizer(b)
14+
15+
# ============================================================================
16+
# NEW: ALL metric functions have the SAME signature!
17+
# (model, maximizer, data, context) -> value
18+
# ============================================================================
19+
20+
# Simple metric - just uses model, maximizer, and data
21+
compute_gap = (model, max, data, ctx) -> DecisionFocusedLearningBenchmarks.compute_gap(b, data, model, max)  # qualified: a bare `compute_gap` here would resolve to this very binding and recurse
22+
23+
# Metric that also uses context
24+
compute_gap_ratio =
25+
(model, max, data, ctx) -> begin
26+
# data is the dataset from 'on' parameter
27+
# context gives access to everything else
28+
train_gap = DecisionFocusedLearningBenchmarks.compute_gap(b, ctx.train_dataset, model, max)
29+
data_gap = DecisionFocusedLearningBenchmarks.compute_gap(b, data, model, max)
30+
return train_gap / data_gap
31+
end
32+
33+
# Metric that ignores data, just uses context
34+
get_epoch = (model, max, data, ctx) -> ctx.epoch
35+
36+
# Metric that uses everything
37+
complex_metric = (model, max, data, ctx) -> begin
38+
# Can access:
39+
# - model, max (always provided)
40+
# - data (the dataset from 'on')
41+
# - ctx.epoch
42+
# - ctx.train_dataset, ctx.validation_dataset
43+
# - ctx.training_loss, ctx.validation_loss
44+
gap = DecisionFocusedLearningBenchmarks.compute_gap(b, data, model, max)
45+
return gap * ctx.epoch # silly example, but shows flexibility
46+
end
47+
48+
# ============================================================================
49+
# Usage - Same function signature works everywhere!
50+
# ============================================================================
51+
52+
callbacks = [
53+
# on=:validation (default) - data will be validation_dataset
54+
Metric(:gap, compute_gap),
55+
# Creates: val_gap
56+
57+
# on=:both - function called twice with train and val datasets
58+
Metric(:gap, compute_gap; on=:both),
59+
# Creates: train_gap, val_gap
60+
61+
# on=test_instances - data will be test_instances
62+
Metric(:test_gap, compute_gap; on=test_instances),
63+
# Creates: test_gap
64+
65+
# Complex metric using context
66+
Metric(:gap_ratio, compute_gap_ratio; on=:validation),
67+
# Creates: val_gap_ratio
68+
69+
# Ignore data parameter completely
70+
Metric(:current_epoch, get_epoch),
71+
# Creates: val_current_epoch (on=:validation by default)
72+
]
73+
74+
# ============================================================================
75+
# Benefits of Consistent Signature
76+
# ============================================================================
77+
78+
# ✅ ALWAYS the same signature: (model, max, data, ctx) -> value
79+
# ✅ No confusion about what arguments metric_fn receives
80+
# ✅ Easy to write - just follow one pattern
81+
# ✅ Easy to compose - all functions compatible
82+
# ✅ Full flexibility - context gives access to everything
83+
# ✅ Can ignore unused parameters (data or parts of context)
84+
85+
# ============================================================================
86+
# Comparison: OLD vs NEW
87+
# ============================================================================
88+
89+
# OLD (inconsistent signatures):
90+
# on=nothing → metric_fn(context) # 1 arg
91+
# on=:both → metric_fn(model, maximizer, dataset) # 3 args
92+
# on=data → metric_fn(model, maximizer, data) # 3 args
93+
# 😕 Confusing! Different signatures for different modes!
94+
95+
# NEW (consistent signature):
96+
# Always: metric_fn(model, maximizer, data, context) # 4 args
97+
# ✨ Clear! Same signature everywhere!
98+
99+
# ============================================================================
100+
# Practical Example: Define metrics once, use everywhere
101+
# ============================================================================
102+
103+
# Define your metrics library with consistent signature
104+
module MyMetrics
105+
gap(model, max, data, ctx) = compute_gap(benchmark, data, model, max)
106+
regret(model, max, data, ctx) = compute_regret(benchmark, data, model, max)
107+
accuracy(model, max, data, ctx) = compute_accuracy(benchmark, data, model, max)
108+
109+
# Complex metric using context
110+
function overfitting_indicator(model, max, data, ctx)
111+
train_metric = gap(model, max, ctx.train_dataset, ctx)
112+
val_metric = gap(model, max, ctx.validation_dataset, ctx)
113+
return val_metric - train_metric
114+
end
115+
end
116+
117+
# Use them easily
118+
callbacks = [
119+
Metric(:gap, MyMetrics.gap; on=:both),
120+
Metric(:regret, MyMetrics.regret; on=:both),
121+
Metric(:test_accuracy, MyMetrics.accuracy; on=test_instances),
122+
Metric(:overfitting, MyMetrics.overfitting_indicator),
123+
]
124+
125+
# ============================================================================
126+
# Advanced: Higher-order functions
127+
# ============================================================================
128+
129+
# Create a metric factory that returns properly-signed functions
130+
function dataset_metric(benchmark, compute_fn)
131+
return (model, max, data, ctx) -> compute_fn(benchmark, data, model, max)
132+
end
133+
134+
# Use it
135+
callbacks = [
136+
Metric(:gap, dataset_metric(b, compute_gap); on=:both),
137+
Metric(:regret, dataset_metric(b, compute_regret); on=:both),
138+
]
139+
140+
# ============================================================================
141+
# Migration Helper
142+
# ============================================================================
143+
144+
# If you have old-style functions: (model, max, data) -> value
145+
# Wrap them easily:
146+
old_compute_gap = (model, max, data) -> DecisionFocusedLearningBenchmarks.compute_gap(b, data, model, max)
147+
148+
# Convert to new signature:
149+
new_compute_gap = (model, max, data, ctx) -> old_compute_gap(model, max, data)
150+
# Or more concisely:
151+
new_compute_gap = (model, max, data, _) -> old_compute_gap(model, max, data)
152+
153+
Metric(:gap, new_compute_gap; on=:both)

examples/two_argument_signature.jl

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# Simplified Metric Signature - Just (data, context)!
2+
3+
using DecisionFocusedLearningAlgorithms
4+
using DecisionFocusedLearningBenchmarks
5+
using MLUtils: splitobs
6+
7+
b = ArgmaxBenchmark()
8+
dataset = generate_dataset(b, 100)
9+
train, val, test = splitobs(dataset; at=(0.3, 0.3, 0.4))
10+
model = generate_statistical_model(b)
11+
maximizer = generate_maximizer(b)
12+
13+
# ============================================================================
14+
# NEW: Metric functions take just 2 arguments: (data, context)
15+
# Everything you need is in context!
16+
# ============================================================================
17+
18+
# Simple metric - model and maximizer from context
19+
compute_gapp = (data, ctx) -> DecisionFocusedLearningBenchmarks.compute_gap(b, data, ctx.model, ctx.maximizer)
20+
21+
# Complex metric - access other datasets from context
22+
compute_ratio =
23+
(data, ctx) -> begin
24+
train_gap = compute_gap(b, ctx.train_dataset, ctx.model, ctx.maximizer)
25+
val_gap = compute_gap(b, data, ctx.model, ctx.maximizer)
26+
return train_gap / val_gap
27+
end
28+
29+
# Context-only metrics - ignore data completely
30+
get_epoch = (_, ctx) -> ctx.epoch
31+
32+
# ============================================================================
33+
# Usage Examples
34+
# ============================================================================
35+
36+
callbacks = [
37+
# Default: on=:validation
38+
Metric(:gap, compute_gapp),
39+
# Creates: val_gap
40+
41+
# Automatic train and validation
42+
Metric(:gap, compute_gapp; on=:both),
43+
# Creates: train_gap, val_gap
44+
45+
# Specific test set
46+
Metric(:test_gap, compute_gapp; on=test),
47+
# Creates: test_gap
48+
49+
# Complex metric using context
50+
Metric(:gap_ratio, compute_ratio),
51+
# Creates: val_gap_ratio
52+
53+
# Context-only metrics
54+
Metric(:current_epoch, get_epoch),
55+
]
56+
57+
# Note: training_loss and validation_loss are automatically tracked in history!
58+
# Access them with: get(history, :training_loss), get(history, :validation_loss)
59+
60+
history = fyl_train_model!(model, maximizer, train, val; epochs=100, callbacks=callbacks)
61+
62+
# ============================================================================
63+
# Why This is Better
64+
# ============================================================================
65+
66+
# BEFORE: Redundant parameters (4 arguments)
67+
# metric_fn(model, maximizer, data, context)
68+
# - model and maximizer are ALSO in context (redundant!)
69+
# - Longer signature
70+
# - More typing
71+
72+
# AFTER: Clean and minimal (2 arguments)
73+
# metric_fn(data, context)
74+
# - Get model from ctx.model
75+
# - Get maximizer from ctx.maximizer
76+
# - Everything in one place (context)
77+
# - Shorter, cleaner
78+
79+
# ============================================================================
80+
# Real-World Example
81+
# ============================================================================
82+
83+
# Define your metric functions
84+
compute_gap = (data, ctx) -> DecisionFocusedLearningBenchmarks.compute_gap(b, data, ctx.model, ctx.maximizer)  # `b`, not undefined `benchmark`; inner call qualified to avoid self-shadowing
85+
compute_regret = (data, ctx) -> DecisionFocusedLearningBenchmarks.compute_regret(b, data, ctx.model, ctx.maximizer)  # `b`, not undefined `benchmark`; inner call qualified to avoid self-shadowing
86+
87+
# Metric that uses multiple datasets
88+
overfitting_indicator =
89+
(data, ctx) -> begin
90+
train_metric = DecisionFocusedLearningBenchmarks.compute_gap(b, ctx.train_dataset, ctx.model, ctx.maximizer)
91+
val_metric = DecisionFocusedLearningBenchmarks.compute_gap(b, ctx.validation_dataset, ctx.model, ctx.maximizer)
92+
return val_metric - train_metric
93+
end
94+
95+
# Metric that evaluates policy on environments
96+
eval_policy = (envs, ctx) -> begin
97+
policy = Policy("", "", PolicyWrapper(ctx.model))
98+
rewards, _ = evaluate_policy!(policy, envs, 100)
99+
return mean(rewards)
100+
end
101+
102+
test_envs = generate_environments(b, test)
103+
104+
callbacks = [
105+
Metric(:gap, compute_gap; on=:both),
106+
Metric(:regret, compute_regret; on=:both),
107+
Metric(:test_gap, compute_gap; on=test),
108+
Metric(:overfitting, overfitting_indicator),
109+
Metric(:test_reward, eval_policy; on=test_envs),
110+
]
111+
112+
# ============================================================================
113+
# Metric Library Pattern
114+
# ============================================================================
115+
116+
# Create a module with all your metrics
117+
module MyMetrics
118+
gap(data, ctx) = compute_gap(benchmark, data, ctx.model, ctx.maximizer)
119+
regret(data, ctx) = compute_regret(benchmark, data, ctx.model, ctx.maximizer)
120+
121+
# More complex metrics
122+
overfitting(data, ctx) = begin
123+
train = gap(ctx.train_dataset, ctx)
124+
val = gap(ctx.validation_dataset, ctx)
125+
return val - train
126+
end
127+
end
128+
129+
# Use them
130+
callbacks = [
131+
Metric(:gap, MyMetrics.gap; on=:both),
132+
Metric(:regret, MyMetrics.regret; on=:both),
133+
Metric(:overfitting, MyMetrics.overfitting),
134+
]
135+
136+
# ============================================================================
137+
# Migration from 4-argument signature
138+
# ============================================================================
139+
140+
# If you have old 4-argument functions:
141+
old_metric = (model, max, data, ctx) -> compute_gap(b, data, model, max)
142+
143+
# Convert to new 2-argument:
144+
new_metric = (data, ctx) -> compute_gap(b, data, ctx.model, ctx.maximizer)
145+
146+
# Or just update inline:
147+
Metric(:gap, (data, ctx) -> compute_gap(b, data, ctx.model, ctx.maximizer); on=:both)
148+
149+
# ============================================================================
150+
# Benefits Summary
151+
# ============================================================================
152+
153+
# ✅ Cleaner: 2 arguments instead of 4
154+
# ✅ Less redundancy: No duplicate model/maximizer
155+
# ✅ Consistent: Everything from context
156+
# ✅ Simpler: Less to type and remember
157+
# ✅ Flexible: Context has everything you need

examples/using_mvhistory.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Using MVHistory for Metrics Storage
2+
3+
using DecisionFocusedLearningAlgorithms
4+
using DecisionFocusedLearningBenchmarks
5+
using MLUtils: splitobs
6+
using ValueHistories
7+
using Plots
8+
9+
b = ArgmaxBenchmark()
10+
dataset = generate_dataset(b, 100)
11+
train_instances, val_instances, test_instances = splitobs(dataset; at=(0.3, 0.3, 0.4))
12+
13+
model = generate_statistical_model(b; seed=0)
14+
maximizer = generate_maximizer(b)
15+
16+
compute_gap_fn = (m, max, data) -> compute_gap(b, data, m, max)
17+
18+
# Define callbacks
19+
callbacks = [
20+
Metric(:gap, compute_gap_fn; on=:both),
21+
Metric(:test_gap, compute_gap_fn; on=test_instances),
22+
]
23+
24+
# Train and get MVHistory back
25+
history = fyl_train_model!(
26+
model, maximizer, train_instances, val_instances; epochs=100, callbacks=callbacks
27+
)
28+
29+
# ============================================================================
30+
# Working with MVHistory - Much Cleaner!
31+
# ============================================================================
32+
33+
# Get values and iterations
34+
epochs, train_losses = get(history, :training_loss)
35+
epochs, val_losses = get(history, :validation_loss)
36+
epochs, train_gaps = get(history, :train_gap)
37+
epochs, val_gaps = get(history, :val_gap)
38+
test_epochs, test_gaps = get(history, :test_gap)
39+
40+
# Plot multiple metrics
41+
plot(epochs, train_losses; label="Train Loss")
42+
plot!(epochs, val_losses; label="Val Loss")
43+
44+
plot(epochs, train_gaps; label="Train Gap")
45+
plot!(epochs, val_gaps; label="Val Gap")
46+
plot!(test_epochs, test_gaps; label="Test Gap")
47+
48+
using JLD2
49+
@save "training_history.jld2" history
50+
@load "training_history.jld2" history

scripts/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
[deps]
2+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
23
DecisionFocusedLearningAlgorithms = "46d52364-bc3b-4fac-a992-eb1d3ef2de15"
34
DecisionFocusedLearningBenchmarks = "2fbe496a-299b-4c81-bab5-c44dfc55cf20"
5+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
46
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
57
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
68
TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
9+
ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7"

0 commit comments

Comments
 (0)