@@ -100,19 +100,6 @@ def do_train(args):
100
100
Stack (dtype = "int64" ), # masked_positions
101
101
Stack (dtype = "int64" ), # candidate_labels_ids [candidate_num, label_length]
102
102
): [data for data in fn (samples )]
103
- elif args .task_name == "cluewsc" :
104
- batchify_fn = lambda samples , fn = Tuple (
105
- Pad (axis = 0 , pad_val = tokenizer .pad_token_id ), # src_ids
106
- Pad (axis = 0 , pad_val = tokenizer .pad_token_type_id ), # token_type_ids
107
- Stack (dtype = "int64" ), # masked_positions
108
- Stack (dtype = "int64" ), # masked_lm_labels
109
- ): [data for data in fn (samples )]
110
- batchify_test_fn = lambda samples , fn = Tuple (
111
- Pad (axis = 0 , pad_val = tokenizer .pad_token_id ), # src_ids
112
- Pad (axis = 0 , pad_val = tokenizer .pad_token_type_id ), # token_type_ids
113
- Stack (dtype = "int64" ), # masked_positions
114
- Stack (dtype = "int64" ), # masked_positions
115
- ): [data for data in fn (samples )]
116
103
else :
117
104
# [src_ids, token_type_ids, masked_positions, masked_lm_labels]
118
105
batchify_fn = lambda samples , fn = Tuple (
0 commit comments