diff --git a/probing/data_former.py b/probing/data_former.py
index 30a447f..2f6b2c2 100644
--- a/probing/data_former.py
+++ b/probing/data_former.py
@@ -24,7 +24,7 @@ def __init__(
         self.shuffle = shuffle
         self.data_path = get_probe_task_path(probe_task, data_path)
 
-        self.samples, self.unique_labels = self.form_data(sep=sep)
+        self.samples, self.unique_labels, self.num_words = self.form_data(sep=sep)
 
     def __len__(self):
         return len(self.samples)
@@ -48,8 +48,9 @@ def form_data(
        samples_dict = defaultdict(list)
        unique_labels = set()
        dataset = pd.read_csv(self.data_path, sep=sep, header=None, dtype=str)
-        for _, (stage, label, text) in dataset.iterrows():
-            samples_dict[stage].append((text, label))
+        for _, (stage, label, text, word_indices) in dataset.iterrows():
+            num_words = len(word_indices.split(","))
+            samples_dict[stage].append((text, label, word_indices))
             unique_labels.add(label)
 
         if self.shuffle:
@@ -58,7 +59,7 @@
             }
         else:
             samples_dict = {k: np.array(v) for k, v in samples_dict.items()}
-        return samples_dict, unique_labels
+        return samples_dict, unique_labels, num_words
 
 
 class EncodedVectorFormer(Dataset):
diff --git a/probing/ud_filter/filtering_probing.py b/probing/ud_filter/filtering_probing.py
index d114353..0c71e19 100644
--- a/probing/ud_filter/filtering_probing.py
+++ b/probing/ud_filter/filtering_probing.py
@@ -41,7 +41,7 @@ def __init__(self, shuffle: bool = True):
         self.classes: Dict[
             str, Tuple[Dict[str, Dict[str, Any]], Dict[Tuple[str, str], Dict[str, Any]]]
         ] = {}
-        self.probing_dict: Dict[str, List[str]] = {}
+        self.probing_dict: Dict[str, List[Tuple[str, List[int]]]] = {}
         self.parts_data: Dict[str, List[List[str]]] = {}
 
     def upload_files(
@@ -91,10 +91,11 @@ def _filter_conllu(self, class_label: str) -> Tuple[List[str], List[str]]:
         for sentence in self.sentences:
             sf = SentenceFilter(sentence)
             tokenized_sentence = " ".join(wordpunct_tokenize(sentence.metadata["text"]))
-            if sf.filter_sentence(node_pattern, constraints):
-                matching.append(tokenized_sentence)
+            filter_result = sf.filter_sentence(node_pattern, constraints)
+            if filter_result is not False:
+                matching.append((tokenized_sentence, filter_result))
             else:
-                not_matching.append(tokenized_sentence)
+                not_matching.append((tokenized_sentence, ()))
         return matching, not_matching
 
     def filter_and_convert(
@@ -128,7 +129,7 @@ def filter_and_convert(
             matching, not_matching = self._filter_conllu(label)
             self.probing_dict[label] = matching
             if len(self.classes) == 1:
-                self.probing_dict["not_" + list(self.classes.keys())[0]] = not_matching
+                self.probing_dict["not_" + label] = not_matching
 
         self.probing_dict = delete_duplicates(self.probing_dict)
         self.parts_data = subsamples_split(
diff --git a/probing/ud_filter/sentence_filter.py b/probing/ud_filter/sentence_filter.py
index b43f253..2d626ce 100644
--- a/probing/ud_filter/sentence_filter.py
+++ b/probing/ud_filter/sentence_filter.py
@@ -189,7 +189,7 @@ def find_isomorphism(self) -> bool:
                 k: {edges[i]} for i, k in enumerate(self.possible_token_pairs)
             }
             self.nodes_tokens = {
-                np[i]: [list(self.possible_token_pairs[np])[0][i]]
+                np[i]: list(self.possible_token_pairs[np])[0][i]
                 for np in self.possible_token_pairs
                 for i in range(2)
             }
@@ -243,6 +243,6 @@ def filter_sentence(
         else:
             self.sent_deprels = self.all_deprels()
         if self.match_constraints():
-            return True
+            return tuple(self.nodes_tokens.values())
         else:
             return False
diff --git a/probing/ud_filter/utils.py b/probing/ud_filter/utils.py
index c8d6361..8f0bc84 100644
--- a/probing/ud_filter/utils.py
+++ b/probing/ud_filter/utils.py
@@ -49,7 +49,7 @@ def subsamples_split(
     if not probing_data:
         raise Exception("All classes have less sentences than the number of classes")
     parts = {}
-    data, labels = map(np.array, zip(*probing_data))
+    data, labels = map(list, zip(*probing_data))
     X_train, X_test, y_train, y_test = train_test_split(
         data,
         labels,
@@ -58,19 +58,20 @@
         shuffle=shuffle,
         random_state=random_seed,
     )
-
     if len(partition) == 2:
         parts = {split[0]: [X_train, y_train], split[1]: [X_test, y_test]}
     else:
         filtered_labels = filter_labels_after_split(y_test)
         if len(filtered_labels) >= 2:
-            X_train = X_train[np.isin(y_train, filtered_labels)]
-            y_train = y_train[np.isin(y_train, filtered_labels)]
-            X_test = X_test[np.isin(y_test, filtered_labels)]
-            y_test = y_test[np.isin(y_test, filtered_labels)]
+            train_mask = np.isin(y_train, filtered_labels)
+            X_train = [X_train[i] for i in range(len(train_mask)) if train_mask[i]]
+            y_train = [y_train[i] for i in range(len(train_mask)) if train_mask[i]]
+            test_mask = np.isin(y_test, filtered_labels)
+            X_test = [X_test[i] for i in range(len(test_mask)) if test_mask[i]]
+            y_test = [y_test[i] for i in range(len(test_mask)) if test_mask[i]]
             val_size = partition[1] / (1 - partition[0])
-            if y_test.size != 0:
+            if len(y_test) != 0:
                 X_val, X_test, y_val, y_test = train_test_split(
                     X_test,
                     y_test,
@@ -118,8 +119,9 @@ def writer(
     with open(result_path, "w", encoding="utf-8") as newf:
         my_writer = csv.writer(newf, delimiter="\t", lineterminator="\n")
         for part in partition_sets:
-            for sentence, value in zip(*partition_sets[part]):
-                my_writer.writerow([part, value, sentence])
+            for sentence_and_ids, value in zip(*partition_sets[part]):
+                sentence, ids = sentence_and_ids
+                my_writer.writerow([part, value, sentence, ",".join([str(x) for x in ids])])
     return result_path
 
 
@@ -150,11 +152,11 @@ def determine_ud_savepath(
 
 
 def delete_duplicates(probing_dict: Dict[str, List[str]]) -> Dict[str, List[str]]:
     """Deletes sentences with more than one different classes of node_pattern found"""
-    all_sent = [s for cl_sent in probing_dict.values() for s in cl_sent]
-    duplicates = [item for item, count in Counter(all_sent).items() if count > 1]
+    all_sent = [sent for cl_sent in probing_dict.values() for sent, inds in cl_sent]
+    duplicates = {item for item, count in Counter(all_sent).items() if count > 1}
     new_probing_dict = {}
     for cl in probing_dict:
-        new_probing_dict[cl] = [s for s in probing_dict[cl] if s not in duplicates]
+        new_probing_dict[cl] = [(sent, ind) for sent, ind in probing_dict[cl] if sent not in duplicates]
     return new_probing_dict
 
diff --git a/probing/ud_parser/ud_parser.py b/probing/ud_parser/ud_parser.py
index 0260b15..6bf7bcb 100644
--- a/probing/ud_parser/ud_parser.py
+++ b/probing/ud_parser/ud_parser.py
@@ -65,8 +65,9 @@ def writer(
         with open(result_path, "w", encoding="utf-8") as newf:
             my_writer = csv.writer(newf, delimiter="\t", lineterminator="\n")
             for part in partition_sets:
-                for sentence, value in zip(*partition_sets[part]):
-                    my_writer.writerow([part, value, sentence])
+                for sentence_and_id, value in zip(*partition_sets[part]):
+                    sentence, token_id = sentence_and_id
+                    my_writer.writerow([part, value, sentence, token_id])
         return result_path
 
     def find_category_token(
@@ -134,7 +135,8 @@ def classify(
                     )
                 ):
                     value = category_token["feats"][category]
-                    probing_data[value].append(s_text)
+                    token_id = category_token["id"] - 1
+                    probing_data[value].append((s_text, token_id))
             elif self.sorting == "by_pos_and_deprel":
                 pos, deprel = subcategory.split("_")
                 if (
@@ -142,7 +144,8 @@
                     and category_token["deprel"] == deprel
                 ):
                     value = category_token["feats"][category]
-                    probing_data[value].append(s_text)
+                    token_id = category_token["id"] - 1
+                    probing_data[value].append((s_text, token_id))
         return probing_data
 
     def filter_labels_after_split(self, labels: List[Any]) -> List[Any]:
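For orientation, a minimal self-contained sketch of the data layout this patch introduces: the writers now emit four tab-separated columns (split, label, sentence, target token indices: a single id from ud_parser, comma-joined ids from the ud_filter pipeline), and TextFormer.form_data parses the fourth column back per sample. The rows below are invented for illustration and are not taken from any real dataset.

    # Illustrative only -- mirrors the new TSV layout and the parsing loop
    # in TextFormer.form_data; the rows below are fabricated examples.
    from io import StringIO

    import pandas as pd

    tsv = (
        "tr\tSing\tThe cat sleeps .\t2\n"
        "te\tPlur\tThe cats sleep .\t1,2\n"
    )

    dataset = pd.read_csv(StringIO(tsv), sep="\t", header=None, dtype=str)
    for _, (stage, label, text, word_indices) in dataset.iterrows():
        num_words = len(word_indices.split(","))  # number of target tokens
        print(stage, label, text, word_indices.split(","), num_words)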