
Commit f3010d9

Merge pull request #1411 from Steffy-zxf/upstream-develop
fix skep data type on Windows
2 parents: 208e880 + c43df37

6 files changed: +14 -14 lines changed
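The same change repeats across all six SKEP example scripts below: the lists returned by the tokenizer are wrapped in np.array(..., dtype="int64"). As the PR title suggests, the motivation is that on Windows (with NumPy 1.x) the default integer type is int32, so arrays built without an explicit dtype end up with a different integer width than on Linux/macOS and can fail dtype checks downstream. A minimal sketch of the difference; the token_ids list and the printed dtypes are illustrative, not part of the commit:

import numpy as np

# What the tokenizer hands back: a plain Python list of ids.
token_ids = [101, 2023, 2003, 102]

implicit = np.array(token_ids)                  # dtype follows the platform default
explicit = np.array(token_ids, dtype="int64")   # dtype pinned to int64 everywhere

print(implicit.dtype)  # typically int32 on Windows (NumPy 1.x), int64 on Linux/macOS
print(explicit.dtype)  # int64 on every platform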

examples/sentiment_analysis/skep/predict_aspect.py

Lines changed: 2 additions & 2 deletions
@@ -104,8 +104,8 @@ def convert_example(example,
         text_pair=example["text_pair"],
         max_seq_len=max_seq_length)

-    input_ids = encoded_inputs["input_ids"]
-    token_type_ids = encoded_inputs["token_type_ids"]
+    input_ids = np.array(encoded_inputs["input_ids"], dtype="int64")
+    token_type_ids = np.array(encoded_inputs["token_type_ids"], dtype="int64")

     if not is_test:
         label = np.array([example["label"]], dtype="int64")

examples/sentiment_analysis/skep/predict_opinion.py

Lines changed: 3 additions & 3 deletions
@@ -70,9 +70,9 @@ def convert_example(example, tokenizer, max_seq_length=512, is_test=False):
         return_length=True,
         is_split_into_words=True,
         max_seq_len=max_seq_length)
-    input_ids = encoded_inputs["input_ids"]
-    token_type_ids = encoded_inputs["token_type_ids"]
-    seq_len = encoded_inputs["seq_len"]
+    input_ids = np.array(encoded_inputs["input_ids"], dtype="int64")
+    token_type_ids = np.array(encoded_inputs["token_type_ids"], dtype="int64")
+    seq_len = np.array(encoded_inputs["seq_len"], dtype="int64")

     return input_ids, token_type_ids, seq_len

examples/sentiment_analysis/skep/predict_sentence.py

Lines changed: 2 additions & 2 deletions
@@ -71,8 +71,8 @@ def convert_example(example,
         token_type_ids(obj: `list[int]`): List of sequence pair mask.
     """
     encoded_inputs = tokenizer(text=example, max_seq_len=max_seq_length)
-    input_ids = encoded_inputs["input_ids"]
-    token_type_ids = encoded_inputs["token_type_ids"]
+    input_ids = np.array(encoded_inputs["input_ids"], dtype="int64")
+    token_type_ids = np.array(encoded_inputs["token_type_ids"], dtype="int64")

     return input_ids, token_type_ids

examples/sentiment_analysis/skep/train_aspect.py

Lines changed: 2 additions & 2 deletions
@@ -93,8 +93,8 @@ def convert_example(example,
         text_pair=example["text_pair"],
         max_seq_len=max_seq_length)

-    input_ids = encoded_inputs["input_ids"]
-    token_type_ids = encoded_inputs["token_type_ids"]
+    input_ids = np.array(encoded_inputs["input_ids"], dtype="int64")
+    token_type_ids = np.array(encoded_inputs["token_type_ids"], dtype="int64")

     if not is_test:
         label = np.array([example["label"]], dtype="int64")

examples/sentiment_analysis/skep/train_opinion.py

Lines changed: 3 additions & 3 deletions
@@ -94,9 +94,9 @@ def convert_example_to_feature(example,
         is_split_into_words=True,
         max_seq_len=max_seq_len)

-    input_ids = tokenized_input['input_ids']
-    token_type_ids = tokenized_input['token_type_ids']
-    seq_len = tokenized_input['seq_len']
+    input_ids = np.array(tokenized_input['input_ids'], dtype="int64")
+    token_type_ids = np.array(tokenized_input['token_type_ids'], dtype="int64")
+    seq_len = np.array(tokenized_input['seq_len'], dtype="int64")

     if is_test:
         return input_ids, token_type_ids, seq_len

examples/sentiment_analysis/skep/train_sentence.py

Lines changed: 2 additions & 2 deletions
@@ -122,8 +122,8 @@ def convert_example(example,
     encoded_inputs = tokenizer(
         text=example["text"], max_seq_len=max_seq_length)

-    input_ids = encoded_inputs["input_ids"]
-    token_type_ids = encoded_inputs["token_type_ids"]
+    input_ids = np.array(encoded_inputs["input_ids"], dtype="int64")
+    token_type_ids = np.array(encoded_inputs["token_type_ids"], dtype="int64")

     if not is_test:
         if dataset_name == "sst-2":
