Skip to content

Commit 243cdd7

Browse files
authored
Fix plato-2 and plato-mini dtype bug (#767)
* fix unified transformer dtype problem * fix win dtype bug * Fix plato-2 and plato-mini dtype bug
1 parent 147a943 commit 243cdd7

File tree

2 files changed

+2
-0
lines changed

2 files changed

+2
-0
lines changed

examples/dialogue/plato-2/interaction.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def interact(args):
7676
example, is_infer=True)
7777
data = plato_reader._pad_batch_records([record], is_infer=True)
7878
inputs = gen_inputs(data, args.latent_type_size)
79+
inputs['tgt_ids'] = inputs['tgt_ids'].astype('int64')
7980
pred = model(inputs)[0]
8081
bot_response = pred["response"]
8182
print(

examples/dialogue/unified_transformer/interaction.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def interaction(args, model, tokenizer):
4848
add_start_token_as_response=True,
4949
return_tensors=True,
5050
is_split_into_words=False)
51+
inputs['input_ids'] = inputs['input_ids'].astype('int64')
5152
ids, scores = model.generate(
5253
input_ids=inputs['input_ids'],
5354
token_type_ids=inputs['token_type_ids'],

0 commit comments

Comments
 (0)