@@ -57,70 +57,51 @@ func InnerTestLGMSLTR(t *testing.T, nThreads int) {
5757}
5858
5959func TestLGHiggs (t * testing.T ) {
60- InnerTestLGHiggs (t , 1 )
61- InnerTestLGHiggs (t , 2 )
62- InnerTestLGHiggs (t , 3 )
63- InnerTestLGHiggs (t , 4 )
64- }
65-
66- func InnerTestLGHiggs (t * testing.T , nThreads int ) {
67- // loading test data
68- path := filepath .Join ("testdata" , "higgs_1000examples_test.libsvm" )
69- reader , err := os .Open (path )
70- if err != nil {
71- t .Skipf ("Skipping due to absence of %s" , path )
72- }
73- bufReader := bufio .NewReader (reader )
74- csrMat , err := CSRMatFromLibsvm (bufReader , 0 , true )
75- if err != nil {
76- t .Fatal (err )
77- }
78- nRows := csrMat .Rows ()
79-
60+ filename := "lghiggs_1000examples_true_predictions.txt"
8061 // loading model
81- path = filepath .Join ("testdata" , "lghiggs.model" )
62+ path : = filepath .Join ("testdata" , "lghiggs.model" )
8263 model , err := LGEnsembleFromFile (path )
83- if err != nil {
84- t .Fatal (err )
85- }
86-
87- // loading true predictions as DenseMat
88- path = filepath .Join ("testdata" , "lghiggs_1000examples_true_predictions.txt" )
89- reader , err = os .Open (path )
9064 if err != nil {
9165 t .Skipf ("Skipping due to absence of %s" , path )
9266 }
93- bufReader = bufio .NewReader (reader )
94- truePredictions , err := DenseMatFromCsv (bufReader , 0 , false , "," , 0.0 )
95- if err != nil {
96- t .Fatal (err )
97- }
67+ const tolerance = 1e-12
9868
99- predictions := make ([]float64 , nRows )
100- model .PredictCSR (csrMat .RowHeaders , csrMat .ColIndexes , csrMat .Values , predictions , 0 , nThreads )
101-
102- // compare results
103- if err := almostEqualFloat64Slices (truePredictions .Values , predictions , 1e-12 ); err != nil {
104- t .Fatalf ("different predictions: %s" , err .Error ())
105- }
69+ // Dense matrix
70+ InnerTestHiggs (t , model , 1 , true , filename , tolerance )
71+ InnerTestHiggs (t , model , 2 , true , filename , tolerance )
72+ InnerTestHiggs (t , model , 3 , true , filename , tolerance )
73+ InnerTestHiggs (t , model , 4 , true , filename , tolerance )
74+
75+ InnerTestHiggs (t , model , 1 , false , filename , tolerance )
76+ InnerTestHiggs (t , model , 2 , false , filename , tolerance )
77+ InnerTestHiggs (t , model , 3 , false , filename , tolerance )
78+ InnerTestHiggs (t , model , 4 , false , filename , tolerance )
10679}
10780
10881func TestXGHiggs (t * testing.T ) {
10982 t .Skip ("have mismatch on 45 element" )
83+ filename := "xghiggs_1000examples_true_predictions.txt"
84+ // loading model
85+ path := filepath .Join ("testdata" , "xghiggs.model" )
86+ model , err := XGEnsembleFromFile (path )
87+ if err != nil {
88+ t .Skipf ("Skipping due to absence of %s" , path )
89+ }
90+ const tolerance = 1e-5
91+
11092 // Dense matrix
111- InnerTestXGHiggs (t , 1 , true )
112- InnerTestXGHiggs (t , 2 , true )
113- InnerTestXGHiggs (t , 3 , true )
114- InnerTestXGHiggs (t , 4 , true )
115-
116- // CSR matrix
117- InnerTestXGHiggs (t , 1 , false )
118- InnerTestXGHiggs (t , 2 , false )
119- InnerTestXGHiggs (t , 3 , false )
120- InnerTestXGHiggs (t , 4 , false )
93+ InnerTestHiggs (t , model , 1 , true , filename , tolerance )
94+ InnerTestHiggs (t , model , 2 , true , filename , tolerance )
95+ InnerTestHiggs (t , model , 3 , true , filename , tolerance )
96+ InnerTestHiggs (t , model , 4 , true , filename , tolerance )
97+
98+ InnerTestHiggs (t , model , 1 , false , filename , tolerance )
99+ InnerTestHiggs (t , model , 2 , false , filename , tolerance )
100+ InnerTestHiggs (t , model , 3 , false , filename , tolerance )
101+ InnerTestHiggs (t , model , 4 , false , filename , tolerance )
121102}
122103
123- func InnerTestXGHiggs (t * testing.T , nThreads int , dense bool ) {
104+ func InnerTestHiggs (t * testing.T , model Ensemble , nThreads int , dense bool , truePredictionsFilename string , tolerance float64 ) {
124105 // loading test data
125106 path := filepath .Join ("testdata" , "higgs_1000examples_test.libsvm" )
126107 reader , err := os .Open (path )
@@ -145,15 +126,8 @@ func InnerTestXGHiggs(t *testing.T, nThreads int, dense bool) {
145126 nRows = csrMat .Rows ()
146127 }
147128
148- // loading model
149- path = filepath .Join ("testdata" , "xghiggs.model" )
150- model , err := XGEnsembleFromFile (path )
151- if err != nil {
152- t .Fatal (err )
153- }
154-
155129 // loading true predictions as DenseMat
156- path = filepath .Join ("testdata" , "xghiggs_1000examples_true_predictions.txt" )
130+ path = filepath .Join ("testdata" , truePredictionsFilename )
157131 reader , err = os .Open (path )
158132 if err != nil {
159133 t .Skipf ("Skipping due to absence of %s" , path )
@@ -171,7 +145,7 @@ func InnerTestXGHiggs(t *testing.T, nThreads int, dense bool) {
171145 model .PredictCSR (csrMat .RowHeaders , csrMat .ColIndexes , csrMat .Values , predictions , 0 , nThreads )
172146 }
173147 // compare results
174- if err := almostEqualFloat64Slices (truePredictions .Values , predictions , 1e-5 ); err != nil {
148+ if err := almostEqualFloat64Slices (truePredictions .Values , predictions , tolerance ); err != nil {
175149 t .Fatalf ("different predictions: %s" , err .Error ())
176150 }
177151}
@@ -213,65 +187,35 @@ func InnerBenchmarkLGMSLTR(b *testing.B, nThreads int) {
213187}
214188
215189func BenchmarkLGHiggs_dense_1thread (b * testing.B ) {
216- InnerBenchmarkLGHiggs (b , 1 , true )
190+ model , err := LGEnsembleFromFile (filepath .Join ("testdata" , "lghiggs.model" ))
191+ if err != nil {
192+ b .Fatal (err )
193+ }
194+ InnerBenchmarkHiggs (b , model , 1 , true )
217195}
218196
219197func BenchmarkLGHiggs_dense_4thread (b * testing.B ) {
220- InnerBenchmarkLGHiggs (b , 4 , true )
198+ model , err := LGEnsembleFromFile (filepath .Join ("testdata" , "lghiggs.model" ))
199+ if err != nil {
200+ b .Fatal (err )
201+ }
202+ InnerBenchmarkHiggs (b , model , 4 , true )
221203}
222204
223205func BenchmarkLGHiggs_csr_1thread (b * testing.B ) {
224- InnerBenchmarkLGHiggs (b , 1 , false )
225- }
226-
227- func BenchmarkLGHiggs_csr_4thread (b * testing.B ) {
228- InnerBenchmarkLGHiggs (b , 4 , false )
229- }
230-
231- func InnerBenchmarkLGHiggs (b * testing.B , nThreads int , dense bool ) {
232- // loading test data
233- path := filepath .Join ("testdata" , "higgs_1000examples_test.libsvm" )
234- reader , err := os .Open (path )
206+ model , err := LGEnsembleFromFile (filepath .Join ("testdata" , "lghiggs.model" ))
235207 if err != nil {
236- b .Skipf ("Skipping due to absence of %s" , path )
237- }
238- bufReader := bufio .NewReader (reader )
239- var denseMat DenseMat
240- var csrMat CSRMat
241- var nRows uint32
242- if dense {
243- denseMat , err = DenseMatFromLibsvm (bufReader , 0 , true )
244- if err != nil {
245- b .Fatal (err )
246- }
247- nRows = denseMat .Rows
248- } else {
249- csrMat , err = CSRMatFromLibsvm (bufReader , 0 , true )
250- if err != nil {
251- b .Fatal (err )
252- }
253- nRows = csrMat .Rows ()
208+ b .Fatal (err )
254209 }
210+ InnerBenchmarkHiggs (b , model , 1 , false )
211+ }
255212
256- // loading model
257- path = filepath .Join ("testdata" , "lghiggs.model" )
258- model , err := LGEnsembleFromFile (path )
213+ func BenchmarkLGHiggs_csr_4thread (b * testing.B ) {
214+ model , err := LGEnsembleFromFile (filepath .Join ("testdata" , "lghiggs.model" ))
259215 if err != nil {
260216 b .Fatal (err )
261217 }
262-
263- // do benchmark
264- b .ResetTimer ()
265- predictions := make ([]float64 , nRows )
266- if dense {
267- for i := 0 ; i < b .N ; i ++ {
268- model .PredictDense (denseMat .Values , denseMat .Rows , denseMat .Cols , predictions , 0 , nThreads )
269- }
270- } else {
271- for i := 0 ; i < b .N ; i ++ {
272- model .PredictCSR (csrMat .RowHeaders , csrMat .ColIndexes , csrMat .Values , predictions , 0 , nThreads )
273- }
274- }
218+ InnerBenchmarkHiggs (b , model , 4 , false )
275219}
276220
277221func TestXGAgaricus_1thread (t * testing.T ) {
@@ -327,22 +271,38 @@ func InnerTestXGAgaricus(t *testing.T, nThreads int) {
327271}
328272
329273func BenchmarkXGHiggs_dense_1thread (b * testing.B ) {
330- InnerBenchmarkXGHiggs (b , 1 , true )
274+ model , err := XGEnsembleFromFile (filepath .Join ("testdata" , "xghiggs.model" ))
275+ if err != nil {
276+ b .Fatal (err )
277+ }
278+ InnerBenchmarkHiggs (b , model , 1 , true )
331279}
332280
333281func BenchmarkXGHiggs_dense_4thread (b * testing.B ) {
334- InnerBenchmarkXGHiggs (b , 4 , true )
282+ model , err := XGEnsembleFromFile (filepath .Join ("testdata" , "xghiggs.model" ))
283+ if err != nil {
284+ b .Fatal (err )
285+ }
286+ InnerBenchmarkHiggs (b , model , 4 , true )
335287}
336288
337289func BenchmarkXGHiggs_csr_1thread (b * testing.B ) {
338- InnerBenchmarkXGHiggs (b , 1 , false )
290+ model , err := XGEnsembleFromFile (filepath .Join ("testdata" , "xghiggs.model" ))
291+ if err != nil {
292+ b .Fatal (err )
293+ }
294+ InnerBenchmarkHiggs (b , model , 1 , false )
339295}
340296
341297func BenchmarkXGHiggs_csr_4thread (b * testing.B ) {
342- InnerBenchmarkXGHiggs (b , 4 , false )
298+ model , err := XGEnsembleFromFile (filepath .Join ("testdata" , "xghiggs.model" ))
299+ if err != nil {
300+ b .Fatal (err )
301+ }
302+ InnerBenchmarkHiggs (b , model , 4 , false )
343303}
344304
345- func InnerBenchmarkXGHiggs (b * testing.B , nThreads int , dense bool ) {
305+ func InnerBenchmarkHiggs (b * testing.B , model Ensemble , nThreads int , dense bool ) {
346306 // loading test data
347307 path := filepath .Join ("testdata" , "higgs_1000examples_test.libsvm" )
348308 reader , err := os .Open (path )
@@ -367,13 +327,6 @@ func InnerBenchmarkXGHiggs(b *testing.B, nThreads int, dense bool) {
367327 nRows = csrMat .Rows ()
368328 }
369329
370- // loading model
371- path = filepath .Join ("testdata" , "xghiggs.model" )
372- model , err := XGEnsembleFromFile (path )
373- if err != nil {
374- b .Fatal (err )
375- }
376-
377330 // do benchmark
378331 b .ResetTimer ()
379332 predictions := make ([]float64 , nRows )
0 commit comments