@@ -54,51 +54,59 @@ def to_lodtensor(data, place):
54
54
return res
55
55
56
56
57
- def chop_data (data , chop_len = 80 , batch_len = 50 ):
57
+ def chop_data (data , chop_len = 80 , batch_size = 50 ):
58
58
data = [(x [0 ][:chop_len ], x [1 ]) for x in data if len (x [0 ]) >= chop_len ]
59
59
60
- return data [:batch_len ]
60
+ return data [:batch_size ]
61
61
62
62
63
63
def prepare_feed_data (data , place ):
64
64
tensor_words = to_lodtensor (map (lambda x : x [0 ], data ), place )
65
65
66
66
label = np .array (map (lambda x : x [1 ], data )).astype ("int64" )
67
- label = label .reshape ([50 , 1 ])
67
+ label = label .reshape ([len ( label ) , 1 ])
68
68
tensor_label = core .LoDTensor ()
69
69
tensor_label .set (label , place )
70
70
71
71
return tensor_words , tensor_label
72
72
73
73
74
74
def main ():
75
- word_dict = paddle . dataset . imdb . word_dict ()
76
- cost , acc = lstm_net ( dict_dim = len ( word_dict ), class_dim = 2 )
75
+ BATCH_SIZE = 100
76
+ PASS_NUM = 5
77
77
78
- batch_size = 100
79
- train_data = paddle .batch (
80
- paddle .reader .buffered (
81
- paddle .dataset .imdb .train (word_dict ), size = batch_size * 10 ),
82
- batch_size = batch_size )
78
+ word_dict = paddle .dataset .imdb .word_dict ()
79
+ print "load word dict successfully"
80
+ dict_dim = len (word_dict )
81
+ class_dim = 2
83
82
84
- data = chop_data ( next ( train_data ()) )
83
+ cost , acc = lstm_net ( dict_dim = dict_dim , class_dim = class_dim )
85
84
85
+ train_data = paddle .batch (
86
+ paddle .reader .shuffle (
87
+ paddle .dataset .imdb .train (word_dict ), buf_size = BATCH_SIZE * 10 ),
88
+ batch_size = BATCH_SIZE )
86
89
place = core .CPUPlace ()
87
- tensor_words , tensor_label = prepare_feed_data (data , place )
88
90
exe = Executor (place )
91
+
89
92
exe .run (framework .default_startup_program ())
90
93
91
- while True :
92
- outs = exe .run (framework .default_main_program (),
93
- feed = {"words" : tensor_words ,
94
- "label" : tensor_label },
95
- fetch_list = [cost , acc ])
96
- cost_val = np .array (outs [0 ])
97
- acc_val = np .array (outs [1 ])
98
-
99
- print ("cost=" + str (cost_val ) + " acc=" + str (acc_val ))
100
- if acc_val > 0.9 :
101
- break
94
+ for pass_id in xrange (PASS_NUM ):
95
+ for data in train_data ():
96
+ chopped_data = chop_data (data )
97
+ tensor_words , tensor_label = prepare_feed_data (chopped_data , place )
98
+
99
+ outs = exe .run (framework .default_main_program (),
100
+ feed = {"words" : tensor_words ,
101
+ "label" : tensor_label },
102
+ fetch_list = [cost , acc ])
103
+ cost_val = np .array (outs [0 ])
104
+ acc_val = np .array (outs [1 ])
105
+
106
+ print ("cost=" + str (cost_val ) + " acc=" + str (acc_val ))
107
+ if acc_val > 0.7 :
108
+ exit (0 )
109
+ exit (1 )
102
110
103
111
104
112
if __name__ == '__main__' :
0 commit comments