diff --git a/src/obnb/label/split/__init__.py b/src/obnb/label/split/__init__.py index 204e5dc..2039f70 100644 --- a/src/obnb/label/split/__init__.py +++ b/src/obnb/label/split/__init__.py @@ -1,5 +1,6 @@ -"""Genearting data splits from the labelset collection.""" +"""Generating data splits from the labelset collection.""" +from obnb.label.split.explicit import ByTermSplit from obnb.label.split.holdout import ( AllHoldout, RandomRatioHoldout, @@ -20,4 +21,5 @@ "RandomRatioPartition", "RatioPartition", "ThresholdPartition", + "ByTermSplit", ] diff --git a/src/obnb/label/split/base.py b/src/obnb/label/split/base.py index b6f8c30..36b0331 100644 --- a/src/obnb/label/split/base.py +++ b/src/obnb/label/split/base.py @@ -36,6 +36,32 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({', '.join(attrs)})" + def __call__( + self, + ids: List[str], + y: np.ndarray, + ) -> Iterator[Tuple[np.ndarray, ...]]: + """Split the input ids into multiple splits, e.g. a test, train, validation + split. The means by which this splitting occurs should be defined by classes + that inherit from this base class. + + Note: + Inheriting classes should yield the value instead of returning it, + to make it compatible with the sklearn split methods. See the + implementation of the BaseSortedSplit class for an example. + + Args: + ids: List of entity IDs to put in each split. + y: Labels for each entity(?) + + Yields: + Iterator of splits. Each split is a tuple of numpy arrays, where + each array contains the IDs of the entities in the split. + + """ + + raise NotImplementedError + class BaseSortedSplit(BaseSplit): """BaseSortedSplit object for splitting dataset based on sorting.""" diff --git a/src/obnb/label/split/explicit.py b/src/obnb/label/split/explicit.py new file mode 100644 index 0000000..d569975 --- /dev/null +++ b/src/obnb/label/split/explicit.py @@ -0,0 +1,123 @@ +from typing import Iterable, Iterator, List, Tuple + +import numpy +from numpy import ndarray + +from obnb.label.collection import LabelsetCollection +from obnb.label.split.base import BaseSplit + + +class ByTermSplit(BaseSplit): + """Produces splits based on an explicit list of terms. + + Genes which match each term will be placed in the split corresponding to + that term. A split with a single term '*' will act as a catch-all for any + genes that weren't matched by any of the other splits. This would allow you + to, e.g., only retain a specific set of genes in the training set, and place + all others in the test set. + + Note that if the '*' split is not provided, any genes that don't match any + of the other splits will not be present in the returned splits at all. + + """ + + def __init__( + self, + labelset: LabelsetCollection, + split_terms: Iterable[Iterable[str]], + exclusive: bool = False, + ) -> None: + """Initialize ByTermSplit object with reference labels and terms for splits. + + Args: + labelset: LabelsetCollection object containing terms for each + gene ID. + split_terms: a nested collection. The first level of nesting + indicates the splits; the second level within each identifies + terms that should be matched to place a gene in that split. + exclusive: if True, a gene can occur only once across all the + splits; it will belong to the first split in which it occurs. + + """ + self.labelset = labelset + self.split_terms = [set(x) for x in split_terms] + self.exclusive = exclusive + + # verify that there's only one catch-all split + if sum(1 for x in self.split_terms if x == {"*"}) > 1: + raise ValueError("Only one catch-all '*' split is allowed") + + # convert labelset into a dataframe where one can search for + # the terms associated with each gene ID like so: + # self.long_df[self.long_df["Value"] == str(gene_id)]["Name"] + df = self.labelset.to_df() + self.long_df = df.melt( + id_vars=["Name"], + value_vars=df.columns.difference(["Info", "Size", "Name"]), + value_name="Value", + ).dropna(subset=["Value"]) + + # group gene id and aggregate terms into a set, which makes + # it faster to retrieve all the terms for a given gene ID + self.gene_id_to_terms = ( + self.long_df.groupby("Value")["Name"] + .apply(set) + .reset_index() + .rename(columns={"Value": "GeneID", "Name": "Terms"}) + ) + + super().__init__() + + def __call__(self, ids: List[str], y: ndarray) -> Iterator[Tuple[ndarray, ...]]: + """Produce splits based on the terms associated with each gene ID. + + For each gene ID, look up the term it's associated with in the labelset, + and place it in the corresponding split. + + Returns as many splits as there are elements in the split_terms tuple. + + """ + # alias field to shorten the code below + gdf = self.gene_id_to_terms + + # for each split, filter to the gene IDs that have at least one + # term in the split + result = [ + ( + { + gene_id + for gene_id in ids + if gdf[gdf["GeneID"] == str(gene_id)]["Terms"].values[0] & terms + } + if terms != {"*"} + else {"*"} + ) + for terms in self.split_terms + ] + + # if one of the resulting splits ended up as the wildcard, we need to + # fill in that split with any gene that wasn't matched by any of + # the other splits + for idx in range(len(result)): + if result[idx] == {"*"}: + result[idx] = { + gene_id + for gene_id in ids + if not any( + gdf[gdf["GeneID"] == str(gene_id)]["Terms"].values[0] & terms + for terms in self.split_terms + ) + } + + if self.exclusive: + # if exclusive, remove genes in the current split that occurred + # in any previous split + # (we skip the first split since there's nothing with which to + # compare it) + for idx in range(1, len(result)): + result[idx] = result[idx] - set().union(*result[:idx]) + + # yield it in the format returned by other splitters, e.g. a tuple of + # numpy arrays, each of which contain indices into the 'ids' array + # passed into the splitter. + yield tuple([numpy.asarray([ids.index(v) for v in x]) for x in result])