# TODO: Implement validation loss as a metric callback
# TODO: batch training option
# TODO: parallelize loss computation on validation set
+# TODO: add a supervised learning training method that fyl_train calls, so new supervised losses can be tested easily if needed
+# TODO: easier way to define and provide metrics

function fyl_train_model!(
    model,
@@ -13,7 +15,7 @@ function fyl_train_model!(
    maximizer_kwargs=(sample -> (; instance=sample.info)),
    metrics_callbacks::NamedTuple=NamedTuple(),
)
-    perturbed = PerturbedAdditive(maximizer; nb_samples=50, ε=1.0, threaded=true, seed=0)
+    perturbed = PerturbedAdditive(maximizer; nb_samples=50, ε=0.0, threaded=true, seed=0)
    loss = FenchelYoungLoss(perturbed)

    optimizer = Adam()
@@ -86,7 +88,7 @@ function fyl_train_model!(
end

function fyl_train_model(b::AbstractBenchmark; kwargs...)
-    dataset = generate_dataset(b, 100)
+    dataset = generate_dataset(b, 20)
    train_dataset, validation_dataset, _ = splitobs(dataset; at=(0.3, 0.3, 0.4))
    model = generate_statistical_model(b)
    maximizer = generate_maximizer(b)
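
For context, here is a minimal standalone sketch of the loss construction that `fyl_train_model!` performs with the `PerturbedAdditive`/`FenchelYoungLoss` pair from InferOpt.jl. The toy maximizer, the score vector `θ`, and the target `y_target` are stand-ins invented for illustration, and the sketch keeps `ε=1.0` so the perturbation is active, whereas this commit sets `ε=0.0` inside `fyl_train_model!`.

```julia
# Minimal sketch (not part of this commit): the same Fenchel-Young loss wiring
# as in fyl_train_model!, but on a toy maximizer so it runs standalone.
using InferOpt: PerturbedAdditive, FenchelYoungLoss

# Hypothetical maximizer: argmax of θ'y over y ∈ {-1, 1}^d is just sign.(θ).
toy_maximizer(θ; kwargs...) = sign.(θ)

# Same constructor call as in the diff, except ε is kept at 1.0 here so the
# additive Gaussian perturbation actually does something.
perturbed = PerturbedAdditive(toy_maximizer; nb_samples=50, ε=1.0, threaded=true, seed=0)
loss = FenchelYoungLoss(perturbed)

θ = randn(5)                 # hypothetical score vector
y_target = toy_maximizer(θ)  # hypothetical target solution
@show loss(θ, y_target)      # scalar Fenchel-Young loss for this (θ, y) pair
```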