-
Notifications
You must be signed in to change notification settings - Fork 50
Expand file tree
/
Copy pathutils.py
More file actions
154 lines (124 loc) · 4.4 KB
/
utils.py
File metadata and controls
154 lines (124 loc) · 4.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""
General utility functions
"""
from typing import Callable, Optional, Sequence, Union
import numpy as np
import pandas as pd
import torch
from torch import Tensor
def torch_maxval(x: Tensor, **kwargs) -> Tensor:
    """
    Return the maximum values of a tensor, discarding any indices.

    Args:
        x: Input tensor.
        **kwargs: Keyword arguments forwarded to torch.max
            (for example dim and keepdim).

    Returns:
        The maximum values. When no reduction dimension is given this is
        the maximum over all elements of x.
    """
    result = torch.max(x, **kwargs)
    # torch.max returns a plain tensor for a full reduction, but a
    # (values, indices) pair when reducing along `dim`. Indexing a 0-dim
    # tensor with [0] raises, so only unpack when a pair was returned.
    if isinstance(result, Tensor):
        return result
    return result[0]
def torch_minval(x: Tensor, **kwargs) -> Tensor:
    """
    Return the minimum values of a tensor, discarding any indices.

    Args:
        x: Input tensor.
        **kwargs: Keyword arguments forwarded to torch.min
            (for example dim and keepdim).

    Returns:
        The minimum values. When no reduction dimension is given this is
        the minimum over all elements of x.
    """
    result = torch.min(x, **kwargs)
    # torch.min returns a plain tensor for a full reduction, but a
    # (values, indices) pair when reducing along `dim`. Indexing a 0-dim
    # tensor with [0] raises, so only unpack when a pair was returned.
    if isinstance(result, Tensor):
        return result
    return result[0]
def torch_log2fc(x: Tensor, y: Tensor) -> Tensor:
    """
    Compute the elementwise log2 fold change, log2(x / y), for tensors.

    Args:
        x: Numerator tensor.
        y: Denominator tensor.

    Returns:
        The base-2 logarithm of x divided by y.
    """
    ratio = torch.divide(x, y)
    return torch.log2(ratio)
def np_log2fc(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """
    Compute the elementwise log2 fold change, log2(x / y), for numpy arrays.

    Args:
        x: Numerator array.
        y: Denominator array.

    Returns:
        The base-2 logarithm of x divided by y.
    """
    ratio = np.divide(x, y)
    return np.log2(ratio)
def get_aggfunc(
    func: Optional[Union[str, Callable]], tensor: bool = False
) -> Optional[Callable]:
    """
    Return a function to aggregate values.

    Args:
        func: A function or the name of a function. Supported names
            are "max", "min", "mean", and "sum". If a function is supplied, it
            will be returned unchanged. If None, None is returned.
        tensor: If True, it is assumed that the inputs will be torch tensors.
            If False, it is assumed that the inputs will be numpy arrays.

    Returns:
        The desired function, or None if func is None.

    Raises:
        NotImplementedError: If the input is neither None, a function nor
            a supported function name.
    """
    # None and user-supplied callables are passed through untouched.
    if func is None or callable(func):
        return func
    if func == "max":
        return torch_maxval if tensor else np.max
    if func == "min":
        return torch_minval if tensor else np.min
    if func == "mean":
        return torch.mean if tensor else np.mean
    if func == "sum":
        return torch.sum if tensor else np.sum
    raise NotImplementedError(
        f"Unsupported aggregation function {func!r}; expected a callable "
        "or one of 'max', 'min', 'mean', 'sum'."
    )
def get_compare_func(
    func: Optional[Union[str, Callable]], tensor: bool = False
) -> Optional[Callable]:
    """
    Return a function to compare two values.

    Args:
        func: A function or the name of a function. Supported names are
            "subtract", "divide", and "log2FC". If a function is supplied, it
            will be returned unchanged. If None, None is returned, so callers
            that need a usable function must not pass None.
        tensor: If True, it is assumed that the inputs will be torch tensors.
            If False, it is assumed that the inputs will be numpy arrays.

    Returns:
        The desired function, or None if func is None.

    Raises:
        NotImplementedError: If the input is neither None, a function nor
            a supported function name.
    """
    # None and user-supplied callables are passed through untouched.
    if func is None or callable(func):
        return func
    if func == "subtract":
        return torch.subtract if tensor else np.subtract
    if func == "divide":
        return torch.divide if tensor else np.divide
    if func == "log2FC":
        return torch_log2fc if tensor else np_log2fc
    raise NotImplementedError(
        f"Unsupported comparison function {func!r}; expected a callable "
        "or one of 'subtract', 'divide', 'log2FC'."
    )
def get_transform_func(
    func: Optional[Union[str, Callable]], tensor: bool = False
) -> Optional[Callable]:
    """
    Return a function to transform the input.

    Args:
        func: A function or the name of a function. Supported names are
            "log" and "log1p". If a function is supplied, it will be returned
            unchanged. If None, None is returned (no identity function is
            substituted), so callers must handle the None case themselves.
        tensor: If True, it is assumed that the inputs will be torch tensors.
            If False, it is assumed that the inputs will be numpy arrays.

    Returns:
        The desired function, or None if func is None.

    Raises:
        NotImplementedError: If the input is neither None, a function nor
            a supported function name.
    """
    # None and user-supplied callables are passed through untouched.
    if func is None or callable(func):
        return func
    if func == "log":
        return torch.log if tensor else np.log
    if func == "log1p":
        return torch.log1p if tensor else np.log1p
    raise NotImplementedError(
        f"Unsupported transform function {func!r}; expected a callable "
        "or one of 'log', 'log1p'."
    )
def make_list(
    x: Optional[Union[pd.Series, np.ndarray, Tensor, Sequence, int, float, str]],
) -> list:
    """
    Convert various kinds of inputs into a list.

    Args:
        x: An input value or sequence of values.

    Returns:
        The input values in list format. None is returned as None and a
        list is returned as-is (not copied). Arrays and tensors are squeezed
        before conversion, so size-1 dimensions are dropped; an input that
        still has more than one dimension after squeezing yields nested lists.

    Raises:
        NotImplementedError: If the input type is not supported.
    """
    if x is None or isinstance(x, list):
        return x
    # Scalars, including NumPy scalars such as np.int64, become a
    # one-element list.
    if isinstance(x, (int, float, str, np.generic)):
        return [x]
    if isinstance(x, (pd.Series, pd.Index)):
        return x.tolist()
    if isinstance(x, (np.ndarray, Tensor)):
        # np.matrix objects stay 2-D under squeeze(), so convert to a plain
        # ndarray first.
        if isinstance(x, np.matrix):
            x = np.array(x)
        values = x.squeeze().tolist()
        # A single-element input squeezes to 0-d, and tolist() on a 0-d
        # array/tensor gives a bare scalar; wrap it so a list is returned.
        return values if isinstance(values, list) else [values]
    if isinstance(x, (tuple, set, frozenset, range)):
        return list(x)
    raise NotImplementedError(
        f"Cannot convert object of type {type(x).__name__} to a list."
    )