File tree Expand file tree Collapse file tree 3 files changed +9
-2
lines changed
examples/language_model/gpt/faster_gpt
ops/faster_transformer/sample Expand file tree Collapse file tree 3 files changed +9
-2
lines changed Original file line number Diff line number Diff line change @@ -124,7 +124,7 @@ def do_predict(args):
124
124
paddle .fluid .core ._cuda_synchronize (place )
125
125
logger .info ("Average test time for decoding is %f ms" % (
126
126
(time .time () - start ) / 50 * 1000 ))
127
- output_sequence = out_seq .numpy ()
127
+ output_sequence = out_seq .numpy (). tolist ()
128
128
for i in range (args .batch_size ):
129
129
print ("========== Sample-%d ==========" % i )
130
130
print (tokenizer .convert_ids_to_string (output_sequence [i ]))
Original file line number Diff line number Diff line change @@ -130,7 +130,7 @@ def do_predict(args):
130
130
paddle .device .cuda .synchronize (place )
131
131
logger .info ("Average test time for decoding is %f ms" % (
132
132
(time .time () - start ) / 50 * 1000 ))
133
- output_sequence = out_seq .numpy ()
133
+ output_sequence = out_seq .numpy (). tolist ()
134
134
for i in range (args .batch_size ):
135
135
print ("========== Sample-%d ==========" % i )
136
136
print (tokenizer .convert_ids_to_string (output_sequence [i ]))
Original file line number Diff line number Diff line change @@ -1127,6 +1127,13 @@ def prepare_faster_entry(self, kwargs):
1127
1127
raise AttributeError (
1128
1128
"'beam_search' is not supported yet in the faster version of GPT"
1129
1129
)
1130
+ # Currently, FasterTransformer only support restricted size_per_head.
1131
+ size_per_head = self .gpt .config ["hidden_size" ] // self .gpt .config [
1132
+ "num_attention_heads" ]
1133
+ if size_per_head not in [32 , 64 , 128 ]:
1134
+ raise AttributeError (
1135
+ "'size_per_head = %d' is not supported yet in the faster version of GPT"
1136
+ % size_per_head )
1130
1137
self ._faster_entry = FasterGPT (
1131
1138
self , use_fp16_decoding = use_fp16_decoding ).forward
1132
1139
return self ._faster_entry
You can’t perform that action at this time.
0 commit comments