Skip to content

Commit 1f98d65

Browse files
author
Mads Dabros
committed
Refactor
1 parent 65f7342 commit 1f98d65

File tree

1 file changed

+20
-50
lines changed

1 file changed

+20
-50
lines changed
Lines changed: 20 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
using BenchmarkDotNet.Attributes;
1+
using System.Collections.Generic;
2+
using BenchmarkDotNet.Attributes;
23
using SharpLearning.AdaBoost.Learners;
4+
using SharpLearning.Common.Interfaces;
35
using SharpLearning.Containers.Matrices;
46
using SharpLearning.DecisionTrees.Learners;
57
using SharpLearning.GradientBoost.Learners;
@@ -19,14 +21,17 @@ public class RegressionLearners
1921
double[] m_targets;
2022

2123
// Define learners here. Use default parameters for benchmarks.
22-
readonly RegressionDecisionTreeLearner m_regressionDecisionTreeLearner = new();
23-
readonly RegressionAdaBoostLearner m_regressionAdaBoostLearner = new();
24-
readonly RegressionRandomForestLearner m_regressionRandomForestLearner = new();
25-
readonly RegressionExtremelyRandomizedTreesLearner m_regressionExtremelyRandomizedTreesLearner = new();
26-
readonly RegressionAbsoluteLossGradientBoostLearner m_regressionAbsoluteLossGradientBoostLearner = new();
27-
readonly RegressionHuberLossGradientBoostLearner m_regressionHuberLossGradientBoostLearner = new();
28-
readonly RegressionQuantileLossGradientBoostLearner m_regressionQuantileLossGradientBoostLearner = new();
29-
readonly RegressionSquareLossGradientBoostLearner m_regressionSquareLossGradientBoostLearner = new();
24+
readonly Dictionary<string, ILearner<double>> m_learners = new()
25+
{
26+
{ nameof(RegressionDecisionTreeLearner), new RegressionDecisionTreeLearner() },
27+
{ nameof(RegressionAdaBoostLearner), new RegressionAdaBoostLearner() },
28+
{ nameof(RegressionRandomForestLearner), new RegressionRandomForestLearner() },
29+
{ nameof(RegressionExtremelyRandomizedTreesLearner), new RegressionExtremelyRandomizedTreesLearner() },
30+
{ nameof(RegressionAbsoluteLossGradientBoostLearner), new RegressionAbsoluteLossGradientBoostLearner() },
31+
{ nameof(RegressionHuberLossGradientBoostLearner), new RegressionHuberLossGradientBoostLearner() },
32+
{ nameof(RegressionQuantileLossGradientBoostLearner), new RegressionQuantileLossGradientBoostLearner() },
33+
{ nameof(RegressionSquareLossGradientBoostLearner), new RegressionSquareLossGradientBoostLearner() }
34+
};
3035

3136
[GlobalSetup]
3237
public void GlobalSetup()
@@ -38,51 +43,16 @@ public void GlobalSetup()
3843
}
3944

4045
[Benchmark]
41-
public void RegressionDecisionTreeLearner_Learn()
46+
[ArgumentsSource(nameof(GetLearners))]
47+
public void Learn(string learnerName)
4248
{
43-
m_regressionDecisionTreeLearner.Learn(m_features, m_targets);
49+
var learner = m_learners[learnerName];
50+
learner.Learn(m_features, m_targets);
4451
}
4552

46-
[Benchmark]
47-
public void RegressionAdaBoostLearner_Learn()
48-
{
49-
m_regressionAdaBoostLearner.Learn(m_features, m_targets);
50-
}
51-
52-
[Benchmark]
53-
public void RegressionRandomForestLearner_Learn()
54-
{
55-
m_regressionRandomForestLearner.Learn(m_features, m_targets);
56-
}
57-
58-
[Benchmark]
59-
public void RegressionExtremelyRandomizedTreesLearner_Learn()
60-
{
61-
m_regressionExtremelyRandomizedTreesLearner.Learn(m_features, m_targets);
62-
}
63-
64-
[Benchmark]
65-
public void RegressionAbsoluteLossGradientBoostLearner_Learn()
66-
{
67-
m_regressionAbsoluteLossGradientBoostLearner.Learn(m_features, m_targets);
68-
}
69-
70-
[Benchmark]
71-
public void RegressionHuberLossGradientBoostLearner_Learn()
72-
{
73-
m_regressionHuberLossGradientBoostLearner.Learn(m_features, m_targets);
74-
}
75-
76-
[Benchmark]
77-
public void RegressionQuantileLossGradientBoostLearner_Learn()
78-
{
79-
m_regressionQuantileLossGradientBoostLearner.Learn(m_features, m_targets);
80-
}
81-
82-
[Benchmark]
83-
public void RegressionSquareLossGradientBoostLearner_Learn()
53+
public IEnumerable<string> GetLearners()
8454
{
85-
m_regressionSquareLossGradientBoostLearner.Learn(m_features, m_targets);
55+
return m_learners.Keys;
8656
}
8757
}
8858
}

0 commit comments

Comments
 (0)