diff --git a/CHANGELOG.md b/CHANGELOG.md index 9cb2fe38..276f9400 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - MetaICLTask now supports fewshots less than 16 and only support getting the test split - set default logging level to `"WARNING"` instead of `"ERROR"` when invoking `python -m catwalk` - changed MetaICLModel formatting to always preserve whitespace, to reproduce MetaICL results +- improved speed of rank classification models by aggregating sequence logits on GPU rather than on CPU ### Added @@ -23,8 +24,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Adds a new MetaICLModel that replicates the formatting and truncation used by MetaICL for few shot evaluation - Optional `random_subsample_seed` for PredictStep - An option for rank classification to average log likelihoods by token length -- Adds support for inference with IA3 adapters loaded from a file on decoder only ranked classification models +- Adds support for inference with IA3 adaptors loaded from a file on decoder only ranked classification models - Added the ability to train `HFAutoModel` +- Add support for MetaICL's race-high and numer_sense tasks +- Prefix caching for DecoderOnlyRCModel that reuses overlapping prefixes between instances rather than recomputing them ### Fixed diff --git a/catwalk/models/ia3.py b/catwalk/models/ia3.py index 5a5c698c..8a27eda7 100644 --- a/catwalk/models/ia3.py +++ b/catwalk/models/ia3.py @@ -11,8 +11,16 @@ class DecoderOnlyIA3Mixin: @classmethod - def _make_model(self, pretrained_model_name_or_path: str, *, ia3_weights_file: str = None, **kwargs) -> GPT2LMHeadModel: - model = cached_transformers.get(AutoModelForCausalLM, pretrained_model_name_or_path, True) + def _make_model( + + self, + pretrained_model_name_or_path: str, + *, + override_weights_file: str = None, + ia3_weights_file: str = None, + **kwargs + ) -> 
GPT2LMHeadModel: + model = cached_transformers.get(AutoModelForCausalLM, pretrained_model_name_or_path, True, override_weights_file=override_weights_file) isinstance(model, GPT2LMHeadModel) config = IA3ForGPT2Config() model = modify_with_ia3(model, config) @@ -27,6 +35,8 @@ def __init__( pretrained_model_name_or_path: str, *, likelihood_averaging: str = 'token', + override_weights_file: str = None, + prefix_caching: bool = True, max_length_per_example: int = 256, continuation_seperator: str = '\n', example_seperator: str = '\n\n\n', @@ -36,6 +46,8 @@ def __init__( super().__init__( pretrained_model_name_or_path, likelihood_averaging=likelihood_averaging, + override_weights_file=override_weights_file, + prefix_caching=prefix_caching, max_length_per_example=max_length_per_example, continuation_seperator=continuation_seperator, example_seperator=example_seperator, diff --git a/catwalk/models/metaicl.py b/catwalk/models/metaicl.py index 32294bb7..538aafd1 100644 --- a/catwalk/models/metaicl.py +++ b/catwalk/models/metaicl.py @@ -17,6 +17,8 @@ def __init__( pretrained_model_name_or_path: str, *, likelihood_averaging: str = 'token', + override_weights_file: str = None, + prefix_caching: bool = True, max_length_per_example: int = 256, continuation_seperator: str = '\n', example_seperator: str = '\n\n\n', @@ -25,6 +27,8 @@ def __init__( super().__init__( pretrained_model_name_or_path, likelihood_averaging=likelihood_averaging, + override_weights_file=override_weights_file, + prefix_caching=prefix_caching, **model_kwargs ) self.max_length_per_example = max_length_per_example diff --git a/catwalk/models/rank_classification.py b/catwalk/models/rank_classification.py index 6dd56bdd..0029ceec 100644 --- a/catwalk/models/rank_classification.py +++ b/catwalk/models/rank_classification.py @@ -1,5 +1,6 @@ import collections -from typing import Dict, Any, List, Tuple, Sequence, Iterator, Union, Mapping, Optional, cast, Callable +from dataclasses import dataclass +from typing import 
Dict, Any, List, OrderedDict, Tuple, Sequence, Iterator, Union, Mapping, Optional, cast, Callable import more_itertools import torch @@ -8,11 +9,12 @@ from torch import log_softmax from torch.nn.utils.rnn import pad_sequence from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, T5ForConditionalGeneration, GPT2LMHeadModel, \ - AutoTokenizer, GPT2Tokenizer, T5TokenizerFast + AutoTokenizer, GPT2Tokenizer, T5TokenizerFast, BatchEncoding from catwalk import cached_transformers from catwalk.model import Model, TrainableModel, Instance from catwalk.task import Task, InstanceFormat, RankClassificationInstance +from catwalk.utils import PrefixTrie _Model = Union[T5ForConditionalGeneration, GPT2LMHeadModel] _Tokenizer = Union[T5TokenizerFast, GPT2Tokenizer] @@ -26,6 +28,7 @@ def __init__( pretrained_model_name_or_path: str, *, likelihood_averaging: str = 'char', + override_weights_file: str = None, **model_kwargs ): """ @@ -36,16 +39,20 @@ def __init__( likelihood_averaging : `str`, optional (default = `char`) The method for averaging the sum likelihood of the continuation. 'char' averages by character length, 'token' averages by token length. + override_weights_file : `str`, optional (default = `None`) + If set, this specifies a file from which to load alternate weights that override the + weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created + with `torch.save()`. model_kwargs: Additional kwargs passed to the `_make_model` method. 
""" assert likelihood_averaging in {'char', 'token'} self.pretrained_model_name_or_path = pretrained_model_name_or_path self.likelihood_averaging = likelihood_averaging + self.override_weights_file = override_weights_file self.model_kwargs = model_kwargs - @classmethod - def _make_model(cls, pretrained_model_name_or_path: str, **kwargs) -> _Model: + def _make_model(cls, pretrained_model_name_or_path: str, *, override_weights_file: str = None, **kwargs) -> _Model: raise NotImplementedError def predict( # type: ignore @@ -59,7 +66,7 @@ def predict( # type: ignore fewshot_seed: int = None ) -> Iterator[Dict[str, Any]]: device = resolve_device() - model = self._make_model(self.pretrained_model_name_or_path, **self.model_kwargs).to(device).eval() + model = self._make_model(self.pretrained_model_name_or_path, override_weights_file=self.override_weights_file, **self.model_kwargs).to(device).eval() tokenizer = cached_transformers.get_tokenizer(AutoTokenizer, self.pretrained_model_name_or_path) for instance_chunk in more_itertools.chunked(instances, max_instances_in_memory): @@ -192,8 +199,8 @@ def collate_for_training(self, instances: Sequence[Tuple[Task, Instance]]) -> An @Model.register("rc::encoder_decoder") class EncoderDecoderRCModel(RankClassificationModel): @classmethod - def _make_model(cls, pretrained_model_name_or_path: str, **kwargs) -> T5ForConditionalGeneration: - return cached_transformers.get(AutoModelForSeq2SeqLM, pretrained_model_name_or_path, False) + def _make_model(cls, pretrained_model_name_or_path: str, *, override_weights_file: str = None, **kwargs) -> T5ForConditionalGeneration: + return cached_transformers.get(AutoModelForSeq2SeqLM, pretrained_model_name_or_path, False, override_weights_file=override_weights_file) def _run_loglikelihood( self, @@ -242,11 +249,11 @@ def _run_loglikelihood( for field_name, tensors in unpadded_batch.items() } - batch_logits = log_softmax(model(**padded_batch).logits, dim=-1).cpu() + batch_logits = 
log_softmax(model(**padded_batch).logits, dim=-1) for i, instance_logits, decoder_input_ids in zip(batch_of_indices, batch_logits, unpadded_batch["labels"]): instance_logits = instance_logits[:len(decoder_input_ids)] - instance_logits = torch.gather(instance_logits, 1, decoder_input_ids.unsqueeze(-1)) + instance_logits = torch.gather(instance_logits, 1, decoder_input_ids.unsqueeze(-1).to(model.device)) denom = len(tuples[i][1]) if self.likelihood_averaging == 'char' else len(decoder_input_ids) results[i] = float(instance_logits.sum()) / denom @@ -254,11 +261,48 @@ def _run_loglikelihood( return cast(Sequence[float], results) +@dataclass +class CacheData: + cached_sequence: Optional[Sequence[Optional[int]]] = None + cached_past_key_values: torch.Tensor = None + longest_prefix_to_indices: Optional[Dict[Sequence[Optional[int]],Sequence[int]]] = None + indices_to_longest_prefix: Optional[OrderedDict[int,Sequence[Optional[int]]]] = None + @Model.register("rc::decoder_only") class DecoderOnlyRCModel(RankClassificationModel): + def __init__( + self, + pretrained_model_name_or_path: str, + *, + likelihood_averaging: str = 'char', + override_weights_file: str = None, + prefix_caching: bool = False, + **model_kwargs + ): + """ + # Parameters + + pretrained_model_name_or_path : `str` + The name of the transformer, for example `"gpt2-large"` + likelihood_averaging : `str`, optional (default = `char`) + The method for averaging the sum likelihood of the continuation. 'char' averages by + character length, 'token' averages by token length. + override_weights_file : `str`, optional (default = `None`) + If set, this specifies a file from which to load alternate weights that override the + weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created + with `torch.save()`. + prefix_caching : `bool`, optional (default = `False`) + If set to True uses a caching strategy that improves performance when many inputs in a task + share prefixes. 
This orders the dataset by common prefixes and caches the current shared prefix. + model_kwargs: + Additional kwargs passed to the `_make_model` method. + """ + super().__init__(pretrained_model_name_or_path, likelihood_averaging=likelihood_averaging, override_weights_file=override_weights_file, **model_kwargs) + self.prefix_caching = prefix_caching + @classmethod - def _make_model(cls, pretrained_model_name_or_path: str, **kwargs) -> GPT2LMHeadModel: - return cached_transformers.get(AutoModelForCausalLM, pretrained_model_name_or_path, False) + def _make_model(cls, pretrained_model_name_or_path: str, *, override_weights_file: str = None, **kwargs) -> GPT2LMHeadModel: + return cached_transformers.get(AutoModelForCausalLM, pretrained_model_name_or_path, False, override_weights_file=override_weights_file) def _run_loglikelihood( self, @@ -267,9 +311,19 @@ def _run_loglikelihood( tokenizer: _Tokenizer, batch_size: int = 32, ) -> Sequence[float]: + cache = CacheData() if self.prefix_caching else None + tokenized_contexts = tokenizer([t[0] for t in tuples]) tokenized_continuations = tokenizer([t[1] for t in tuples]) + self._final_truncatation( + tokenized_contexts, tokenized_continuations, tokenizer.model_max_length + ) + + ordered_indices = self._reorder_instances( + tokenized_contexts, tokenized_continuations, cache + ) + # transpose the token ids so we can access them one instance at a time cc_pairs: List[Dict[str, Tuple[torch.Tensor, torch.Tensor]]] = [] assert tokenized_contexts.keys() == tokenized_continuations.keys() @@ -287,14 +341,6 @@ def _run_loglikelihood( torch.tensor(continuation, dtype=torch.long) ) - # find out the order to process sequences in - lengths = torch.tensor([ - len(cc_pair["input_ids"][0]) + len(cc_pair["input_ids"][1]) - for cc_pair in cc_pairs - ], dtype=torch.int) - ordered_indices = torch.argsort(lengths, descending=True) - del lengths - # actually do the processing results: List[Optional[float]] = [None] * len(ordered_indices) with 
torch.inference_mode(): @@ -302,32 +348,199 @@ def _run_loglikelihood( Tqdm.tqdm(ordered_indices, desc="Running log-likelihood queries"), batch_size) for batch_of_indices in batches_of_indices: - unpadded_batch = collections.defaultdict(list) - input_lengths = [] - batch_contexts = [] - batch_continuations = [] - for index in batch_of_indices: - for field_name, (context_ids, continuation_ids) in cc_pairs[index].items(): - ids = torch.cat([context_ids, continuation_ids]) - ids = ids[-(tokenizer.model_max_length+1):][:-1] - unpadded_batch[field_name].append(ids) - - input_lengths.append(len(unpadded_batch["input_ids"][-1])) - batch_contexts.append(cc_pairs[index]["input_ids"][0]) - batch_continuations.append(cc_pairs[index]["input_ids"][1]) - - padded_batch = { - field_name: pad_sequence(tensors, batch_first=True).to(model.device) - for field_name, tensors in unpadded_batch.items() - } - - batch_logits = log_softmax(model(**padded_batch)[0], dim=-1).cpu() + inputs, input_lengths, batch_contexts, batch_continuations = self._get_inputs(batch_of_indices, cc_pairs, model, cache) + batch_logits = log_softmax(model(**inputs)[0], dim=-1) z = zip(batch_of_indices, batch_logits, input_lengths, batch_contexts, batch_continuations) for i, instance_logits, input_length, instance_context, instance_continuation in z: + assert input_length-len(instance_continuation) >=0 instance_logits = instance_logits[input_length-len(instance_continuation):input_length] - instance_logits = torch.gather(instance_logits, 1, instance_continuation.unsqueeze(-1)) + instance_logits = torch.gather(instance_logits, 1, instance_continuation.unsqueeze(-1).to(model.device)) denom = len(tuples[i][1]) if self.likelihood_averaging == 'char' else len(instance_continuation) results[i] = float(instance_logits.sum()) / denom assert None not in results return cast(Sequence[float], results) + + def _final_truncatation(self, tokenized_contexts: BatchEncoding, tokenized_continuations: BatchEncoding, model_max_length: 
int): + """ Apply a last pass of truncation on the concatenated inputs to make sure it fits in the model_max_length""" + assert len(tokenized_contexts['input_ids']) == len(tokenized_continuations['input_ids']) + for i in range(len(tokenized_contexts['input_ids'])): + context_len = len(tokenized_contexts['input_ids'][i]) + cont_len = len(tokenized_continuations['input_ids'][i]) + assert cont_len < model_max_length + if context_len + cont_len > model_max_length: + tokenized_contexts['input_ids'][i] = tokenized_contexts['input_ids'][i][-model_max_length + cont_len:] + tokenized_contexts['attention_mask'][i] = tokenized_contexts['attention_mask'][i][-model_max_length + cont_len:] + + def _reorder_instances(self, tokenized_contexts: BatchEncoding, tokenized_continuations: BatchEncoding, cache: CacheData = None) -> Sequence[int]: + if self.prefix_caching: + assert cache is not None, 'prefix reordering requires a CacheData object' + return self._reorder_by_prefix(tokenized_contexts, tokenized_continuations, cache) + else: + return self._reorder_by_longest(tokenized_contexts, tokenized_continuations) + + def _reorder_by_prefix(self, tokenized_contexts: BatchEncoding, tokenized_continuations: BatchEncoding, cache: CacheData) -> Sequence[int]: + combined_ids = [context + continuation for context, continuation in zip(tokenized_contexts['input_ids'], tokenized_continuations['input_ids'])] + cache.longest_prefix_to_indices = self._greedy_assign_prefix_by_total_coverage(combined_ids) + cache.indices_to_longest_prefix = OrderedDict() + # secondarily sort by length so that largest batches that may cause memory overflow are likely to come early + for prefix in sorted(cache.longest_prefix_to_indices.keys(), key = lambda x : -len(x)): + # sequence indices for each prefix are already sorted by length from reading trie from leaf to root + for index in cache.longest_prefix_to_indices[prefix]: + cache.indices_to_longest_prefix[index] = prefix + return 
list(cache.indices_to_longest_prefix.keys()) + + def _greedy_assign_prefix_by_total_coverage(self, sequences: Sequence[Sequence[int]]) -> Dict[Sequence[Optional[int]],Sequence[int]]: + """Returns a Dict of prefixes and the sequence indices assigned to them. Sorts possible prefixes by total tokens covered in + subsequences and assigns sequences to the first prefix they appear in. PrefixTrie only tracks subsequences after a minimum + track_after_depth so short coincidental overlaps are ignored.""" + longest_prefix_to_indices: Dict[Sequence[Optional[int]],Sequence[int]] = {} + trie = PrefixTrie(sequences) + leaves = trie.get_leaf_nodes() + leaves_sequences = [tuple(leaf.get_sequence()) for leaf in leaves] + leaves_and_sequences = Tqdm.tqdm(zip(leaves_sequences, leaves), desc="Finding prefixes", total=len(leaves)) + leaves2prefixes = {leaf_sequence:leaf.get_subsequences() for leaf_sequence, leaf in leaves_and_sequences} + + # greedily assign sequences to prefixes with top total coverage + indices_already_assigned = set() + for leaf_sequence in sorted(leaves_sequences, key=lambda leaf_sequence : -leaves2prefixes[leaf_sequence][1]): + prefix_indices, _ = leaves2prefixes[leaf_sequence] + prefix_indices = [prefix_index for prefix_index in prefix_indices if prefix_index not in indices_already_assigned] + indices_already_assigned.update(prefix_indices) + if len(prefix_indices) > 0: + longest_prefix_to_indices[leaf_sequence] = tuple(prefix_indices) + + return longest_prefix_to_indices + + def _reorder_by_longest(self, tokenized_contexts: BatchEncoding, tokenized_continuations: BatchEncoding) -> Sequence[int]: + assert len(tokenized_contexts['input_ids']) == len(tokenized_continuations['input_ids']) + lengths = torch.tensor([ + len(tokenized_contexts["input_ids"][i]) + len(tokenized_continuations["input_ids"][i]) + for i in range(len(tokenized_contexts['input_ids'])) + ], dtype=torch.int) + return torch.argsort(lengths, descending=True).tolist() + + def _get_inputs(self, 
batch_of_indices: Sequence[int], cc_pairs: List[Dict[str, Tuple[torch.Tensor, torch.Tensor]]], model: _Model, cache: CacheData = None): + if self.prefix_caching: + assert cache is not None + return self._get_inputs_with_cache(batch_of_indices, cc_pairs, model, cache) + else: + return self._get_inputs_without_cache(batch_of_indices, cc_pairs, model) + + + def _get_inputs_with_cache(self, batch_of_indices: Sequence[int], cc_pairs: List[Dict[str, Tuple[torch.Tensor, torch.Tensor]]], model: _Model, cache: CacheData): + assert cache.indices_to_longest_prefix is not None + prefixes = [cache.indices_to_longest_prefix[index] for index in batch_of_indices] + prefix2cache = OrderedDict() + + # compute prefixes + if prefixes[0] == cache.cached_sequence: + prefix2cache[prefixes[0]] = cache.cached_past_key_values + + uncached_prefixes = list(set(prefix for prefix in prefixes if prefix not in prefix2cache)) # ordering must be fixed + if len(uncached_prefixes) > 0: + unpadded_prefixes = [torch.tensor(prefix) for prefix in uncached_prefixes] + unpadded_prefix_mask = [torch.ones_like(prefix) for prefix in unpadded_prefixes] + padded_prefixes = pad_sequence(unpadded_prefixes, batch_first=True).to(model.device) + padded_prefix_masks = pad_sequence(unpadded_prefix_mask, batch_first=True, padding_value=0.0).to(model.device) + past_key_values = model(input_ids=padded_prefixes, attention_mask=padded_prefix_masks).past_key_values + # tensor(layers, keys/values, batch_size, num_heads, sequence_len, embed_size_per_head) + past_key_values = torch.stack(tuple(torch.stack(past_key_values[i]) for i in range(len(past_key_values)))) + for i, prefix in enumerate(uncached_prefixes): + # tensor(layers, keys/values, num_heads, sequence_len, embed_size_per_head) + prefix2cache[prefix] = past_key_values[:,:,i,:,:len(prefix),:] + + # update cache with last one retrieved since instances come in order by common prefix + cache.cached_sequence = prefixes[-1] + cache.cached_past_key_values = 
prefix2cache[prefixes[-1]] + + # pad and mask batched past_key_values + unpadded_past_keys_values = [prefix2cache[prefix] for prefix in prefixes] + unpadded_past_keys_values_attn_mask = [] + # only use the prefixed part of past_key_values that is present in the instance + for prefix_idx, cc_pairs_idx in enumerate(batch_of_indices): + is_identical = True + for tok_idx, tok in enumerate(cc_pairs[cc_pairs_idx]['input_ids'][0]): + if tok.item() != prefixes[prefix_idx][tok_idx]: + unpadded_past_keys_values_attn_mask.append(torch.tensor([1] * tok_idx, dtype=torch.int64)) + is_identical = False + break + if is_identical: + # Avoid empty input by leaving last token of context for input because continuations drop one token for right shift + max_prefix_len = len(cc_pairs[cc_pairs_idx]['input_ids'][0]) - 1 + unpadded_past_keys_values_attn_mask.append(torch.tensor([1] * max_prefix_len, dtype=torch.int64)) + + # past_keys_values needs its own attention mask + padded_past_keys_values_attn_mask = pad_sequence(unpadded_past_keys_values_attn_mask, batch_first=True, padding_value=0) + cache_lengths = [mask.sum().item() for mask in padded_past_keys_values_attn_mask] + max_past_key_value_len = max(cache_lengths) + + # pad and truncate past_keys_values to longest actually used + unpadded_past_keys_values = [t.transpose(0,-2) for t in unpadded_past_keys_values] + padded_past_keys_values = pad_sequence(unpadded_past_keys_values, batch_first=True) + padded_past_keys_values = padded_past_keys_values.permute((4, 2, 0, 3, 1, 5)) + # tensor(layers, keys/values, batch_size, num_heads, sequence_len, embed_size_per_head) + padded_past_keys_values = padded_past_keys_values[:,:,:,:,:max_past_key_value_len] + + # make input_ids by removing whatever parts of past_key_values are present + unpadded_input_ids = [] + input_lengths = [] + batch_contexts = [] + batch_continuations = [] + + for prefix_idx, cc_pairs_idx in enumerate(batch_of_indices): + context_ids, continuation_ids = 
cc_pairs[cc_pairs_idx]['input_ids'] + ids = torch.cat([context_ids, continuation_ids])[:-1] + ids = ids[cache_lengths[prefix_idx]:] + + # used to find logits specifically for continuation + input_lengths.append(len(ids)) + batch_contexts.append(cc_pairs[cc_pairs_idx]["input_ids"][0]) + batch_continuations.append(cc_pairs[cc_pairs_idx]["input_ids"][1]) + + unpadded_input_ids.append(ids) + + # batch and pad and make attention mask + unpadded_attn_mask = [torch.ones_like(t) for t in unpadded_input_ids] + padded_attn_mask = pad_sequence(unpadded_attn_mask, batch_first=True, padding_value=0) + padded_input_ids = pad_sequence(unpadded_input_ids, batch_first=True) + + # combine the attention masks + full_attn_mask = torch.cat((padded_past_keys_values_attn_mask, padded_attn_mask), dim=1) + assert full_attn_mask.shape[1] <= model.config.n_positions, "Presently batches with wide range of prefix and input lengths are not supported due overrun of max model size" + + # make position_ids + max_input_len = padded_input_ids.shape[-1] + position_ids = torch.stack([torch.arange(cache_length, cache_length + max_input_len) for cache_length in cache_lengths], dim=0) + position_ids = position_ids * padded_attn_mask + assert (position_ids < model.config.n_positions).all() + + inputs = { + 'input_ids': padded_input_ids.to(model.device), + 'past_key_values': padded_past_keys_values, + 'attention_mask': full_attn_mask.to(model.device), + 'position_ids': position_ids.to(model.device) + } + + return inputs, input_lengths, batch_contexts, batch_continuations + + def _get_inputs_without_cache(self, batch_of_indices: Sequence[int], cc_pairs: List[Dict[str, Tuple[torch.Tensor, torch.Tensor]]], model: _Model): + unpadded_batch = collections.defaultdict(list) + input_lengths = [] + batch_contexts = [] + batch_continuations = [] + + for index in batch_of_indices: + for field_name, (context_ids, continuation_ids) in cc_pairs[index].items(): + ids = torch.cat([context_ids, continuation_ids])[:-1] + 
unpadded_batch[field_name].append(ids) + + input_lengths.append(len(unpadded_batch["input_ids"][-1])) + batch_contexts.append(cc_pairs[index]["input_ids"][0]) + batch_continuations.append(cc_pairs[index]["input_ids"][1]) + + padded_batch = { + field_name: pad_sequence(tensors, batch_first=True).to(model.device) + for field_name, tensors in unpadded_batch.items() + } + return padded_batch, input_lengths, batch_contexts, batch_continuations \ No newline at end of file diff --git a/catwalk/task.py b/catwalk/task.py index d158a6a6..55b94b0b 100644 --- a/catwalk/task.py +++ b/catwalk/task.py @@ -37,13 +37,12 @@ "squad_metrics": torchmetrics.SQuAD, } - -def classification_metrics(num_classes: int): +def classification_metrics(num_classes: int, *, average = None): return { "acc": torchmetrics.Accuracy, - "f1": partial(torchmetrics.F1Score, num_classes=num_classes, average=None), - "precision": partial(torchmetrics.Precision, num_classes=num_classes, average=None), - "recall": partial(torchmetrics.Recall, num_classes=num_classes, average=None) + "f1": partial(torchmetrics.F1Score, num_classes=num_classes, average=average), + "precision": partial(torchmetrics.Precision, num_classes=num_classes, average=average), + "recall": partial(torchmetrics.Recall, num_classes=num_classes, average=average) } diff --git a/catwalk/tasks/__init__.py b/catwalk/tasks/__init__.py index 5c169969..7b486a63 100644 --- a/catwalk/tasks/__init__.py +++ b/catwalk/tasks/__init__.py @@ -270,8 +270,11 @@ "metaicl::unifiedqa:openbookqa_with_ir": MetaICLTask("unifiedqa:openbookqa_with_ir").add_metrics(MC_METRICS), "metaicl::unifiedqa:mctest": MetaICLTask("unifiedqa:mctest").add_metrics(MC_METRICS), "metaicl::unifiedqa:ai2_science_middle": MetaICLTask("unifiedqa:ai2_science_middle").add_metrics(MC_METRICS), - + "metaicl::commonsense_qa": MetaICLTask("commonsense_qa").add_metrics(MC_METRICS), + + "metaicl::numer_sense": MetaICLTask("numer_sense").add_metrics(classification_metrics(12)), + 
"metaicl::race-high": MetaICLTask("race-high").add_metrics(MC_METRICS), } for config in datasets.get_dataset_config_names("bigscience/P3"): diff --git a/catwalk/utils/__init__.py b/catwalk/utils/__init__.py new file mode 100644 index 00000000..1988e6d8 --- /dev/null +++ b/catwalk/utils/__init__.py @@ -0,0 +1 @@ +from catwalk.utils.prefix_trie import PrefixTrie \ No newline at end of file diff --git a/catwalk/utils/prefix_trie.py b/catwalk/utils/prefix_trie.py new file mode 100644 index 00000000..7d77403d --- /dev/null +++ b/catwalk/utils/prefix_trie.py @@ -0,0 +1,85 @@ +from typing import List, Optional, Sequence, Tuple, Dict +from tango.common import Tqdm + +class PrefixTrie(): + def __init__(self, sequences: Sequence[Sequence[int]], track_after_depth: int = 10): + """ + Returns a PrefixTrie for ordering examples by common prefixes + + # Parameters + + sequences : `Sequence[Sequence[int]]` + Sequences of tokens to add to the Trie + track_after_depth : `int` + Only record sequence indices in nodes at or below this depth. This allows distinct + sequences that coincidentally start with the same first few tokens as another sequence + not to be dropped from this barely overlapping prefix. Sequences shorter than the + minimum depth will only have their index recorded in their final node. 
+ """ + self.root = PrefixTrieNode() + self.track_after_depth = track_after_depth + self.nodes: List['PrefixTrieNode'] = [] + for i, sequence in Tqdm.tqdm(enumerate(sequences), desc="Building PrefixTrie for caching", total=len(sequences)): + self._add_sequence(sequence=sequence, index=i) + # only need to track sequences at forks and terminations + for node in self.nodes: + if len(node.children) == 1: + node.subsequences_on_this_path = node.subsequences_ending_here + + def _add_sequence(self, sequence: Sequence[int], index: int): + seq_len = len(sequence) + current_node = self.root + for token_idx, token in enumerate(sequence): + if token not in current_node.children: + current_node.children[token] = PrefixTrieNode(parent=current_node, token=token) + self.nodes.append(current_node.children[token]) + current_node = current_node.children[token] + if (token_idx + 1 >= self.track_after_depth) or (token_idx + 1 >= seq_len): + current_node.subsequences_on_this_path[index] = token_idx + 1 + current_node.subsequences_ending_here[index] = len(sequence) + + def get_leaf_nodes(self) -> List['PrefixTrieNode']: + return [node for node in self.nodes if len(node.children) == 0] + +class PrefixTrieNode(): + def __init__(self, parent: 'PrefixTrieNode' = None, token: int = None): + self.parent = parent + self.token = token + self.subsequences_on_this_path: Dict[int,int] = {} + self.subsequences_ending_here: Dict[int,int] = {} + self.lengths_covered: List[int] = [] + self.children: Dict[int,'PrefixTrieNode'] = {} + + def get_sequence(self) -> List[Optional[int]]: + """Returns the sequence associated with a node""" + current_node = self + sequence = [] + while current_node.parent is not None: + sequence.append(current_node.token) + current_node = current_node.parent + return sequence[::-1] + + def get_subsequences(self) -> Tuple[List[int], int]: + """ + Returns a tuple of: + - a list of all indices for subsequences of the current node including itself + starting with longest and 
decreasing + - an int, the total number of tokens covered in all subsequences by this prefix + + Note when a PrefixTrie with track_after_depth > 0, some subsequences will be intentionally + ignored here as their indices are not registered in low depth nodes. + """ + current_node = self + indices = [] + already_found = set() + total_lengths_covered = 0 + while current_node.parent is not None: + new_indices = [] + for index, length_covered in current_node.subsequences_on_this_path.items(): + if index not in already_found: + new_indices.append(index) + total_lengths_covered += length_covered + already_found.update(new_indices) + indices.extend(new_indices) + current_node = current_node.parent + return indices, total_lengths_covered \ No newline at end of file diff --git a/experiments/prefix_cache_demo.py b/experiments/prefix_cache_demo.py new file mode 100644 index 00000000..d8a535be --- /dev/null +++ b/experiments/prefix_cache_demo.py @@ -0,0 +1,86 @@ +import argparse +import os +from tango.common.logging import initialize_logging +import time + +from catwalk.models import MetaICLModel +from catwalk.steps import CalculateMetricsStep, PredictStep +from catwalk.tasks import TASK_SETS + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--zeroshot', action='store_true') + parser.add_argument('--no_prefix_caching', action='store_true') + parser.add_argument('--first_n_tasks', type=int, default=20) + args = parser.parse_args() + + start = time.time() + initialize_logging(log_level="ERROR") + os.environ['TOKENIZERS_PARALLELISM'] = "false" + + tasks = TASK_SETS['metaicl-classification-eval'] + tasks = sorted(tasks)[:args.first_n_tasks] + + num_shots = 0 if args.zeroshot else 16 + if args.zeroshot: + batch_size = 64 + elif args.no_prefix_caching: + batch_size = 16 # to account for larger input sizes with ICL + # CACHING with batching does not work close to the max model size as + # the largest prefix + largest continuation in a batch must be <= max 
model size + else: + batch_size = 1 + limit = 1000 + random_subsample_seed=42 + seeds = [100] if args.zeroshot else [100, 13, 21, 42, 87] + + model = MetaICLModel('gpt2-large', continuation_seperator = ' ' if args.zeroshot else '\n', prefix_caching = not args.no_prefix_caching) + + seed2metrics = {} + for fewshot_seed in seeds: + metric_task_dict = {} + for task in tasks: + + predictions = PredictStep( + model=model, + task=task, + batch_size=batch_size, + limit=limit, + random_subsample_seed=random_subsample_seed, + num_shots=num_shots, + fewshot_seed=fewshot_seed, + ) + metrics = CalculateMetricsStep( + model=model, + task=task, + predictions=predictions) + metric_task_dict[task] = metrics + seed2metrics[fewshot_seed] = metric_task_dict + + avg_f1_per_seed = [] + avg_acc_per_seed = [] + for seed, metric_task_dict in seed2metrics.items(): + total_sum_f1 = 0.0 + total_sum_acc = 0.0 + for task, metrics in metric_task_dict.items(): + for metric, result in metrics.result().items(): + avg_result = result.mean() + if metric == 'f1': + total_sum_f1 += avg_result.item() + elif metric == 'acc': + total_sum_acc += avg_result.item() + print(f"{task}\t{seed}\t{metric}\t{avg_result}") + avg_f1_per_seed.append(total_sum_f1 / len(tasks)) + avg_acc_per_seed.append(total_sum_acc / len(tasks)) + + print(f"avg macro f1 over seeds {sum(avg_f1_per_seed) / len(seeds)}") + print(f"min macro f1 over seeds {min(avg_f1_per_seed)}") + print(f"avg macro acc over seeds {sum(avg_acc_per_seed) / len(seeds)}") + print(f"min macro acc over seeds {min(avg_acc_per_seed)}") + + end = time.time() + print(f"total seconds elapsed: {end - start}") + +if __name__ == "__main__": + + main() \ No newline at end of file diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..9271bd4e --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,10 @@ +from catwalk.utils import PrefixTrie + +def test_prefix_trie(): + sequences = [[1,2,3],[2,3,4],[1,2,3,4]] + trie = 
PrefixTrie(sequences, track_after_depth=1) + leaves = trie.get_leaf_nodes() + assert leaves[0].get_sequence() == [2,3,4] + assert leaves[1].get_sequence() == [1,2,3,4] + assert leaves[0].get_subsequences() == ([1], 3) + assert leaves[1].get_subsequences() == ([2,0], 7) \ No newline at end of file