Skip to content

Commit dbaac64

Browse files
authored
Merge pull request #41 from KrishnaswamyLab/dev
Dev
2 parents 1106f50 + b15aa99 commit dbaac64

14 files changed

+5461
-1402
lines changed

.travis.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
language: python
22
python:
3-
- '3.5'
43
- '3.6'
54
- '3.7'
65
- '3.8'

comparison/comparison.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -188,20 +188,33 @@ def fit_kNN(self, data, **kwargs):
188188
self.graph_knn = gt.Graph(data, n_pca=100, kernel_symm=None, use_pygsp=True,
189189
random_state=self.seed, **kwargs)
190190

191-
def generate_ground_truth_pdf(self, data_phate=None):
192-
'''Takes a set of PHATE coordinates over a set of points and creates an underlying
193-
ground truth pdf over the points as a convex combination of the input phate coords.
194-
'''
195-
np.random.seed(self.seed)
191+
def fit_phate(self, data, **kwargs):
192+
try:
193+
import phate
194+
except ModuleNotFoundError:
195+
raise ModuleNotFoundError('PHATE must be installed. Install via pip `pip install --user phate`')
196196

197-
if data_phate is not None:
198-
self.data_phate = data_phate
197+
self.set_phate(phate.PHATE(n_components=3, **kwargs).fit_transform(data))
198+
return self.data_phate
199199

200+
def set_phate(self, data_phate):
    '''Validates, normalizes and stores a 3D PHATE embedding on `self.data_phate`.

    Parameters
    ----------
    data_phate : array, shape=[n_samples, 3]
        PHATE embedding of the input data.

    Raises
    ------
    ValueError
        If `data_phate` is not a 2D array with exactly 3 columns.
    '''
    # Import the submodule explicitly: a bare `import scipy` is not guaranteed
    # to expose `scipy.stats` as an attribute.
    import scipy.stats

    data_phate = np.asarray(data_phate)
    if data_phate.ndim != 2 or data_phate.shape[1] != 3:
        raise ValueError('data_phate must have 3 dimensions')
    # Every component must be mean-centered. Testing the mean of the whole
    # array is not enough, because offsets of different columns can cancel out.
    if not np.allclose(data_phate.mean(axis=0), 0):
        data_phate = scipy.stats.zscore(data_phate, axis=0)
    self.data_phate = data_phate
207+
208+
def generate_ground_truth_pdf(self, data_phate=None):
209+
'''Takes a set of PHATE coordinates over a set of points and creates an underlying
210+
ground truth pdf over the points as a convex combination of the input phate coords.
211+
'''
212+
np.random.seed(self.seed)
213+
214+
if data_phate is not None:
215+
self.set_phate(data_phate)
216+
elif self.data_phate is None:
217+
raise ValueError('data_phate must be set prior to running generate_ground_truth_pdf().')
205218

206219
# Create an array of values that sums to 1
207220
data_simplex = np.sort(np.random.uniform(size=(2)))
1.04 MB
Binary file not shown.

doc/Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# You can set these variables from the command line.
55
SPHINXOPTS =
66
SPHINXBUILD = sphinx-build
7-
SPHINXPROJ = PHATE
7+
SPHINXPROJ = MELD
88
SOURCEDIR = source
99
BUILDDIR = build
1010

@@ -17,4 +17,4 @@ help:
1717
# Catch-all target: route all unknown targets to Sphinx using the new
1818
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
1919
%: Makefile
20-
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
20+
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

doc/source/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ scipy>=1.1.0
33
pandas>=0.25
44
future
55
graphtools>=0.1.8.1
6-
sphinx<=1.8.5
6+
sphinx
77
sphinxcontrib-napoleon
88
autodocsumm

meld/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from .meld import MELD
44
from .cluster import VertexFrequencyCluster
55
from .version import __version__
6+
from .benchmark import Benchmarker

meld/benchmark.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
# Copyright (C) 2020 Krishnaswamy Lab, Yale University
2+
3+
import numpy as np
4+
import scipy
5+
import sklearn
6+
import meld
7+
import graphtools as gt
8+
9+
10+
class Benchmarker(object):
    """Creates random signals over a dataset for benchmarking.

    Results are used for quantitative comparisons and for parameter searches
    on a specific dataset.

    Parameters
    ----------
    seed : integer, optional, default: None
        Random seed. It is handed to `numpy.random.seed` by the methods that
        draw random numbers and to `graphtools.Graph` as `random_state`.
        NOTE(review): `numpy.random.seed` is not expected to accept a
        `numpy.random.RandomState` instance, so an integer (or None) is the
        safe choice -- confirm before passing anything else.

    Attributes
    ----------
    data_phate : array, shape=[n_samples, 3]
        Embedding of the data used to create random signals
    pdf : array, shape=[n_samples]
        Ground truth probability density function created over the input data.
    RES_int : array, shape=[n_samples]
        An integer (0 / 1) representation of the RES.
    RES : array, shape=[n_samples]
        Raw Experimental Signal (RES) as described in Burkhardt et al. (2020).
        Indicates the sample ('ctrl' or 'expt') to which each cell is assigned.
    graph : graphtools.base.BaseGraph
        The graph built on the input data by `fit_graph`
    graph_kNN : graphtools.graphs.kNNGraph
        Initialised to None; no method of this class assigns it.
    meld_op : meld.meld.MELD
        MELD operator used to derive an EES
    EES : array, shape=[n_samples]
        Enhanced Experimental Signal (EES) for the 'expt' condition: a
        conditional probability that a cell was originally sampled from that
        condition. Should converge to Benchmarker.pdf
    estimates : dict
        Initialised empty; no method of this class writes to it.

    """

    def __init__(self, seed=None):
        self.seed = seed
        self.data_phate = None
        self.pdf = None
        self.RES_int = None
        self.RES = None
        self.graph = None
        self.graph_kNN = None
        self.meld_op = None
        self.EES = None
        self.estimates = {}

    def set_seed(self, seed):
        """Sets random seed.

        Parameters
        ----------
        seed : integer
            Random seed used by subsequent calls.

        Returns
        -------
        seed : integer
            Newly set random seed.

        """
        self.seed = seed
        return self.seed

    def set_phate(self, data_phate):
        """Validates, normalizes and stores a 3D PHATE embedding.

        Parameters
        ----------
        data_phate : array, shape=[n_samples, 3]
            PHATE embedding for input data.

        Returns
        -------
        data_phate : array, shape=[n_samples, 3]
            The stored embedding. It is z-scored per component unless every
            component was already mean-centered.

        Raises
        ------
        ValueError
            If `data_phate` is not a 2D array with exactly 3 columns.

        """
        # Import the submodule explicitly: a bare `import scipy` is not
        # guaranteed to expose `scipy.stats` as an attribute.
        import scipy.stats

        data_phate = np.asarray(data_phate)
        if data_phate.ndim != 2 or data_phate.shape[1] != 3:
            raise ValueError('data_phate must have 3 dimensions')
        # Every component must be mean-centered. Testing the mean of the whole
        # array is not enough, because column offsets can cancel each other.
        if not np.allclose(data_phate.mean(axis=0), 0):
            data_phate = scipy.stats.zscore(data_phate, axis=0)
        self.data_phate = data_phate
        return self.data_phate

    def fit_graph(self, data, **kwargs):
        """Fits a graphtools.Graph to input data

        Parameters
        ----------
        data : array, shape=[n_samples, n_observations]
            Input data
        **kwargs : dict
            Keyword arguments passed to gt.Graph(). `n_pca` defaults to 100
            and may be overridden here.

        Returns
        -------
        graph : graphtools.Graph
            Graph fit to data

        """
        # A default rather than a hard-coded argument, so that a caller who
        # passes n_pca does not trigger a duplicate-keyword TypeError.
        kwargs.setdefault('n_pca', 100)
        self.graph = gt.Graph(data, use_pygsp=True, random_state=self.seed, **kwargs)
        return self.graph

    def fit_phate(self, data, **kwargs):
        """Generates a 3D phate embedding of input data

        Parameters
        ----------
        data : array, shape=[n_samples, n_observations]
            Input data
        **kwargs : dict
            Keyword arguments passed to phate.PHATE().

        Returns
        -------
        data_phate : array, shape=[n_samples, 3]
            Normalized PHATE embedding for input data.

        Raises
        ------
        ModuleNotFoundError
            If the optional `phate` package is not installed.

        """
        # phate is an optional dependency, so it is imported lazily and a
        # missing install is reported with an actionable hint.
        try:
            import phate
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                'PHATE must be installed. Install via pip `pip install --user phate`'
            ) from err

        self.set_phate(phate.PHATE(n_components=3, **kwargs).fit_transform(data))
        return self.data_phate

    def generate_ground_truth_pdf(self, data_phate=None):
        """Creates a random density function over input data.

        Takes a set of PHATE coordinates over a set of points and creates an underlying
        ground truth pdf over the points as a convex combination of the input phate coords.

        Parameters
        ----------
        data_phate : array, shape=[n_samples, 3], optional
            PHATE embedding for input data. When omitted, the embedding
            previously stored by `set_phate` or `fit_phate` is used.

        Returns
        -------
        pdf : array, shape=[n_samples]
            Ground truth conditional probability of the sample given the data.

        Raises
        ------
        ValueError
            If no embedding is passed and none has been stored yet.

        """
        import scipy.special

        if data_phate is not None:
            self.set_phate(data_phate)
        elif self.data_phate is None:
            raise ValueError('data_phate must be set prior to running generate_ground_truth_pdf().')

        # Seeds the global NumPy generator so the draw below is reproducible
        np.random.seed(self.seed)

        # Create an array of 3 non-negative values that sums to 1: the gaps
        # between 0, two sorted uniform draws, and 1.
        data_simplex = np.sort(np.random.uniform(size=(2)))
        data_simplex = np.hstack([0, data_simplex, 1])
        data_simplex = np.diff(data_simplex)
        np.random.shuffle(data_simplex)

        # Weight each PHATE component by the simplex weights
        sort_axis = np.sum(self.data_phate * data_simplex, axis=1)

        # Squash the weighted components into (0, 1) with the logistic
        # sigmoid (expit, i.e. the inverse of the logit)
        self.pdf = scipy.special.expit(sort_axis)
        return self.pdf

    def generate_RES(self):
        """Samples a Raw Experimental Signal (RES) from the ground truth pdf.

        Each point is assigned to 'expt' with probability `pdf` and to 'ctrl'
        otherwise. The result is stored on `RES` (labels) and `RES_int` (0 / 1).

        Returns
        -------
        RES : array, shape=[n_samples]
            Array of 'ctrl' / 'expt' labels.

        Raises
        ------
        ValueError
            If the ground truth pdf has not been generated yet.

        """
        if self.pdf is None:
            raise ValueError('pdf must be set prior to running generate_RES(). '
                             'Run generate_ground_truth_pdf() first.')

        # Seeds the global NumPy generator so the draw below is reproducible
        np.random.seed(self.seed)

        # Create RES
        self.RES_int = np.random.binomial(1, self.pdf)
        self.RES = np.where(self.RES_int == 0, 'ctrl', 'expt')
        return self.RES

    def calculate_EES(self, data=None, **kwargs):
        """Runs MELD on the RES to derive the EES of the 'expt' condition.

        Parameters
        ----------
        data : array, shape=[n_samples, n_observations], optional
            Input data. Only needed when no graph has been fit yet, in which
            case `fit_graph(data)` is called first.
        **kwargs : dict
            Keyword arguments passed to meld.MELD(). `verbose` defaults to
            False and may be overridden here.

        Returns
        -------
        EES : array, shape=[n_samples]
            Enhanced Experimental Signal of the 'expt' condition.

        Raises
        ------
        NameError
            If no graph has been fit and `data` is not given.
        ValueError
            If the RES has not been generated yet.

        """
        if self.RES is None:
            raise ValueError('RES must be set prior to running calculate_EES(). '
                             'Run generate_RES() first.')

        np.random.seed(self.seed)
        if self.graph is None:
            if data is None:
                # NameError is kept for backward compatibility with callers
                # that already catch it.
                raise NameError("Must pass `data` unless graph has already been fit")
            self.fit_graph(data)

        # A default rather than a hard-coded argument, so that a caller who
        # passes verbose does not trigger a duplicate-keyword TypeError.
        kwargs.setdefault('verbose', False)
        self.meld_op = meld.MELD(**kwargs).fit(self.graph)
        self.EES = self.meld_op.transform(self.RES)
        self.EES = self.EES['expt'].values  # Only keep the expt condition
        return self.EES

    def calculate_mse(self, estimate):
        """Calculates MSE between the ground truth PDF and an estimate

        Parameters
        ----------
        estimate : array, shape=[n_samples]
            Estimate to compare against `Benchmarker.pdf`.

        Returns
        -------
        mse : float
            Mean squared error as computed by sklearn.metrics.mean_squared_error.

        """
        # Import the submodule explicitly: a bare `import sklearn` does not
        # load `sklearn.metrics`.
        import sklearn.metrics

        return sklearn.metrics.mean_squared_error(self.pdf, estimate)

meld/version.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
# Copyright (C) 2020 Krishnaswamy Lab, Yale University
22

3-
__version__ = "0.3.0"
3+
__version__ = "0.3.1"
4+

0 commit comments

Comments
 (0)