From f6b64b88b3157b7bb71aa3708063067c69ad5326 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Fri, 8 Jul 2022 13:17:53 -0700 Subject: [PATCH 01/28] prefix trie --- catwalk/utils/prefix_trie.py | 46 ++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 catwalk/utils/prefix_trie.py diff --git a/catwalk/utils/prefix_trie.py b/catwalk/utils/prefix_trie.py new file mode 100644 index 00000000..bb959026 --- /dev/null +++ b/catwalk/utils/prefix_trie.py @@ -0,0 +1,46 @@ +from typing import List, Sequence, Set +from typing import Dict + + +class PrefixTrie(): + def __init__(self, sequences: Sequence[Sequence[int]]): + self.root = PrefixTrieNode(parent=None, token=None) + self.nodes = [] + for i, sequence in enumerate(sequences): + self.add_sequence(sequence=sequence, index=i) + + def add_sequence(self, sequence: Sequence[int], index: int): + current_node = self.root + for token in sequence: + if token not in current_node.children: + current_node.children[token] = PrefixTrieNode(parent=current_node, token=token) + self.nodes.append(current_node.children[token]) + current_node = current_node.children[token] + current_node.indices.append(index) + + def get_leaf_nodes(self) -> List['PrefixTrieNode']: + return [node for node in self.nodes if len(node.children) == 0] + +class PrefixTrieNode(): + def __init__(self, parent: 'PrefixTrieNode', token: int): + self.parent = parent + self.token = token + self.indices: List[int] = [] + self.children: Dict[int,'PrefixTrieNode'] = {} + + def get_sequence(self) -> List[int]: + if self.parent is not None: + sequence = self.parent.get_sequence() + sequence.append(self.token) + return sequence + else: + return [] + + def get_prefix_indices(self) -> Set[int]: + """Returns all indices for subsequences of the current node including itself""" + if self.parent is not None: + sequence = self.parent.get_prefix_indices() + sequence.update(self.indices) + return sequence + else: + return set() \ No newline at end of 
file From f1b6534937b42fd2911cc2bc4e75adb94443c084 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Fri, 8 Jul 2022 13:18:05 -0700 Subject: [PATCH 02/28] prefix trie --- catwalk/utils/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 catwalk/utils/__init__.py diff --git a/catwalk/utils/__init__.py b/catwalk/utils/__init__.py new file mode 100644 index 00000000..1988e6d8 --- /dev/null +++ b/catwalk/utils/__init__.py @@ -0,0 +1 @@ +from catwalk.utils.prefix_trie import PrefixTrie \ No newline at end of file From ab2314cdb817dc55803d8fb62a504b8ccf083916 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Fri, 15 Jul 2022 11:35:00 -0700 Subject: [PATCH 03/28] generalized caching --- catwalk/models/metaicl.py | 3 +- catwalk/models/rank_classification.py | 226 ++++++++++++++++++++++---- catwalk/utils/prefix_trie.py | 67 +++++--- 3 files changed, 245 insertions(+), 51 deletions(-) diff --git a/catwalk/models/metaicl.py b/catwalk/models/metaicl.py index f808c074..2cb11c9b 100644 --- a/catwalk/models/metaicl.py +++ b/catwalk/models/metaicl.py @@ -16,11 +16,12 @@ def __init__( self, pretrained_model_name_or_path: str, *, + prefix_caching: bool = True, max_length_per_example: int = 256, continuation_seperator: str = '\n', example_seperator: str = '\n\n\n' ): - super().__init__(pretrained_model_name_or_path) + super().__init__(pretrained_model_name_or_path, prefix_caching=prefix_caching) self.max_length_per_example = max_length_per_example self.continuation_seperator = continuation_seperator self.example_seperator = example_seperator diff --git a/catwalk/models/rank_classification.py b/catwalk/models/rank_classification.py index 261e5211..09ca9da5 100644 --- a/catwalk/models/rank_classification.py +++ b/catwalk/models/rank_classification.py @@ -1,5 +1,5 @@ import collections -from typing import Dict, Any, List, Tuple, Sequence, Iterator, Union, Mapping, Optional, cast +from typing import Dict, Any, List, OrderedDict, Tuple, Sequence, Iterator, Union, Mapping, 
Optional, cast import more_itertools import torch @@ -8,11 +8,12 @@ from torch import log_softmax from torch.nn.utils.rnn import pad_sequence from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, T5ForConditionalGeneration, GPT2LMHeadModel, \ - AutoTokenizer, GPT2Tokenizer, T5TokenizerFast + AutoTokenizer, GPT2Tokenizer, T5TokenizerFast, BatchEncoding from catwalk import cached_transformers from catwalk.model import Model from catwalk.task import Task, InstanceFormat, RankClassificationInstance +from catwalk.utils import PrefixTrie _Model = Union[T5ForConditionalGeneration, GPT2LMHeadModel] _Tokenizer = Union[T5TokenizerFast, GPT2Tokenizer] @@ -171,6 +172,11 @@ def _run_loglikelihood( @Model.register("rc::decoder_only") class DecoderOnlyRCModel(RankClassificationModel): + def __init__(self, pretrained_model_name_or_path: str, *, prefix_caching: bool = False): + super().__init__(pretrained_model_name_or_path) + self.prefix_caching = prefix_caching + self._reset_cache_variables() + @classmethod def _make_model(cls, pretrained_model_name_or_path: str) -> GPT2LMHeadModel: return cached_transformers.get(AutoModelForCausalLM, pretrained_model_name_or_path, False) @@ -185,6 +191,14 @@ def _run_loglikelihood( tokenized_contexts = tokenizer([t[0] for t in tuples]) tokenized_continuations = tokenizer([t[1] for t in tuples]) + self._final_truncatation( + tokenized_contexts, tokenized_continuations, tokenizer.model_max_length + ) + + ordered_indices = self._reorder_instances( + tokenized_contexts, tokenized_continuations + ) + # transpose the token ids so we can access them one instance at a time cc_pairs: List[Dict[str, Tuple[torch.Tensor, torch.Tensor]]] = [] assert tokenized_contexts.keys() == tokenized_continuations.keys() @@ -202,14 +216,6 @@ def _run_loglikelihood( torch.tensor(continuation, dtype=torch.long) ) - # find out the order to process sequences in - lengths = torch.tensor([ - len(cc_pair["input_ids"][0]) + len(cc_pair["input_ids"][1]) - for 
cc_pair in cc_pairs - ], dtype=torch.int) - ordered_indices = torch.argsort(lengths, descending=True) - del lengths - # actually do the processing results: List[Optional[float]] = [None] * len(ordered_indices) with torch.inference_mode(): @@ -217,31 +223,191 @@ def _run_loglikelihood( Tqdm.tqdm(ordered_indices, desc="Running log-likelihood queries"), batch_size) for batch_of_indices in batches_of_indices: - unpadded_batch = collections.defaultdict(list) - input_lengths = [] - batch_contexts = [] - batch_continuations = [] - for index in batch_of_indices: - for field_name, (context_ids, continuation_ids) in cc_pairs[index].items(): - ids = torch.cat([context_ids, continuation_ids]) - ids = ids[-(tokenizer.model_max_length+1):][:-1] - unpadded_batch[field_name].append(ids) - - input_lengths.append(len(unpadded_batch["input_ids"][-1])) - batch_contexts.append(cc_pairs[index]["input_ids"][0]) - batch_continuations.append(cc_pairs[index]["input_ids"][1]) - - padded_batch = { - field_name: pad_sequence(tensors, batch_first=True).to(model.device) - for field_name, tensors in unpadded_batch.items() - } - - batch_logits = log_softmax(model(**padded_batch)[0], dim=-1).cpu() + inputs, input_lengths, batch_contexts, batch_continuations = self._get_inputs(batch_of_indices, cc_pairs, model) + batch_logits = log_softmax(model(**inputs)[0], dim=-1).cpu() z = zip(batch_of_indices, batch_logits, input_lengths, batch_contexts, batch_continuations) for i, instance_logits, input_length, instance_context, instance_continuation in z: + assert input_length-len(instance_continuation) >=0 instance_logits = instance_logits[input_length-len(instance_continuation):input_length] instance_logits = torch.gather(instance_logits, 1, instance_continuation.unsqueeze(-1)) results[i] = float(instance_logits.sum()) / len(tuples[i][1]) assert None not in results return cast(Sequence[float], results) + + def _final_truncatation(self, tokenized_contexts: BatchEncoding, tokenized_continuations: 
BatchEncoding, model_max_length: int): + """ Apply a last pass of truncation on the concatenated inputs to make sure it fits in the model_max_length""" + assert len(tokenized_contexts['input_ids']) == len(tokenized_continuations['input_ids']) + for i in range(len(tokenized_contexts['input_ids'])): + context_len = len(tokenized_contexts['input_ids'][i]) + cont_len = len(tokenized_continuations['input_ids'][i]) + assert cont_len < model_max_length + if context_len + cont_len > model_max_length: + tokenized_contexts['input_ids'][i] = tokenized_contexts['input_ids'][i][-model_max_length + cont_len:] + tokenized_contexts['attention_mask'][i] = tokenized_contexts['attention_mask'][i][-model_max_length + cont_len:] + + def _reorder_instances(self, tokenized_contexts: BatchEncoding, tokenized_continuations: BatchEncoding) -> Sequence[int]: + if self.prefix_caching: + return self._reorder_by_prefix(tokenized_contexts, tokenized_continuations) + else: + return self._reorder_by_longest(tokenized_contexts, tokenized_continuations) + + def _reorder_by_prefix(self, tokenized_contexts: BatchEncoding, tokenized_continuations: BatchEncoding) -> Sequence[int]: + self._reset_cache_variables() + combined_ids = [context + continuation for context, continuation in zip(tokenized_contexts['input_ids'], tokenized_continuations['input_ids'])] + self.longest_prefix_to_indices = self._order_by_common_prefix(combined_ids) + self.indices_to_longest_prefix = OrderedDict() + for prefix in sorted(self.longest_prefix_to_indices.keys(), key = lambda x : -len(x)): + # indices for each prefix are already sorted by trie + for index in self.longest_prefix_to_indices[prefix]: + self.indices_to_longest_prefix[index] = prefix + return list(self.indices_to_longest_prefix.keys()) + + def _order_by_common_prefix(self, sequences: Sequence[Sequence[int]]) -> Dict[Sequence[int],Sequence[int]]: + longest_prefix_to_indices: Dict[Sequence[int],Sequence[int]] = {} + trie = PrefixTrie(sequences) + leaves = 
trie.get_leaf_nodes() + leaves_sequences = [tuple(leaf.get_sequence()) for leaf in leaves] + leaves_and_sequences = Tqdm.tqdm(zip(leaves_sequences, leaves), desc="Finding prefixes", total=len(leaves)) + leaves2prefixes = {leaf_sequence:leaf.get_prefix_indices() for leaf_sequence, leaf in leaves_and_sequences} + + indices_already_assigned = set() + for leaf_sequence in sorted(leaves_sequences, key=lambda leaf_sequence : -leaves2prefixes[leaf_sequence][1]): + prefix_indices, _ = leaves2prefixes[leaf_sequence] + prefix_indices = [prefix_index for prefix_index in prefix_indices if prefix_index not in indices_already_assigned] + indices_already_assigned.update(prefix_indices) + if len(prefix_indices) > 0: + longest_prefix_to_indices[leaf_sequence] = tuple(prefix_indices) + + return longest_prefix_to_indices + + def _reorder_by_longest(self, tokenized_contexts: BatchEncoding, tokenized_continuations: BatchEncoding) -> Sequence[int]: + assert len(tokenized_contexts['input_ids']) == len(tokenized_continuations['input_ids']) + lengths = torch.tensor([ + len(tokenized_contexts["input_ids"][i]) + len(tokenized_continuations["input_ids"][i]) + for i in range(len(tokenized_contexts['input_ids'])) + ], dtype=torch.int) + return torch.argsort(lengths, descending=True).tolist() + + def _reset_cache_variables(self): + self.cached_sequence: Sequence[int] = None + self.cached_past_key_values: Tuple[Tuple[torch.Tensor]] = None + self.longest_prefix_to_indices: Dict[Sequence[int],Sequence[int]] = None + self.indices_to_longest_prefix: OrderedDict[int,Sequence[int]] = None + + def _get_inputs(self, batch_of_indices: Sequence[int], cc_pairs: List[Dict[str, Tuple[torch.Tensor, torch.Tensor]]], model: _Model): + if self.prefix_caching: + return self._get_inputs_with_cache(batch_of_indices, cc_pairs, model) + else: + return self._get_inputs_without_cache(batch_of_indices, cc_pairs, model) + + + def _get_inputs_with_cache(self, batch_of_indices: Sequence[int], cc_pairs: List[Dict[str, 
Tuple[torch.Tensor, torch.Tensor]]], model: _Model): + prefixes = [self.indices_to_longest_prefix[index] for index in batch_of_indices] + prefix2cache = OrderedDict() + + # compute prefixes + for prefix in set(prefixes): + if prefix == self.cached_sequence: + past_key_values = self.cached_past_key_values + else: + past_key_values = model(input_ids=torch.tensor(prefix).to(model.device)).past_key_values + # tensor(layers, keys/values, batch_size, num_heads, sequence_len, embed_size_per_head) + past_key_values = torch.stack(tuple(torch.stack(past_key_values[i]) for i in range(len(past_key_values)))) + # tensor(layers, keys/values, num_heads, sequence_len, embed_size_per_head) + past_key_values = past_key_values.squeeze(2) + prefix2cache[prefix] = past_key_values + + # update cache with last one retrieved since instances come in order by common prefix + self.cached_sequence, self.cached_past_key_values = list(prefix2cache.items())[-1] + + # pad and mask batched past_key_values + unpadded_past_keys_values = [prefix2cache[prefix] for prefix in prefixes] + unpadded_past_keys_values_attn_mask = [] + # only use the prefixed part of past_key_values that is present in the instance + for prefix_idx, cc_pairs_idx in enumerate(batch_of_indices): + is_identical = True + for tok_idx, tok in enumerate(cc_pairs[cc_pairs_idx]['input_ids'][0]): + if tok.item() != prefixes[prefix_idx][tok_idx]: + unpadded_past_keys_values_attn_mask.append(torch.tensor([1] * tok_idx, dtype=torch.int64)) + is_identical = False + break + if is_identical: + # Avoid empty input by leaving last token of context for input because continuations drop one token for right shift + max_prefix_len = len(cc_pairs[cc_pairs_idx]['input_ids'][0]) - 1 + unpadded_past_keys_values_attn_mask.append(torch.tensor([1] * max_prefix_len, dtype=torch.int64)) + + # past_keys_values needs its own attention mask + padded_past_keys_values_attn_mask = pad_sequence(unpadded_past_keys_values_attn_mask, batch_first=True, padding_value=0) 
+ cache_lengths = [mask.sum().item() for mask in padded_past_keys_values_attn_mask] + max_past_key_value_len = max(cache_lengths) + + # pad and truncate past_keys_values to longest actually used + unpadded_past_keys_values = [t.transpose(0,-2) for t in unpadded_past_keys_values] + padded_past_keys_values = pad_sequence(unpadded_past_keys_values, batch_first=True) + padded_past_keys_values = padded_past_keys_values.permute((4, 2, 0, 3, 1, 5)) + # tensor(layers, keys/values, batch_size, num_heads, sequence_len, embed_size_per_head) + padded_past_keys_values = padded_past_keys_values[:,:,:,:,:max_past_key_value_len] + + # make input_ids by removing whatever parts of past_key_values are present + unpadded_input_ids = [] + input_lengths = [] + batch_contexts = [] + batch_continuations = [] + + for prefix_idx, cc_pairs_idx in enumerate(batch_of_indices): + context_ids, continuation_ids = cc_pairs[cc_pairs_idx]['input_ids'] + ids = torch.cat([context_ids, continuation_ids])[:-1] + ids = ids[cache_lengths[prefix_idx]:] + + # used to find logits specifically for continuation + input_lengths.append(len(ids)) + batch_contexts.append(cc_pairs[cc_pairs_idx]["input_ids"][0]) + batch_continuations.append(cc_pairs[cc_pairs_idx]["input_ids"][1]) + + unpadded_input_ids.append(ids) + + # batch and pad and make attention mask + unpadded_attn_mask = [torch.ones_like(t) for t in unpadded_input_ids] + padded_attn_mask = pad_sequence(unpadded_attn_mask, batch_first=True, padding_value=0) + padded_input_ids = pad_sequence(unpadded_input_ids, batch_first=True) + + # combine the attention masks + full_attn_mask = torch.cat((padded_past_keys_values_attn_mask, padded_attn_mask), dim=1) + assert full_attn_mask.shape[1] <= model.config.n_positions, "Presently batches with wide range of prefix and input lengths are not supported due overrun of max model size" + + # make position_ids + max_input_len = padded_input_ids.shape[-1] + position_ids = torch.stack([torch.arange(cache_length, cache_length 
+ max_input_len) for cache_length in cache_lengths], dim=0) + position_ids = position_ids * padded_attn_mask + assert (position_ids < model.config.n_positions).all() + + inputs = { + 'input_ids': padded_input_ids.to(model.device), + 'past_key_values': padded_past_keys_values, + 'attention_mask': full_attn_mask.to(model.device), + 'position_ids': position_ids.to(model.device) + } + + return inputs, input_lengths, batch_contexts, batch_continuations + + def _get_inputs_without_cache(self, batch_of_indices: Sequence[int], cc_pairs: List[Dict[str, Tuple[torch.Tensor, torch.Tensor]]], model: _Model): + unpadded_batch = collections.defaultdict(list) + input_lengths = [] + batch_contexts = [] + batch_continuations = [] + + for index in batch_of_indices: + for field_name, (context_ids, continuation_ids) in cc_pairs[index].items(): + ids = torch.cat([context_ids, continuation_ids])[:-1] + unpadded_batch[field_name].append(ids) + + input_lengths.append(len(unpadded_batch["input_ids"][-1])) + batch_contexts.append(cc_pairs[index]["input_ids"][0]) + batch_continuations.append(cc_pairs[index]["input_ids"][1]) + + padded_batch = { + field_name: pad_sequence(tensors, batch_first=True).to(model.device) + for field_name, tensors in unpadded_batch.items() + } + return padded_batch, input_lengths, batch_contexts, batch_continuations \ No newline at end of file diff --git a/catwalk/utils/prefix_trie.py b/catwalk/utils/prefix_trie.py index bb959026..1e7988de 100644 --- a/catwalk/utils/prefix_trie.py +++ b/catwalk/utils/prefix_trie.py @@ -1,22 +1,31 @@ -from typing import List, Sequence, Set -from typing import Dict +from typing import List, Sequence, Tuple, Dict +from tango.common import Tqdm class PrefixTrie(): - def __init__(self, sequences: Sequence[Sequence[int]]): + def __init__(self, sequences: Sequence[Sequence[int]], minimum_prefix: int = 10): self.root = PrefixTrieNode(parent=None, token=None) + self.minimum_prefix = minimum_prefix self.nodes = [] - for i, sequence in 
enumerate(sequences): + for i, sequence in Tqdm.tqdm(enumerate(sequences), desc="Building PrefixTrie for caching", total=len(sequences)): self.add_sequence(sequence=sequence, index=i) + # remove all indices and lenghts_covered at non-forking, non-leaf nodes + for node in self.nodes: + if len(node.children) == 1: + node.indices = [] + node.lengths_covered = [] def add_sequence(self, sequence: Sequence[int], index: int): + assert len(sequence) >= self.minimum_prefix, f"sequence with len {len(sequence)} too small for PrefixTrie with minimum_prefix {self.minimum_prefix}" current_node = self.root - for token in sequence: + for token_idx, token in enumerate(sequence): if token not in current_node.children: current_node.children[token] = PrefixTrieNode(parent=current_node, token=token) self.nodes.append(current_node.children[token]) current_node = current_node.children[token] - current_node.indices.append(index) + if token_idx + 1 >= self.minimum_prefix: + current_node.indices.append(index) + current_node.lengths_covered.append(token_idx + 1) def get_leaf_nodes(self) -> List['PrefixTrieNode']: return [node for node in self.nodes if len(node.children) == 0] @@ -26,21 +35,39 @@ def __init__(self, parent: 'PrefixTrieNode', token: int): self.parent = parent self.token = token self.indices: List[int] = [] + self.lengths_covered: List[int] = [] self.children: Dict[int,'PrefixTrieNode'] = {} def get_sequence(self) -> List[int]: - if self.parent is not None: - sequence = self.parent.get_sequence() - sequence.append(self.token) - return sequence - else: - return [] + current_node = self + sequence = [] + while current_node.parent is not None: + sequence.append(current_node.token) + current_node = current_node.parent + return sequence[::-1] - def get_prefix_indices(self) -> Set[int]: - """Returns all indices for subsequences of the current node including itself""" - if self.parent is not None: - sequence = self.parent.get_prefix_indices() - sequence.update(self.indices) - return 
sequence - else: - return set() \ No newline at end of file + def get_prefix_indices(self) -> Tuple[List[int], int]: + """Returns all indices for subsequences of the current node including itself starting with longest and decreasing""" + current_node = self + indices = [] + already_found = set() + total_lengths_covered = 0 + while current_node.parent is not None: + new_indices = [] + for index, length_covered in zip(current_node.indices, current_node.lengths_covered): + if index not in already_found: + new_indices.append(index) + total_lengths_covered += length_covered + already_found.update(new_indices) + indices.extend(new_indices) + current_node = current_node.parent + return indices, total_lengths_covered + +if __name__ == "__main__": + sequences = [[1,2,3],[2,3,4],[1,2,3,4]] + trie = PrefixTrie(sequences, minimum_prefix=1) + leaves = trie.get_leaf_nodes() + assert leaves[0].get_sequence() == [2,3,4] + assert leaves[1].get_sequence() == [1,2,3,4] + assert leaves[0].get_prefix_indices() == ([1], 3) + assert leaves[1].get_prefix_indices() == ([2,0], 7) \ No newline at end of file From 235e63c05a162c2ab9b037c4b60521170f92e178 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Fri, 15 Jul 2022 11:45:45 -0700 Subject: [PATCH 04/28] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1cb2753b..2bacb936 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - SQuAD and SQuADShifts tasks - Adds a new MetaICLTask that supports the evaluation classification tasks in that benchmark - Adds a new MetaICLModel that replicates the formatting and truncation used by MetaICL for few shot evaluation +- Prefix caching for DecoderOnlyRCModel that reuses overlapping prefixes between instances rather than recomputing them ### Fixed From 2c8655439a418f72d9ac28d33ff1897d73e1dcca Mon Sep 17 00:00:00 2001 From: jagnusson Date: Fri, 15 Jul 
2022 12:16:40 -0700 Subject: [PATCH 05/28] fix type hints --- catwalk/models/rank_classification.py | 2 +- catwalk/utils/prefix_trie.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/catwalk/models/rank_classification.py b/catwalk/models/rank_classification.py index 09ca9da5..c1dcb1c3 100644 --- a/catwalk/models/rank_classification.py +++ b/catwalk/models/rank_classification.py @@ -291,7 +291,7 @@ def _reorder_by_longest(self, tokenized_contexts: BatchEncoding, tokenized_conti def _reset_cache_variables(self): self.cached_sequence: Sequence[int] = None - self.cached_past_key_values: Tuple[Tuple[torch.Tensor]] = None + self.cached_past_key_values: torch.Tensor = None self.longest_prefix_to_indices: Dict[Sequence[int],Sequence[int]] = None self.indices_to_longest_prefix: OrderedDict[int,Sequence[int]] = None diff --git a/catwalk/utils/prefix_trie.py b/catwalk/utils/prefix_trie.py index 1e7988de..2dbdd94d 100644 --- a/catwalk/utils/prefix_trie.py +++ b/catwalk/utils/prefix_trie.py @@ -4,9 +4,9 @@ class PrefixTrie(): def __init__(self, sequences: Sequence[Sequence[int]], minimum_prefix: int = 10): - self.root = PrefixTrieNode(parent=None, token=None) + self.root = PrefixTrieNode() self.minimum_prefix = minimum_prefix - self.nodes = [] + self.nodes: List['PrefixTrieNode'] = [] for i, sequence in Tqdm.tqdm(enumerate(sequences), desc="Building PrefixTrie for caching", total=len(sequences)): self.add_sequence(sequence=sequence, index=i) # remove all indices and lenghts_covered at non-forking, non-leaf nodes @@ -31,7 +31,7 @@ def get_leaf_nodes(self) -> List['PrefixTrieNode']: return [node for node in self.nodes if len(node.children) == 0] class PrefixTrieNode(): - def __init__(self, parent: 'PrefixTrieNode', token: int): + def __init__(self, parent: 'PrefixTrieNode' = None, token: int = None): self.parent = parent self.token = token self.indices: List[int] = [] From fef986d4952ac8d99f61ffd781dd29641e0f56d8 Mon Sep 17 00:00:00 2001 From: jagnusson 
Date: Fri, 15 Jul 2022 13:07:18 -0700 Subject: [PATCH 06/28] fix type hints --- catwalk/utils/prefix_trie.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/catwalk/utils/prefix_trie.py b/catwalk/utils/prefix_trie.py index 2dbdd94d..8fa2243e 100644 --- a/catwalk/utils/prefix_trie.py +++ b/catwalk/utils/prefix_trie.py @@ -1,4 +1,4 @@ -from typing import List, Sequence, Tuple, Dict +from typing import List, Optional, Sequence, Tuple, Dict from tango.common import Tqdm @@ -38,7 +38,7 @@ def __init__(self, parent: 'PrefixTrieNode' = None, token: int = None): self.lengths_covered: List[int] = [] self.children: Dict[int,'PrefixTrieNode'] = {} - def get_sequence(self) -> List[int]: + def get_sequence(self) -> List[Optional[int]]: current_node = self sequence = [] while current_node.parent is not None: From bb150b5a9b8e67d139d8511ff181301e2083573d Mon Sep 17 00:00:00 2001 From: jagnusson Date: Fri, 15 Jul 2022 13:38:15 -0700 Subject: [PATCH 07/28] fix type hints --- catwalk/models/rank_classification.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/catwalk/models/rank_classification.py b/catwalk/models/rank_classification.py index c1dcb1c3..4a1d0260 100644 --- a/catwalk/models/rank_classification.py +++ b/catwalk/models/rank_classification.py @@ -263,8 +263,8 @@ def _reorder_by_prefix(self, tokenized_contexts: BatchEncoding, tokenized_contin self.indices_to_longest_prefix[index] = prefix return list(self.indices_to_longest_prefix.keys()) - def _order_by_common_prefix(self, sequences: Sequence[Sequence[int]]) -> Dict[Sequence[int],Sequence[int]]: - longest_prefix_to_indices: Dict[Sequence[int],Sequence[int]] = {} + def _order_by_common_prefix(self, sequences: Sequence[Sequence[int]]) -> Dict[Sequence[Optional[int]],Sequence[int]]: + longest_prefix_to_indices: Dict[Sequence[Optional[int]],Sequence[int]] = {} trie = PrefixTrie(sequences) leaves = trie.get_leaf_nodes() leaves_sequences = [tuple(leaf.get_sequence()) for leaf in 
leaves] @@ -290,9 +290,9 @@ def _reorder_by_longest(self, tokenized_contexts: BatchEncoding, tokenized_conti return torch.argsort(lengths, descending=True).tolist() def _reset_cache_variables(self): - self.cached_sequence: Sequence[int] = None + self.cached_sequence: Sequence[Optional[int]] = None self.cached_past_key_values: torch.Tensor = None - self.longest_prefix_to_indices: Dict[Sequence[int],Sequence[int]] = None + self.longest_prefix_to_indices: Dict[Sequence[Optional[int]],Sequence[int]] = None self.indices_to_longest_prefix: OrderedDict[int,Sequence[int]] = None def _get_inputs(self, batch_of_indices: Sequence[int], cc_pairs: List[Dict[str, Tuple[torch.Tensor, torch.Tensor]]], model: _Model): From 2b993e83ed61d04ee745a78b1e779ee9e27594a6 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Fri, 22 Jul 2022 16:48:51 -0700 Subject: [PATCH 08/28] relax min prefix requirement --- catwalk/utils/prefix_trie.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/catwalk/utils/prefix_trie.py b/catwalk/utils/prefix_trie.py index 8fa2243e..6cfdec83 100644 --- a/catwalk/utils/prefix_trie.py +++ b/catwalk/utils/prefix_trie.py @@ -3,9 +3,22 @@ class PrefixTrie(): - def __init__(self, sequences: Sequence[Sequence[int]], minimum_prefix: int = 10): + def __init__(self, sequences: Sequence[Sequence[int]], track_after_depth: int = 10): + """ + Returns a PrefixTrie for ordering examples by common prefixes + + # Parameters + + sequences : `Sequence[Sequence[int]]` + Sequences of tokens to add to the add to the Trie + track_after_depth : `int` + Only record sequence indices in nodes at or below this depth. This allows distinct + sequences that coincidentally start with the first few tokens as another sequence + not to be dropped from this barely overlapping prefix. Sequences shorter than the + minimum depth will only have their index recorded in their final node. 
+ """ self.root = PrefixTrieNode() - self.minimum_prefix = minimum_prefix + self.track_after_depth = track_after_depth self.nodes: List['PrefixTrieNode'] = [] for i, sequence in Tqdm.tqdm(enumerate(sequences), desc="Building PrefixTrie for caching", total=len(sequences)): self.add_sequence(sequence=sequence, index=i) @@ -16,14 +29,14 @@ def __init__(self, sequences: Sequence[Sequence[int]], minimum_prefix: int = 10) node.lengths_covered = [] def add_sequence(self, sequence: Sequence[int], index: int): - assert len(sequence) >= self.minimum_prefix, f"sequence with len {len(sequence)} too small for PrefixTrie with minimum_prefix {self.minimum_prefix}" + seq_len = len(sequence) current_node = self.root for token_idx, token in enumerate(sequence): if token not in current_node.children: current_node.children[token] = PrefixTrieNode(parent=current_node, token=token) self.nodes.append(current_node.children[token]) current_node = current_node.children[token] - if token_idx + 1 >= self.minimum_prefix: + if (token_idx + 1 >= self.track_after_depth) or (token_idx + 1 >= seq_len): current_node.indices.append(index) current_node.lengths_covered.append(token_idx + 1) @@ -65,7 +78,7 @@ def get_prefix_indices(self) -> Tuple[List[int], int]: if __name__ == "__main__": sequences = [[1,2,3],[2,3,4],[1,2,3,4]] - trie = PrefixTrie(sequences, minimum_prefix=1) + trie = PrefixTrie(sequences, track_after_depth=1) leaves = trie.get_leaf_nodes() assert leaves[0].get_sequence() == [2,3,4] assert leaves[1].get_sequence() == [1,2,3,4] From 7138a13d52fc6467b0a4222fdc8ca4e41c534967 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Fri, 22 Jul 2022 16:51:02 -0700 Subject: [PATCH 09/28] fix likelihood averaging --- catwalk/models/rank_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/catwalk/models/rank_classification.py b/catwalk/models/rank_classification.py index 4a1d0260..2cbd7ff7 100644 --- a/catwalk/models/rank_classification.py +++ 
b/catwalk/models/rank_classification.py @@ -230,7 +230,7 @@ def _run_loglikelihood( assert input_length-len(instance_continuation) >=0 instance_logits = instance_logits[input_length-len(instance_continuation):input_length] instance_logits = torch.gather(instance_logits, 1, instance_continuation.unsqueeze(-1)) - results[i] = float(instance_logits.sum()) / len(tuples[i][1]) + results[i] = float(instance_logits.sum()) / len(tokenized_continuations.input_ids[i]) assert None not in results return cast(Sequence[float], results) From 7a4ddd29fe4ae5902da1f1ff872ee1ad636f330e Mon Sep 17 00:00:00 2001 From: jagnusson Date: Fri, 22 Jul 2022 17:53:22 -0700 Subject: [PATCH 10/28] fix cached transformer override_weights_file --- CHANGELOG.md | 1 + catwalk/cached_transformers.py | 8 +++----- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2bacb936..e4b1fcb7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Fixed bug causing few-shot to use more than specified number of shots +- Fixed bug in cached_transformer.get() that prevented using override_weights_file arg ## [v0.1.0](https://github.com/allenai/catwalk/releases/tag/v0.1.0) - 2022-06-10 diff --git a/catwalk/cached_transformers.py b/catwalk/cached_transformers.py index cb00fd32..4fe7bde4 100644 --- a/catwalk/cached_transformers.py +++ b/catwalk/cached_transformers.py @@ -98,11 +98,9 @@ def strip_prefix(s): ) override_weights = {strip_prefix(k): override_weights[k] for k in valid_keys} - transformer = cls.from_config( # type: ignore - cls.from_pretrained( # type: ignore - model_name, - **kwargs, - ) + transformer = cls.from_pretrained( # type: ignore + model_name, + **kwargs, ) # When DistributedDataParallel or DataParallel is used, the state dict of the # DistributedDataParallel/DataParallel wrapper prepends "module." 
to all parameters From 0c5dcb137123381bea0c8af8a070f8f6dbb61e70 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Mon, 25 Jul 2022 11:33:06 -0700 Subject: [PATCH 11/28] optional random_subsample_seed for PredictStep --- catwalk/steps.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/catwalk/steps.py b/catwalk/steps.py index 532022b5..8341edcd 100644 --- a/catwalk/steps.py +++ b/catwalk/steps.py @@ -1,5 +1,6 @@ from typing import Union, Dict, Any, Optional, Sequence, Iterable from collections import defaultdict +from random import Random from tango import Step, JsonFormat from tango.common.sequences import SqliteSparseSequence @@ -33,6 +34,7 @@ def run( task: Union[str, Task], split: Optional[str] = None, limit: Optional[int] = None, + random_subsample_seed: Optional[int] = None, **kwargs ) -> Sequence[Any]: if isinstance(model, str): @@ -44,8 +46,8 @@ def run( results = SqliteSparseSequence(self.work_dir_for_run / "result.sqlite") instances = task.get_split(split) - if limit is not None: - instances = instances[:limit] + if limit is not None and len(instances) > limit: + instances = instances[:limit] if random_subsample_seed is None else Random(random_subsample_seed).sample(instances, limit) instances = instances[len(results):] for result in model.predict(task, instances, **kwargs): results.append(result) From 39b4f07863bc83a52cb9e62ae72def6abe0d3f1a Mon Sep 17 00:00:00 2001 From: jagnusson Date: Mon, 25 Jul 2022 11:34:09 -0700 Subject: [PATCH 12/28] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e4b1fcb7..9d8f743b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Adds a new MetaICLTask that supports the evaluation classification tasks in that benchmark - Adds a new MetaICLModel that replicates the formatting and truncation used by MetaICL for few shot evaluation - Prefix caching 
for DecoderOnlyRCModel that reuses overlapping prefixes between instances rather than recomputing them +- Option random_subsample_seed for PredictStep ### Fixed From 9e00fbc370d457bdac692acda1042effe83cddc1 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Mon, 25 Jul 2022 11:51:56 -0700 Subject: [PATCH 13/28] allow different metric averaging --- catwalk/task.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/catwalk/task.py b/catwalk/task.py index 323f2abc..303e53f9 100644 --- a/catwalk/task.py +++ b/catwalk/task.py @@ -38,12 +38,12 @@ "squad_metrics": torchmetrics.SQuAD, } -def classification_metrics(num_classes: int): +def classification_metrics(num_classes: int, *, average = None): return { "acc": torchmetrics.Accuracy, - "f1": partial(torchmetrics.F1Score, num_classes=num_classes, average=None), - "precision": partial(torchmetrics.Precision, num_classes=num_classes, average=None), - "recall": partial(torchmetrics.Recall, num_classes=num_classes, average=None) + "f1": partial(torchmetrics.F1Score, num_classes=num_classes, average=average), + "precision": partial(torchmetrics.Precision, num_classes=num_classes, average=average), + "recall": partial(torchmetrics.Recall, num_classes=num_classes, average=average) } From 054847fe5fe13e5ad29e1eaf36f2a17a26e85471 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Mon, 25 Jul 2022 12:12:50 -0700 Subject: [PATCH 14/28] expose override_weights_file in RC models --- catwalk/models/rank_classification.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/catwalk/models/rank_classification.py b/catwalk/models/rank_classification.py index 2cbd7ff7..9cb68c3e 100644 --- a/catwalk/models/rank_classification.py +++ b/catwalk/models/rank_classification.py @@ -22,11 +22,12 @@ class RankClassificationModel(Model): VERSION = "001nul" - def __init__(self, pretrained_model_name_or_path: str): + def __init__(self, pretrained_model_name_or_path: str, override_weights_file: str = 
None): self.pretrained_model_name_or_path = pretrained_model_name_or_path + self.override_weights_file = override_weights_file @classmethod - def _make_model(cls, pretrained_model_name_or_path: str) -> _Model: + def _make_model(cls, pretrained_model_name_or_path: str, override_weights_file: str = None) -> _Model: raise NotImplementedError def predict( # type: ignore @@ -40,7 +41,7 @@ def predict( # type: ignore fewshot_seed: int = None ) -> Iterator[Dict[str, Any]]: device = resolve_device() - model = self._make_model(self.pretrained_model_name_or_path).to(device).eval() + model = self._make_model(self.pretrained_model_name_or_path, self.override_weights_file).to(device).eval() tokenizer = cached_transformers.get_tokenizer(AutoTokenizer, self.pretrained_model_name_or_path) for instance_chunk in more_itertools.chunked(instances, max_instances_in_memory): @@ -109,8 +110,8 @@ def _run_loglikelihood( @Model.register("rc::encoder_decoder") class EncoderDecoderRCModel(RankClassificationModel): @classmethod - def _make_model(cls, pretrained_model_name_or_path: str) -> T5ForConditionalGeneration: - return cached_transformers.get(AutoModelForSeq2SeqLM, pretrained_model_name_or_path, False) + def _make_model(cls, pretrained_model_name_or_path: str, override_weights_file: str = None) -> T5ForConditionalGeneration: + return cached_transformers.get(AutoModelForSeq2SeqLM, pretrained_model_name_or_path, False, override_weights_file=override_weights_file) def _run_loglikelihood( self, @@ -172,14 +173,14 @@ def _run_loglikelihood( @Model.register("rc::decoder_only") class DecoderOnlyRCModel(RankClassificationModel): - def __init__(self, pretrained_model_name_or_path: str, *, prefix_caching: bool = False): - super().__init__(pretrained_model_name_or_path) + def __init__(self, pretrained_model_name_or_path: str, *, override_weights_file: str = None, prefix_caching: bool = False): + super().__init__(pretrained_model_name_or_path, override_weights_file=override_weights_file) 
self.prefix_caching = prefix_caching self._reset_cache_variables() @classmethod - def _make_model(cls, pretrained_model_name_or_path: str) -> GPT2LMHeadModel: - return cached_transformers.get(AutoModelForCausalLM, pretrained_model_name_or_path, False) + def _make_model(cls, pretrained_model_name_or_path: str, override_weights_file: str = None) -> GPT2LMHeadModel: + return cached_transformers.get(AutoModelForCausalLM, pretrained_model_name_or_path, False, override_weights_file=override_weights_file) def _run_loglikelihood( self, From d19deb41899fd767f21119b8bff92ca5946c78ba Mon Sep 17 00:00:00 2001 From: jagnusson Date: Mon, 25 Jul 2022 12:14:08 -0700 Subject: [PATCH 15/28] expose override_weights_file in MetaICLModel --- catwalk/models/metaicl.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/catwalk/models/metaicl.py b/catwalk/models/metaicl.py index 2cb11c9b..34729289 100644 --- a/catwalk/models/metaicl.py +++ b/catwalk/models/metaicl.py @@ -19,9 +19,10 @@ def __init__( prefix_caching: bool = True, max_length_per_example: int = 256, continuation_seperator: str = '\n', - example_seperator: str = '\n\n\n' + example_seperator: str = '\n\n\n', + override_weights_file: str = None ): - super().__init__(pretrained_model_name_or_path, prefix_caching=prefix_caching) + super().__init__(pretrained_model_name_or_path, override_weights_file=override_weights_file, prefix_caching=prefix_caching) self.max_length_per_example = max_length_per_example self.continuation_seperator = continuation_seperator self.example_seperator = example_seperator From 5dd423daf15d8aa7b5f7b5105c54e0086f17b143 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Mon, 25 Jul 2022 12:43:39 -0700 Subject: [PATCH 16/28] oops deleted import by accident --- catwalk/models/rank_classification.py | 1 + 1 file changed, 1 insertion(+) diff --git a/catwalk/models/rank_classification.py b/catwalk/models/rank_classification.py index 205a9d71..bd770a61 100644 --- 
a/catwalk/models/rank_classification.py +++ b/catwalk/models/rank_classification.py @@ -1,3 +1,4 @@ +import collections from typing import Dict, Any, List, OrderedDict, Tuple, Sequence, Iterator, Union, Mapping, Optional, cast, Callable import more_itertools From 8fbf35994f3841ae1e5be7dbade2d92790afa2f9 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Mon, 25 Jul 2022 15:37:09 -0700 Subject: [PATCH 17/28] move cache data out of model attributes --- catwalk/models/rank_classification.py | 56 ++++++++++++++------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/catwalk/models/rank_classification.py b/catwalk/models/rank_classification.py index bd770a61..4356aa62 100644 --- a/catwalk/models/rank_classification.py +++ b/catwalk/models/rank_classification.py @@ -1,4 +1,5 @@ import collections +from dataclasses import dataclass from typing import Dict, Any, List, OrderedDict, Tuple, Sequence, Iterator, Union, Mapping, Optional, cast, Callable import more_itertools @@ -235,12 +236,18 @@ def _run_loglikelihood( return cast(Sequence[float], results) +@dataclass +class CacheData: + cached_sequence: Sequence[Optional[int]] = None + cached_past_key_values: torch.Tensor = None + longest_prefix_to_indices: Dict[Sequence[Optional[int]],Sequence[int]] = None + indices_to_longest_prefix: OrderedDict[int,Sequence[int]] = None + @Model.register("rc::decoder_only") class DecoderOnlyRCModel(RankClassificationModel): def __init__(self, pretrained_model_name_or_path: str, *, override_weights_file: str = None, prefix_caching: bool = False): super().__init__(pretrained_model_name_or_path, override_weights_file=override_weights_file) self.prefix_caching = prefix_caching - self._reset_cache_variables() @classmethod def _make_model(cls, pretrained_model_name_or_path: str, override_weights_file: str = None) -> GPT2LMHeadModel: @@ -253,6 +260,8 @@ def _run_loglikelihood( tokenizer: _Tokenizer, batch_size: int = 32, ) -> Sequence[float]: + cache = CacheData() if 
self.prefix_caching else None + tokenized_contexts = tokenizer([t[0] for t in tuples]) tokenized_continuations = tokenizer([t[1] for t in tuples]) @@ -261,7 +270,7 @@ def _run_loglikelihood( ) ordered_indices = self._reorder_instances( - tokenized_contexts, tokenized_continuations + tokenized_contexts, tokenized_continuations, cache ) # transpose the token ids so we can access them one instance at a time @@ -288,7 +297,7 @@ def _run_loglikelihood( Tqdm.tqdm(ordered_indices, desc="Running log-likelihood queries"), batch_size) for batch_of_indices in batches_of_indices: - inputs, input_lengths, batch_contexts, batch_continuations = self._get_inputs(batch_of_indices, cc_pairs, model) + inputs, input_lengths, batch_contexts, batch_continuations = self._get_inputs(batch_of_indices, cc_pairs, model, cache) batch_logits = log_softmax(model(**inputs)[0], dim=-1).cpu() z = zip(batch_of_indices, batch_logits, input_lengths, batch_contexts, batch_continuations) for i, instance_logits, input_length, instance_context, instance_continuation in z: @@ -311,22 +320,22 @@ def _final_truncatation(self, tokenized_contexts: BatchEncoding, tokenized_conti tokenized_contexts['input_ids'][i] = tokenized_contexts['input_ids'][i][-model_max_length + cont_len:] tokenized_contexts['attention_mask'][i] = tokenized_contexts['attention_mask'][i][-model_max_length + cont_len:] - def _reorder_instances(self, tokenized_contexts: BatchEncoding, tokenized_continuations: BatchEncoding) -> Sequence[int]: + def _reorder_instances(self, tokenized_contexts: BatchEncoding, tokenized_continuations: BatchEncoding, cache: CacheData = None) -> Sequence[int]: if self.prefix_caching: - return self._reorder_by_prefix(tokenized_contexts, tokenized_continuations) + return self._reorder_by_prefix(tokenized_contexts, tokenized_continuations, cache) else: return self._reorder_by_longest(tokenized_contexts, tokenized_continuations) - def _reorder_by_prefix(self, tokenized_contexts: BatchEncoding, 
tokenized_continuations: BatchEncoding) -> Sequence[int]: - self._reset_cache_variables() + def _reorder_by_prefix(self, tokenized_contexts: BatchEncoding, tokenized_continuations: BatchEncoding, cache: CacheData) -> Sequence[int]: + assert cache is not None, 'prefix reordering requires a CacheData object' combined_ids = [context + continuation for context, continuation in zip(tokenized_contexts['input_ids'], tokenized_continuations['input_ids'])] - self.longest_prefix_to_indices = self._order_by_common_prefix(combined_ids) - self.indices_to_longest_prefix = OrderedDict() - for prefix in sorted(self.longest_prefix_to_indices.keys(), key = lambda x : -len(x)): + cache.longest_prefix_to_indices = self._order_by_common_prefix(combined_ids) + cache.indices_to_longest_prefix = OrderedDict() + for prefix in sorted(cache.longest_prefix_to_indices.keys(), key = lambda x : -len(x)): # indices for each prefix are already sorted by trie - for index in self.longest_prefix_to_indices[prefix]: - self.indices_to_longest_prefix[index] = prefix - return list(self.indices_to_longest_prefix.keys()) + for index in cache.longest_prefix_to_indices[prefix]: + cache.indices_to_longest_prefix[index] = prefix + return list(cache.indices_to_longest_prefix.keys()) def _order_by_common_prefix(self, sequences: Sequence[Sequence[int]]) -> Dict[Sequence[Optional[int]],Sequence[int]]: longest_prefix_to_indices: Dict[Sequence[Optional[int]],Sequence[int]] = {} @@ -354,27 +363,22 @@ def _reorder_by_longest(self, tokenized_contexts: BatchEncoding, tokenized_conti ], dtype=torch.int) return torch.argsort(lengths, descending=True).tolist() - def _reset_cache_variables(self): - self.cached_sequence: Sequence[Optional[int]] = None - self.cached_past_key_values: torch.Tensor = None - self.longest_prefix_to_indices: Dict[Sequence[Optional[int]],Sequence[int]] = None - self.indices_to_longest_prefix: OrderedDict[int,Sequence[int]] = None - - def _get_inputs(self, batch_of_indices: Sequence[int], cc_pairs: 
List[Dict[str, Tuple[torch.Tensor, torch.Tensor]]], model: _Model): + def _get_inputs(self, batch_of_indices: Sequence[int], cc_pairs: List[Dict[str, Tuple[torch.Tensor, torch.Tensor]]], model: _Model, cache: CacheData = None): if self.prefix_caching: - return self._get_inputs_with_cache(batch_of_indices, cc_pairs, model) + return self._get_inputs_with_cache(batch_of_indices, cc_pairs, model, cache) else: return self._get_inputs_without_cache(batch_of_indices, cc_pairs, model) - def _get_inputs_with_cache(self, batch_of_indices: Sequence[int], cc_pairs: List[Dict[str, Tuple[torch.Tensor, torch.Tensor]]], model: _Model): - prefixes = [self.indices_to_longest_prefix[index] for index in batch_of_indices] + def _get_inputs_with_cache(self, batch_of_indices: Sequence[int], cc_pairs: List[Dict[str, Tuple[torch.Tensor, torch.Tensor]]], model: _Model, cache: CacheData): + assert cache is not None + prefixes = [cache.indices_to_longest_prefix[index] for index in batch_of_indices] prefix2cache = OrderedDict() # compute prefixes for prefix in set(prefixes): - if prefix == self.cached_sequence: - past_key_values = self.cached_past_key_values + if prefix == cache.cached_sequence: + past_key_values = cache.cached_past_key_values else: past_key_values = model(input_ids=torch.tensor(prefix).to(model.device)).past_key_values # tensor(layers, keys/values, batch_size, num_heads, sequence_len, embed_size_per_head) @@ -384,7 +388,7 @@ def _get_inputs_with_cache(self, batch_of_indices: Sequence[int], cc_pairs: List prefix2cache[prefix] = past_key_values # update cache with last one retrieved since instances come in order by common prefix - self.cached_sequence, self.cached_past_key_values = list(prefix2cache.items())[-1] + cache.cached_sequence, cache.cached_past_key_values = list(prefix2cache.items())[-1] # pad and mask batched past_key_values unpadded_past_keys_values = [prefix2cache[prefix] for prefix in prefixes] From 451ad3419f2cd5f9d715450df8e14f68ddfe91c8 Mon Sep 17 00:00:00 2001 
From: jagnusson Date: Mon, 25 Jul 2022 15:55:41 -0700 Subject: [PATCH 18/28] use consistent arg order --- catwalk/models/metaicl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/catwalk/models/metaicl.py b/catwalk/models/metaicl.py index 7fc7d931..ca36f4a4 100644 --- a/catwalk/models/metaicl.py +++ b/catwalk/models/metaicl.py @@ -16,11 +16,11 @@ def __init__( self, pretrained_model_name_or_path: str, *, + override_weights_file: str = None, prefix_caching: bool = True, max_length_per_example: int = 256, continuation_seperator: str = '\n', - example_seperator: str = '\n\n\n', - override_weights_file: str = None + example_seperator: str = '\n\n\n' ): super().__init__(pretrained_model_name_or_path, override_weights_file=override_weights_file, prefix_caching=prefix_caching) self.max_length_per_example = max_length_per_example From c54249346e9d3495082d279ec11befbd031d2ffa Mon Sep 17 00:00:00 2001 From: jagnusson Date: Mon, 25 Jul 2022 16:21:56 -0700 Subject: [PATCH 19/28] wrestle with mypy --- catwalk/models/rank_classification.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/catwalk/models/rank_classification.py b/catwalk/models/rank_classification.py index 4356aa62..8e7afcd8 100644 --- a/catwalk/models/rank_classification.py +++ b/catwalk/models/rank_classification.py @@ -238,10 +238,10 @@ def _run_loglikelihood( @dataclass class CacheData: - cached_sequence: Sequence[Optional[int]] = None + cached_sequence: Optional[Sequence[Optional[int]]] = None cached_past_key_values: torch.Tensor = None - longest_prefix_to_indices: Dict[Sequence[Optional[int]],Sequence[int]] = None - indices_to_longest_prefix: OrderedDict[int,Sequence[int]] = None + longest_prefix_to_indices: Optional[Dict[Sequence[Optional[int]],Sequence[int]]] = None + indices_to_longest_prefix: Optional[OrderedDict[int,Sequence[Optional[int]]]] = None @Model.register("rc::decoder_only") class DecoderOnlyRCModel(RankClassificationModel): @@ -322,12 +322,12 
@@ def _final_truncatation(self, tokenized_contexts: BatchEncoding, tokenized_conti def _reorder_instances(self, tokenized_contexts: BatchEncoding, tokenized_continuations: BatchEncoding, cache: CacheData = None) -> Sequence[int]: if self.prefix_caching: + assert cache is not None, 'prefix reordering requires a CacheData object' return self._reorder_by_prefix(tokenized_contexts, tokenized_continuations, cache) else: return self._reorder_by_longest(tokenized_contexts, tokenized_continuations) def _reorder_by_prefix(self, tokenized_contexts: BatchEncoding, tokenized_continuations: BatchEncoding, cache: CacheData) -> Sequence[int]: - assert cache is not None, 'prefix reordering requires a CacheData object' combined_ids = [context + continuation for context, continuation in zip(tokenized_contexts['input_ids'], tokenized_continuations['input_ids'])] cache.longest_prefix_to_indices = self._order_by_common_prefix(combined_ids) cache.indices_to_longest_prefix = OrderedDict() @@ -365,13 +365,14 @@ def _reorder_by_longest(self, tokenized_contexts: BatchEncoding, tokenized_conti def _get_inputs(self, batch_of_indices: Sequence[int], cc_pairs: List[Dict[str, Tuple[torch.Tensor, torch.Tensor]]], model: _Model, cache: CacheData = None): if self.prefix_caching: + assert cache is not None return self._get_inputs_with_cache(batch_of_indices, cc_pairs, model, cache) else: return self._get_inputs_without_cache(batch_of_indices, cc_pairs, model) def _get_inputs_with_cache(self, batch_of_indices: Sequence[int], cc_pairs: List[Dict[str, Tuple[torch.Tensor, torch.Tensor]]], model: _Model, cache: CacheData): - assert cache is not None + assert cache.indices_to_longest_prefix is not None prefixes = [cache.indices_to_longest_prefix[index] for index in batch_of_indices] prefix2cache = OrderedDict() From 5052e0ffc64f3ac4f112aae36c8b391775f3478f Mon Sep 17 00:00:00 2001 From: jagnusson Date: Mon, 25 Jul 2022 17:58:41 -0700 Subject: [PATCH 20/28] docs and better names --- 
catwalk/models/rank_classification.py | 26 ++++++++++++++++++++++---- catwalk/utils/prefix_trie.py | 19 ++++++++++++++----- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/catwalk/models/rank_classification.py b/catwalk/models/rank_classification.py index 8e7afcd8..1264c360 100644 --- a/catwalk/models/rank_classification.py +++ b/catwalk/models/rank_classification.py @@ -246,6 +246,19 @@ class CacheData: @Model.register("rc::decoder_only") class DecoderOnlyRCModel(RankClassificationModel): def __init__(self, pretrained_model_name_or_path: str, *, override_weights_file: str = None, prefix_caching: bool = False): + """ + # Parameters + + pretrained_model_name_or_path : `str` + The name of the transformer, for example `"gpt2-large"` + override_weights_file : `str`, optional (default = `None`) + If set, this specifies a file from which to load alternate weights that override the + weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created + with `torch.save()`. + prefix_caching : `bool`, optional (default = `False`) + If set to True uses a caching strategy that improves performance when many inputs in a task + share prefixes. This orders the dataset by common prefixes and caches the current shared prefix. 
+ """ super().__init__(pretrained_model_name_or_path, override_weights_file=override_weights_file) self.prefix_caching = prefix_caching @@ -329,22 +342,27 @@ def _reorder_instances(self, tokenized_contexts: BatchEncoding, tokenized_contin def _reorder_by_prefix(self, tokenized_contexts: BatchEncoding, tokenized_continuations: BatchEncoding, cache: CacheData) -> Sequence[int]: combined_ids = [context + continuation for context, continuation in zip(tokenized_contexts['input_ids'], tokenized_continuations['input_ids'])] - cache.longest_prefix_to_indices = self._order_by_common_prefix(combined_ids) + cache.longest_prefix_to_indices = self._greedy_assign_prefix_by_total_coverage(combined_ids) cache.indices_to_longest_prefix = OrderedDict() + # secondarily sort by length so that largest batches that may cause memory overflow are likely come early for prefix in sorted(cache.longest_prefix_to_indices.keys(), key = lambda x : -len(x)): - # indices for each prefix are already sorted by trie + # sequence indices for each prefix are already sorted by length from reading trie from leaf to root for index in cache.longest_prefix_to_indices[prefix]: cache.indices_to_longest_prefix[index] = prefix return list(cache.indices_to_longest_prefix.keys()) - def _order_by_common_prefix(self, sequences: Sequence[Sequence[int]]) -> Dict[Sequence[Optional[int]],Sequence[int]]: + def _greedy_assign_prefix_by_total_coverage(self, sequences: Sequence[Sequence[int]]) -> Dict[Sequence[Optional[int]],Sequence[int]]: + """Returns a Dict of prefixes and the sequence indices assigned to them. Sorts possible prefixes by total tokens covered in + subsequences and assigns sequences to the first prefix they appear in. 
PrefixTrie only tracks subsequences after a minimum + track_after_depth so short coincidental overlaps are ignored.""" longest_prefix_to_indices: Dict[Sequence[Optional[int]],Sequence[int]] = {} trie = PrefixTrie(sequences) leaves = trie.get_leaf_nodes() leaves_sequences = [tuple(leaf.get_sequence()) for leaf in leaves] leaves_and_sequences = Tqdm.tqdm(zip(leaves_sequences, leaves), desc="Finding prefixes", total=len(leaves)) - leaves2prefixes = {leaf_sequence:leaf.get_prefix_indices() for leaf_sequence, leaf in leaves_and_sequences} + leaves2prefixes = {leaf_sequence:leaf.get_subsequences() for leaf_sequence, leaf in leaves_and_sequences} + # greedily assign sequences to prefixes with top total coverage indices_already_assigned = set() for leaf_sequence in sorted(leaves_sequences, key=lambda leaf_sequence : -leaves2prefixes[leaf_sequence][1]): prefix_indices, _ = leaves2prefixes[leaf_sequence] diff --git a/catwalk/utils/prefix_trie.py b/catwalk/utils/prefix_trie.py index 6cfdec83..9f0633ba 100644 --- a/catwalk/utils/prefix_trie.py +++ b/catwalk/utils/prefix_trie.py @@ -21,14 +21,14 @@ def __init__(self, sequences: Sequence[Sequence[int]], track_after_depth: int = self.track_after_depth = track_after_depth self.nodes: List['PrefixTrieNode'] = [] for i, sequence in Tqdm.tqdm(enumerate(sequences), desc="Building PrefixTrie for caching", total=len(sequences)): - self.add_sequence(sequence=sequence, index=i) - # remove all indices and lenghts_covered at non-forking, non-leaf nodes + self._add_sequence(sequence=sequence, index=i) + # only need to track sequences at forks and terminations for node in self.nodes: if len(node.children) == 1: node.indices = [] node.lengths_covered = [] - def add_sequence(self, sequence: Sequence[int], index: int): + def _add_sequence(self, sequence: Sequence[int], index: int): seq_len = len(sequence) current_node = self.root for token_idx, token in enumerate(sequence): @@ -52,6 +52,7 @@ def __init__(self, parent: 'PrefixTrieNode' = None, 
token: int = None): self.children: Dict[int,'PrefixTrieNode'] = {} def get_sequence(self) -> List[Optional[int]]: + """Returns the sequence associated with a node""" current_node = self sequence = [] while current_node.parent is not None: @@ -59,8 +60,16 @@ def get_sequence(self) -> List[Optional[int]]: current_node = current_node.parent return sequence[::-1] - def get_prefix_indices(self) -> Tuple[List[int], int]: - """Returns all indices for subsequences of the current node including itself starting with longest and decreasing""" + def get_subsequences(self) -> Tuple[List[int], int]: + """ + Returns a tuple of: + - a list of all indices for subsequences of the current node including itself + starting with longest and decreasing + - an int, the total number of tokens covered in all subsequences by this prefix + + Note when a PrefixTrie with track_after_depth > 0, some subsequences will be intentionally + ignored here as their indices are not registered in low depth nodes. + """ current_node = self indices = [] already_found = set() From e584320c9fa4ab74130a4f368c023499574d40fd Mon Sep 17 00:00:00 2001 From: jagnusson Date: Mon, 25 Jul 2022 17:59:44 -0700 Subject: [PATCH 21/28] test for PrefixTrie --- tests/test_utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 tests/test_utils.py diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..9271bd4e --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,10 @@ +from catwalk.utils import PrefixTrie + +def test_prefix_trie(): + sequences = [[1,2,3],[2,3,4],[1,2,3,4]] + trie = PrefixTrie(sequences, track_after_depth=1) + leaves = trie.get_leaf_nodes() + assert leaves[0].get_sequence() == [2,3,4] + assert leaves[1].get_sequence() == [1,2,3,4] + assert leaves[0].get_subsequences() == ([1], 3) + assert leaves[1].get_subsequences() == ([2,0], 7) \ No newline at end of file From 8c02abef4bf41126509a60278b367050ec8e1de8 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Mon, 
25 Jul 2022 18:00:24 -0700 Subject: [PATCH 22/28] handle full overlap corner case --- catwalk/utils/prefix_trie.py | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/catwalk/utils/prefix_trie.py b/catwalk/utils/prefix_trie.py index 9f0633ba..7d77403d 100644 --- a/catwalk/utils/prefix_trie.py +++ b/catwalk/utils/prefix_trie.py @@ -1,7 +1,6 @@ from typing import List, Optional, Sequence, Tuple, Dict from tango.common import Tqdm - class PrefixTrie(): def __init__(self, sequences: Sequence[Sequence[int]], track_after_depth: int = 10): """ @@ -25,8 +24,7 @@ def __init__(self, sequences: Sequence[Sequence[int]], track_after_depth: int = # only need to track sequences at forks and terminations for node in self.nodes: if len(node.children) == 1: - node.indices = [] - node.lengths_covered = [] + node.subsequences_on_this_path = node.subsequences_ending_here def _add_sequence(self, sequence: Sequence[int], index: int): seq_len = len(sequence) @@ -37,8 +35,8 @@ def _add_sequence(self, sequence: Sequence[int], index: int): self.nodes.append(current_node.children[token]) current_node = current_node.children[token] if (token_idx + 1 >= self.track_after_depth) or (token_idx + 1 >= seq_len): - current_node.indices.append(index) - current_node.lengths_covered.append(token_idx + 1) + current_node.subsequences_on_this_path[index] = token_idx + 1 + current_node.subsequences_ending_here[index] = len(sequence) def get_leaf_nodes(self) -> List['PrefixTrieNode']: return [node for node in self.nodes if len(node.children) == 0] @@ -47,7 +45,8 @@ class PrefixTrieNode(): def __init__(self, parent: 'PrefixTrieNode' = None, token: int = None): self.parent = parent self.token = token - self.indices: List[int] = [] + self.subsequences_on_this_path: Dict[int,int] = {} + self.subsequences_ending_here: Dict[int,int] = {} self.lengths_covered: List[int] = [] self.children: Dict[int,'PrefixTrieNode'] = {} @@ -76,20 +75,11 @@ def get_subsequences(self) -> 
Tuple[List[int], int]: total_lengths_covered = 0 while current_node.parent is not None: new_indices = [] - for index, length_covered in zip(current_node.indices, current_node.lengths_covered): + for index, length_covered in current_node.subsequences_on_this_path.items(): if index not in already_found: new_indices.append(index) total_lengths_covered += length_covered already_found.update(new_indices) indices.extend(new_indices) current_node = current_node.parent - return indices, total_lengths_covered - -if __name__ == "__main__": - sequences = [[1,2,3],[2,3,4],[1,2,3,4]] - trie = PrefixTrie(sequences, track_after_depth=1) - leaves = trie.get_leaf_nodes() - assert leaves[0].get_sequence() == [2,3,4] - assert leaves[1].get_sequence() == [1,2,3,4] - assert leaves[0].get_prefix_indices() == ([1], 3) - assert leaves[1].get_prefix_indices() == ([2,0], 7) \ No newline at end of file + return indices, total_lengths_covered \ No newline at end of file From 827b35025644c277085a431e4b2534c111d7f928 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Thu, 28 Jul 2022 15:11:16 -0700 Subject: [PATCH 23/28] expose prefix_caching in ia3 --- catwalk/models/ia3.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/catwalk/models/ia3.py b/catwalk/models/ia3.py index 10e8907a..8a27eda7 100644 --- a/catwalk/models/ia3.py +++ b/catwalk/models/ia3.py @@ -36,6 +36,7 @@ def __init__( *, likelihood_averaging: str = 'token', override_weights_file: str = None, + prefix_caching: bool = True, max_length_per_example: int = 256, continuation_seperator: str = '\n', example_seperator: str = '\n\n\n', @@ -46,6 +47,7 @@ def __init__( pretrained_model_name_or_path, likelihood_averaging=likelihood_averaging, override_weights_file=override_weights_file, + prefix_caching=prefix_caching, max_length_per_example=max_length_per_example, continuation_seperator=continuation_seperator, example_seperator=example_seperator, From 4e98c01c774d96f263f8d2687289f76322ca839f Mon Sep 17 00:00:00 2001 From: jagnusson Date: 
Fri, 29 Jul 2022 11:32:33 -0700 Subject: [PATCH 24/28] batch processing of cached prefixes --- catwalk/models/rank_classification.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/catwalk/models/rank_classification.py b/catwalk/models/rank_classification.py index 8cde703f..14e1b47a 100644 --- a/catwalk/models/rank_classification.py +++ b/catwalk/models/rank_classification.py @@ -434,19 +434,25 @@ def _get_inputs_with_cache(self, batch_of_indices: Sequence[int], cc_pairs: List prefix2cache = OrderedDict() # compute prefixes - for prefix in set(prefixes): - if prefix == cache.cached_sequence: - past_key_values = cache.cached_past_key_values - else: - past_key_values = model(input_ids=torch.tensor(prefix).to(model.device)).past_key_values - # tensor(layers, keys/values, batch_size, num_heads, sequence_len, embed_size_per_head) - past_key_values = torch.stack(tuple(torch.stack(past_key_values[i]) for i in range(len(past_key_values)))) + if prefixes[0] == cache.cached_sequence: + prefix2cache[prefixes[0]] = cache.cached_past_key_values + + uncached_prefixes = list(set(prefix for prefix in prefixes if prefix not in prefix2cache)) # ordering must be fixed + if len(uncached_prefixes) > 0: + unpadded_prefixes = [torch.tensor(prefix) for prefix in uncached_prefixes] + unpadded_prefix_mask = [torch.ones_like(prefix) for prefix in unpadded_prefixes] + padded_prefixes = pad_sequence(unpadded_prefixes, batch_first=True).to(model.device) + padded_prefix_masks = pad_sequence(unpadded_prefix_mask, batch_first=True, padding_value=0.0).to(model.device) + past_key_values = model(input_ids=padded_prefixes, attention_mask=padded_prefix_masks).past_key_values + # tensor(layers, keys/values, batch_size, num_heads, sequence_len, embed_size_per_head) + past_key_values = torch.stack(tuple(torch.stack(past_key_values[i]) for i in range(len(past_key_values)))) + for i, prefix in enumerate(uncached_prefixes): # tensor(layers, keys/values, 
num_heads, sequence_len, embed_size_per_head) - past_key_values = past_key_values.squeeze(2) - prefix2cache[prefix] = past_key_values + prefix2cache[prefix] = past_key_values[:,:,i,:,:len(prefix),:] # update cache with last one retrieved since instances come in order by common prefix - cache.cached_sequence, cache.cached_past_key_values = list(prefix2cache.items())[-1] + cache.cached_sequence = prefixes[-1] + cache.cached_past_key_values = prefix2cache[prefixes[-1]] # pad and mask batched past_key_values unpadded_past_keys_values = [prefix2cache[prefix] for prefix in prefixes] From fda8433360eaacbfe2c9ab4fb4a47e75b6a2e7c0 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Thu, 18 Aug 2022 17:38:18 -0700 Subject: [PATCH 25/28] gpu logit processing for rc models --- CHANGELOG.md | 1 + catwalk/models/rank_classification.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5bb235a8..2b02fd0b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - MetaICLTask now supports fewshots less than 16 and only support getting the test split - set default logging level to `"WARNING"` instead of `"ERROR"` when invoking `python -m catwalk` - changed MetaICLModel formatting to always preserve whitespace, to reproduce MetaICL results +- improved speed of rank classification models by aggregating sequence logits on GPU rather than on CPU ### Added diff --git a/catwalk/models/rank_classification.py b/catwalk/models/rank_classification.py index 6dd56bdd..9d68420f 100644 --- a/catwalk/models/rank_classification.py +++ b/catwalk/models/rank_classification.py @@ -242,11 +242,11 @@ def _run_loglikelihood( for field_name, tensors in unpadded_batch.items() } - batch_logits = log_softmax(model(**padded_batch).logits, dim=-1).cpu() + batch_logits = log_softmax(model(**padded_batch).logits, dim=-1) for i, instance_logits, decoder_input_ids in 
zip(batch_of_indices, batch_logits, unpadded_batch["labels"]): instance_logits = instance_logits[:len(decoder_input_ids)] - instance_logits = torch.gather(instance_logits, 1, decoder_input_ids.unsqueeze(-1)) + instance_logits = torch.gather(instance_logits, 1, decoder_input_ids.unsqueeze(-1).to(model.device)) denom = len(tuples[i][1]) if self.likelihood_averaging == 'char' else len(decoder_input_ids) results[i] = float(instance_logits.sum()) / denom @@ -321,11 +321,11 @@ def _run_loglikelihood( for field_name, tensors in unpadded_batch.items() } - batch_logits = log_softmax(model(**padded_batch)[0], dim=-1).cpu() + batch_logits = log_softmax(model(**padded_batch)[0], dim=-1) z = zip(batch_of_indices, batch_logits, input_lengths, batch_contexts, batch_continuations) for i, instance_logits, input_length, instance_context, instance_continuation in z: instance_logits = instance_logits[input_length-len(instance_continuation):input_length] - instance_logits = torch.gather(instance_logits, 1, instance_continuation.unsqueeze(-1)) + instance_logits = torch.gather(instance_logits, 1, instance_continuation.unsqueeze(-1).to(model.device)) denom = len(tuples[i][1]) if self.likelihood_averaging == 'char' else len(instance_continuation) results[i] = float(instance_logits.sum()) / denom From 4acdd95e2dabe659e568ca71fcb0eab6c7af4443 Mon Sep 17 00:00:00 2001 From: jagnusson Date: Thu, 18 Aug 2022 17:55:38 -0700 Subject: [PATCH 26/28] add 2 new tasks --- CHANGELOG.md | 1 + catwalk/tasks/__init__.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5bb235a8..95596e11 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Optional random_subsample_seed for PredictStep - An option for rank classification to average log likelihoods by token length - Adds support for inference with IA3 adaptors loaded from a file on decoder only ranked 
classification models +- Add support for MetaICL's race-high and numer_sense tasks ### Fixed diff --git a/catwalk/tasks/__init__.py b/catwalk/tasks/__init__.py index d8144cff..09f2dde8 100644 --- a/catwalk/tasks/__init__.py +++ b/catwalk/tasks/__init__.py @@ -241,8 +241,11 @@ "metaicl::unifiedqa:openbookqa_with_ir": MetaICLTask("unifiedqa:openbookqa_with_ir").add_metrics(MC_METRICS), "metaicl::unifiedqa:mctest": MetaICLTask("unifiedqa:mctest").add_metrics(MC_METRICS), "metaicl::unifiedqa:ai2_science_middle": MetaICLTask("unifiedqa:ai2_science_middle").add_metrics(MC_METRICS), + + "metaicl::commonsense_qa": MetaICLTask("commonsense_qa").add_metrics(MC_METRICS), - "metaicl::commonsense_qa": MetaICLTask("commonsense_qa").add_metrics(MC_METRICS), + "metaicl::numer_sense": MetaICLTask("numer_sense").add_metrics(classification_metrics(12)), + "metaicl::race-high": MetaICLTask("race-high").add_metrics(MC_METRICS), } for config in datasets.get_dataset_config_names("bigscience/P3"): From 357928253fecf0aaec66a128b59a88d85c8be3df Mon Sep 17 00:00:00 2001 From: jagnusson Date: Thu, 18 Aug 2022 18:17:38 -0700 Subject: [PATCH 27/28] clean up bad merge --- catwalk/models/rank_classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/catwalk/models/rank_classification.py b/catwalk/models/rank_classification.py index 10b50aaf..0029ceec 100644 --- a/catwalk/models/rank_classification.py +++ b/catwalk/models/rank_classification.py @@ -349,12 +349,12 @@ def _run_loglikelihood( batch_size) for batch_of_indices in batches_of_indices: inputs, input_lengths, batch_contexts, batch_continuations = self._get_inputs(batch_of_indices, cc_pairs, model, cache) - batch_logits = log_softmax(model(**inputs)[0], dim=-1).cpu() + batch_logits = log_softmax(model(**inputs)[0], dim=-1) z = zip(batch_of_indices, batch_logits, input_lengths, batch_contexts, batch_continuations) for i, instance_logits, input_length, instance_context, instance_continuation in z: assert 
input_length-len(instance_continuation) >=0 instance_logits = instance_logits[input_length-len(instance_continuation):input_length] - instance_logits = torch.gather(instance_logits, 1, instance_continuation.unsqueeze(-1)) + instance_logits = torch.gather(instance_logits, 1, instance_continuation.unsqueeze(-1).to(model.device)) denom = len(tuples[i][1]) if self.likelihood_averaging == 'char' else len(instance_continuation) results[i] = float(instance_logits.sum()) / denom From 3cab644d3e897e67a809ca1a9accd5a839451883 Mon Sep 17 00:00:00 2001 From: IanMagnusson Date: Mon, 29 Aug 2022 18:46:43 -0700 Subject: [PATCH 28/28] An example usage of prefix_caching --- experiments/prefix_cache_demo.py | 86 ++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 experiments/prefix_cache_demo.py diff --git a/experiments/prefix_cache_demo.py b/experiments/prefix_cache_demo.py new file mode 100644 index 00000000..d8a535be --- /dev/null +++ b/experiments/prefix_cache_demo.py @@ -0,0 +1,86 @@ +import argparse +import os +from tango.common.logging import initialize_logging +import time + +from catwalk.models import MetaICLModel +from catwalk.steps import CalculateMetricsStep, PredictStep +from catwalk.tasks import TASK_SETS + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--zeroshot', action='store_true') + parser.add_argument('--no_prefix_caching', action='store_true') + parser.add_argument('--first_n_tasks', type=int, default=20) + args = parser.parse_args() + + start = time.time() + initialize_logging(log_level="ERROR") + os.environ['TOKENIZERS_PARALLELISM'] = "false" + + tasks = TASK_SETS['metaicl-classification-eval'] + tasks = sorted(tasks)[:args.first_n_tasks] + + num_shots = 0 if args.zeroshot else 16 + if args.zeroshot: + batch_size = 64 + elif args.no_prefix_caching: + batch_size = 16 # to account for larger input sizes with ICL + # CACHING with batching does not work close to the max model size as + # the largest 
prefix + largest continuation in a batch must be <= max model size + else: + batch_size = 1 + limit = 1000 + random_subsample_seed=42 + seeds = [100] if args.zeroshot else [100, 13, 21, 42, 87] + + model = MetaICLModel('gpt2-large', continuation_seperator = ' ' if args.zeroshot else '\n', prefix_caching = not args.no_prefix_caching) + + seed2metrics = {} + for fewshot_seed in seeds: + metric_task_dict = {} + for task in tasks: + + predictions = PredictStep( + model=model, + task=task, + batch_size=batch_size, + limit=limit, + random_subsample_seed=random_subsample_seed, + num_shots=num_shots, + fewshot_seed=fewshot_seed, + ) + metrics = CalculateMetricsStep( + model=model, + task=task, + predictions=predictions) + metric_task_dict[task] = metrics + seed2metrics[fewshot_seed] = metric_task_dict + + avg_f1_per_seed = [] + avg_acc_per_seed = [] + for seed, metric_task_dict in seed2metrics.items(): + total_sum_f1 = 0.0 + total_sum_acc = 0.0 + for task, metrics in metric_task_dict.items(): + for metric, result in metrics.result().items(): + avg_result = result.mean() + if metric == 'f1': + total_sum_f1 += avg_result.item() + elif metric == 'acc': + total_sum_acc += avg_result.item() + print(f"{task}\t{seed}\t{metric}\t{avg_result}") + avg_f1_per_seed.append(total_sum_f1 / len(tasks)) + avg_acc_per_seed.append(total_sum_acc / len(tasks)) + + print(f"avg macro f1 over seeds {sum(avg_f1_per_seed) / len(seeds)}") + print(f"min macro f1 over seeds {min(avg_f1_per_seed)}") + print(f"avg macro acc over seeds {sum(avg_acc_per_seed) / len(seeds)}") + print(f"min macro acc over seeds {min(avg_acc_per_seed)}") + + end = time.time() + print(f"total seconds elapsed: {end - start}") + +if __name__ == "__main__": + + main() \ No newline at end of file