@@ -1,53 +1,75 @@
 # TODO : every N epochs
 # TODO : best_model saving method, using default metric validation loss, overwritten in dagger
 # TODO : Implement validation loss as a metric callback
+# TODO : batch training option
+# TODO : parallelize loss computation on validation set
 
 function fyl_train_model!(
     model,
     maximizer,
     train_dataset::AbstractArray{<:DataSample},
     validation_dataset;
     epochs=100,
-    maximizer_kwargs=(sample -> (; instance=sample.instance)),
+    maximizer_kwargs=(sample -> (; instance=sample.info)),
     metrics_callbacks::NamedTuple=NamedTuple(),
 )
-    perturbed = PerturbedAdditive(maximizer; nb_samples=20, ε=1.0, threaded=true)
+    perturbed = PerturbedAdditive(maximizer; nb_samples=50, ε=1.0, threaded=true, seed=0)
     loss = FenchelYoungLoss(perturbed)
 
     optimizer = Adam()
     opt_state = Flux.setup(optimizer, model)
 
     total_loss = 0.0
     for sample in validation_dataset
-        (; x, y_true) = sample
-        total_loss += loss(model(x), y_true; maximizer_kwargs(sample)...)
+        (; x, y) = sample
+        total_loss += loss(model(x), y; maximizer_kwargs(sample)...)
     end
     loss_history = [total_loss / length(validation_dataset)]
 
+    total_train_loss = 0.0
+    for sample in train_dataset
+        (; x, y) = sample
+        total_train_loss += loss(model(x), y; maximizer_kwargs(sample)...)
+    end
+
     # Initialize metrics history with epoch 0 for type stability
     metrics_history = _initialize_nested_metrics(metrics_callbacks, model, maximizer, 0)
 
     # Add validation and training losses to metrics
     metrics_history = merge(
-        metrics_history, (; validation_loss=[total_loss / length(validation_dataset)])
+        metrics_history,
+        (;
+            validation_loss=[total_loss / length(validation_dataset)],
+            training_loss=[total_train_loss / length(train_dataset)],
+        ),
     )
 
     @showprogress for epoch in 1:epochs
+        l = 0.0  # running training loss accumulated over the epoch
         for sample in train_dataset
-            (; x, y_true) = sample
-            grads = Flux.gradient(model) do m
-                loss(m(x), y_true; maximizer_kwargs(sample)...)
+            (; x, y) = sample
+            val, grads = Flux.withgradient(model) do m
+                loss(m(x), y; maximizer_kwargs(sample)...)
             end
+            l += val
             Flux.update!(opt_state, model, grads[1])
         end
         # Evaluate on validation set
         total_loss = 0.0
         for sample in validation_dataset
-            (; x, y_true) = sample
-            total_loss += loss(model(x), y_true; maximizer_kwargs(sample)...)
+            (; x, y) = sample
+            total_loss += loss(model(x), y; maximizer_kwargs(sample)...)
         end
         push!(loss_history, total_loss / length(validation_dataset))
         push!(metrics_history.validation_loss, total_loss / length(validation_dataset))
+        # `l` mixes losses evaluated at different parameter values within the
+        # epoch, so the training loss is recomputed below with the final model.
+        # push!(metrics_history.training_loss, l / length(train_dataset))
+
+        total_loss = 0.0
+        for sample in train_dataset
+            (; x, y) = sample
+            total_loss += loss(model(x), y; maximizer_kwargs(sample)...)
+        end
+        push!(metrics_history.training_loss, total_loss / length(train_dataset))
 
         # Call metrics callbacks
         if !isempty(metrics_callbacks)
@@ -64,10 +86,8 @@ function fyl_train_model!(
 end
 
 function fyl_train_model(b::AbstractBenchmark; kwargs...)
-    dataset = generate_dataset(b, 30)
-    train_dataset, validation_dataset, test_dataset = dataset[2:2],
-    dataset[11:20],
-    dataset[21:30]
+    dataset = generate_dataset(b, 100)
+    train_dataset, validation_dataset, _ = splitobs(dataset; at=(0.3, 0.3, 0.4))
     model = generate_statistical_model(b)
     maximizer = generate_maximizer(b)
     return fyl_train_model!(model, maximizer, train_dataset, validation_dataset; kwargs...)
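
For reference, a hedged usage sketch of the wrapper above (not part of the diff). `MyBenchmark` is a hypothetical `AbstractBenchmark` implementation; the callback signature `(model, maximizer, epoch)` is an assumption inferred from the `_initialize_nested_metrics(metrics_callbacks, model, maximizer, 0)` call, and `Flux.trainables` assumes a recent Flux version that exports it.

    using Flux  # fyl_train_model and the benchmark types come from the surrounding package

    # Hypothetical benchmark standing in for any concrete AbstractBenchmark.
    b = MyBenchmark()

    # Each callback is assumed to receive (model, maximizer, epoch); this one
    # tracks the squared norm of the model parameters across epochs.
    callbacks = (;
        weight_norm=(model, maximizer, epoch) -> sum(p -> sum(abs2, p), Flux.trainables(model)),
    )

    fyl_train_model(b; epochs=50, metrics_callbacks=callbacks)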