Skip to content

Commit 7e8b4dc

Browse files
committed
[*] use Ensemble interface for common test cases
1 parent 986d193 commit 7e8b4dc

File tree

1 file changed

+72
-119
lines changed

1 file changed

+72
-119
lines changed

leaves_test.go

Lines changed: 72 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -57,70 +57,51 @@ func InnerTestLGMSLTR(t *testing.T, nThreads int) {
5757
}
5858

5959
func TestLGHiggs(t *testing.T) {
60-
InnerTestLGHiggs(t, 1)
61-
InnerTestLGHiggs(t, 2)
62-
InnerTestLGHiggs(t, 3)
63-
InnerTestLGHiggs(t, 4)
64-
}
65-
66-
func InnerTestLGHiggs(t *testing.T, nThreads int) {
67-
// loading test data
68-
path := filepath.Join("testdata", "higgs_1000examples_test.libsvm")
69-
reader, err := os.Open(path)
70-
if err != nil {
71-
t.Skipf("Skipping due to absence of %s", path)
72-
}
73-
bufReader := bufio.NewReader(reader)
74-
csrMat, err := CSRMatFromLibsvm(bufReader, 0, true)
75-
if err != nil {
76-
t.Fatal(err)
77-
}
78-
nRows := csrMat.Rows()
79-
60+
filename := "lghiggs_1000examples_true_predictions.txt"
8061
// loading model
81-
path = filepath.Join("testdata", "lghiggs.model")
62+
path := filepath.Join("testdata", "lghiggs.model")
8263
model, err := LGEnsembleFromFile(path)
83-
if err != nil {
84-
t.Fatal(err)
85-
}
86-
87-
// loading true predictions as DenseMat
88-
path = filepath.Join("testdata", "lghiggs_1000examples_true_predictions.txt")
89-
reader, err = os.Open(path)
9064
if err != nil {
9165
t.Skipf("Skipping due to absence of %s", path)
9266
}
93-
bufReader = bufio.NewReader(reader)
94-
truePredictions, err := DenseMatFromCsv(bufReader, 0, false, ",", 0.0)
95-
if err != nil {
96-
t.Fatal(err)
97-
}
67+
const tolerance = 1e-12
9868

99-
predictions := make([]float64, nRows)
100-
model.PredictCSR(csrMat.RowHeaders, csrMat.ColIndexes, csrMat.Values, predictions, 0, nThreads)
101-
102-
// compare results
103-
if err := almostEqualFloat64Slices(truePredictions.Values, predictions, 1e-12); err != nil {
104-
t.Fatalf("different predictions: %s", err.Error())
105-
}
69+
// Dense matrix
70+
InnerTestHiggs(t, model, 1, true, filename, tolerance)
71+
InnerTestHiggs(t, model, 2, true, filename, tolerance)
72+
InnerTestHiggs(t, model, 3, true, filename, tolerance)
73+
InnerTestHiggs(t, model, 4, true, filename, tolerance)
74+
75+
InnerTestHiggs(t, model, 1, false, filename, tolerance)
76+
InnerTestHiggs(t, model, 2, false, filename, tolerance)
77+
InnerTestHiggs(t, model, 3, false, filename, tolerance)
78+
InnerTestHiggs(t, model, 4, false, filename, tolerance)
10679
}
10780

10881
func TestXGHiggs(t *testing.T) {
10982
t.Skip("have mismatch on 45 element")
83+
filename := "xghiggs_1000examples_true_predictions.txt"
84+
// loading model
85+
path := filepath.Join("testdata", "xghiggs.model")
86+
model, err := XGEnsembleFromFile(path)
87+
if err != nil {
88+
t.Skipf("Skipping due to absence of %s", path)
89+
}
90+
const tolerance = 1e-5
91+
11092
// Dense matrix
111-
InnerTestXGHiggs(t, 1, true)
112-
InnerTestXGHiggs(t, 2, true)
113-
InnerTestXGHiggs(t, 3, true)
114-
InnerTestXGHiggs(t, 4, true)
115-
116-
// CSR matrix
117-
InnerTestXGHiggs(t, 1, false)
118-
InnerTestXGHiggs(t, 2, false)
119-
InnerTestXGHiggs(t, 3, false)
120-
InnerTestXGHiggs(t, 4, false)
93+
InnerTestHiggs(t, model, 1, true, filename, tolerance)
94+
InnerTestHiggs(t, model, 2, true, filename, tolerance)
95+
InnerTestHiggs(t, model, 3, true, filename, tolerance)
96+
InnerTestHiggs(t, model, 4, true, filename, tolerance)
97+
98+
InnerTestHiggs(t, model, 1, false, filename, tolerance)
99+
InnerTestHiggs(t, model, 2, false, filename, tolerance)
100+
InnerTestHiggs(t, model, 3, false, filename, tolerance)
101+
InnerTestHiggs(t, model, 4, false, filename, tolerance)
121102
}
122103

123-
func InnerTestXGHiggs(t *testing.T, nThreads int, dense bool) {
104+
func InnerTestHiggs(t *testing.T, model Ensemble, nThreads int, dense bool, truePredictionsFilename string, tolerance float64) {
124105
// loading test data
125106
path := filepath.Join("testdata", "higgs_1000examples_test.libsvm")
126107
reader, err := os.Open(path)
@@ -145,15 +126,8 @@ func InnerTestXGHiggs(t *testing.T, nThreads int, dense bool) {
145126
nRows = csrMat.Rows()
146127
}
147128

148-
// loading model
149-
path = filepath.Join("testdata", "xghiggs.model")
150-
model, err := XGEnsembleFromFile(path)
151-
if err != nil {
152-
t.Fatal(err)
153-
}
154-
155129
// loading true predictions as DenseMat
156-
path = filepath.Join("testdata", "xghiggs_1000examples_true_predictions.txt")
130+
path = filepath.Join("testdata", truePredictionsFilename)
157131
reader, err = os.Open(path)
158132
if err != nil {
159133
t.Skipf("Skipping due to absence of %s", path)
@@ -171,7 +145,7 @@ func InnerTestXGHiggs(t *testing.T, nThreads int, dense bool) {
171145
model.PredictCSR(csrMat.RowHeaders, csrMat.ColIndexes, csrMat.Values, predictions, 0, nThreads)
172146
}
173147
// compare results
174-
if err := almostEqualFloat64Slices(truePredictions.Values, predictions, 1e-5); err != nil {
148+
if err := almostEqualFloat64Slices(truePredictions.Values, predictions, tolerance); err != nil {
175149
t.Fatalf("different predictions: %s", err.Error())
176150
}
177151
}
@@ -213,65 +187,35 @@ func InnerBenchmarkLGMSLTR(b *testing.B, nThreads int) {
213187
}
214188

215189
func BenchmarkLGHiggs_dense_1thread(b *testing.B) {
216-
InnerBenchmarkLGHiggs(b, 1, true)
190+
model, err := LGEnsembleFromFile(filepath.Join("testdata", "lghiggs.model"))
191+
if err != nil {
192+
b.Fatal(err)
193+
}
194+
InnerBenchmarkHiggs(b, model, 1, true)
217195
}
218196

219197
func BenchmarkLGHiggs_dense_4thread(b *testing.B) {
220-
InnerBenchmarkLGHiggs(b, 4, true)
198+
model, err := LGEnsembleFromFile(filepath.Join("testdata", "lghiggs.model"))
199+
if err != nil {
200+
b.Fatal(err)
201+
}
202+
InnerBenchmarkHiggs(b, model, 4, true)
221203
}
222204

223205
func BenchmarkLGHiggs_csr_1thread(b *testing.B) {
224-
InnerBenchmarkLGHiggs(b, 1, false)
225-
}
226-
227-
func BenchmarkLGHiggs_csr_4thread(b *testing.B) {
228-
InnerBenchmarkLGHiggs(b, 4, false)
229-
}
230-
231-
func InnerBenchmarkLGHiggs(b *testing.B, nThreads int, dense bool) {
232-
// loading test data
233-
path := filepath.Join("testdata", "higgs_1000examples_test.libsvm")
234-
reader, err := os.Open(path)
206+
model, err := LGEnsembleFromFile(filepath.Join("testdata", "lghiggs.model"))
235207
if err != nil {
236-
b.Skipf("Skipping due to absence of %s", path)
237-
}
238-
bufReader := bufio.NewReader(reader)
239-
var denseMat DenseMat
240-
var csrMat CSRMat
241-
var nRows uint32
242-
if dense {
243-
denseMat, err = DenseMatFromLibsvm(bufReader, 0, true)
244-
if err != nil {
245-
b.Fatal(err)
246-
}
247-
nRows = denseMat.Rows
248-
} else {
249-
csrMat, err = CSRMatFromLibsvm(bufReader, 0, true)
250-
if err != nil {
251-
b.Fatal(err)
252-
}
253-
nRows = csrMat.Rows()
208+
b.Fatal(err)
254209
}
210+
InnerBenchmarkHiggs(b, model, 1, false)
211+
}
255212

256-
// loading model
257-
path = filepath.Join("testdata", "lghiggs.model")
258-
model, err := LGEnsembleFromFile(path)
213+
func BenchmarkLGHiggs_csr_4thread(b *testing.B) {
214+
model, err := LGEnsembleFromFile(filepath.Join("testdata", "lghiggs.model"))
259215
if err != nil {
260216
b.Fatal(err)
261217
}
262-
263-
// do benchmark
264-
b.ResetTimer()
265-
predictions := make([]float64, nRows)
266-
if dense {
267-
for i := 0; i < b.N; i++ {
268-
model.PredictDense(denseMat.Values, denseMat.Rows, denseMat.Cols, predictions, 0, nThreads)
269-
}
270-
} else {
271-
for i := 0; i < b.N; i++ {
272-
model.PredictCSR(csrMat.RowHeaders, csrMat.ColIndexes, csrMat.Values, predictions, 0, nThreads)
273-
}
274-
}
218+
InnerBenchmarkHiggs(b, model, 4, false)
275219
}
276220

277221
func TestXGAgaricus_1thread(t *testing.T) {
@@ -327,22 +271,38 @@ func InnerTestXGAgaricus(t *testing.T, nThreads int) {
327271
}
328272

329273
func BenchmarkXGHiggs_dense_1thread(b *testing.B) {
330-
InnerBenchmarkXGHiggs(b, 1, true)
274+
model, err := XGEnsembleFromFile(filepath.Join("testdata", "xghiggs.model"))
275+
if err != nil {
276+
b.Fatal(err)
277+
}
278+
InnerBenchmarkHiggs(b, model, 1, true)
331279
}
332280

333281
func BenchmarkXGHiggs_dense_4thread(b *testing.B) {
334-
InnerBenchmarkXGHiggs(b, 4, true)
282+
model, err := XGEnsembleFromFile(filepath.Join("testdata", "xghiggs.model"))
283+
if err != nil {
284+
b.Fatal(err)
285+
}
286+
InnerBenchmarkHiggs(b, model, 4, true)
335287
}
336288

337289
func BenchmarkXGHiggs_csr_1thread(b *testing.B) {
338-
InnerBenchmarkXGHiggs(b, 1, false)
290+
model, err := XGEnsembleFromFile(filepath.Join("testdata", "xghiggs.model"))
291+
if err != nil {
292+
b.Fatal(err)
293+
}
294+
InnerBenchmarkHiggs(b, model, 1, false)
339295
}
340296

341297
func BenchmarkXGHiggs_csr_4thread(b *testing.B) {
342-
InnerBenchmarkXGHiggs(b, 4, false)
298+
model, err := XGEnsembleFromFile(filepath.Join("testdata", "xghiggs.model"))
299+
if err != nil {
300+
b.Fatal(err)
301+
}
302+
InnerBenchmarkHiggs(b, model, 4, false)
343303
}
344304

345-
func InnerBenchmarkXGHiggs(b *testing.B, nThreads int, dense bool) {
305+
func InnerBenchmarkHiggs(b *testing.B, model Ensemble, nThreads int, dense bool) {
346306
// loading test data
347307
path := filepath.Join("testdata", "higgs_1000examples_test.libsvm")
348308
reader, err := os.Open(path)
@@ -367,13 +327,6 @@ func InnerBenchmarkXGHiggs(b *testing.B, nThreads int, dense bool) {
367327
nRows = csrMat.Rows()
368328
}
369329

370-
// loading model
371-
path = filepath.Join("testdata", "xghiggs.model")
372-
model, err := XGEnsembleFromFile(path)
373-
if err != nil {
374-
b.Fatal(err)
375-
}
376-
377330
// do benchmark
378331
b.ResetTimer()
379332
predictions := make([]float64, nRows)

0 commit comments

Comments
 (0)