Skip to content

Commit a2c9509

Browse files
committed
minor code improvements
1 parent 17acb8b commit a2c9509

File tree

3 files changed

+187
-172
lines changed

3 files changed

+187
-172
lines changed

src/tmplot/_distance.py

Lines changed: 50 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1-
__all__ = [
2-
'get_topics_dist', 'get_topics_scatter', 'get_top_topic_words']
3-
from typing import Union, List
1+
__all__ = ["get_topics_dist", "get_topics_scatter", "get_top_topic_words"]
2+
from typing import Optional, Union, List
43
from itertools import combinations
5-
from pandas import DataFrame
4+
from pandas import DataFrame, Index
65
import numpy as np
76
from scipy.special import kl_div
87
from scipy.spatial import distance
98
from sklearn.manifold import (
10-
TSNE, Isomap, LocallyLinearEmbedding, MDS, SpectralEmbedding)
9+
TSNE,
10+
Isomap,
11+
LocallyLinearEmbedding,
12+
MDS,
13+
SpectralEmbedding,
14+
)
1115
from ._helpers import calc_topics_marg_probs
1216

1317

@@ -28,15 +32,14 @@ def _dist_jsd(a1: np.ndarray, a2: np.ndarray):
2832

2933
def _dist_jef(a1: np.ndarray, a2: np.ndarray):
3034
vals = (a1 - a2) * (np.log(a1) - np.log(a2))
31-
vals[(vals <= 0) | ~np.isfinite(vals)] = 0.
35+
vals[(vals <= 0) | ~np.isfinite(vals)] = 0.0
3236
return vals.sum()
3337

3438

3539
def _dist_hel(a1: np.ndarray, a2: np.ndarray):
3640
a1[(a1 <= 0) | ~np.isfinite(a1)] = 1e-64
3741
a2[(a2 <= 0) | ~np.isfinite(a2)] = 1e-64
38-
hel_val = distance.euclidean(
39-
np.sqrt(a1), np.sqrt(a2)) / np.sqrt(2)
42+
hel_val = distance.euclidean(np.sqrt(a1), np.sqrt(a2)) / np.sqrt(2)
4043
return hel_val
4144

4245

@@ -52,19 +55,18 @@ def _dist_tv(a1: np.ndarray, a2: np.ndarray):
5255
return dist
5356

5457

55-
def _dist_jac(a1: np.ndarray, a2: np.ndarray, top_words=100):
56-
a = np.argsort(a1)[:-top_words-1:-1]
57-
b = np.argsort(a2)[:-top_words-1:-1]
58+
def _dist_jac(a1: np.ndarray, a2: np.ndarray, top_words=100):
59+
a = np.argsort(a1)[: -top_words - 1 : -1]
60+
b = np.argsort(a2)[: -top_words - 1 : -1]
5861
j_num = np.intersect1d(a, b, assume_unique=False).size
5962
j_den = np.union1d(a, b).size
6063
jac_val = 1 - j_num / j_den
6164
return jac_val
6265

6366

6467
def get_topics_dist(
65-
phi: Union[np.ndarray, DataFrame],
66-
method: str = "sklb",
67-
**kwargs) -> np.ndarray:
68+
phi: Union[np.ndarray, DataFrame], method: str = "sklb", **kwargs
69+
) -> np.ndarray:
6870
"""Finding closest topics in models.
6971
7072
Parameters
@@ -110,16 +112,18 @@ def get_topics_dist(
110112
for i, j in topics_pairs:
111113
_dist_func = dist_funcs.get(method, "sklb")
112114
topics_dists[((i, j), (j, i))] = _dist_func(
113-
phi_copy[:, i], phi_copy[:, j], **kwargs)
115+
phi_copy[:, i], phi_copy[:, j], **kwargs
116+
)
114117

115118
return topics_dists
116119

117120

118121
def get_topics_scatter(
119-
topic_dists: np.ndarray,
120-
theta: np.ndarray,
121-
method: str = 'tsne',
122-
method_kws: dict = None) -> DataFrame:
122+
topic_dists: np.ndarray,
123+
theta: np.ndarray,
124+
method: str = "tsne",
125+
method_kws: Optional[dict] = None,
126+
) -> DataFrame:
123127
"""Calculate topics coordinates for a scatter plot.
124128
125129
Parameters
@@ -146,52 +150,52 @@ def get_topics_scatter(
146150
Topics scatter coordinates.
147151
"""
148152
if not method_kws:
149-
method_kws = {'n_components': 2}
153+
method_kws = {"n_components": 2}
150154

151-
if method == 'tsne':
152-
method_kws.setdefault('init', 'pca')
153-
method_kws.setdefault('learning_rate', 'auto')
154-
method_kws.setdefault(
155-
'perplexity', min(50, max(topic_dists.shape[0] // 2, 1)))
155+
if method == "tsne":
156+
method_kws.setdefault("init", "pca")
157+
method_kws.setdefault("learning_rate", "auto")
158+
method_kws.setdefault("perplexity", min(50, max(topic_dists.shape[0] // 2, 1)))
156159
transformer = TSNE(**method_kws)
157160

158-
elif method == 'sem':
159-
method_kws.setdefault('affinity', 'precomputed')
161+
elif method == "sem":
162+
method_kws.setdefault("affinity", "precomputed")
160163
transformer = SpectralEmbedding(**method_kws)
161164

162-
elif method == 'mds':
163-
method_kws.setdefault('dissimilarity', 'precomputed')
164-
method_kws.setdefault('normalized_stress', 'auto')
165+
elif method == "mds":
166+
method_kws.setdefault("dissimilarity", "precomputed")
167+
method_kws.setdefault("normalized_stress", "auto")
165168
transformer = MDS(**method_kws)
166169

167-
elif method == 'lle':
168-
method_kws['method'] = 'standard'
170+
elif method == "lle":
171+
method_kws["method"] = "standard"
169172
transformer = LocallyLinearEmbedding(**method_kws)
170173

171-
elif method == 'ltsa':
172-
method_kws['method'] = 'ltsa'
174+
elif method == "ltsa":
175+
method_kws["method"] = "ltsa"
173176
transformer = LocallyLinearEmbedding(**method_kws)
174177

175-
elif method == 'isomap':
178+
elif method == "isomap":
176179
transformer = Isomap(**method_kws)
177180

178181
coords = transformer.fit_transform(topic_dists)
179182

180-
topics_xy = DataFrame(coords, columns=['x', 'y'])
181-
topics_xy['topic'] = topics_xy.index.astype(int)
182-
topics_xy['size'] = calc_topics_marg_probs(theta)
183-
size_sum = topics_xy['size'].sum()
183+
topics_xy = DataFrame(coords, columns=Index(["x", "y"]))
184+
topics_xy["topic"] = topics_xy.index.astype(int)
185+
topics_xy["size"] = calc_topics_marg_probs(theta)
186+
size_sum = topics_xy["size"].sum()
184187
if size_sum > 0:
185-
topics_xy['size'] *= (100 / topics_xy['size'].sum())
188+
topics_xy["size"] *= 100 / topics_xy["size"].sum()
186189
else:
187-
topics_xy['size'] = np.nan
190+
topics_xy["size"] = np.nan
188191
return topics_xy
189192

190193

191194
def get_top_topic_words(
192-
phi: DataFrame,
193-
words_num: int = 20,
194-
topics_idx: Union[List[int], np.ndarray] = None) -> DataFrame:
195+
phi: DataFrame,
196+
words_num: int = 20,
197+
topics_idx: Optional[Union[List[int], np.ndarray]] = None,
198+
) -> DataFrame:
195199
"""Select top topic words from a fitted model.
196200
197201
Parameters
@@ -209,9 +213,6 @@ def get_top_topic_words(
209213
DataFrame
210214
Words with highest probabilities in all (or selected) topics.
211215
"""
212-
return phi.loc[:, topics_idx or phi.columns]\
213-
.apply(
214-
lambda x: x
215-
.sort_values(ascending=False)
216-
.head(words_num).index, axis=0
216+
return phi.loc[:, topics_idx or phi.columns].apply(
217+
lambda x: x.sort_values(ascending=False).head(words_num).index, axis=0
217218
)

src/tmplot/_stability.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
1-
__all__ = ['get_closest_topics', 'get_stable_topics']
1+
__all__ = ["get_closest_topics", "get_stable_topics"]
22
from typing import List, Tuple, Any
33
import numpy as np
44
import tqdm
5-
from ._distance import _dist_klb, _dist_sklb, _dist_jsd, _dist_jef, _dist_hel, \
6-
_dist_bhat, _dist_jac, _dist_tv
5+
from ._distance import (
6+
_dist_klb,
7+
_dist_sklb,
8+
_dist_jsd,
9+
_dist_jef,
10+
_dist_hel,
11+
_dist_bhat,
12+
_dist_jac,
13+
_dist_tv,
14+
)
715
from ._helpers import get_phi
816

917
dist_funcs = {
@@ -19,11 +27,12 @@
1927

2028

2129
def get_closest_topics(
22-
models: List[Any],
23-
ref: int = 0,
24-
method: str = "sklb",
25-
top_words: int = 100,
26-
verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
30+
models: List[Any],
31+
ref: int = 0,
32+
method: str = "sklb",
33+
top_words: int = 100,
34+
verbose: bool = True,
35+
) -> Tuple[np.ndarray, np.ndarray]:
2736
"""Finding closest topics in models.
2837
2938
Parameters
@@ -93,7 +102,6 @@ def enum_func(x):
93102

94103
# Iterating over all models
95104
for mid, model in enum_func(models):
96-
97105
# Current model is equal to reference model, skipping
98106
if mid == ref:
99107
continue
@@ -105,7 +113,8 @@ def enum_func(x):
105113
for t_ref in range(topics_num):
106114
for t in range(topics_num):
107115
all_vs_all_dists[t_ref, t] = dist_func(
108-
model_ref_phi.iloc[:, t_ref], get_phi(model).iloc[:, t])
116+
model_ref_phi.iloc[:, t_ref], get_phi(model).iloc[:, t]
117+
)
109118

110119
# Creating two arrays for the closest topics ids and distance values
111120
if method == "jac":
@@ -119,14 +128,15 @@ def enum_func(x):
119128

120129

121130
def get_stable_topics(
122-
closest_topics: np.ndarray,
123-
dist: np.ndarray,
124-
norm: bool = True,
125-
inverse: bool = True,
126-
inverse_factor: float = 1.0,
127-
ref: int = 0,
128-
thres: float = 0.9,
129-
thres_models: int = 2) -> Tuple[np.ndarray, np.ndarray]:
131+
closest_topics: np.ndarray,
132+
dist: np.ndarray,
133+
norm: bool = True,
134+
inverse: bool = True,
135+
inverse_factor: float = 1.0,
136+
ref: int = 0,
137+
thres: float = 0.9,
138+
thres_models: int = 2,
139+
) -> Tuple[np.ndarray, np.ndarray]:
130140
"""Finding stable topics in models.
131141
132142
Parameters
@@ -179,7 +189,5 @@ def get_stable_topics(
179189
dist_arr = np.asarray(dist)
180190
dist_ready = dist_arr / dist_arr.max() if norm else dist_arr.copy()
181191
dist_ready = inverse_factor - dist_ready if inverse else dist_ready
182-
mask = (
183-
np.sum(np.delete(dist_ready, ref, axis=1) >= thres, axis=1)
184-
>= thres_models)
192+
mask = np.sum(np.delete(dist_ready, ref, axis=1) >= thres, axis=1) >= thres_models
185193
return closest_topics[mask], dist_ready[mask]

0 commit comments

Comments
 (0)