Skip to content

Commit 65f7342

Browse files
author
Mads Dabros
committed
Refactor
1 parent b1576e1 commit 65f7342

File tree

1 file changed

+17
-29
lines changed

1 file changed

+17
-29
lines changed
Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
using BenchmarkDotNet.Attributes;
1+
using System.Collections.Generic;
2+
using BenchmarkDotNet.Attributes;
23
using SharpLearning.AdaBoost.Learners;
4+
using SharpLearning.Common.Interfaces;
35
using SharpLearning.Containers.Matrices;
46
using SharpLearning.DecisionTrees.Learners;
57
using SharpLearning.GradientBoost.Learners;
@@ -21,11 +23,14 @@ public class ClassificationLearners
2123
double[] m_targets;
2224

2325
// Define learners here. Use default parameters for benchmarks.
24-
readonly ClassificationDecisionTreeLearner m_classificationDecisionTreeLearner = new();
25-
readonly ClassificationAdaBoostLearner m_classificationAdaBoostLearner = new();
26-
readonly ClassificationRandomForestLearner m_classificationRandomForestLearner = new();
27-
readonly ClassificationExtremelyRandomizedTreesLearner m_classificationExtremelyRandomizedTreesLearner = new();
28-
readonly ClassificationBinomialGradientBoostLearner m_classificationBinomialGradientBoostLearner = new();
26+
readonly Dictionary<string, ILearner<double>> m_learners = new()
27+
{
28+
{ nameof(ClassificationDecisionTreeLearner), new ClassificationDecisionTreeLearner() },
29+
{ nameof(ClassificationAdaBoostLearner), new ClassificationAdaBoostLearner() },
30+
{ nameof(ClassificationRandomForestLearner), new ClassificationRandomForestLearner() },
31+
{ nameof(ClassificationExtremelyRandomizedTreesLearner), new ClassificationExtremelyRandomizedTreesLearner() },
32+
{ nameof(ClassificationBinomialGradientBoostLearner), new ClassificationBinomialGradientBoostLearner() },
33+
};
2934

3035
[GlobalSetup]
3136
public void GlobalSetup()
@@ -38,33 +43,16 @@ public void GlobalSetup()
3843
}
3944

4045
[Benchmark]
41-
public void ClassificationDecisionTreeLearner_Learn()
42-
{
43-
m_classificationDecisionTreeLearner.Learn(m_features, m_targets);
44-
}
45-
46-
[Benchmark]
47-
public void ClassificationAdaBoostLearner_Learn()
48-
{
49-
m_classificationAdaBoostLearner.Learn(m_features, m_targets);
50-
}
51-
52-
[Benchmark]
53-
public void ClassificationRandomForestLearner_Learn()
46+
[ArgumentsSource(nameof(GetLearners))]
47+
public void Learn(string learnerName)
5448
{
55-
m_classificationRandomForestLearner.Learn(m_features, m_targets);
49+
var learner = m_learners[learnerName];
50+
learner.Learn(m_features, m_targets);
5651
}
5752

58-
[Benchmark]
59-
public void ClassificationExtremelyRandomizedTreesLearner_Learn()
60-
{
61-
m_classificationExtremelyRandomizedTreesLearner.Learn(m_features, m_targets);
62-
}
63-
64-
[Benchmark]
65-
public void ClassificationBinomialGradientBoostLearner_Learn()
53+
public IEnumerable<string> GetLearners()
6654
{
67-
m_classificationBinomialGradientBoostLearner.Learn(m_features, m_targets);
55+
return m_learners.Keys;
6856
}
6957
}
7058
}

0 commit comments

Comments
 (0)