12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- import numpy as np
16
15
import os
16
+ import unittest
17
17
18
18
from paddlenlp .data import SamplerHelper
19
19
from paddlenlp .datasets import load_dataset
20
-
21
- from common_test import CpuCommonTest
22
- import util
23
- import unittest
20
+ from tests .common_test import CpuCommonTest
21
+ from tests .testing_utils import assert_raises , get_tests_dir
24
22
25
23
26
24
def cmp(x, y):
    """Three-way comparison of ``x`` and ``y``.

    Returns:
        -1 if ``x < y``, 1 if ``x > y``, otherwise 0.
    """
    return -1 if x < y else 1 if x > y else 0
28
26
29
27
30
28
class TestSampler (CpuCommonTest ):
31
@classmethod
def setUpClass(cls):
    """Load the shared training dataset once for every test in the class.

    The dataset is read from a local ``train.json`` file under the tests
    directory (resolved through ``get_tests_dir``) and stored on the class
    as ``train_ds``.
    """
    fixture_path = get_tests_dir(os.path.join("fixtures", "dummy"))
    train_file = os.path.join(fixture_path, "tnews", "train.json")
    cls.train_ds = load_dataset("clue", "tnews", data_files=[train_file])
35
33
36
34
def test_length(self):
    """Check ``len()`` of a ``SamplerHelper`` and that it follows ``length``.

    The sampler built over ``self.train_ds`` must report 10 items
    (presumably the number of examples in the dummy fixture -- confirm
    against the fixture file), and ``len()`` must agree with the
    ``length`` attribute both before and after it is reassigned.
    """
    train_batch_sampler = SamplerHelper(self.train_ds)
    self.check_output_equal(len(train_batch_sampler), 10)
    self.check_output_equal(len(train_batch_sampler), train_batch_sampler.length)

    # Overriding ``length`` must be reflected by ``len()``.
    train_batch_sampler.length = 5
    self.check_output_equal(len(train_batch_sampler), 5)
43
41
44
42
def test_iter1 (self ):
45
43
train_ds_len = len (self .train_ds )
@@ -63,23 +61,15 @@ def test_list(self):
63
61
def test_shuffle_no_buffer_size(self):
    """Check the first samples yielded by ``shuffle(seed=102)``.

    ``expected_result`` maps an iteration position to the sample expected
    at that position; positions not listed are not checked.
    """
    train_batch_sampler = SamplerHelper(self.train_ds)
    shuffle_sampler = train_batch_sampler.shuffle(seed=102)
    expected_result = {0: 4, 1: 9}
    for i, sample in enumerate(shuffle_sampler):
        if i in expected_result:
            self.check_output_equal(sample, expected_result[i])
74
68
75
69
def test_shuffle_buffer_size(self):
    """Check the first samples yielded by ``shuffle(buffer_size=10, seed=102)``.

    ``expected_result`` maps an iteration position to the sample expected
    at that position; positions not listed are not checked.
    """
    train_batch_sampler = SamplerHelper(self.train_ds)
    shuffle_sampler = train_batch_sampler.shuffle(buffer_size=10, seed=102)
    expected_result = {0: 4, 1: 9}
    for i, sample in enumerate(shuffle_sampler):
        if i in expected_result:
            self.check_output_equal(sample, expected_result[i])
@@ -88,12 +78,12 @@ def test_sort_buffer_size(self):
88
78
train_ds_len = len (self .train_ds )
89
79
ds_iter = iter (range (train_ds_len - 1 , - 1 , - 1 ))
90
80
train_batch_sampler = SamplerHelper (self .train_ds , ds_iter )
91
- sort_sampler = train_batch_sampler .sort (cmp = lambda x , y , dataset : cmp (x , y ), buffer_size = 12500 )
81
+ sort_sampler = train_batch_sampler .sort (cmp = lambda x , y , dataset : cmp (x , y ), buffer_size = 5 )
92
82
for i , sample in enumerate (sort_sampler ):
93
- if i < 12500 :
94
- self .check_output_equal (i + 12500 , sample )
83
+ if i < 5 :
84
+ self .check_output_equal (i + 5 , sample )
95
85
else :
96
- self .check_output_equal (i - 12500 , sample )
86
+ self .check_output_equal (i - 5 , sample )
97
87
98
88
def test_sort_no_buffer_size (self ):
99
89
train_ds_len = len (self .train_ds )
@@ -111,14 +101,16 @@ def test_batch(self):
111
101
for j , minibatch in enumerate (sample ):
112
102
self .check_output_equal (i * batch_size + j , minibatch )
113
103
114
@assert_raises(ValueError)
def test_batch_oversize(self):
    """Batch with a size function that reports ``len(data_source)`` per sample.

    The ``assert_raises(ValueError)`` decorator presumably makes the test
    pass only if a ``ValueError`` is raised while building or iterating the
    batch sampler -- confirm against ``tests.testing_utils``.
    """
    train_batch_sampler = SamplerHelper(self.train_ds)
    batch_size = 3

    batch_sampler = train_batch_sampler.batch(
        batch_size,
        key=lambda size_so_far, minibatch_len: max(size_so_far, minibatch_len),
        # Every sample is sized as the whole data source; presumably this
        # exceeds ``batch_size`` and triggers the error -- TODO confirm.
        batch_size_fn=lambda new, count, sofar, data_source: len(data_source),
    )
    for i, sample in enumerate(batch_sampler):
        for j, minibatch in enumerate(sample):
            self.check_output_equal(i * batch_size + j, minibatch)
@@ -143,8 +135,9 @@ def test_apply(self):
143
135
train_ds_len = len (self .train_ds )
144
136
ds_iter = iter (range (train_ds_len - 1 , - 1 , - 1 ))
145
137
train_batch_sampler = SamplerHelper (self .train_ds , ds_iter )
146
- fn = lambda sampler : SamplerHelper .sort (sampler , cmp = lambda x , y , dataset : cmp (x , y ))
147
- apply_sampler = train_batch_sampler .apply (fn )
138
+ apply_sampler = train_batch_sampler .apply (
139
+ lambda sampler : SamplerHelper .sort (sampler , cmp = lambda x , y , dataset : cmp (x , y ))
140
+ )
148
141
for i , sample in enumerate (apply_sampler ):
149
142
self .check_output_equal (i , sample )
150
143
0 commit comments